diff --git a/.circleci/cimodel/data/caffe2_build_definitions.py b/.circleci/cimodel/data/caffe2_build_definitions.py index a419b5e47396b..87b1a85752d06 100644 --- a/.circleci/cimodel/data/caffe2_build_definitions.py +++ b/.circleci/cimodel/data/caffe2_build_definitions.py @@ -14,7 +14,7 @@ DOCKER_IMAGE_PATH_BASE = "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/" -DOCKER_IMAGE_VERSION = 301 +DOCKER_IMAGE_VERSION = 306 @dataclass diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 261e9823b64df..ef328d1668014 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -13,7 +13,7 @@ ]), ("gcc", [ ("4.8", [X("3.6")]), - ("5.4", [ + ("5.4", [ # All this subtree rebases to master and then build XImportant("3.6"), ("3.6", [ ("namedtensor", [XImportant(True)]), diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index 22c4956d50b3e..76dbb8373ab07 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -13,7 +13,9 @@ DOCKER_IMAGE_PATH_BASE = "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/" -DOCKER_IMAGE_VERSION = 339 +# ARE YOU EDITING THIS NUMBER? MAKE SURE YOU READ THE GUIDANCE AT THE +# TOP OF .circleci/config.yml +DOCKER_IMAGE_VERSION = 347 @dataclass @@ -184,6 +186,7 @@ def instantiate_configs(): distro_name = fc.find_prop("distro_name") compiler_name = fc.find_prop("compiler_name") + compiler_version = fc.find_prop("compiler_version") is_xla = fc.find_prop("is_xla") or False parms_list_ignored_for_docker_image = [] @@ -244,6 +247,20 @@ def instantiate_configs(): if cuda_version == "9" and python_version == "3.6": c.dependent_tests = gen_dependent_configs(c) + if (compiler_name == "gcc" + and compiler_version == "5.4" + and not is_namedtensor): + bc_breaking_check = Conf( + "backward-compatibility-check", + [], + is_xla=False, + restrict_phases=["test"], + is_namedtensor=False, + is_important=True, + parent_build=c, + ) + c.dependent_tests.append(bc_breaking_check) + config_list.append(c) return config_list diff --git a/.circleci/cimodel/lib/miniyaml.py b/.circleci/cimodel/lib/miniyaml.py index e4de65bf871ac..ccd888ab2b0c3 100644 --- a/.circleci/cimodel/lib/miniyaml.py +++ b/.circleci/cimodel/lib/miniyaml.py @@ -9,23 +9,13 @@ def is_dict(data): - return type(data) is dict or type(data) is OrderedDict + return type(data) in [dict, OrderedDict] def is_collection(data): return is_dict(data) or type(data) is list -# TODO can eventually drop this custom sorting -def sortkey(x): - k = x[0] - return ( - k == "<<", - k != "environment", - k, - ) - - def render(fh, data, depth, is_list_member=False): """ PyYaml does not allow precise control over the quoting @@ -39,7 +29,7 @@ def render(fh, data, depth, is_list_member=False): tuples = list(data.items()) if type(data) is not OrderedDict: - tuples.sort(key=sortkey) + tuples.sort() for i, (k, v) in enumerate(tuples): @@ -51,10 +41,6 @@ def render(fh, data, depth, is_list_member=False): render(fh, v, depth + 1 + int(is_list_member)) - # TODO Could eventually drop this cosmetic convention - if depth == 2: - fh.write("\n") - elif type(data) is list: for v in data: render(fh, v, depth, True) diff --git a/.circleci/config.yml b/.circleci/config.yml index f313dc0cd23f7..b38aa8c513bb0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -19,50 +19,96 @@ docker_config_defaults: &docker_config_defaults # This IAM user only allows read-write access to ECR aws_access_key_id: ${CIRCLECI_AWS_ACCESS_KEY_FOR_ECR_READ_WRITE_V4} aws_secret_access_key: ${CIRCLECI_AWS_SECRET_KEY_FOR_ECR_READ_WRITE_V4} +commands: + # NB: This command must be run as the first command in a job. It + # attaches the workspace at ~/workspace; this workspace is generated + # by the setup job. Note that ~/workspace is not the default working + # directory (that's ~/project). + should_run_job: + description: "Test if the job should run or not" + steps: + - attach_workspace: + name: Attaching workspace + at: ~/workspace + - run: + name: Should run job + no_output_timeout: "2m" + command: ~/workspace/.circleci/scripts/should_run_job.sh + + # This system setup script is meant to run before the CI-related scripts, e.g., + # installing Git client, checking out code, setting up CI env, and + # building/testing. + setup_linux_system_environment: + steps: + - run: + name: Set Up System Environment + no_output_timeout: "1h" + command: ~/workspace/.circleci/scripts/setup_linux_system_environment.sh -# This system setup script is meant to run before the CI-related scripts, e.g., -# installing Git client, checking out code, setting up CI env, and -# building/testing. -setup_linux_system_environment: &setup_linux_system_environment - name: Set Up System Environment - no_output_timeout: "1h" - command: ~/workspace/.circleci/scripts/setup_linux_system_environment.sh - -# NB: This (and the command below) must be run after attaching -# ~/workspace. This is NOT the default working directory (that's -# ~/project); this workspace is generated by the setup job. -should_run_job: &should_run_job - name: Should Run Job After attach_workspace - no_output_timeout: "2m" - command: ~/workspace/.circleci/scripts/should_run_job.sh - -setup_ci_environment: &setup_ci_environment - name: Set Up CI Environment After attach_workspace - no_output_timeout: "1h" - command: ~/workspace/.circleci/scripts/setup_ci_environment.sh - -# Installs expect and moreutils so that we can call `unbuffer` and `ts`. -# Also installs OpenMP -# !!!!NOTE!!!! this is copied into a binary_macos_brew_update job which is the -# same but does not install libomp. If you are changing this, consider if you -# need to change that step as well. -macos_brew_update: &macos_brew_update - name: Brew update and install moreutils, expect and libomp - no_output_timeout: "1h" - command: | - set -ex - # See https://discourse.brew.sh/t/fetching-homebrew-repos-is-slow/5374/3 - brew untap caskroom/homebrew-cask - # moreutils installs a `parallel` executable by default, which conflicts - # with the executable from the GNU `parallel`, so we must unlink GNU - # `parallel` first, and relink it afterwards - brew update - brew unlink parallel - brew install moreutils - brew link parallel --overwrite - brew install expect - brew install libomp + setup_ci_environment: + steps: + - run: + name: Set Up CI Environment After attach_workspace + no_output_timeout: "1h" + command: ~/workspace/.circleci/scripts/setup_ci_environment.sh + + brew_update: + description: "Update Homebrew and install base formulae" + steps: + - run: + name: Update Homebrew + no_output_timeout: "10m" + command: | + set -ex + + # Update repositories manually. + # Running `brew update` produces a comparison between the + # current checkout and the updated checkout, which takes a + # very long time because the existing checkout is 2y old. + for path in $(find /usr/local/Homebrew -type d -name .git) + do + cd $path/.. + git fetch --depth=1 origin + git reset --hard origin/master + done + + export HOMEBREW_NO_AUTO_UPDATE=1 + + # Install expect and moreutils so that we can call `unbuffer` and `ts`. + # moreutils installs a `parallel` executable by default, which conflicts + # with the executable from the GNU `parallel`, so we must unlink GNU + # `parallel` first, and relink it afterwards. + brew unlink parallel + brew install moreutils + brew link parallel --overwrite + brew install expect + + brew_install: + description: "Install Homebrew formulae" + parameters: + formulae: + type: string + default: "" + steps: + - run: + name: Install << parameters.formulae >> + no_output_timeout: "10m" + command: | + set -ex + export HOMEBREW_NO_AUTO_UPDATE=1 + brew install << parameters.formulae >> + + run_brew_for_macos_build: + steps: + - brew_update + - brew_install: + formulae: libomp + run_brew_for_ios_build: + steps: + - brew_update + - brew_install: + formulae: libtool ############################################################################## # Binary build (nightlies nightly build) defaults @@ -113,26 +159,6 @@ binary_run_in_docker: &binary_run_in_docker # This step only runs on circleci linux machine executors that themselves # need to start docker images command: ~/workspace/.circleci/scripts/binary_run_in_docker.sh - -# This is copied almost verbatim from the macos_brew_update job -# In version 2.1 and above we could make this a command and pass a parameter to -# it, but in this version there is no way to pass a parameter to a step -binary_macos_brew_update: &binary_macos_brew_update - name: Brew update and install moreutils and expect - no_output_timeout: "1h" - command: | - set -eux -o pipefail - # See https://discourse.brew.sh/t/fetching-homebrew-repos-is-slow/5374/3 - brew untap caskroom/homebrew-cask - # moreutils installs a `parallel` executable by default, which conflicts - # with the executable from the GNU `parallel`, so we must unlink GNU - # `parallel` first, and relink it afterwards - brew update - brew unlink parallel - brew install moreutils - brew link parallel --overwrite - brew install expect - ############################################################################## # Build parameters ############################################################################## @@ -156,6 +182,25 @@ pytorch_params: &pytorch_params USE_CUDA_DOCKER_RUNTIME: << parameters.use_cuda_docker_runtime >> resource_class: << parameters.resource_class >> +pytorch_ios_params: &pytorch_ios_params + parameters: + build_environment: + type: string + default: "" + ios_arch: + type: string + default: "" + ios_platform: + type: string + default: "" + environment: + BUILD_ENVIRONMENT: << parameters.build_environment >> + IOS_ARCH: << parameters.ios_arch >> + IOS_PLATFORM: << parameters.ios_platform >> + + + + caffe2_params: &caffe2_params parameters: build_environment: @@ -247,15 +292,10 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment + - should_run_job + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: Build no_output_timeout: "1h" @@ -266,6 +306,28 @@ jobs: docker pull ${DOCKER_IMAGE} >/dev/null export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE}) + # TODO We may want to move the rebase logic to a separate step after checkout + # Rebase to master only if in xenial_py3_6_gcc5_4 case + if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then + echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT" + set -x + git config --global user.email "circleci.ossci@gmail.com" + git config --global user.name "CircleCI" + git config remote.origin.url https://github.com/pytorch/pytorch.git + git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master + git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet + export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master` + echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET} + export GIT_COMMIT=${CIRCLE_SHA1} + echo "GIT_COMMIT: " ${GIT_COMMIT} + git checkout -f ${GIT_COMMIT} + git reset --hard ${GIT_COMMIT} + git merge --no-edit --no-ff ${GIT_MERGE_TARGET} + set +x + else + echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT" + fi + git submodule sync && git submodule update -q --init --recursive docker cp /home/circleci/project/. $id:/var/lib/jenkins/workspace @@ -274,11 +336,6 @@ jobs: NAMED_FLAG="export BUILD_NAMEDTENSOR=1" fi - # dispatch aten ops statically for mobile - if [[ ${BUILD_ENVIRONMENT} == *"android"* ]]; then - NAMED_FLAG="export USE_STATIC_DISPATCH=1" - fi - export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo '"$NAMED_FLAG"' && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/build.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts @@ -314,14 +371,9 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Test no_output_timeout: "90m" @@ -331,7 +383,7 @@ jobs: output_image=${DOCKER_IMAGE}-${CIRCLE_SHA1} if [[ ${BUILD_ENVIRONMENT} == *"namedtensor"* ]]; then export COMMIT_DOCKER_IMAGE=$output_image-namedtensor - NAMED_FLAG="export BUILD_NAMEDTENSOR=1" + export NAMED_FLAG="export BUILD_NAMEDTENSOR=1 && export TEST_NAMEDTENSOR=1" elif [[ ${BUILD_ENVIRONMENT} == *"xla"* ]]; then export COMMIT_DOCKER_IMAGE=$output_image-xla else @@ -345,27 +397,21 @@ jobs: export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) fi if [[ ${BUILD_ENVIRONMENT} == *"multigpu"* ]]; then - export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo '"$NAMED_FLAG"' && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/multigpu-test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "${NAMED_FLAG}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/multigpu-test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' else - export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo '"$NAMED_FLAG"'&& echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "${NAMED_FLAG}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' fi echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts - caffe2_linux_build: <<: *caffe2_params machine: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment + - should_run_job + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: Build no_output_timeout: "1h" @@ -422,14 +468,9 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *setup_linux_system_environment - - run: - <<: *should_run_job - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Test no_output_timeout: "1h" @@ -490,13 +531,9 @@ jobs: xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - checkout - - run: - <<: *macos_brew_update + - run_brew_for_macos_build - run: name: Build no_output_timeout: "1h" @@ -573,10 +610,7 @@ jobs: <<: *binary_linux_build_params steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: <<: *binary_checkout - run: @@ -630,17 +664,12 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace + - should_run_job # TODO: We shouldn't attach the workspace multiple times - attach_workspace: at: /home/circleci/project - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - setup_linux_system_environment + - setup_ci_environment - run: <<: *binary_checkout - run: @@ -658,14 +687,9 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - attach_workspace: at: /home/circleci/project - run: @@ -691,10 +715,8 @@ jobs: at: ~/workspace - attach_workspace: at: /home/circleci/project - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - setup_linux_system_environment + - setup_ci_environment - run: <<: *binary_checkout - run: @@ -726,8 +748,7 @@ jobs: <<: *binary_checkout - run: <<: *binary_populate_env - - run: - <<: *binary_macos_brew_update + - brew_update - run: <<: *binary_install_miniconda - run: @@ -748,16 +769,12 @@ jobs: xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: <<: *binary_checkout - run: <<: *binary_populate_env - - run: - <<: *binary_macos_brew_update + - brew_update - run: <<: *binary_install_miniconda @@ -789,16 +806,12 @@ jobs: xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: <<: *binary_checkout - run: <<: *binary_populate_env - - run: - <<: *binary_macos_brew_update + - brew_update - run: <<: *binary_install_miniconda - attach_workspace: # TODO - we can `cp` from ~/workspace @@ -811,7 +824,45 @@ jobs: cat "$script" source "$script" + binary_ios_build: + <<: *pytorch_ios_params + macos: + xcode: "10.2.1" + steps: + - attach_workspace: + at: ~/workspace + - should_run_job + - checkout + - run_brew_for_ios_build + - run: + name: Build + contxt: org-member + no_output_timeout: "1h" + command: | + script="/Users/distiller/project/.circleci/scripts/binary_ios_build.sh" + cat "$script" + source "$script" + - persist_to_workspace: + root: /Users/distiller/workspace/ + paths: ios + binary_ios_upload: + <<: *pytorch_ios_params + macos: + xcode: "10.2.1" + steps: + - attach_workspace: + at: ~/workspace + - should_run_job + - checkout + - run_brew_for_ios_build + - run: + name: Upload + no_output_timeout: "1h" + command: | + script="/Users/distiller/project/.circleci/scripts/binary_ios_upload.sh" + cat "$script" + source "$script" setup: docker: - image: circleci/python:3.7.3 @@ -848,7 +899,7 @@ jobs: pytorch_short_perf_test_gpu: environment: BUILD_ENVIRONMENT: pytorch-short-perf-test-gpu - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" PYTHON_VERSION: "3.6" USE_CUDA_DOCKER_RUNTIME: "1" resource_class: gpu.medium @@ -856,14 +907,9 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Perf Test no_output_timeout: "1h" @@ -889,20 +935,15 @@ jobs: environment: BUILD_ENVIRONMENT: pytorch-python-doc-push # TODO: stop hardcoding this - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" resource_class: large machine: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Doc Build and Push no_output_timeout: "1h" @@ -939,20 +980,15 @@ jobs: pytorch_cpp_doc_push: environment: BUILD_ENVIRONMENT: pytorch-cpp-doc-push - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" resource_class: large machine: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Doc Build and Push no_output_timeout: "1h" @@ -993,26 +1029,21 @@ jobs: xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - checkout - - run: - <<: *macos_brew_update + - run_brew_for_macos_build - run: name: Build no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 # Install sccache sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache - export SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2 + # This IAM user allows write access to S3 bucket for sccache set +x export AWS_ACCESS_KEY_ID=${CIRCLECI_AWS_ACCESS_KEY_FOR_SCCACHE_S3_BUCKET_V4} @@ -1022,14 +1053,14 @@ jobs: chmod a+x .jenkins/pytorch/macos-build.sh unbuffer .jenkins/pytorch/macos-build.sh 2>&1 | ts - mkdir -p /Users/distiller/pytorch-ci-env/workspace - # copy with -a to preserve relative structure (e.g., symlinks), and be recursive - cp -a /Users/distiller/project/. /Users/distiller/pytorch-ci-env/workspace + cp -a ~/project ~/workspace + - persist_to_workspace: - root: /Users/distiller/pytorch-ci-env + root: ~/workspace paths: - - "*" + - miniconda3 + - project pytorch_macos_10_13_py3_test: environment: @@ -1039,12 +1070,8 @@ jobs: steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml # This workspace also carries binaries from the build job - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *macos_brew_update + - should_run_job + - run_brew_for_macos_build - run: name: Test no_output_timeout: "1h" @@ -1053,10 +1080,7 @@ jobs: export IN_CIRCLECI=1 # copy with -a to preserve relative structure (e.g., symlinks), and be recursive - # TODO: I'm not sure why we can't just run our job in - # ~/workspace and call it a day - # NB: Yes, you need workspace twice - cp -a ~/workspace/workspace/. /Users/distiller/project + cp -a ~/workspace/project/. ~/project chmod a+x .jenkins/pytorch/macos-test.sh unbuffer .jenkins/pytorch/macos-test.sh 2>&1 | ts @@ -1068,13 +1092,9 @@ jobs: xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - checkout - - run: - <<: *macos_brew_update + - run_brew_for_macos_build - run: name: Build no_output_timeout: "1h" @@ -1116,21 +1136,16 @@ jobs: pytorch_android_gradle_build: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" PYTHON_VERSION: "3.6" resource_class: large machine: image: ubuntu-1604:201903-01 steps: - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment + - should_run_job + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: pytorch android gradle build no_output_timeout: "1h" @@ -1204,19 +1219,52 @@ jobs: path: ~/workspace/build_android_artifacts/artifacts.tgz destination: artifacts.tgz + pytorch_android_publish_snapshot: + environment: + BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-publish-snapshot + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" + PYTHON_VERSION: "3.6" + resource_class: large + machine: + image: ubuntu-1604:201903-01 + steps: + - should_run_job + - setup_linux_system_environment + - checkout + - setup_ci_environment + - run: + name: pytorch android gradle build + no_output_timeout: "1h" + command: | + set -eux + docker_image_commit=${DOCKER_IMAGE}-${CIRCLE_SHA1} + + docker_image_libtorch_android_x86_32_gradle=${docker_image_commit}-android-x86_32-gradle + + echo "docker_image_commit: "${docker_image_commit} + echo "docker_image_libtorch_android_x86_32_gradle: "${docker_image_libtorch_android_x86_32_gradle} + + # x86_32 + docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null + export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32_gradle}) + + export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" && echo "export SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" && echo "export ANDROID_SIGN_KEY=${ANDROID_SIGN_KEY}" && echo "export ANDROID_SIGN_PASS=${ANDROID_SIGN_PASS}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/publish_android_snapshot.sh") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1' + echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts + + output_image=${docker_image_libtorch_android_x86_32_gradle}-publish-snapshot + docker commit "$id_x86_32" ${output_image} + docker push ${output_image} + pytorch_android_gradle_build-x86_32: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-only-x86_32 - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" PYTHON_VERSION: "3.6" resource_class: large machine: image: ubuntu-1604:201903-01 steps: - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: name: filter out not PR runs no_output_timeout: "5m" @@ -1225,11 +1273,9 @@ jobs: if [ -z "${CIRCLE_PULL_REQUEST:-}" ]; then circleci step halt fi - - run: - <<: *setup_linux_system_environment + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: pytorch android gradle build only x86_32 (for PR) no_output_timeout: "1h" @@ -1248,13 +1294,55 @@ jobs: mkdir -p ~/workspace/build_android_x86_32_artifacts docker cp $id:/var/lib/jenkins/workspace/android/artifacts.tgz ~/workspace/build_android_x86_32_artifacts/ - output_image=${DOCKER_IMAGE}-${CIRCLE_SHA1}-android-gradle-x86_32 + output_image=${docker_image_libtorch_android_x86_32}-gradle docker commit "$id" ${output_image} docker push ${output_image} - store_artifacts: path: ~/workspace/build_android_x86_32_artifacts/artifacts.tgz destination: artifacts.tgz - + + pytorch_ios_build: + <<: *pytorch_ios_params + macos: + xcode: "10.2.1" + steps: + # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml + - should_run_job + - checkout + - run_brew_for_ios_build + - run: + name: Build + no_output_timeout: "1h" + command: | + set -e + export IN_CIRCLECI=1 + WORKSPACE=/Users/distiller/workspace + PROJ_ROOT=/Users/distiller/project + export TCLLIBPATH="/usr/local/lib" + + # Install conda + curl -o ~/Downloads/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + chmod +x ~/Downloads/conda.sh + /bin/bash ~/Downloads/conda.sh -b -p ~/anaconda + export PATH="~/anaconda/bin:${PATH}" + source ~/anaconda/bin/activate + # Install dependencies + conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests + # sync submodules + cd ${PROJ_ROOT} + git submodule sync + git submodule update --init --recursive + # export + export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} + # run build script + chmod a+x ${PROJ_ROOT}/scripts/build_ios.sh + echo "IOS_ARCH: ${IOS_ARCH}" + echo "IOS_PLATFORM: ${IOS_PLATFORM}" + export BUILD_PYTORCH_MOBILE=1 + export IOS_ARCH=${IOS_ARCH} + export IOS_PLATFORM=${IOS_PLATFORM} + unbuffer ${PROJ_ROOT}/scripts/build_ios.sh 2>&1 | ts + # update_s3_htmls job # These jobs create html files for every cpu/cu## folder in s3. The html # files just store the names of all the files in that folder (which are @@ -1268,8 +1356,7 @@ jobs: steps: - attach_workspace: at: ~/workspace - - run: - <<: *setup_linux_system_environment + - setup_linux_system_environment - run: <<: *binary_checkout # N.B. we do not run binary_populate_env. The only variable we need is @@ -1323,8 +1410,7 @@ jobs: steps: - attach_workspace: at: ~/workspace - - run: - <<: *setup_linux_system_environment + - setup_linux_system_environment - run: <<: *binary_checkout - run: @@ -1389,14 +1475,14 @@ workflows: requires: - setup build_environment: "pytorch-linux-xenial-py2.7.9-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:347" - pytorch_linux_test: name: pytorch_linux_xenial_py2_7_9_test requires: - setup - pytorch_linux_xenial_py2_7_9_build build_environment: "pytorch-linux-xenial-py2.7.9-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py2_7_build @@ -1408,7 +1494,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py2.7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:347" - pytorch_linux_test: name: pytorch_linux_xenial_py2_7_test requires: @@ -1420,7 +1506,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py2.7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_5_build @@ -1432,7 +1518,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.5-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347" - pytorch_linux_test: name: pytorch_linux_xenial_py3_5_test requires: @@ -1444,7 +1530,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.5-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_pynightly_build @@ -1456,7 +1542,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-pynightly-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:347" - pytorch_linux_test: name: pytorch_linux_xenial_pynightly_test requires: @@ -1468,7 +1554,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-pynightly-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc4_8_build @@ -1480,7 +1566,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.6-gcc4.8-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc4.8:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc4.8:347" - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc4_8_test requires: @@ -1492,35 +1578,43 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.6-gcc4.8-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc4.8:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc4.8:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc5_4_build requires: - setup build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347" - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc5_4_test requires: - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347" + resource_class: large + - pytorch_linux_test: + name: pytorch_linux_backward_compatibility_check_test + requires: + - setup + - pytorch_linux_xenial_py3_6_gcc5_4_build + build_environment: "pytorch-linux-backward-compatibility-check-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347" resource_class: large - pytorch_linux_build: name: pytorch_namedtensor_linux_xenial_py3_6_gcc5_4_build requires: - setup build_environment: "pytorch-namedtensor-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347" - pytorch_linux_test: name: pytorch_namedtensor_linux_xenial_py3_6_gcc5_4_test requires: - setup - pytorch_namedtensor_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-namedtensor-linux-xenial-py3.6-gcc5.4-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc7_build @@ -1532,7 +1626,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.6-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:347" - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc7_test requires: @@ -1544,49 +1638,49 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.6-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_asan_build requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-asan-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:347" - pytorch_linux_test: name: pytorch_linux_xenial_py3_clang5_asan_test requires: - setup - pytorch_linux_xenial_py3_clang5_asan_build build_environment: "pytorch-linux-xenial-py3-clang5-asan-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:347" resource_class: large - pytorch_linux_build: name: pytorch_namedtensor_linux_xenial_py3_clang5_asan_build requires: - setup build_environment: "pytorch-namedtensor-linux-xenial-py3-clang5-asan-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:347" - pytorch_linux_test: name: pytorch_namedtensor_linux_xenial_py3_clang5_asan_test requires: - setup - pytorch_namedtensor_linux_xenial_py3_clang5_asan_build build_environment: "pytorch-namedtensor-linux-xenial-py3-clang5-asan-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:347" resource_class: large - pytorch_linux_build: name: pytorch_xla_linux_xenial_py3_6_clang7_build requires: - setup build_environment: "pytorch-xla-linux-xenial-py3.6-clang7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:347" - pytorch_linux_test: name: pytorch_xla_linux_xenial_py3_6_clang7_test requires: - setup - pytorch_xla_linux_xenial_py3_6_clang7_build build_environment: "pytorch-xla-linux-xenial-py3.6-clang7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:347" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_cuda9_cudnn7_py2_build @@ -1598,7 +1692,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py2-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:347" - pytorch_linux_test: name: pytorch_linux_xenial_cuda9_cudnn7_py2_test requires: @@ -1610,7 +1704,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py2-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -1618,14 +1712,14 @@ workflows: requires: - setup build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" - pytorch_linux_test: name: pytorch_linux_xenial_cuda9_cudnn7_py3_test requires: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_test: @@ -1634,7 +1728,7 @@ workflows: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-multigpu-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" use_cuda_docker_runtime: "1" resource_class: gpu.large - pytorch_linux_test: @@ -1643,7 +1737,7 @@ workflows: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-NO_AVX2-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_test: @@ -1652,7 +1746,7 @@ workflows: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-NO_AVX-NO_AVX2-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_test: @@ -1661,7 +1755,7 @@ workflows: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-slow-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_test: @@ -1670,7 +1764,7 @@ workflows: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-nogpu-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" resource_class: large - pytorch_short_perf_test_gpu: requires: @@ -1686,14 +1780,14 @@ workflows: requires: - setup build_environment: "pytorch-namedtensor-linux-xenial-cuda9-cudnn7-py2-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:347" - pytorch_linux_test: name: pytorch_namedtensor_linux_xenial_cuda9_cudnn7_py2_test requires: - setup - pytorch_namedtensor_linux_xenial_cuda9_cudnn7_py2_build build_environment: "pytorch-namedtensor-linux-xenial-cuda9-cudnn7-py2-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py2:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -1706,7 +1800,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:347" - pytorch_linux_test: name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test requires: @@ -1718,7 +1812,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -1731,7 +1825,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7:347" - pytorch_linux_build: name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build requires: @@ -1742,7 +1836,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:347" - pytorch_linux_test: name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_test requires: @@ -1754,7 +1848,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:347" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -1762,7 +1856,7 @@ workflows: requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build requires: @@ -1773,7 +1867,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build requires: @@ -1784,7 +1878,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build requires: @@ -1795,36 +1889,45 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" # Warning: indentation here matters! # Pytorch MacOS builds - pytorch_macos_10_13_py3_build: requires: - setup - filters: - branches: - only: master - pytorch_macos_10_13_py3_test: requires: - setup - pytorch_macos_10_13_py3_build - filters: - branches: - only: master - pytorch_macos_10_13_cuda9_2_cudnn7_py3_build: requires: - setup - pytorch_android_gradle_build-x86_32: + name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32 requires: - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build - pytorch_android_gradle_build: + name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build requires: - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build + # Pytorch iOS PR builds + - pytorch_ios_build: + name: pytorch_ios_10_2_1_x86_64_build + build_environment: "pytorch-ios-10.2.1-x86_64_build" + ios_platform: "SIMULATOR" + requires: + - setup + - pytorch_ios_build: + name: pytorch_ios_10_2_1_arm64_build + build_environment: "pytorch-ios-10.2.1-arm64_build" + ios_arch: "arm64" + requires: + - setup - caffe2_linux_build: name: caffe2_py2_gcc4_8_ubuntu14_04_build requires: @@ -1835,7 +1938,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:306" - caffe2_linux_test: name: caffe2_py2_gcc4_8_ubuntu14_04_test requires: @@ -1847,7 +1950,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-gcc4.8-ubuntu14.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc4.8-ubuntu14.04:306" resource_class: large - caffe2_linux_build: name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_build @@ -1859,7 +1962,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:306" - caffe2_linux_test: name: caffe2_py2_cuda9_0_cudnn7_ubuntu16_04_test requires: @@ -1872,14 +1975,14 @@ workflows: - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-ubuntu16.04-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:306" resource_class: gpu.medium - caffe2_linux_build: name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build requires: - setup build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:306" - caffe2_linux_test: name: caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_test requires: @@ -1887,14 +1990,14 @@ workflows: - caffe2_cmake_cuda9_0_cudnn7_ubuntu16_04_build build_environment: "caffe2-cmake-cuda9.0-cudnn7-ubuntu16.04-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-ubuntu16.04:306" resource_class: gpu.medium - caffe2_linux_build: name: caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:306" - caffe2_linux_test: name: caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_test requires: @@ -1902,35 +2005,35 @@ workflows: - caffe2_py2_cuda9_1_cudnn7_ubuntu16_04_build build_environment: "caffe2-py2-cuda9.1-cudnn7-ubuntu16.04-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.1-cudnn7-ubuntu16.04:306" resource_class: gpu.medium - caffe2_linux_build: name: caffe2_py2_mkl_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-mkl-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:306" - caffe2_linux_test: name: caffe2_py2_mkl_ubuntu16_04_test requires: - setup - caffe2_py2_mkl_ubuntu16_04_build build_environment: "caffe2-py2-mkl-ubuntu16.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-mkl-ubuntu16.04:306" resource_class: large - caffe2_linux_build: name: caffe2_onnx_py2_gcc5_ubuntu16_04_build requires: - setup build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:306" - caffe2_linux_test: name: caffe2_onnx_py2_gcc5_ubuntu16_04_test requires: - setup - caffe2_onnx_py2_gcc5_ubuntu16_04_build build_environment: "caffe2-onnx-py2-gcc5-ubuntu16.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-gcc5-ubuntu16.04:306" resource_class: large - caffe2_linux_build: name: caffe2_py2_clang3_8_ubuntu16_04_build @@ -1942,7 +2045,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-clang3.8-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.8-ubuntu16.04:306" build_only: "1" - caffe2_linux_build: name: caffe2_py2_clang3_9_ubuntu16_04_build @@ -1954,35 +2057,35 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-clang3.9-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang3.9-ubuntu16.04:306" build_only: "1" - caffe2_linux_build: name: caffe2_py2_clang7_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-clang7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-clang7-ubuntu16.04:306" build_only: "1" - caffe2_linux_build: name: caffe2_onnx_py3_6_clang7_ubuntu16_04_build requires: - setup build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:306" - caffe2_linux_test: name: caffe2_onnx_py3_6_clang7_ubuntu16_04_test requires: - setup - caffe2_onnx_py3_6_clang7_ubuntu16_04_build build_environment: "caffe2-onnx-py3.6-clang7-ubuntu16.04-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py3.6-clang7-ubuntu16.04:306" resource_class: large - caffe2_linux_build: name: caffe2_py2_android_ubuntu16_04_build requires: - setup build_environment: "caffe2-py2-android-ubuntu16.04-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-android-ubuntu16.04:306" build_only: "1" - caffe2_linux_build: name: caffe2_py2_cuda9_0_cudnn7_centos7_build @@ -1994,7 +2097,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:306" - caffe2_linux_test: name: caffe2_py2_cuda9_0_cudnn7_centos7_test requires: @@ -2007,7 +2110,7 @@ workflows: - /ci-all\/.*/ build_environment: "caffe2-py2-cuda9.0-cudnn7-centos7-test" use_cuda_docker_runtime: "1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:301" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/caffe2/py2-cuda9.0-cudnn7-centos7:306" resource_class: gpu.medium - caffe2_macos_build: name: caffe2_py2_ios_macos10_13_build @@ -2582,7 +2685,6 @@ workflows: build_environment: "libtorch 2.7 cpu" requires: - setup - ############################################################################## # Daily binary build trigger ############################################################################## @@ -2971,7 +3073,66 @@ workflows: build_environment: "libtorch 2.7 cpu" requires: - setup + # Pytorch iOS binary builds + - binary_ios_build: + name: pytorch_ios_10_2_1_nightly_x86_64_build + build_environment: "libtorch-ios-10.2.1-nightly-x86_64-build" + ios_platform: "SIMULATOR" + ios_arch: "x86_64" + requires: + - setup + - binary_ios_build: + name: pytorch_ios_10_2_1_nigthly_arm64_build + build_environment: "libtorch-ios-10.2.1-nightly-arm64-build" + ios_arch: "arm64" + ios_platform: "OS" + requires: + - setup + - binary_ios_upload: + build_environment: "libtorch-ios-10.2.1-nightly-binary-build-upload" + context: org-member + requires: + - setup + - pytorch_ios_10_2_1_nightly_x86_64_build + - pytorch_ios_10_2_1_nigthly_arm64_build + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + + - pytorch_android_gradle_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_android_gradle_build + requires: + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build + - pytorch_android_publish_snapshot: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_android_publish_snapshot + requires: + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_android_gradle_build + context: org-member ############################################################################## # Nightly tests ############################################################################## diff --git a/.circleci/generate_config_yml.py b/.circleci/generate_config_yml.py index 5f94ad2575188..abefc0f585480 100755 --- a/.circleci/generate_config_yml.py +++ b/.circleci/generate_config_yml.py @@ -74,6 +74,7 @@ def write(self, output_filehandle): # Order of this list matters to the generated config.yml. YAML_SOURCES = [ File("header-section.yml"), + File("commands.yml"), File("nightly-binary-build-defaults.yml"), Header("Build parameters"), File("pytorch-build-params.yml"), @@ -91,12 +92,15 @@ def write(self, output_filehandle): Listgen(pytorch_build_definitions.get_workflow_jobs, 3), File("workflows-pytorch-macos-builds.yml"), File("workflows-pytorch-android-gradle-build.yml"), + File("workflows-pytorch-ios-builds.yml"), Listgen(caffe2_build_definitions.get_workflow_jobs, 3), File("workflows-binary-builds-smoke-subset.yml"), Header("Daily smoke test trigger"), Treegen(binary_build_definitions.add_binary_smoke_test_jobs, 1), Header("Daily binary build trigger"), Treegen(binary_build_definitions.add_binary_build_jobs, 1), + File("workflows-nightly-ios-binary-builds.yml"), + File("workflows-nightly-android-binary-builds.yml"), Header("Nightly tests"), Listgen(binary_build_definitions.get_nightly_tests, 3), File("workflows-nightly-uploads-header.yml"), diff --git a/.circleci/scripts/binary_ios_build.sh b/.circleci/scripts/binary_ios_build.sh new file mode 100644 index 0000000000000..c15813b5c5d71 --- /dev/null +++ b/.circleci/scripts/binary_ios_build.sh @@ -0,0 +1,38 @@ +#!/bin/bash +set -eux -o pipefail + +echo "" +echo "PWD: ${PWD}" +WORKSPACE=/Users/distiller/workspace +PROJ_ROOT=/Users/distiller/project +export TCLLIBPATH="/usr/local/lib" +# Install conda +curl -o ~/Downloads/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +chmod +x ~/Downloads/conda.sh +/bin/bash ~/Downloads/conda.sh -b -p ~/anaconda +export PATH="~/anaconda/bin:${PATH}" +source ~/anaconda/bin/activate +# Install dependencies +conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests +export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +# sync submodules +cd ${PROJ_ROOT} +git submodule sync +git submodule update --init --recursive +# run build script +chmod a+x ${PROJ_ROOT}/scripts/build_ios.sh +echo "########################################################" +cat ${PROJ_ROOT}/scripts/build_ios.sh +echo "########################################################" +echo "IOS_ARCH: ${IOS_ARCH}" +echo "IOS_PLATFORM: ${IOS_PLATFORM}" +export BUILD_PYTORCH_MOBILE=1 +export IOS_ARCH=${IOS_ARCH} +export IOS_PLATFORM=${IOS_PLATFORM} +unbuffer ${PROJ_ROOT}/scripts/build_ios.sh 2>&1 | ts +#store the binary +cd ${WORKSPACE} +DEST_DIR=${WORKSPACE}/ios +mkdir -p ${DEST_DIR} +cp -R ${PROJ_ROOT}/build_ios/install ${DEST_DIR} +mv ${DEST_DIR}/install ${DEST_DIR}/${IOS_ARCH} \ No newline at end of file diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh new file mode 100644 index 0000000000000..32dde062e3243 --- /dev/null +++ b/.circleci/scripts/binary_ios_upload.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -eux -o pipefail + +echo "" +echo "PWD: $(pwd)" +WORKSPACE=/Users/distiller/workspace +PROJ_ROOT=/Users/distiller/project +ARTIFACTS_DIR=${WORKSPACE}/ios +ls ${ARTIFACTS_DIR} +ZIP_DIR=${WORKSPACE}/zip +mkdir -p ${ZIP_DIR}/install/lib +mkdir -p ${ZIP_DIR}/src +# copy header files +cp -R ${ARTIFACTS_DIR}/arm64/include ${ZIP_DIR}/install/ +# build a FAT bianry +cd ${ZIP_DIR}/install/lib +target_libs=(libc10.a libclog.a libcpuinfo.a libqnnpack.a libtorch.a) +for lib in ${target_libs[*]} +do + libs=(${ARTIFACTS_DIR}/x86_64/lib/${lib} ${ARTIFACTS_DIR}/arm64/lib/${lib}) + lipo -create "${libs[@]}" -o ${ZIP_DIR}/install/lib/${lib} +done +# for nnpack, we only support arm64 build +cp ${ARTIFACTS_DIR}/arm64/lib/libnnpack.a ./ +lipo -i ${ZIP_DIR}/install/lib/*.a +# copy the umbrella header and license +cp ${PROJ_ROOT}/ios/LibTorch.h ${ZIP_DIR}/src/ +cp ${PROJ_ROOT}/LICENSE ${ZIP_DIR}/ +# zip the library +ZIPFILE=libtorch_ios_nightly_build.zip +cd ${ZIP_DIR} +#for testing +touch version.txt +echo $(date +%s) > version.txt +zip -r ${ZIPFILE} install src version.txt LICENSE +# upload to aws +brew install awscli +set +x +export AWS_ACCESS_KEY_ID=${AWS_S3_ACCESS_KEY_FOR_PYTORCH_BINARY_UPLOAD} +export AWS_SECRET_ACCESS_KEY=${AWS_S3_ACCESS_SECRET_FOR_PYTORCH_BINARY_UPLOAD} +set +x +# echo "AWS KEY: ${AWS_ACCESS_KEY_ID}" +# echo "AWS SECRET: ${AWS_SECRET_ACCESS_KEY}" +aws s3 cp ${ZIPFILE} s3://ossci-ios-build/ --acl public-read diff --git a/.circleci/scripts/build_android_gradle.sh b/.circleci/scripts/build_android_gradle.sh index 109c2c7a221ab..4d986229854d7 100755 --- a/.circleci/scripts/build_android_gradle.sh +++ b/.circleci/scripts/build_android_gradle.sh @@ -4,68 +4,46 @@ set -eux -o pipefail export ANDROID_NDK_HOME=/opt/ndk export ANDROID_HOME=/opt/android/sdk -export GRADLE_VERSION=5.1.1 +export GRADLE_VERSION=4.10.3 export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION export GRADLE_PATH=$GRADLE_HOME/bin/gradle -PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main +BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include +BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib -JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs -mkdir -p $JNI_LIBS_DIR -JNI_LIBS_DIR_x86=${JNI_LIBS_DIR}/x86 -mkdir -p $JNI_LIBS_DIR_x86 - -JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include -mkdir -p $JNI_INCLUDE_DIR -JNI_INCLUDE_DIR_x86=${JNI_INCLUDE_DIR}/x86 +BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include +BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib -env -echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" +BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include +BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib -if [[ "${BUILD_ENVIRONMENT}" == *-gradle-build-only-x86_32* ]]; then - BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include - BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib -else - BUILD_ANDROID_INCLUDE_DIR_x86=~/workspace/build_android/install/include - BUILD_ANDROID_LIB_DIR_x86=~/workspace/build_android/install/lib +BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include +BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib - BUILD_ANDROID_INCLUDE_DIR_x86_64=~/workspace/build_android_install_x86_64/install/include - BUILD_ANDROID_LIB_DIR_x86_64=~/workspace/build_android_install_x86_64/install/lib - - BUILD_ANDROID_INCLUDE_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/include - BUILD_ANDROID_LIB_DIR_arm_v7a=~/workspace/build_android_install_arm_v7a/install/lib - - BUILD_ANDROID_INCLUDE_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/include - BUILD_ANDROID_LIB_DIR_arm_v8a=~/workspace/build_android_install_arm_v8a/install/lib +PYTORCH_ANDROID_SRC_MAIN_DIR=~/workspace/android/pytorch_android/src/main - JNI_LIBS_DIR_x86_64=${JNI_LIBS_DIR}/x86_64 - mkdir -p $JNI_LIBS_DIR_x86_64 - JNI_LIBS_DIR_arm_v7a=${JNI_LIBS_DIR}/armeabi-v7a - mkdir -p $JNI_LIBS_DIR_arm_v7a - JNI_LIBS_DIR_arm_v8a=${JNI_LIBS_DIR}/arm64-v8a - mkdir -p $JNI_LIBS_DIR_arm_v8a +JNI_INCLUDE_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/cpp/libtorch_include +mkdir -p $JNI_INCLUDE_DIR - JNI_INCLUDE_DIR_x86_64=${JNI_INCLUDE_DIR}/x86_64 - JNI_INCLUDE_DIR_arm_v7a=${JNI_INCLUDE_DIR}/armeabi-v7a - JNI_INCLUDE_DIR_arm_v8a=${JNI_INCLUDE_DIR}/arm64-v8a +JNI_LIBS_DIR=${PYTORCH_ANDROID_SRC_MAIN_DIR}/jniLibs +mkdir -p $JNI_LIBS_DIR - ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR_x86_64} - ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR_arm_v7a} - ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR_arm_v8a} +ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR}/x86 +ln -s ${BUILD_ANDROID_LIB_DIR_x86} ${JNI_LIBS_DIR}/x86 - ln -s ${BUILD_ANDROID_LIB_DIR_x86_64}/libc10.so ${JNI_LIBS_DIR_x86_64}/libc10.so - ln -s ${BUILD_ANDROID_LIB_DIR_x86_64}/libtorch.so ${JNI_LIBS_DIR_x86_64}/libtorch.so +if [[ "${BUILD_ENVIRONMENT}" != *-gradle-build-only-x86_32* ]]; then +ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86_64} ${JNI_INCLUDE_DIR}/x86_64 +ln -s ${BUILD_ANDROID_LIB_DIR_x86_64} ${JNI_LIBS_DIR}/x86_64 - ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a}/libc10.so ${JNI_LIBS_DIR_arm_v7a}/libc10.so - ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a}/libtorch.so ${JNI_LIBS_DIR_arm_v7a}/libtorch.so +ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v7a} ${JNI_INCLUDE_DIR}/armeabi-v7a +ln -s ${BUILD_ANDROID_LIB_DIR_arm_v7a} ${JNI_LIBS_DIR}/armeabi-v7a - ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a}/libc10.so ${JNI_LIBS_DIR_arm_v8a}/libc10.so - ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a}/libtorch.so ${JNI_LIBS_DIR_arm_v8a}/libtorch.so +ln -s ${BUILD_ANDROID_INCLUDE_DIR_arm_v8a} ${JNI_INCLUDE_DIR}/arm64-v8a +ln -s ${BUILD_ANDROID_LIB_DIR_arm_v8a} ${JNI_LIBS_DIR}/arm64-v8a fi -ln -s ${BUILD_ANDROID_INCLUDE_DIR_x86} ${JNI_INCLUDE_DIR_x86} -ln -s ${BUILD_ANDROID_LIB_DIR_x86}/libc10.so ${JNI_LIBS_DIR_x86}/libc10.so -ln -s ${BUILD_ANDROID_LIB_DIR_x86}/libtorch.so ${JNI_LIBS_DIR_x86}/libtorch.so +env +echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" export GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties rm -f $GRADLE_LOCAL_PROPERTIES diff --git a/.circleci/scripts/publish_android_snapshot.sh b/.circleci/scripts/publish_android_snapshot.sh new file mode 100755 index 0000000000000..b309f17021c1f --- /dev/null +++ b/.circleci/scripts/publish_android_snapshot.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# DO NOT ADD 'set -x' not to reveal CircleCI secret context environment variables +set -eu -o pipefail + +export ANDROID_NDK_HOME=/opt/ndk +export ANDROID_HOME=/opt/android/sdk + +export GRADLE_VERSION=4.10.3 +export GRADLE_HOME=/opt/gradle/gradle-$GRADLE_VERSION +export GRADLE_PATH=$GRADLE_HOME/bin/gradle + +echo "BUILD_ENVIRONMENT:$BUILD_ENVIRONMENT" +ls -la ~/workspace + +GRADLE_PROPERTIES=~/workspace/android/gradle.properties + +IS_SNAPSHOT="$(grep 'VERSION_NAME=[0-9\.]\+-SNAPSHOT' "$GRADLE_PROPERTIES")" +echo "IS_SNAPSHOT:$IS_SNAPSHOT" + +if [ -z "$IS_SNAPSHOT" ]; then + echo "Error: version is not snapshot." +elif [ -z "$SONATYPE_NEXUS_USERNAME" ]; then + echo "Error: missing env variable SONATYPE_NEXUS_USERNAME." +elif [ -z "$SONATYPE_NEXUS_PASSWORD" ]; then + echo "Error: missing env variable SONATYPE_NEXUS_PASSWORD." +elif [ -z "$ANDROID_SIGN_KEY" ]; then + echo "Error: missing env variable ANDROID_SIGN_KEY." +elif [ -z "$ANDROID_SIGN_PASS" ]; then + echo "Error: missing env variable ANDROID_SIGN_PASS." +else + GRADLE_LOCAL_PROPERTIES=~/workspace/android/local.properties + rm -f $GRADLE_LOCAL_PROPERTIES + + echo "sdk.dir=/opt/android/sdk" >> $GRADLE_LOCAL_PROPERTIES + echo "ndk.dir=/opt/ndk" >> $GRADLE_LOCAL_PROPERTIES + + echo "SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" >> $GRADLE_PROPERTIES + echo "SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" >> $GRADLE_PROPERTIES + + echo "signing.keyId=${ANDROID_SIGN_KEY}" >> $GRADLE_PROPERTIES + echo "signing.password=${ANDROID_SIGN_PASS}" >> $GRADLE_PROPERTIES + + $GRADLE_PATH -p ~/workspace/android/ uploadArchives +fi diff --git a/.circleci/scripts/setup_ci_environment.sh b/.circleci/scripts/setup_ci_environment.sh index b81d0d34d6388..80782380c8fbb 100755 --- a/.circleci/scripts/setup_ci_environment.sh +++ b/.circleci/scripts/setup_ci_environment.sh @@ -45,7 +45,7 @@ retry () { retry sudo pip -q install awscli==1.16.35 if [ -n "${USE_CUDA_DOCKER_RUNTIME:-}" ]; then - DRIVER_FN="NVIDIA-Linux-x86_64-410.104.run" + DRIVER_FN="NVIDIA-Linux-x86_64-430.40.run" wget "https://s3.amazonaws.com/ossci-linux/nvidia_driver/$DRIVER_FN" sudo /bin/bash "$DRIVER_FN" -s --no-drm || (sudo cat /var/log/nvidia-installer.log && false) nvidia-smi diff --git a/.circleci/scripts/should_run_job.py b/.circleci/scripts/should_run_job.py index 3381fff0ead01..4e0603074b183 100644 --- a/.circleci/scripts/should_run_job.py +++ b/.circleci/scripts/should_run_job.py @@ -35,17 +35,27 @@ 'manywheel 2.7mu cpu devtoolset7', 'libtorch 2.7m cpu devtoolset7', 'libtorch 2.7m cpu gcc5.4_cxx11-abi', + 'libtorch-ios-10.2.1-nightly-x86_64-build', + 'libtorch-ios-10.2.1-nightly-arm64-build', + 'libtorch-ios-10.2.1-nightly-binary-build-upload', # Caffe2 Android 'caffe2-py2-android-ubuntu16.04', # Caffe2 OSX 'caffe2-py2-system-macos10.13', # PyTorch OSX + 'pytorch-macos-10.13-py3', 'pytorch-macos-10.13-cuda9.2-cudnn7-py3', # PyTorch Android 'pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32-build', # PyTorch Android gradle 'pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-only-x86_32', + # Pytorch iOS builds + 'pytorch-ios-10.2.1-x86_64_build', + 'pytorch-ios-10.2.1-arm64_build', + + # Pytorch backward compatibility check + 'pytorch-linux-backward-compatibility-check-test', # XLA 'pytorch-xla-linux-xenial-py3.6-clang7', diff --git a/.circleci/verbatim-sources/binary-job-specs.yml b/.circleci/verbatim-sources/binary-job-specs.yml index 16034793cb30c..4e95223177595 100644 --- a/.circleci/verbatim-sources/binary-job-specs.yml +++ b/.circleci/verbatim-sources/binary-job-specs.yml @@ -2,10 +2,7 @@ <<: *binary_linux_build_params steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: <<: *binary_checkout - run: @@ -59,17 +56,12 @@ image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace + - should_run_job # TODO: We shouldn't attach the workspace multiple times - attach_workspace: at: /home/circleci/project - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - setup_linux_system_environment + - setup_ci_environment - run: <<: *binary_checkout - run: @@ -87,14 +79,9 @@ image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - attach_workspace: at: /home/circleci/project - run: @@ -120,10 +107,8 @@ at: ~/workspace - attach_workspace: at: /home/circleci/project - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - setup_linux_system_environment + - setup_ci_environment - run: <<: *binary_checkout - run: @@ -155,8 +140,7 @@ <<: *binary_checkout - run: <<: *binary_populate_env - - run: - <<: *binary_macos_brew_update + - brew_update - run: <<: *binary_install_miniconda - run: @@ -177,16 +161,12 @@ xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: <<: *binary_checkout - run: <<: *binary_populate_env - - run: - <<: *binary_macos_brew_update + - brew_update - run: <<: *binary_install_miniconda @@ -218,16 +198,12 @@ xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: <<: *binary_checkout - run: <<: *binary_populate_env - - run: - <<: *binary_macos_brew_update + - brew_update - run: <<: *binary_install_miniconda - attach_workspace: # TODO - we can `cp` from ~/workspace @@ -240,3 +216,42 @@ cat "$script" source "$script" + binary_ios_build: + <<: *pytorch_ios_params + macos: + xcode: "10.2.1" + steps: + - attach_workspace: + at: ~/workspace + - should_run_job + - checkout + - run_brew_for_ios_build + - run: + name: Build + contxt: org-member + no_output_timeout: "1h" + command: | + script="/Users/distiller/project/.circleci/scripts/binary_ios_build.sh" + cat "$script" + source "$script" + - persist_to_workspace: + root: /Users/distiller/workspace/ + paths: ios + + binary_ios_upload: + <<: *pytorch_ios_params + macos: + xcode: "10.2.1" + steps: + - attach_workspace: + at: ~/workspace + - should_run_job + - checkout + - run_brew_for_ios_build + - run: + name: Upload + no_output_timeout: "1h" + command: | + script="/Users/distiller/project/.circleci/scripts/binary_ios_upload.sh" + cat "$script" + source "$script" \ No newline at end of file diff --git a/.circleci/verbatim-sources/binary_update_htmls.yml b/.circleci/verbatim-sources/binary_update_htmls.yml index 0ac7d16d0e372..89ff83e6b73cd 100644 --- a/.circleci/verbatim-sources/binary_update_htmls.yml +++ b/.circleci/verbatim-sources/binary_update_htmls.yml @@ -1,4 +1,4 @@ - + # update_s3_htmls job # These jobs create html files for every cpu/cu## folder in s3. The html # files just store the names of all the files in that folder (which are @@ -12,8 +12,7 @@ steps: - attach_workspace: at: ~/workspace - - run: - <<: *setup_linux_system_environment + - setup_linux_system_environment - run: <<: *binary_checkout # N.B. we do not run binary_populate_env. The only variable we need is @@ -67,8 +66,7 @@ steps: - attach_workspace: at: ~/workspace - - run: - <<: *setup_linux_system_environment + - setup_linux_system_environment - run: <<: *binary_checkout - run: diff --git a/.circleci/verbatim-sources/caffe2-job-specs.yml b/.circleci/verbatim-sources/caffe2-job-specs.yml index 34dfed1b7d1f1..f465d76fad850 100644 --- a/.circleci/verbatim-sources/caffe2-job-specs.yml +++ b/.circleci/verbatim-sources/caffe2-job-specs.yml @@ -4,15 +4,10 @@ image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment + - should_run_job + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: Build no_output_timeout: "1h" @@ -69,14 +64,9 @@ image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *setup_linux_system_environment - - run: - <<: *should_run_job - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Test no_output_timeout: "1h" @@ -137,13 +127,9 @@ xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - checkout - - run: - <<: *macos_brew_update + - run_brew_for_macos_build - run: name: Build no_output_timeout: "1h" diff --git a/.circleci/verbatim-sources/commands.yml b/.circleci/verbatim-sources/commands.yml new file mode 100644 index 0000000000000..acf44d3d42561 --- /dev/null +++ b/.circleci/verbatim-sources/commands.yml @@ -0,0 +1,90 @@ +commands: + # NB: This command must be run as the first command in a job. It + # attaches the workspace at ~/workspace; this workspace is generated + # by the setup job. Note that ~/workspace is not the default working + # directory (that's ~/project). + should_run_job: + description: "Test if the job should run or not" + steps: + - attach_workspace: + name: Attaching workspace + at: ~/workspace + - run: + name: Should run job + no_output_timeout: "2m" + command: ~/workspace/.circleci/scripts/should_run_job.sh + + # This system setup script is meant to run before the CI-related scripts, e.g., + # installing Git client, checking out code, setting up CI env, and + # building/testing. + setup_linux_system_environment: + steps: + - run: + name: Set Up System Environment + no_output_timeout: "1h" + command: ~/workspace/.circleci/scripts/setup_linux_system_environment.sh + + setup_ci_environment: + steps: + - run: + name: Set Up CI Environment After attach_workspace + no_output_timeout: "1h" + command: ~/workspace/.circleci/scripts/setup_ci_environment.sh + + brew_update: + description: "Update Homebrew and install base formulae" + steps: + - run: + name: Update Homebrew + no_output_timeout: "10m" + command: | + set -ex + + # Update repositories manually. + # Running `brew update` produces a comparison between the + # current checkout and the updated checkout, which takes a + # very long time because the existing checkout is 2y old. + for path in $(find /usr/local/Homebrew -type d -name .git) + do + cd $path/.. + git fetch --depth=1 origin + git reset --hard origin/master + done + + export HOMEBREW_NO_AUTO_UPDATE=1 + + # Install expect and moreutils so that we can call `unbuffer` and `ts`. + # moreutils installs a `parallel` executable by default, which conflicts + # with the executable from the GNU `parallel`, so we must unlink GNU + # `parallel` first, and relink it afterwards. + brew unlink parallel + brew install moreutils + brew link parallel --overwrite + brew install expect + + brew_install: + description: "Install Homebrew formulae" + parameters: + formulae: + type: string + default: "" + steps: + - run: + name: Install << parameters.formulae >> + no_output_timeout: "10m" + command: | + set -ex + export HOMEBREW_NO_AUTO_UPDATE=1 + brew install << parameters.formulae >> + + run_brew_for_macos_build: + steps: + - brew_update + - brew_install: + formulae: libomp + + run_brew_for_ios_build: + steps: + - brew_update + - brew_install: + formulae: libtool diff --git a/.circleci/verbatim-sources/header-section.yml b/.circleci/verbatim-sources/header-section.yml index 804cdb8d9488e..462901c978672 100644 --- a/.circleci/verbatim-sources/header-section.yml +++ b/.circleci/verbatim-sources/header-section.yml @@ -19,47 +19,3 @@ docker_config_defaults: &docker_config_defaults # This IAM user only allows read-write access to ECR aws_access_key_id: ${CIRCLECI_AWS_ACCESS_KEY_FOR_ECR_READ_WRITE_V4} aws_secret_access_key: ${CIRCLECI_AWS_SECRET_KEY_FOR_ECR_READ_WRITE_V4} - -# This system setup script is meant to run before the CI-related scripts, e.g., -# installing Git client, checking out code, setting up CI env, and -# building/testing. -setup_linux_system_environment: &setup_linux_system_environment - name: Set Up System Environment - no_output_timeout: "1h" - command: ~/workspace/.circleci/scripts/setup_linux_system_environment.sh - -# NB: This (and the command below) must be run after attaching -# ~/workspace. This is NOT the default working directory (that's -# ~/project); this workspace is generated by the setup job. -should_run_job: &should_run_job - name: Should Run Job After attach_workspace - no_output_timeout: "2m" - command: ~/workspace/.circleci/scripts/should_run_job.sh - -setup_ci_environment: &setup_ci_environment - name: Set Up CI Environment After attach_workspace - no_output_timeout: "1h" - command: ~/workspace/.circleci/scripts/setup_ci_environment.sh - -# Installs expect and moreutils so that we can call `unbuffer` and `ts`. -# Also installs OpenMP -# !!!!NOTE!!!! this is copied into a binary_macos_brew_update job which is the -# same but does not install libomp. If you are changing this, consider if you -# need to change that step as well. -macos_brew_update: &macos_brew_update - name: Brew update and install moreutils, expect and libomp - no_output_timeout: "1h" - command: | - set -ex - # See https://discourse.brew.sh/t/fetching-homebrew-repos-is-slow/5374/3 - brew untap caskroom/homebrew-cask - # moreutils installs a `parallel` executable by default, which conflicts - # with the executable from the GNU `parallel`, so we must unlink GNU - # `parallel` first, and relink it afterwards - brew update - brew unlink parallel - brew install moreutils - brew link parallel --overwrite - brew install expect - brew install libomp - diff --git a/.circleci/verbatim-sources/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs-custom.yml index 825f03f23b33c..1de89113f9bf4 100644 --- a/.circleci/verbatim-sources/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs-custom.yml @@ -1,7 +1,7 @@ pytorch_short_perf_test_gpu: environment: BUILD_ENVIRONMENT: pytorch-short-perf-test-gpu - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" PYTHON_VERSION: "3.6" USE_CUDA_DOCKER_RUNTIME: "1" resource_class: gpu.medium @@ -9,14 +9,9 @@ image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Perf Test no_output_timeout: "1h" @@ -42,20 +37,15 @@ environment: BUILD_ENVIRONMENT: pytorch-python-doc-push # TODO: stop hardcoding this - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" resource_class: large machine: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Doc Build and Push no_output_timeout: "1h" @@ -92,20 +82,15 @@ pytorch_cpp_doc_push: environment: BUILD_ENVIRONMENT: pytorch-cpp-doc-push - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:347" resource_class: large machine: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Doc Build and Push no_output_timeout: "1h" @@ -146,26 +131,21 @@ xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - checkout - - run: - <<: *macos_brew_update + - run_brew_for_macos_build - run: name: Build no_output_timeout: "1h" command: | set -e - export IN_CIRCLECI=1 # Install sccache sudo curl https://s3.amazonaws.com/ossci-macos/sccache --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache - export SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2 + # This IAM user allows write access to S3 bucket for sccache set +x export AWS_ACCESS_KEY_ID=${CIRCLECI_AWS_ACCESS_KEY_FOR_SCCACHE_S3_BUCKET_V4} @@ -175,14 +155,14 @@ chmod a+x .jenkins/pytorch/macos-build.sh unbuffer .jenkins/pytorch/macos-build.sh 2>&1 | ts - mkdir -p /Users/distiller/pytorch-ci-env/workspace - # copy with -a to preserve relative structure (e.g., symlinks), and be recursive - cp -a /Users/distiller/project/. /Users/distiller/pytorch-ci-env/workspace + cp -a ~/project ~/workspace + - persist_to_workspace: - root: /Users/distiller/pytorch-ci-env + root: ~/workspace paths: - - "*" + - miniconda3 + - project pytorch_macos_10_13_py3_test: environment: @@ -192,12 +172,8 @@ steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml # This workspace also carries binaries from the build job - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *macos_brew_update + - should_run_job + - run_brew_for_macos_build - run: name: Test no_output_timeout: "1h" @@ -206,10 +182,7 @@ export IN_CIRCLECI=1 # copy with -a to preserve relative structure (e.g., symlinks), and be recursive - # TODO: I'm not sure why we can't just run our job in - # ~/workspace and call it a day - # NB: Yes, you need workspace twice - cp -a ~/workspace/workspace/. /Users/distiller/project + cp -a ~/workspace/project/. ~/project chmod a+x .jenkins/pytorch/macos-test.sh unbuffer .jenkins/pytorch/macos-test.sh 2>&1 | ts @@ -221,13 +194,9 @@ xcode: "9.0" steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - checkout - - run: - <<: *macos_brew_update + - run_brew_for_macos_build - run: name: Build no_output_timeout: "1h" @@ -269,21 +238,16 @@ pytorch_android_gradle_build: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" PYTHON_VERSION: "3.6" resource_class: large machine: image: ubuntu-1604:201903-01 steps: - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment + - should_run_job + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: pytorch android gradle build no_output_timeout: "1h" @@ -357,19 +321,52 @@ path: ~/workspace/build_android_artifacts/artifacts.tgz destination: artifacts.tgz + pytorch_android_publish_snapshot: + environment: + BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-publish-snapshot + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" + PYTHON_VERSION: "3.6" + resource_class: large + machine: + image: ubuntu-1604:201903-01 + steps: + - should_run_job + - setup_linux_system_environment + - checkout + - setup_ci_environment + - run: + name: pytorch android gradle build + no_output_timeout: "1h" + command: | + set -eux + docker_image_commit=${DOCKER_IMAGE}-${CIRCLE_SHA1} + + docker_image_libtorch_android_x86_32_gradle=${docker_image_commit}-android-x86_32-gradle + + echo "docker_image_commit: "${docker_image_commit} + echo "docker_image_libtorch_android_x86_32_gradle: "${docker_image_libtorch_android_x86_32_gradle} + + # x86_32 + docker pull ${docker_image_libtorch_android_x86_32_gradle} >/dev/null + export id_x86_32=$(docker run -t -d -w /var/lib/jenkins ${docker_image_libtorch_android_x86_32_gradle}) + + export COMMAND='((echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace" && echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "export SONATYPE_NEXUS_USERNAME=${SONATYPE_NEXUS_USERNAME}" && echo "export SONATYPE_NEXUS_PASSWORD=${SONATYPE_NEXUS_PASSWORD}" && echo "export ANDROID_SIGN_KEY=${ANDROID_SIGN_KEY}" && echo "export ANDROID_SIGN_PASS=${ANDROID_SIGN_PASS}" && echo "sudo chown -R jenkins workspace && cd workspace && ./.circleci/scripts/publish_android_snapshot.sh") | docker exec -u jenkins -i "$id_x86_32" bash) 2>&1' + echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts + + output_image=${docker_image_libtorch_android_x86_32_gradle}-publish-snapshot + docker commit "$id_x86_32" ${output_image} + docker push ${output_image} + pytorch_android_gradle_build-x86_32: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-only-x86_32 - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:347" PYTHON_VERSION: "3.6" resource_class: large machine: image: ubuntu-1604:201903-01 steps: - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job + - should_run_job - run: name: filter out not PR runs no_output_timeout: "5m" @@ -378,11 +375,9 @@ if [ -z "${CIRCLE_PULL_REQUEST:-}" ]; then circleci step halt fi - - run: - <<: *setup_linux_system_environment + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: pytorch android gradle build only x86_32 (for PR) no_output_timeout: "1h" @@ -401,9 +396,51 @@ mkdir -p ~/workspace/build_android_x86_32_artifacts docker cp $id:/var/lib/jenkins/workspace/android/artifacts.tgz ~/workspace/build_android_x86_32_artifacts/ - output_image=${DOCKER_IMAGE}-${CIRCLE_SHA1}-android-gradle-x86_32 + output_image=${docker_image_libtorch_android_x86_32}-gradle docker commit "$id" ${output_image} docker push ${output_image} - store_artifacts: path: ~/workspace/build_android_x86_32_artifacts/artifacts.tgz destination: artifacts.tgz + + pytorch_ios_build: + <<: *pytorch_ios_params + macos: + xcode: "10.2.1" + steps: + # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml + - should_run_job + - checkout + - run_brew_for_ios_build + - run: + name: Build + no_output_timeout: "1h" + command: | + set -e + export IN_CIRCLECI=1 + WORKSPACE=/Users/distiller/workspace + PROJ_ROOT=/Users/distiller/project + export TCLLIBPATH="/usr/local/lib" + + # Install conda + curl -o ~/Downloads/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + chmod +x ~/Downloads/conda.sh + /bin/bash ~/Downloads/conda.sh -b -p ~/anaconda + export PATH="~/anaconda/bin:${PATH}" + source ~/anaconda/bin/activate + # Install dependencies + conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests + # sync submodules + cd ${PROJ_ROOT} + git submodule sync + git submodule update --init --recursive + # export + export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} + # run build script + chmod a+x ${PROJ_ROOT}/scripts/build_ios.sh + echo "IOS_ARCH: ${IOS_ARCH}" + echo "IOS_PLATFORM: ${IOS_PLATFORM}" + export BUILD_PYTORCH_MOBILE=1 + export IOS_ARCH=${IOS_ARCH} + export IOS_PLATFORM=${IOS_PLATFORM} + unbuffer ${PROJ_ROOT}/scripts/build_ios.sh 2>&1 | ts diff --git a/.circleci/verbatim-sources/job-specs-setup.yml b/.circleci/verbatim-sources/job-specs-setup.yml index 1fad450b438b7..500fb50e5ce67 100644 --- a/.circleci/verbatim-sources/job-specs-setup.yml +++ b/.circleci/verbatim-sources/job-specs-setup.yml @@ -1,4 +1,4 @@ - + setup: docker: - image: circleci/python:3.7.3 diff --git a/.circleci/verbatim-sources/nightly-binary-build-defaults.yml b/.circleci/verbatim-sources/nightly-binary-build-defaults.yml index 087e8691a5975..1aa8e76241e32 100644 --- a/.circleci/verbatim-sources/nightly-binary-build-defaults.yml +++ b/.circleci/verbatim-sources/nightly-binary-build-defaults.yml @@ -48,23 +48,3 @@ binary_run_in_docker: &binary_run_in_docker # This step only runs on circleci linux machine executors that themselves # need to start docker images command: ~/workspace/.circleci/scripts/binary_run_in_docker.sh - -# This is copied almost verbatim from the macos_brew_update job -# In version 2.1 and above we could make this a command and pass a parameter to -# it, but in this version there is no way to pass a parameter to a step -binary_macos_brew_update: &binary_macos_brew_update - name: Brew update and install moreutils and expect - no_output_timeout: "1h" - command: | - set -eux -o pipefail - # See https://discourse.brew.sh/t/fetching-homebrew-repos-is-slow/5374/3 - brew untap caskroom/homebrew-cask - # moreutils installs a `parallel` executable by default, which conflicts - # with the executable from the GNU `parallel`, so we must unlink GNU - # `parallel` first, and relink it afterwards - brew update - brew unlink parallel - brew install moreutils - brew link parallel --overwrite - brew install expect - diff --git a/.circleci/verbatim-sources/pytorch-build-params.yml b/.circleci/verbatim-sources/pytorch-build-params.yml index 86fff991cb03d..052bc7ff8428c 100644 --- a/.circleci/verbatim-sources/pytorch-build-params.yml +++ b/.circleci/verbatim-sources/pytorch-build-params.yml @@ -18,3 +18,22 @@ pytorch_params: &pytorch_params USE_CUDA_DOCKER_RUNTIME: << parameters.use_cuda_docker_runtime >> resource_class: << parameters.resource_class >> +pytorch_ios_params: &pytorch_ios_params + parameters: + build_environment: + type: string + default: "" + ios_arch: + type: string + default: "" + ios_platform: + type: string + default: "" + environment: + BUILD_ENVIRONMENT: << parameters.build_environment >> + IOS_ARCH: << parameters.ios_arch >> + IOS_PLATFORM: << parameters.ios_platform >> + + + + diff --git a/.circleci/verbatim-sources/pytorch-job-specs.yml b/.circleci/verbatim-sources/pytorch-job-specs.yml index 9da4f4a8c1fd6..44710d857c234 100644 --- a/.circleci/verbatim-sources/pytorch-job-specs.yml +++ b/.circleci/verbatim-sources/pytorch-job-specs.yml @@ -5,15 +5,10 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment + - should_run_job + - setup_linux_system_environment - checkout - - run: - <<: *setup_ci_environment + - setup_ci_environment - run: name: Build no_output_timeout: "1h" @@ -24,6 +19,28 @@ jobs: docker pull ${DOCKER_IMAGE} >/dev/null export id=$(docker run -t -d -w /var/lib/jenkins ${DOCKER_IMAGE}) + # TODO We may want to move the rebase logic to a separate step after checkout + # Rebase to master only if in xenial_py3_6_gcc5_4 case + if [[ "${CIRCLE_BRANCH}" != "master" && "${BUILD_ENVIRONMENT}" == *"gcc5"* ]]; then + echo "Merge master branch into $CIRCLE_BRANCH before build in environment $BUILD_ENVIRONMENT" + set -x + git config --global user.email "circleci.ossci@gmail.com" + git config --global user.name "CircleCI" + git config remote.origin.url https://github.com/pytorch/pytorch.git + git config --add remote.origin.fetch +refs/heads/master:refs/remotes/origin/master + git fetch --tags --progress https://github.com/pytorch/pytorch.git +refs/heads/master:refs/remotes/origin/master --depth=50 --quiet + export GIT_MERGE_TARGET=`git log -n 1 --pretty=format:"%H" origin/master` + echo "GIT_MERGE_TARGET: " ${GIT_MERGE_TARGET} + export GIT_COMMIT=${CIRCLE_SHA1} + echo "GIT_COMMIT: " ${GIT_COMMIT} + git checkout -f ${GIT_COMMIT} + git reset --hard ${GIT_COMMIT} + git merge --no-edit --no-ff ${GIT_MERGE_TARGET} + set +x + else + echo "Do NOT merge master branch into $CIRCLE_BRANCH in environment $BUILD_ENVIRONMENT" + fi + git submodule sync && git submodule update -q --init --recursive docker cp /home/circleci/project/. $id:/var/lib/jenkins/workspace @@ -32,11 +49,6 @@ jobs: NAMED_FLAG="export BUILD_NAMEDTENSOR=1" fi - # dispatch aten ops statically for mobile - if [[ ${BUILD_ENVIRONMENT} == *"android"* ]]; then - NAMED_FLAG="export USE_STATIC_DISPATCH=1" - fi - export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo '"$NAMED_FLAG"' && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/build.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts @@ -72,14 +84,9 @@ jobs: image: ubuntu-1604:201903-01 steps: # See Note [Workspace for CircleCI scripts] in job-specs-setup.yml - - attach_workspace: - at: ~/workspace - - run: - <<: *should_run_job - - run: - <<: *setup_linux_system_environment - - run: - <<: *setup_ci_environment + - should_run_job + - setup_linux_system_environment + - setup_ci_environment - run: name: Test no_output_timeout: "90m" @@ -89,7 +96,7 @@ jobs: output_image=${DOCKER_IMAGE}-${CIRCLE_SHA1} if [[ ${BUILD_ENVIRONMENT} == *"namedtensor"* ]]; then export COMMIT_DOCKER_IMAGE=$output_image-namedtensor - NAMED_FLAG="export BUILD_NAMEDTENSOR=1" + export NAMED_FLAG="export BUILD_NAMEDTENSOR=1 && export TEST_NAMEDTENSOR=1" elif [[ ${BUILD_ENVIRONMENT} == *"xla"* ]]; then export COMMIT_DOCKER_IMAGE=$output_image-xla else @@ -103,9 +110,8 @@ jobs: export id=$(docker run -t -d -w /var/lib/jenkins ${COMMIT_DOCKER_IMAGE}) fi if [[ ${BUILD_ENVIRONMENT} == *"multigpu"* ]]; then - export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo '"$NAMED_FLAG"' && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/multigpu-test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "${NAMED_FLAG}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/multigpu-test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' else - export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo '"$NAMED_FLAG"'&& echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' + export COMMAND='((echo "export BUILD_ENVIRONMENT=${BUILD_ENVIRONMENT}" && echo "${NAMED_FLAG}" && echo "source ./workspace/env" && echo "sudo chown -R jenkins workspace && cd workspace && .jenkins/pytorch/test.sh") | docker exec -u jenkins -i "$id" bash) 2>&1' fi echo ${COMMAND} > ./command.sh && unbuffer bash ./command.sh | ts - diff --git a/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml b/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml new file mode 100644 index 0000000000000..c91db337a4e61 --- /dev/null +++ b/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml @@ -0,0 +1,38 @@ + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + - pytorch_linux_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build + build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a" + requires: + - setup + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:339" + + - pytorch_android_gradle_build: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_android_gradle_build + requires: + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build + + - pytorch_android_publish_snapshot: + name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_android_publish_snapshot + requires: + - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_android_gradle_build + context: org-member diff --git a/.circleci/verbatim-sources/workflows-nightly-ios-binary-builds.yml b/.circleci/verbatim-sources/workflows-nightly-ios-binary-builds.yml new file mode 100644 index 0000000000000..ab99d0a18a8e1 --- /dev/null +++ b/.circleci/verbatim-sources/workflows-nightly-ios-binary-builds.yml @@ -0,0 +1,22 @@ + # Pytorch iOS binary builds + - binary_ios_build: + name: pytorch_ios_10_2_1_nightly_x86_64_build + build_environment: "libtorch-ios-10.2.1-nightly-x86_64-build" + ios_platform: "SIMULATOR" + ios_arch: "x86_64" + requires: + - setup + - binary_ios_build: + name: pytorch_ios_10_2_1_nigthly_arm64_build + build_environment: "libtorch-ios-10.2.1-nightly-arm64-build" + ios_arch: "arm64" + ios_platform: "OS" + requires: + - setup + - binary_ios_upload: + build_environment: "libtorch-ios-10.2.1-nightly-binary-build-upload" + context: org-member + requires: + - setup + - pytorch_ios_10_2_1_nightly_x86_64_build + - pytorch_ios_10_2_1_nigthly_arm64_build diff --git a/.circleci/verbatim-sources/workflows-pytorch-android-gradle-build.yml b/.circleci/verbatim-sources/workflows-pytorch-android-gradle-build.yml index d80ed704fbb64..36c3ce0701d0f 100644 --- a/.circleci/verbatim-sources/workflows-pytorch-android-gradle-build.yml +++ b/.circleci/verbatim-sources/workflows-pytorch-android-gradle-build.yml @@ -1,8 +1,10 @@ - pytorch_android_gradle_build-x86_32: + name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-x86_32 requires: - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build - pytorch_android_gradle_build: + name: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build requires: - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_32_build - pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build diff --git a/.circleci/verbatim-sources/workflows-pytorch-ios-builds.yml b/.circleci/verbatim-sources/workflows-pytorch-ios-builds.yml new file mode 100644 index 0000000000000..4d744fea2ee7a --- /dev/null +++ b/.circleci/verbatim-sources/workflows-pytorch-ios-builds.yml @@ -0,0 +1,13 @@ + # Pytorch iOS PR builds + - pytorch_ios_build: + name: pytorch_ios_10_2_1_x86_64_build + build_environment: "pytorch-ios-10.2.1-x86_64_build" + ios_platform: "SIMULATOR" + requires: + - setup + - pytorch_ios_build: + name: pytorch_ios_10_2_1_arm64_build + build_environment: "pytorch-ios-10.2.1-arm64_build" + ios_arch: "arm64" + requires: + - setup diff --git a/.circleci/verbatim-sources/workflows-pytorch-macos-builds.yml b/.circleci/verbatim-sources/workflows-pytorch-macos-builds.yml index 13e8b8bcfb56b..d5baba61b1568 100644 --- a/.circleci/verbatim-sources/workflows-pytorch-macos-builds.yml +++ b/.circleci/verbatim-sources/workflows-pytorch-macos-builds.yml @@ -4,16 +4,10 @@ - pytorch_macos_10_13_py3_build: requires: - setup - filters: - branches: - only: master - pytorch_macos_10_13_py3_test: requires: - setup - pytorch_macos_10_13_py3_build - filters: - branches: - only: master - pytorch_macos_10_13_cuda9_2_cudnn7_py3_build: requires: - setup diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000000000..d0a20b280509b --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,46 @@ +name: Lint + +on: + push: + branches: + - master + pull_request: + +jobs: + flake8-py3: + runs-on: ubuntu-latest + steps: + - name: Setup Python + uses: actions/setup-python@v1 + with: + python-version: 3.7.4 + architecture: x64 + - name: Fetch PyTorch + uses: actions/checkout@master + - name: Checkout PR tip + run: | + set -eux + if [ -z "${GITHUB_HEAD_REF}" ]; then + # We are on master, just set the SHA from our current location + echo ::set-output name=commit_sha::${GITHUB_SHA} + else + # We are on a PR, we need to check out PR branch + git checkout ${GITHUB_HEAD_REF} + echo ::set-output name=commit_sha::$(git rev-parse ${GITHUB_HEAD_REF}) + fi + id: get_pr_tip + - name: Run flake8 + run: | + set -eux + pip install flake8 + flake8 --exit-zero > ${GITHUB_WORKSPACE}/flake8-output.txt + cat ${GITHUB_WORKSPACE}/flake8-output.txt + - name: Add annotations + uses: pytorch/add-annotations-github-action@master + with: + check_name: 'flake8-py3' + linter_output_path: 'flake8-output.txt' + commit_sha: ${{ steps.get_pr_tip.outputs.commit_sha }} + regex: '^(?.*?):(?\d+):(?\d+): (?\w\d+) (?[\s|\w]*)' + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index c415f79d15e40..20d7c5993ca86 100644 --- a/.gitignore +++ b/.gitignore @@ -37,6 +37,7 @@ test/cpp/api/mnist test/custom_operator/model.pt test/data/legacy_modules.t7 test/data/*.pt +test/backward_compatibility/new_schemas.txt dropout_model.pt test/generated_type_hints_smoketest.py test/htmlcov @@ -225,11 +226,6 @@ caffe2.egg-info # Files generated by CLion cmake-build-debug -# Files generated by ctags -CTAGS -tags -TAGS - # BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.) # # Below files are not deleted by "setup.py clean". @@ -244,3 +240,12 @@ TAGS # Files generated when a patch is rejected *.orig *.rej + +# Files generated by ctags +CTAGS +GTAGS +GRTAGS +GSYMS +GPATH +tags +TAGS diff --git a/.jenkins/caffe2/build.sh b/.jenkins/caffe2/build.sh index 91609edfd183d..80065ade1eac3 100755 --- a/.jenkins/caffe2/build.sh +++ b/.jenkins/caffe2/build.sh @@ -274,13 +274,8 @@ fi pip install --user -b /tmp/pip_install_onnx "file://${ROOT_DIR}/third_party/onnx#egg=onnx" if [[ $BUILD_ENVIRONMENT == *rocm* ]]; then - ORIG_COMP=/opt/rocm/hcc/bin/clang-*_original - if [ -e $ORIG_COMP ]; then - # runtime compilation of MIOpen kernels manages to crash sccache - hence undo the wrapping - # note that the wrapping always names the compiler "clang-7.0_original" - WRAPPED=/opt/rocm/hcc/bin/clang-[0-99] - sudo mv $ORIG_COMP $WRAPPED - fi + # runtime compilation of MIOpen kernels manages to crash sccache - hence undo the wrapping + bash tools/amd_build/unwrap_clang.sh fi report_compile_cache_stats diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index 0d35d64f7b1a5..c7ec6e6136622 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -104,10 +104,6 @@ if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then export LANG=C.UTF-8 fi -if [[ "$BUILD_ENVIRONMENT" == *py2* ]]; then - pip install --user requests -fi - pip install --user pytest-sugar "$PYTHON" \ -m pytest \ @@ -135,7 +131,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then # default pip version is too old(9.0.2), unable to support tag `manylinux2010`. # Fix the pip error: Couldn't find a version that satisfies the requirement sudo pip install --upgrade pip - pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==0.5.0.dev817 + pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==0.5.0.dev905 fi "$ROOT_DIR/scripts/onnx/test.sh" fi diff --git a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index e65a077f9dc3c..42012a16d40bf 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -65,9 +65,6 @@ fi if [[ "${BUILD_ENVIRONMENT}" == *-android* ]]; then export ANDROID_NDK=/opt/ndk build_args=() - build_args+=("-DBUILD_CAFFE2_MOBILE=OFF") - - build_args+=("-DBUILD_SHARED_LIBS=ON") if [[ "${BUILD_ENVIRONMENT}" == *-arm-v7a* ]]; then build_args+=("-DANDROID_ABI=armeabi-v7a") elif [[ "${BUILD_ENVIRONMENT}" == *-arm-v8a* ]]; then @@ -77,9 +74,7 @@ if [[ "${BUILD_ENVIRONMENT}" == *-android* ]]; then elif [[ "${BUILD_ENVIRONMENT}" == *-x86_64* ]]; then build_args+=("-DANDROID_ABI=x86_64") fi - - build_args+=("-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") - build_args+=("-DPYTHON_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") + export BUILD_PYTORCH_MOBILE=1 exec ./scripts/build_android.sh "${build_args[@]}" "$@" fi @@ -121,14 +116,9 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then # LMDB is needed to read datasets from https://download.caffe2.ai/databases/resnet_trainer.zip USE_ROCM=1 USE_LMDB=1 USE_OPENCV=1 python setup.py install --user - ORIG_COMP=/opt/rocm/hcc/bin/clang-*_original - if [ -e $ORIG_COMP ]; then - # runtime compilation of MIOpen kernels manages to crash sccache - hence undo the wrapping - # note that the wrapping always names the compiler "clang-7.0_original" - WRAPPED=/opt/rocm/hcc/bin/clang-[0-99] - sudo mv $ORIG_COMP $WRAPPED + # runtime compilation of MIOpen kernels manages to crash sccache - hence undo the wrapping + bash tools/amd_build/unwrap_clang.sh - fi exit 0 fi diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index f58eeb261bfa1..8d6764ec1ae09 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -161,6 +161,11 @@ function pip_install() { pip install --progress-bar off "$@" || pip install --progress-bar off "$@" || pip install --progress-bar off "$@" } +function pip_uninstall() { + # uninstall 2 times + pip uninstall -y "$@" || pip uninstall -y "$@" +} + function get_exit_code() { set +e "$@" diff --git a/.jenkins/pytorch/macos-build.sh b/.jenkins/pytorch/macos-build.sh index 882fcc6124940..a27278c51ee57 100755 --- a/.jenkins/pytorch/macos-build.sh +++ b/.jenkins/pytorch/macos-build.sh @@ -1,27 +1,11 @@ #!/bin/bash # shellcheck disable=SC2034 -COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" - -export PATH="/usr/local/bin:$PATH" -source "$(dirname "${BASH_SOURCE[0]}")/common.sh" - -# Set up conda environment -export PYTORCH_ENV_DIR="${HOME}/pytorch-ci-env" -# If a local installation of conda doesn't exist, we download and install conda -if [ ! -d "${PYTORCH_ENV_DIR}/miniconda3" ]; then - mkdir -p ${PYTORCH_ENV_DIR} - curl https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ${PYTORCH_ENV_DIR}/miniconda3.sh - bash ${PYTORCH_ENV_DIR}/miniconda3.sh -b -p ${PYTORCH_ENV_DIR}/miniconda3 -fi -export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH" -source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate -conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja -rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch* +source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh" git submodule sync --recursive git submodule update --init --recursive -export CMAKE_PREFIX_PATH=${PYTORCH_ENV_DIR}/miniconda3/ +export CMAKE_PREFIX_PATH=${WORKSPACE_DIR}/miniconda3/ # Build PyTorch if [[ "${BUILD_ENVIRONMENT}" == *cuda9.2* ]]; then @@ -43,35 +27,29 @@ else fi fi -export MACOSX_DEPLOYMENT_TARGET=10.9 -export CXX=clang++ -export CC=clang if which sccache > /dev/null; then - printf "#!/bin/sh\nexec sccache $(which clang++) \$*" > "${PYTORCH_ENV_DIR}/clang++" - chmod a+x "${PYTORCH_ENV_DIR}/clang++" + printf "#!/bin/sh\nexec sccache $(which clang++) \$*" > "${WORKSPACE_DIR}/clang++" + chmod a+x "${WORKSPACE_DIR}/clang++" - printf "#!/bin/sh\nexec sccache $(which clang) \$*" > "${PYTORCH_ENV_DIR}/clang" - chmod a+x "${PYTORCH_ENV_DIR}/clang" + printf "#!/bin/sh\nexec sccache $(which clang) \$*" > "${WORKSPACE_DIR}/clang" + chmod a+x "${WORKSPACE_DIR}/clang" if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then - printf "#!/bin/sh\nexec sccache $(which nvcc) \$*" > "${PYTORCH_ENV_DIR}/nvcc" - chmod a+x "${PYTORCH_ENV_DIR}/nvcc" - export CUDA_NVCC_EXECUTABLE="${PYTORCH_ENV_DIR}/nvcc" + printf "#!/bin/sh\nexec sccache $(which nvcc) \$*" > "${WORKSPACE_DIR}/nvcc" + chmod a+x "${WORKSPACE_DIR}/nvcc" + export CUDA_NVCC_EXECUTABLE="${WORKSPACE_DIR}/nvcc" fi - export PATH="${PYTORCH_ENV_DIR}:$PATH" + export PATH="${WORKSPACE_DIR}:$PATH" fi -# If we run too many parallel jobs, we will OOM -export MAX_JOBS=2 - -export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} -python setup.py install +# If we run too many parallel jobs, we will OOM +MAX_JOBS=2 USE_DISTRIBUTED=1 python setup.py install assert_git_not_dirty # Upload torch binaries when the build job is finished if [ -z "${IN_CIRCLECI}" ]; then - 7z a ${IMAGE_COMMIT_TAG}.7z ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch* + 7z a ${IMAGE_COMMIT_TAG}.7z ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch* aws s3 cp ${IMAGE_COMMIT_TAG}.7z s3://ossci-macos-build/pytorch/${IMAGE_COMMIT_TAG}.7z --acl public-read fi diff --git a/.jenkins/pytorch/macos-common.sh b/.jenkins/pytorch/macos-common.sh new file mode 100755 index 0000000000000..c7e71d832ed6a --- /dev/null +++ b/.jenkins/pytorch/macos-common.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +# Common prelude for macos-build.sh and macos-test.sh + +# shellcheck disable=SC2034 +COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" + +source "$(dirname "${BASH_SOURCE[0]}")/common.sh" +export PATH="/usr/local/bin:$PATH" +export WORKSPACE_DIR="${HOME}/workspace" +mkdir -p ${WORKSPACE_DIR} + +# If a local installation of conda doesn't exist, we download and install conda +if [ ! -d "${WORKSPACE_DIR}/miniconda3" ]; then + mkdir -p ${WORKSPACE_DIR} + curl https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ${WORKSPACE_DIR}/miniconda3.sh + bash ${WORKSPACE_DIR}/miniconda3.sh -b -p ${WORKSPACE_DIR}/miniconda3 +fi +export PATH="${WORKSPACE_DIR}/miniconda3/bin:$PATH" +source ${WORKSPACE_DIR}/miniconda3/bin/activate +conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja + +# The torch.hub tests make requests to GitHub. +# +# The certifi package from conda-forge is new enough to make the +# following error disappear (included for future reference): +# +# > ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] +# > certificate verify failed: unable to get local issuer certificate +# > (_ssl.c:1056) +# +conda install -y -c conda-forge certifi + +# Needed by torchvision, which is imported from TestHub in test_utils.py. +conda install -y pillow + +# Building with USE_DISTRIBUTED=1 requires libuv (for Gloo). +conda install -y libuv pkg-config + +# Image commit tag is used to persist the build from the build job +# and to retrieve the build from the test job. +export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} + +# These are required for both the build job and the test job. +# In the latter to test cpp extensions. +export MACOSX_DEPLOYMENT_TARGET=10.9 +export CXX=clang++ +export CC=clang diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 30482f5b607c0..518ac11d76521 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -1,23 +1,9 @@ #!/bin/bash # shellcheck disable=SC2034 -COMPACT_JOB_NAME="${BUILD_ENVIRONMENT}" +source "$(dirname "${BASH_SOURCE[0]}")/macos-common.sh" -source "$(dirname "${BASH_SOURCE[0]}")/common.sh" - -export PATH="/usr/local/bin:$PATH" - -# Set up conda environment -export PYTORCH_ENV_DIR="${HOME}/workspace" -# If a local installation of conda doesn't exist, we download and install conda -if [ ! -d "${PYTORCH_ENV_DIR}/miniconda3" ]; then - mkdir -p ${PYTORCH_ENV_DIR} - curl https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ${PYTORCH_ENV_DIR}/miniconda3.sh - bash ${PYTORCH_ENV_DIR}/miniconda3.sh -b -p ${PYTORCH_ENV_DIR}/miniconda3 -fi -export PATH="${PYTORCH_ENV_DIR}/miniconda3/bin:$PATH" -source ${PYTORCH_ENV_DIR}/miniconda3/bin/activate -conda install -y mkl mkl-include numpy pyyaml setuptools cmake cffi ninja six +conda install -y six pip install -q hypothesis "librosa>=0.6.2" psutil # faulthandler become built-in since 3.3 @@ -26,12 +12,12 @@ if [[ ! $(python -c "import sys; print(int(sys.version_info >= (3, 3)))") == "1" fi if [ -z "${IN_CIRCLECI}" ]; then - rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch* + rm -rf ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch* fi git submodule sync --recursive git submodule update --init --recursive -export CMAKE_PREFIX_PATH=${PYTORCH_ENV_DIR}/miniconda3/ +export CMAKE_PREFIX_PATH=${WORKSPACE_DIR}/miniconda3/ # Test PyTorch if [ -z "${IN_CIRCLECI}" ]; then @@ -43,19 +29,12 @@ if [ -z "${IN_CIRCLECI}" ]; then export DEVELOPER_DIR=/Applications/Xcode9.app/Contents/Developer fi fi -export MACOSX_DEPLOYMENT_TARGET=10.9 -export CXX=clang++ -export CC=clang -# If we run too many parallel jobs, we will OOM -export MAX_JOBS=2 - -export IMAGE_COMMIT_TAG=${BUILD_ENVIRONMENT}-${IMAGE_COMMIT_ID} # Download torch binaries in the test jobs if [ -z "${IN_CIRCLECI}" ]; then - rm -rf ${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages/torch* + rm -rf ${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages/torch* aws s3 cp s3://ossci-macos-build/pytorch/${IMAGE_COMMIT_TAG}.7z ${IMAGE_COMMIT_TAG}.7z - 7z x ${IMAGE_COMMIT_TAG}.7z -o"${PYTORCH_ENV_DIR}/miniconda3/lib/python3.6/site-packages" + 7z x ${IMAGE_COMMIT_TAG}.7z -o"${WORKSPACE_DIR}/miniconda3/lib/python3.6/site-packages" fi # Test that OpenMP is enabled @@ -67,6 +46,10 @@ fi popd test_python_all() { + # The CircleCI worker hostname doesn't resolve to an address. + # This environment variable makes ProcessGroupGloo default to + # using the address associated with the loopback interface. + export GLOO_SOCKET_IFNAME=lo0 echo "Ninja version: $(ninja --version)" python test/run_test.py --verbose assert_git_not_dirty diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index aa2c9d3a9d0c8..c369ad1d58137 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -56,8 +56,6 @@ if [[ "$BUILD_ENVIRONMENT" != *ppc64le* ]]; then pip_install --user mypy || true fi -pip_install --user requests - # faulthandler become built-in since 3.3 if [[ ! $(python -c "import sys; print(int(sys.version_info >= (3, 3)))") == "1" ]]; then pip_install --user faulthandler @@ -177,22 +175,47 @@ test_xla() { export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0" export XRT_WORKERS="localservice:0;grpc://localhost:40934" pushd xla - python test/test_operations.py + echo "Running Python Tests" + ./test/run_tests.sh + + echo "Running MNIST Test" python test/test_train_mnist.py --tidy + + echo "Running C++ Tests" + pushd test/cpp + CC=clang-7 CXX=clang++-7 ./run_tests.sh + popd + assert_git_not_dirty +} + +# Do NOT run this test before any other tests, like test_python_nn, etc. +# Because this function uninstalls the torch built from branch, and install +# nightly version. +test_backward_compatibility() { + set -x + pushd test/backward_compatibility + python dump_all_function_schemas.py --filename new_schemas.txt + pip_uninstall torch + pip_install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + python check_backward_compatibility.py --new-schemas new_schemas.txt popd + set +x assert_git_not_dirty } (cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.parallel_info())") -if [[ "${BUILD_ENVIRONMENT}" == *xla* ]]; then +if [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then + test_backward_compatibility + # Do NOT add tests after bc check tests, see its comment. +elif [[ "${BUILD_ENVIRONMENT}" == *xla* || "${JOB_BASE_NAME}" == *xla* ]]; then test_torchvision test_xla -elif [[ "${BUILD_ENVIRONMENT}" == *-test1 ]]; then +elif [[ "${BUILD_ENVIRONMENT}" == *-test1 || "${JOB_BASE_NAME}" == *-test1 ]]; then test_torchvision test_python_nn -elif [[ "${BUILD_ENVIRONMENT}" == *-test2 ]]; then +elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then test_python_all_except_nn test_aten test_libtorch diff --git a/CMakeLists.txt b/CMakeLists.txt index 34ed1c6206d00..34e01d69ead07 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,6 +83,17 @@ else () set(CPU_INTEL OFF) endif () +# For non-supported platforms, turn USE_DISTRIBUTED off by default. +# It is not tested and likely won't work without additional changes. +if(NOT LINUX) + set(USE_DISTRIBUTED OFF CACHE STRING "Use distributed") + # On macOS, if USE_DISTRIBUTED is enabled (specified by the user), + # then make Gloo build with the libuv transport. + if(APPLE AND USE_DISTRIBUTED) + set(USE_LIBUV ON CACHE STRING "") + endif() +endif() + # ---[ Options. # Note to developers: if you add an option below, make sure you also add it to # cmake/Summary.cmake so that the summary prints out the option values. @@ -151,6 +162,7 @@ option(USE_OPENCV "Use OpenCV" OFF) option(USE_OPENMP "Use OpenMP for parallel code" ON) option(USE_PROF "Use profiling" OFF) option(USE_QNNPACK "Use QNNPACK (quantized 8-bit operators)" ON) +option(USE_PYTORCH_QNNPACK "Use ATen/QNNPACK (quantized 8-bit operators)" ON) option(USE_REDIS "Use Redis" OFF) option(USE_ROCKSDB "Use RocksDB" OFF) option(USE_SNPE "Use Qualcomm's SNPE library" OFF) @@ -173,9 +185,6 @@ cmake_dependent_option( cmake_dependent_option( USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) -cmake_dependent_option( - USE_GLOO_IBVERBS "Use Gloo IB verbs for distributed. Only available if USE_GLOO is on." OFF - "USE_GLOO" OFF) option(USE_TBB "Use TBB" OFF) # Used when building Caffe2 through setup.py @@ -264,9 +273,6 @@ if (MSVC) # Try harder list(APPEND CUDA_NVCC_FLAGS "-Xcompiler /w -w") - - # Turning off USE_DISTRIBUTED on default - set(USE_DISTRIBUTED OFF) endif(MSVC) # Set INTERN_BUILD_MOBILE for all mobile builds. Components that are not @@ -275,6 +281,13 @@ if (ANDROID OR IOS) set(INTERN_BUILD_MOBILE ON) endif() +# Setting `PYTORCH_BUILD_MOBILE` environment variable can force it to do mobile +# build with host toolchain. +if (DEFINED ENV{PYTORCH_BUILD_MOBILE}) + set(INTERN_BUILD_MOBILE ON) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DC10_MOBILE") +endif() + # INTERN_BUILD_ATEN_OPS is used to control whether to build ATen/TH operators. # It's disabled for caffe2 mobile library. if (INTERN_BUILD_MOBILE AND BUILD_CAFFE2_MOBILE) @@ -288,12 +301,18 @@ endif() # When it's disabled it builds libtorch mobile library, which contains ATen/TH ops and native support for # TorchScript model, but doesn't contain not-yet-unified caffe2 ops; if (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) + if (NOT BUILD_SHARED_LIBS) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNO_EXPORT") + endif() set(BUILD_PYTHON OFF) set(BUILD_CAFFE2_OPS OFF) set(USE_DISTRIBUTED OFF) set(FEATURE_TORCH_MOBILE ON) set(NO_API ON) set(USE_FBGEMM OFF) + set(USE_STATIC_DISPATCH ON) + set(INTERN_DISABLE_ONNX ON) + set(INTERN_DISABLE_AUTOGRAD ON) endif() # ---[ Utils @@ -362,6 +381,10 @@ if(USE_QNNPACK) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_QNNPACK") endif() +if(USE_PYTORCH_QNNPACK) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PYTORCH_QNNPACK") +endif() + if(USE_STATIC_DISPATCH) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_STATIC_DISPATCH") endif() diff --git a/README.md b/README.md index 418a7603f4ed3..cb8f8305d8e46 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xa Common ``` -conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests +conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing ``` On Linux diff --git a/android/gradle.properties b/android/gradle.properties index e7742b33dd93b..ec9e4008fa562 100644 --- a/android/gradle.properties +++ b/android/gradle.properties @@ -1,8 +1,8 @@ ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64 -VERSION_NAME=0.0.4 +VERSION_NAME=0.0.7-SNAPSHOT GROUP=org.pytorch -MAVEN_GROUP=com.facebook +MAVEN_GROUP=org.pytorch POM_URL=https://github.com/pytorch/pytorch/tree/master/android POM_SCM_URL=https://github.com/pytorch/pytorch.git POM_SCM_CONNECTION=scm:git:https://github.com/pytorch/pytorch @@ -11,9 +11,14 @@ POM_LICENSE_NAME=BSD 3-Clause POM_LICENSE_URL=https://github.com/pytorch/pytorch/blob/master/LICENSE POM_ISSUES_URL=https://github.com/pytorch/pytorch/issues POM_LICENSE_DIST=repo -POM_DEVELOPER_ID=facebook -POM_DEVELOPER_NAME=facebook +POM_DEVELOPER_ID=pytorch +POM_DEVELOPER_NAME=pytorch +syncWithMavenCentral=true GRADLE_BINTRAY_PLUGIN_VERSION=1.8.0 GRADLE_VERSIONS_PLUGIN_VERSION=0.15.0 ANDROID_MAVEN_GRADLE_PLUGIN_VERSION=2.1 + +# Gradle internals +org.gradle.internal.repository.max.retries=1 +org.gradle.jvmargs=-XX:MaxMetaspaceSize=1024m diff --git a/android/gradle/android_tasks.gradle b/android/gradle/android_tasks.gradle index a828b663a4779..ca188ac72d078 100644 --- a/android/gradle/android_tasks.gradle +++ b/android/gradle/android_tasks.gradle @@ -18,12 +18,12 @@ afterEvaluate { project -> } task androidJavadocJar(type: Jar, dependsOn: androidJavadoc) { - archiveClassifier.set('javadoc') + classifier = 'javadoc' from androidJavadoc.destinationDir } task androidSourcesJar(type: Jar) { - archiveClassifier.set('sources') + classifier = 'sources' from android.sourceSets.main.java.srcDirs } @@ -61,12 +61,12 @@ afterEvaluate { project -> if (POM_PACKAGING == 'jar') { task javadocJar(type: Jar, dependsOn: javadoc) { - archiveClassifier.set('javadoc') + classifier = 'javadoc' from javadoc.destinationDir } task sourcesJar(type: Jar, dependsOn: classes) { - archiveClassifier.set('sources') + classifier = 'sources' from sourceSets.main.allSource } diff --git a/android/gradle/release_bintray.gradle b/android/gradle/release_bintray.gradle index ed118fc41cc16..dc2bcc34003d5 100644 --- a/android/gradle/release_bintray.gradle +++ b/android/gradle/release_bintray.gradle @@ -1,6 +1,6 @@ ext { bintrayRepo = 'maven' - bintrayUserOrg = 'facebook' + bintrayUserOrg = 'pytorch' bintrayName = "${GROUP}:${POM_ARTIFACT_ID}" bintrayDescription = POM_DESCRIPTION projectUrl = POM_URL diff --git a/android/libs/fbjni_local/build.gradle b/android/libs/fbjni_local/build.gradle index 1b642c74c94c7..d4abff87fdc63 100644 --- a/android/libs/fbjni_local/build.gradle +++ b/android/libs/fbjni_local/build.gradle @@ -18,7 +18,14 @@ android { } } } - + buildTypes { + debug { + minifyEnabled false + } + release { + minifyEnabled false + } + } externalNativeBuild { cmake { path "../fbjni/CMakeLists.txt" @@ -31,3 +38,10 @@ dependencies { } apply from: rootProject.file('gradle/release.gradle') + +task sourcesJar(type: Jar) { + from android.sourceSets.main.java.srcDirs + classifier = 'sources' +} + +artifacts.add('archives', sourcesJar) diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt index ea30768005d26..b1b4587e5db5b 100644 --- a/android/pytorch_android/CMakeLists.txt +++ b/android/pytorch_android/CMakeLists.txt @@ -6,16 +6,36 @@ set(CMAKE_VERBOSE_MAKEFILE ON) set(pytorch_android_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) set(libtorch_include_DIR ${pytorch_android_DIR}/libtorch_include/${ANDROID_ABI}) -set(libtorch_SO ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libtorch.so) -set(libc10_SO ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libc10.so) +set(libtorch_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libtorch.a) +set(libc10_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libc10.a) +set(libnnpack_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libnnpack.a) +set(libqnnpack_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libqnnpack.a) +set(libpytorch_qnnpack_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libpytorch_qnnpack.a) +set(libcpuinfo_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libcpuinfo.a) +set(libclog_FILE ${CMAKE_CURRENT_LIST_DIR}/src/main/jniLibs/${ANDROID_ABI}/libclog.a) message(STATUS "libtorch dir:${libtorch_DIR}") -add_library(libtorch SHARED IMPORTED) -set_property(TARGET libtorch PROPERTY IMPORTED_LOCATION ${libtorch_SO}) +add_library(libtorch STATIC IMPORTED) +set_property(TARGET libtorch PROPERTY IMPORTED_LOCATION ${libtorch_FILE}) -add_library(libc10 SHARED IMPORTED ${libc10_SO}) -set_property(TARGET libc10 PROPERTY IMPORTED_LOCATION ${libc10_SO}) +add_library(libc10 STATIC IMPORTED) +set_property(TARGET libc10 PROPERTY IMPORTED_LOCATION ${libc10_FILE}) + +add_library(libnnpack STATIC IMPORTED) +set_property(TARGET libnnpack PROPERTY IMPORTED_LOCATION ${libnnpack_FILE}) + +add_library(libqnnpack STATIC IMPORTED) +set_property(TARGET libqnnpack PROPERTY IMPORTED_LOCATION ${libqnnpack_FILE}) + +add_library(libpytorch_qnnpack STATIC IMPORTED) +set_property(TARGET libpytorch_qnnpack PROPERTY IMPORTED_LOCATION ${libpytorch_qnnpack_FILE}) + +add_library(libcpuinfo STATIC IMPORTED) +set_property(TARGET libcpuinfo PROPERTY IMPORTED_LOCATION ${libcpuinfo_FILE}) + +add_library(libclog STATIC IMPORTED) +set_property(TARGET libclog PROPERTY IMPORTED_LOCATION ${libclog_FILE}) file(GLOB pytorch_android_SOURCES ${pytorch_android_DIR}/*.cpp @@ -29,6 +49,8 @@ target_compile_options(pytorch PRIVATE -fexceptions ) +target_compile_definitions(pytorch PRIVATE USE_STATIC_DISPATCH) + target_include_directories(pytorch PUBLIC ${libtorch_include_DIR} ) @@ -43,6 +65,14 @@ add_subdirectory(${fbjni_DIR} ${fbjni_BUILD_DIR}) target_link_libraries(pytorch fbjni + -Wl,--gc-sections + -Wl,--whole-archive libtorch + -Wl,--no-whole-archive libc10 + libnnpack + libqnnpack + libpytorch_qnnpack + libcpuinfo + libclog ) diff --git a/android/pytorch_android/build.gradle b/android/pytorch_android/build.gradle index 42b40ac448e7e..0b8df56f1e08e 100644 --- a/android/pytorch_android/build.gradle +++ b/android/pytorch_android/build.gradle @@ -35,6 +35,10 @@ android { } } + packagingOptions { + exclude '**/libfbjni.so' + } + useLibrary 'android.test.runner' useLibrary 'android.test.base' useLibrary 'android.test.mock' @@ -59,7 +63,7 @@ apply from: rootProject.file('gradle/release.gradle') task sourcesJar(type: Jar) { from android.sourceSets.main.java.srcDirs - archiveClassifier.set('sources') + classifier = 'sources' } artifacts.add('archives', sourcesJar) diff --git a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java index 5cdeb30b63b23..44ecdac405fa7 100644 --- a/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java +++ b/android/pytorch_android/src/androidTest/java/org/pytorch/PytorchInstrumentedTests.java @@ -1,13 +1,11 @@ package org.pytorch; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - import android.content.Context; -import androidx.test.ext.junit.runners.AndroidJUnit4; -import androidx.test.platform.app.InstrumentationRegistry; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + import java.io.File; import java.io.FileOutputStream; import java.io.IOException; @@ -15,9 +13,14 @@ import java.io.OutputStream; import java.util.HashMap; import java.util.Map; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; + +import androidx.test.ext.junit.runners.AndroidJUnit4; +import androidx.test.platform.app.InstrumentationRegistry; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; @RunWith(AndroidJUnit4.class) public class PytorchInstrumentedTests { @@ -33,7 +36,7 @@ public void setUp() { public void testForwardNull() throws IOException { final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); final IValue input = - IValue.tensor(Tensor.newByteTensor(new long[] {1}, Tensor.allocateByteBuffer(1))); + IValue.tensor(Tensor.newInt8Tensor(new long[] {1}, Tensor.allocateByteBuffer(1))); assertTrue(input.isTensor()); final IValue output = module.forward(input); assertTrue(output.isNull()); @@ -94,13 +97,13 @@ public void testEqFloat() throws IOException { @Test public void testEqTensor() throws IOException { - final long[] inputTensorDims = new long[] {1, 3, 224, 224}; - final long numElements = Tensor.numElements(inputTensorDims); + final long[] inputTensorShape = new long[] {1, 3, 224, 224}; + final long numElements = Tensor.numel(inputTensorShape); final float[] inputTensorData = new float[(int) numElements]; for (int i = 0; i < numElements; ++i) { inputTensorData[i] = i; } - final Tensor inputTensor = Tensor.newFloatTensor(inputTensorDims, inputTensorData); + final Tensor inputTensor = Tensor.newFloat32Tensor(inputTensorShape, inputTensorData); final Module module = Module.load(assetFilePath(TEST_MODULE_ASSET_NAME)); final IValue input = IValue.tensor(inputTensor); @@ -110,7 +113,7 @@ public void testEqTensor() throws IOException { assertTrue(output.isTensor()); final Tensor outputTensor = output.getTensor(); assertNotNull(outputTensor); - assertArrayEquals(inputTensorDims, outputTensor.dims); + assertArrayEquals(inputTensorShape, outputTensor.shape); float[] outputData = outputTensor.getDataAsFloatArray(); for (int i = 0; i < numElements; i++) { assertTrue(inputTensorData[i] == outputData[i]); @@ -216,8 +219,8 @@ public void testRunUndefinedMethod() throws IOException { @Test public void testTensorMethods() { - long[] dims = new long[] {1, 3, 224, 224}; - final int numel = (int) Tensor.numElements(dims); + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); int[] ints = new int[numel]; float[] floats = new float[numel]; @@ -228,16 +231,16 @@ public void testTensorMethods() { floats[i] = i / 1000.f; } - Tensor tensorBytes = Tensor.newByteTensor(dims, bytes); - assertTrue(tensorBytes.isByteTensor()); + Tensor tensorBytes = Tensor.newInt8Tensor(shape, bytes); + assertTrue(tensorBytes.dtype() == Tensor.DTYPE_INT8); assertArrayEquals(bytes, tensorBytes.getDataAsByteArray()); - Tensor tensorInts = Tensor.newIntTensor(dims, ints); - assertTrue(tensorInts.isIntTensor()); + Tensor tensorInts = Tensor.newInt32Tensor(shape, ints); + assertTrue(tensorInts.dtype() == Tensor.DTYPE_INT32); assertArrayEquals(ints, tensorInts.getDataAsIntArray()); - Tensor tensorFloats = Tensor.newFloatTensor(dims, floats); - assertTrue(tensorFloats.isFloatTensor()); + Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats); + assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); float[] floatsOut = tensorFloats.getDataAsFloatArray(); assertTrue(floatsOut.length == numel); for (int i = 0; i < numel; i++) { @@ -247,11 +250,11 @@ public void testTensorMethods() { @Test(expected = IllegalStateException.class) public void testTensorIllegalStateOnWrongType() { - long[] dims = new long[] {1, 3, 224, 224}; - final int numel = (int) Tensor.numElements(dims); + long[] shape = new long[] {1, 3, 224, 224}; + final int numel = (int) Tensor.numel(shape); float[] floats = new float[numel]; - Tensor tensorFloats = Tensor.newFloatTensor(dims, floats); - assertTrue(tensorFloats.isFloatTensor()); + Tensor tensorFloats = Tensor.newFloat32Tensor(shape, floats); + assertTrue(tensorFloats.dtype() == Tensor.DTYPE_FLOAT32); tensorFloats.getDataAsByteArray(); } diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp index 50ffcecd808f3..56988a764567f 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp @@ -10,9 +10,12 @@ namespace pytorch_jni { -constexpr static int kTensorTypeCodeByte = 1; -constexpr static int kTensorTypeCodeInt32 = 2; -constexpr static int kTensorTypeCodeFloat32 = 3; +constexpr static int kTensorDTypeUInt8 = 1; +constexpr static int kTensorDTypeInt8 = 2; +constexpr static int kTensorDTypeInt32 = 3; +constexpr static int kTensorDTypeFloat32 = 4; +constexpr static int kTensorDTypeInt64 = 5; +constexpr static int kTensorDTypeFloat64 = 6; template struct JHashMap @@ -41,53 +44,60 @@ struct JHashMap }; static at::Tensor newAtTensor( - facebook::jni::alias_ref inputData, - facebook::jni::alias_ref inputDims, - jint typeCode) { - const auto inputDimsRank = inputDims->size(); - const auto inputDimsArr = inputDims->getRegion(0, inputDimsRank); - std::vector inputDimsVec; - auto inputNumel = 1; - for (auto i = 0; i < inputDimsRank; ++i) { - inputDimsVec.push_back(inputDimsArr[i]); - inputNumel *= inputDimsArr[i]; + facebook::jni::alias_ref jbuffer, + facebook::jni::alias_ref jshape, + jint jdtype) { + const auto rank = jshape->size(); + const auto shapeArr = jshape->getRegion(0, rank); + std::vector shapeVec{}; + shapeVec.reserve(rank); + auto numel = 1; + for (auto i = 0; i < rank; ++i) { + shapeVec.push_back(shapeArr[i]); + numel *= shapeArr[i]; } JNIEnv* jni = facebook::jni::Environment::current(); - caffe2::TypeMeta inputTypeMeta{}; - int inputDataElementSizeBytes = 0; - if (kTensorTypeCodeFloat32 == typeCode) { - inputDataElementSizeBytes = 4; - inputTypeMeta = caffe2::TypeMeta::Make(); - } else if (kTensorTypeCodeInt32 == typeCode) { - inputDataElementSizeBytes = 4; - inputTypeMeta = caffe2::TypeMeta::Make(); - } else if (kTensorTypeCodeByte == typeCode) { - inputDataElementSizeBytes = 1; - inputTypeMeta = caffe2::TypeMeta::Make(); + caffe2::TypeMeta typeMeta{}; + int dataElementSizeBytes = 0; + if (kTensorDTypeFloat32 == jdtype) { + dataElementSizeBytes = 4; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeInt32 == jdtype) { + dataElementSizeBytes = 4; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeInt8 == jdtype) { + dataElementSizeBytes = 1; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeUInt8 == jdtype) { + dataElementSizeBytes = 1; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeFloat64 == jdtype) { + dataElementSizeBytes = 8; + typeMeta = caffe2::TypeMeta::Make(); + } else if (kTensorDTypeInt64 == jdtype) { + dataElementSizeBytes = 8; + typeMeta = caffe2::TypeMeta::Make(); } else { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, - "Unknown Tensor typeCode %d", - typeCode); + "Unknown Tensor jdtype %d", + jdtype); } - const auto inputDataCapacity = jni->GetDirectBufferCapacity(inputData.get()); - if (inputDataCapacity != inputNumel) { + const auto dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); + if (dataCapacity != numel) { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "Tensor dimensions(elements number:%d, element byte size:%d, total " "bytes:%d) inconsistent with buffer capacity(%d)", - inputNumel, - inputDataElementSizeBytes, - inputNumel * inputDataElementSizeBytes, - inputDataCapacity); + numel, + dataElementSizeBytes, + numel * dataElementSizeBytes, + dataCapacity); } - - at::Tensor inputTensor = torch::empty(torch::IntArrayRef(inputDimsVec)); - inputTensor.unsafeGetTensorImpl()->ShareExternalPointer( - {jni->GetDirectBufferAddress(inputData.get()), at::DeviceType::CPU}, - inputTypeMeta, - inputDataCapacity); - return inputTensor; + return torch::from_blob( + jni->GetDirectBufferAddress(jbuffer.get()), + torch::IntArrayRef(shapeVec), + at::TensorOptions(typeMeta)); } class JTensor : public facebook::jni::JavaClass { @@ -96,8 +106,8 @@ class JTensor : public facebook::jni::JavaClass { static facebook::jni::local_ref newJTensor( facebook::jni::alias_ref jBuffer, - facebook::jni::alias_ref jDims, - jint typeCode) { + facebook::jni::alias_ref jShape, + jint jdtype) { static auto jMethodNewTensor = JTensor::javaClassStatic() ->getStaticMethod( @@ -105,35 +115,41 @@ class JTensor : public facebook::jni::JavaClass { facebook::jni::alias_ref, jint)>("nativeNewTensor"); return jMethodNewTensor( - JTensor::javaClassStatic(), jBuffer, jDims, typeCode); + JTensor::javaClassStatic(), jBuffer, jShape, jdtype); } static facebook::jni::local_ref newJTensorFromAtTensor( const at::Tensor& tensor) { const auto scalarType = tensor.scalar_type(); - int typeCode = 0; + int jdtype = 0; if (at::kFloat == scalarType) { - typeCode = kTensorTypeCodeFloat32; + jdtype = kTensorDTypeFloat32; } else if (at::kInt == scalarType) { - typeCode = kTensorTypeCodeInt32; + jdtype = kTensorDTypeInt32; } else if (at::kByte == scalarType) { - typeCode = kTensorTypeCodeByte; + jdtype = kTensorDTypeUInt8; + } else if (at::kChar == scalarType) { + jdtype = kTensorDTypeInt8; + } else if (at::kLong == scalarType) { + jdtype = kTensorDTypeInt64; + } else if (at::kDouble == scalarType) { + jdtype = kTensorDTypeFloat64; } else { facebook::jni::throwNewJavaException( facebook::jni::gJavaLangIllegalArgumentException, "at::Tensor scalar type is not supported on java side"); } - const auto& tensorDims = tensor.sizes(); - std::vector tensorDimsVec; - for (const auto& dim : tensorDims) { - tensorDimsVec.push_back(dim); + const auto& tensorShape = tensor.sizes(); + std::vector tensorShapeVec; + for (const auto& s : tensorShape) { + tensorShapeVec.push_back(s); } - facebook::jni::local_ref jTensorDims = - facebook::jni::make_long_array(tensorDimsVec.size()); + facebook::jni::local_ref jTensorShape = + facebook::jni::make_long_array(tensorShapeVec.size()); - jTensorDims->setRegion(0, tensorDimsVec.size(), tensorDimsVec.data()); + jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data()); facebook::jni::local_ref jTensorBuffer = facebook::jni::JByteBuffer::allocateDirect(tensor.nbytes()); @@ -142,18 +158,18 @@ class JTensor : public facebook::jni::JavaClass { jTensorBuffer->getDirectBytes(), tensor.storage().data(), tensor.nbytes()); - return JTensor::newJTensor(jTensorBuffer, jTensorDims, typeCode); + return JTensor::newJTensor(jTensorBuffer, jTensorShape, jdtype); } static at::Tensor newAtTensorFromJTensor( facebook::jni::alias_ref jtensor) { - static const auto typeCodeMethod = - JTensor::javaClassStatic()->getMethod("getTypeCode"); - jint typeCode = typeCodeMethod(jtensor); + static const auto dtypeMethod = + JTensor::javaClassStatic()->getMethod("dtype"); + jint jdtype = dtypeMethod(jtensor); - static const auto dimsField = - JTensor::javaClassStatic()->getField("dims"); - auto jdims = jtensor->getFieldValue(dimsField); + static const auto shapeField = + JTensor::javaClassStatic()->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); static auto dataBufferMethod = JTensor::javaClassStatic() @@ -162,7 +178,7 @@ class JTensor : public facebook::jni::JavaClass { "getRawDataBuffer"); facebook::jni::local_ref jbuffer = dataBufferMethod(jtensor); - return newAtTensor(jbuffer, jdims, typeCode); + return newAtTensor(jbuffer, jshape, jdtype); } }; @@ -308,7 +324,7 @@ class JIValue : public facebook::jni::JavaClass { return jMethodListArr(JIValue::javaClassStatic(), jArray); } else if (ivalue.isGenericDict()) { auto dict = ivalue.toGenericDict(); - const auto keyType = dict._keyType(); + const auto keyType = dict.keyType(); if (!keyType) { facebook::jni::throwNewJavaException( @@ -316,7 +332,7 @@ class JIValue : public facebook::jni::JavaClass { "Unknown IValue-Dict key type"); } - const auto keyTypeKind = keyType.value()->kind(); + const auto keyTypeKind = keyType->kind(); if (c10::TypeKind::StringType == keyTypeKind) { static auto jMethodDictStringKey = JIValue::javaClassStatic() @@ -405,17 +421,12 @@ class JIValue : public facebook::jni::JavaClass { std::vector elements; elements.reserve(n); - std::vector types; - types.reserve(n); for (auto i = 0; i < n; ++i) { auto jivalue_element = jarray->getElement(i); auto element = JIValue::JIValueToAtIValue(jivalue_element); - c10::TypePtr typePtr = c10::attemptToRecoverType(element); elements.push_back(std::move(element)); - types.push_back(std::move(typePtr)); } - return c10::ivalue::Tuple::create( - std::move(elements), c10::TupleType::create(std::move(types))); + return c10::ivalue::Tuple::create(std::move(elements)); } else if (JIValue::kTypeCodeBoolList == typeCode) { static const auto jMethodGetBoolList = JIValue::javaClassStatic()->getMethod("getBoolList"); @@ -579,7 +590,11 @@ class PytorchJni : public facebook::jni::HybridClass { at::IValue atIValue = JIValue::JIValueToAtIValue(jinputs->getElement(i)); inputs.push_back(std::move(atIValue)); } - auto output = module_.forward(std::move(inputs)); + auto output = [&]() { + torch::autograd::AutoGradMode guard(false); + at::AutoNonVariableTypeMode non_var_type_mode(true); + return module_.forward(std::move(inputs)); + }(); return JIValue::newJIValueFromAtIValue(output); } @@ -598,7 +613,11 @@ class PytorchJni : public facebook::jni::HybridClass { inputs.push_back(std::move(atIValue)); } if (auto method = module_.find_method(methodName)) { - auto output = (*method)(std::move(inputs)); + auto output = [&]() { + torch::autograd::AutoGradMode guard(false); + at::AutoNonVariableTypeMode non_var_type_mode(true); + return (*method)(std::move(inputs)); + }(); return JIValue::newJIValueFromAtIValue(output); } diff --git a/android/pytorch_android/src/main/java/org/pytorch/IValue.java b/android/pytorch_android/src/main/java/org/pytorch/IValue.java index f85e6ddb192ab..acb7512aa2d09 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/IValue.java +++ b/android/pytorch_android/src/main/java/org/pytorch/IValue.java @@ -3,6 +3,12 @@ import java.util.Locale; import java.util.Map; +/** + * Java representation of a torchscript variable, which is implemented as tagged union that can be + * one of the supported types: https://pytorch.org/docs/stable/jit.html#types. + *

+ * Calling getters for inappropriate types will throw IllegalStateException. + */ public class IValue { private static final int TYPE_CODE_NULL = 1; @@ -84,54 +90,81 @@ public static IValue optionalNull() { return new IValue(TYPE_CODE_NULL); } + /** + * Creates a new IValue instance of torchscript Tensor type. + */ public static IValue tensor(Tensor tensor) { final IValue iv = new IValue(TYPE_CODE_TENSOR); iv.mData = tensor; return iv; } + /** + * Creates a new IValue instance of torchscript bool type. + */ public static IValue bool(boolean value) { final IValue iv = new IValue(TYPE_CODE_BOOL); iv.mData = value; return iv; } + /** + * Creates a new IValue instance of torchscript int type. + */ public static IValue long64(long value) { final IValue iv = new IValue(TYPE_CODE_LONG); iv.mData = value; return iv; } + /** + * Creates a new IValue instance of torchscript float type. + */ public static IValue double64(double value) { final IValue iv = new IValue(TYPE_CODE_DOUBLE); iv.mData = value; return iv; } + /** + * Creates a new IValue instance of torchscript List[bool] type. + */ public static IValue boolList(boolean... list) { final IValue iv = new IValue(TYPE_CODE_BOOL_LIST); iv.mData = list; return iv; } + /** + * Creates a new IValue instance of torchscript List[int] type. + */ public static IValue longList(long... list) { final IValue iv = new IValue(TYPE_CODE_LONG_LIST); iv.mData = list; return iv; } + /** + * Creates a new IValue instance of torchscript List[float] type. + */ public static IValue doubleList(double... list) { final IValue iv = new IValue(TYPE_CODE_DOUBLE_LIST); iv.mData = list; return iv; } + /** + * Creates a new IValue instance of torchscript List[Tensor] type. + */ public static IValue tensorList(Tensor... list) { final IValue iv = new IValue(TYPE_CODE_TENSOR_LIST); iv.mData = list; return iv; } + /** + * Creates a new IValue instance of torchscript List[T] type. All elements must have the same type. + */ public static IValue list(IValue... array) { final int size = array.length; if (size > 0) { @@ -148,18 +181,27 @@ public static IValue list(IValue... array) { return iv; } + /** + * Creates a new IValue instance of torchscript Tuple[T0, T1, ...] type. + */ public static IValue tuple(IValue... array) { final IValue iv = new IValue(TYPE_CODE_TUPLE); iv.mData = array; return iv; } + /** + * Creates a new IValue instance oftorchscript Dict[Str, V] type. + */ public static IValue dictStringKey(Map map) { final IValue iv = new IValue(TYPE_CODE_DICT_STRING_KEY); iv.mData = map; return iv; } + /** + * Creates a new IValue instance of torchscript Dict[int, V] type. + */ public static IValue dictLongKey(Map map) { final IValue iv = new IValue(TYPE_CODE_DICT_LONG_KEY); iv.mData = map; diff --git a/android/pytorch_android/src/main/java/org/pytorch/Module.java b/android/pytorch_android/src/main/java/org/pytorch/Module.java index 38dfc8ddf59d3..8a12a4ee9a31d 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Module.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Module.java @@ -4,22 +4,45 @@ import com.facebook.jni.HybridData; +/** + * Java holder for torch::jit::script::Module which owns it on jni side. + */ public class Module { private NativePeer mNativePeer; + /** + * Loads serialized torchscript module from the specified absolute path on the disk. + * + * @param modelAbsolutePath absolute path to file that contains the serialized torchscript module. + * @return new {@link org.pytorch.Module} object which owns torch::jit::script::Module on jni + * side. + */ public static Module load(final String modelAbsolutePath) { return new Module(modelAbsolutePath); } - private Module(final String modelAbsolutePath) { - this.mNativePeer = new NativePeer(modelAbsolutePath); + private Module(final String moduleAbsolutePath) { + this.mNativePeer = new NativePeer(moduleAbsolutePath); } + /** + * Runs 'forward' method of loaded torchscript module with specified arguments. + * + * @param inputs arguments for torchscript module 'forward' method. + * @return result of torchscript module 'forward' method evaluation + */ public IValue forward(IValue... inputs) { return mNativePeer.forward(inputs); } + /** + * Runs specified method of loaded torchscript module with specified arguments. + * + * @param methodName torchscript module method to run + * @param inputs arguments that will be specified to torchscript module method call + * @return result of torchscript module specified method evaluation + */ public IValue runMethod(String methodName, IValue... inputs) { return mNativePeer.runMethod(methodName, inputs); } diff --git a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java index ee595fe3ab411..4178a9adfb098 100644 --- a/android/pytorch_android/src/main/java/org/pytorch/Tensor.java +++ b/android/pytorch_android/src/main/java/org/pytorch/Tensor.java @@ -3,178 +3,518 @@ import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; +import java.nio.LongBuffer; import java.util.Arrays; import java.util.Locale; +/** + * Representation of Tensor. Tensor shape is stored in {@link Tensor#shape}, elements are stored as + * {@link java.nio.DirectByteBuffer} of one of the supported types. + */ public abstract class Tensor { - private static final int TYPE_CODE_BYTE = 1; - private static final int TYPE_CODE_INT32 = 2; - private static final int TYPE_CODE_FLOAT32 = 3; + + /** Code for dtype torch.uint8. {@link Tensor#dtype()} */ + public static final int DTYPE_UINT8 = 1; + /** Code for dtype torch.int8. {@link Tensor#dtype()} */ + public static final int DTYPE_INT8 = 2; + /** Code for dtype torch.int32. {@link Tensor#dtype()} */ + public static final int DTYPE_INT32 = 3; + /** Code for dtype torch.float32. {@link Tensor#dtype()} */ + public static final int DTYPE_FLOAT32 = 4; + /** Code for dtype torch.int64. {@link Tensor#dtype()} */ + public static final int DTYPE_INT64 = 5; + /** Code for dtype torch.float64. {@link Tensor#dtype()} */ + public static final int DTYPE_FLOAT64 = 6; private static final String ERROR_MSG_DATA_BUFFER_NOT_NULL = "Data buffer must be not null"; private static final String ERROR_MSG_DATA_ARRAY_NOT_NULL = "Data array must be not null"; - private static final String ERROR_MSG_DIMS_NOT_NULL = "Dims must be not null"; - private static final String ERROR_MSG_DIMS_NOT_EMPTY = "Dims must be not empty"; - private static final String ERROR_MSG_INDEX_NOT_NULL = "Index must be not null"; - private static final String ERROR_MSG_DIMS_NON_NEGATIVE = "Dims must be non negative"; + private static final String ERROR_MSG_SHAPE_NOT_NULL = "Shape must be not null"; + private static final String ERROR_MSG_SHAPE_NOT_EMPTY = "Shape must be not empty"; + private static final String ERROR_MSG_SHAPE_NON_NEGATIVE = "Shape elements must be non negative"; private static final String ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER = "Data buffer must have native byte order (java.nio.ByteOrder#nativeOrder)"; private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; - public final long[] dims; + /** Shape of current tensor. */ + public final long[] shape; - private static final int FLOAT_SIZE_BYTES = 4; private static final int INT_SIZE_BYTES = 4; + private static final int FLOAT_SIZE_BYTES = 4; + private static final int LONG_SIZE_BYTES = 8; + private static final int DOUBLE_SIZE_BYTES = 8; + + /** + * Allocates a new direct {@link java.nio.ByteBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#newInt8Tensor(long[], ByteBuffer)}, {@link + * Tensor#newUInt8Tensor(long[], ByteBuffer)}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static ByteBuffer allocateByteBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + } + + public static IntBuffer allocateIntBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asIntBuffer(); + } + /** + * Allocates a new direct {@link java.nio.FloatBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#newFloat32Tensor(long[], FloatBuffer)}. + * + * @param numElements capacity (number of elements) of result buffer. + */ public static FloatBuffer allocateFloatBuffer(int numElements) { return ByteBuffer.allocateDirect(numElements * FLOAT_SIZE_BYTES) .order(ByteOrder.nativeOrder()) .asFloatBuffer(); } - public static IntBuffer allocateIntBuffer(int numElements) { - return ByteBuffer.allocateDirect(numElements * INT_SIZE_BYTES) + /** + * Allocates a new direct {@link java.nio.LongBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#newInt64Tensor(long[], LongBuffer)}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static LongBuffer allocateLongBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * LONG_SIZE_BYTES) .order(ByteOrder.nativeOrder()) - .asIntBuffer(); + .asLongBuffer(); } - public static ByteBuffer allocateByteBuffer(int numElements) { - return ByteBuffer.allocateDirect(numElements).order(ByteOrder.nativeOrder()); + /** + * Allocates a new direct {@link java.nio.DoubleBuffer} with native byte order with specified + * capacity that can be used in {@link Tensor#newFloat64Tensor(long[], DoubleBuffer)}. + * + * @param numElements capacity (number of elements) of result buffer. + */ + public static DoubleBuffer allocateDoubleBuffer(int numElements) { + return ByteBuffer.allocateDirect(numElements * DOUBLE_SIZE_BYTES) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); } - public static Tensor newFloatTensor(long[] dims, float[] data) { + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data as array of + * bytes. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor newUInt8Tensor(long[] shape, byte[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.length, dims); - final int bufferCapacity = (int) numElements(dims); - final FloatBuffer floatBuffer = allocateFloatBuffer(bufferCapacity); - floatBuffer.put(data); - return new Tensor_float32(floatBuffer, dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_uint8(byteBuffer, shape); } - public static Tensor newIntTensor(long[] dims, int[] data) { + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data as array of + * bytes. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor newInt8Tensor(long[] shape, byte[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.length, dims); - final int bufferCapacity = (int) numElements(dims); - final IntBuffer intBuffer = allocateIntBuffer(bufferCapacity); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final ByteBuffer byteBuffer = allocateByteBuffer((int) numel(shape)); + byteBuffer.put(data); + return new Tensor_int8(byteBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data as array of + * ints. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor newInt32Tensor(long[] shape, int[] data) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final IntBuffer intBuffer = allocateIntBuffer((int) numel(shape)); intBuffer.put(data); - return new Tensor_int32(intBuffer, dims); + return new Tensor_int32(intBuffer, shape); } - public static Tensor newByteTensor(long[] dims, byte[] data) { + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data as array + * of floats. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor newFloat32Tensor(long[] shape, float[] data) { checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.length, dims); - final int bufferCapacity = (int) numElements(dims); - final ByteBuffer byteBuffer = allocateByteBuffer(bufferCapacity); - byteBuffer.put(data); - return new Tensor_byte(byteBuffer, dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final FloatBuffer floatBuffer = allocateFloatBuffer((int) numel(shape)); + floatBuffer.put(data); + return new Tensor_float32(floatBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data as array of + * longs. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor newInt64Tensor(long[] shape, long[] data) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final LongBuffer longBuffer = allocateLongBuffer((int) numel(shape)); + longBuffer.put(data); + return new Tensor_int64(longBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data as array + * of doubles. + * + * @param shape Tensor shape + * @param data Tensor elements + */ + public static Tensor newFloat64Tensor(long[] shape, double[] data) { + checkArgument(data != null, ERROR_MSG_DATA_ARRAY_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.length, shape); + final DoubleBuffer doubleBuffer = allocateDoubleBuffer((int) numel(shape)); + doubleBuffer.put(data); + return new Tensor_float64(doubleBuffer, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.uint8 with specified shape and data. + * + * @param shape Tensor shape + * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + */ + public static Tensor newUInt8Tensor(long[] shape, ByteBuffer data) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_uint8(data, shape); } - public static Tensor newFloatTensor(long[] dims, FloatBuffer data) { + /** + * Creates a new Tensor instance with dtype torch.int8 with specified shape and data. + * + * @param shape Tensor shape + * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + */ + public static Tensor newInt8Tensor(long[] shape, ByteBuffer data) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.capacity(), dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); checkArgument( (data.order() == ByteOrder.nativeOrder()), ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); - return new Tensor_float32(data, dims); + return new Tensor_int8(data, shape); } - public static Tensor newIntTensor(long[] dims, IntBuffer data) { + /** + * Creates a new Tensor instance with dtype torch.int32 with specified shape and data. + * + * @param shape Tensor shape + * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + */ + public static Tensor newInt32Tensor(long[] shape, IntBuffer data) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.capacity(), dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); checkArgument( (data.order() == ByteOrder.nativeOrder()), ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); - return new Tensor_int32(data, dims); + return new Tensor_int32(data, shape); } - public static Tensor newByteTensor(long[] dims, ByteBuffer data) { + /** + * Creates a new Tensor instance with dtype torch.float32 with specified shape and data. + * + * @param shape Tensor shape + * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + */ + public static Tensor newFloat32Tensor(long[] shape, FloatBuffer data) { checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkDims(dims); - checkDimsAndDataCapacityConsistency(data.capacity(), dims); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); checkArgument( (data.order() == ByteOrder.nativeOrder()), ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); - return new Tensor_byte(data, dims); + return new Tensor_float32(data, shape); } - private Tensor(long[] dims) { - checkDims(dims); - this.dims = Arrays.copyOf(dims, dims.length); + /** + * Creates a new Tensor instance with dtype torch.int64 with specified shape and data. + * + * @param shape Tensor shape + * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + */ + public static Tensor newInt64Tensor(long[] shape, LongBuffer data) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_int64(data, shape); + } + + /** + * Creates a new Tensor instance with dtype torch.float64 with specified shape and data. + * + * @param shape Tensor shape + * @param data Direct buffer with native byte order that contains {@code Tensor#numel(shape)} + * elements. The buffer is used directly without copying, and changes to its content will + * change the tensor. + */ + public static Tensor newFloat64Tensor(long[] shape, DoubleBuffer data) { + checkArgument(data != null, ERROR_MSG_DATA_BUFFER_NOT_NULL); + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkShape(shape); + checkShapeAndDataCapacityConsistency(data.capacity(), shape); + checkArgument(data.isDirect(), ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT); + checkArgument( + (data.order() == ByteOrder.nativeOrder()), + ERROR_MSG_DATA_BUFFER_MUST_HAVE_NATIVE_BYTE_ORDER); + return new Tensor_float64(data, shape); + } + + private Tensor(long[] shape) { + checkShape(shape); + this.shape = Arrays.copyOf(shape, shape.length); + } + + /** Calculates number of elements in current tensor instance. */ + public long numel() { + return numel(this.shape); } - public static long numElements(long[] dims) { - checkDims(dims); + /** Calculates number of elements in tensor with specified shape. */ + public static long numel(long[] shape) { + checkShape(shape); int result = 1; - for (long dim : dims) { - result *= dim; + for (long s : shape) { + result *= s; } return result; } + /** + * Returns dtype of current tensor. Can be one of {@link Tensor#DTYPE_UINT8}, {@link + * Tensor#DTYPE_INT8}, {@link Tensor#DTYPE_INT32},{@link Tensor#DTYPE_FLOAT32}, {@link + * Tensor#DTYPE_INT64}, {@link Tensor#DTYPE_FLOAT64}. + */ + public abstract int dtype(); + + /** + * Returns newly allocated java byte array that contains a copy of tensor data. + * + * @throws IllegalStateException if it is called for a non-int8 tensor. + */ public byte[] getDataAsByteArray() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); } + /** + * Returns newly allocated java byte array that contains a copy of tensor data. + * + * @throws IllegalStateException if it is called for a non-uint8 tensor. + */ + public byte[] getDataAsUnsignedByteArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + } + + /** + * Returns newly allocated java byte array that contains a copy of tensor data. + * + * @throws IllegalStateException if it is called for a non-int32 tensor. + */ public int[] getDataAsIntArray() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot return data as int array."); } + /** + * Returns newly allocated java byte array that contains a copy of tensor data. + * + * @throws IllegalStateException if it is called for a non-float32 tensor. + */ public float[] getDataAsFloatArray() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); } - public boolean isByteTensor() { - return TYPE_CODE_BYTE == getTypeCode(); - } - - public boolean isIntTensor() { - return TYPE_CODE_INT32 == getTypeCode(); + /** + * Returns newly allocated java byte array that contains a copy of tensor data. + * + * @throws IllegalStateException if it is called for a non-int64 tensor. + */ + public long[] getDataAsLongArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as float array."); } - public boolean isFloatTensor() { - return TYPE_CODE_FLOAT32 == getTypeCode(); + /** + * Returns newly allocated java byte array that contains a copy of tensor data. + * + * @throws IllegalStateException if it is called for a non-float64 tensor. + */ + public double[] getDataAsDoubleArray() { + throw new IllegalStateException( + "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); } - abstract int getTypeCode(); - Buffer getRawDataBuffer() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); } - private static String invalidIndexErrorMessage(int[] index, long dims[]) { - return String.format( - Locale.US, - "Invalid index %s for tensor dimensions %s", - Arrays.toString(index), - Arrays.toString(dims)); + static class Tensor_uint8 extends Tensor { + private final ByteBuffer data; + + private Tensor_uint8(ByteBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public int dtype() { + return DTYPE_UINT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsUnsignedByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.uint8)", Arrays.toString(shape)); + } + } + + static class Tensor_int8 extends Tensor { + private final ByteBuffer data; + + private Tensor_int8(ByteBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public int dtype() { + return DTYPE_INT8; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int8)", Arrays.toString(shape)); + } + } + + static class Tensor_int32 extends Tensor { + private final IntBuffer data; + + private Tensor_int32(IntBuffer data, long[] shape) { + super(shape); + this.data = data; + } + + @Override + public int dtype() { + return DTYPE_INT32; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public int[] getDataAsIntArray() { + data.rewind(); + int[] arr = new int[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=torch.int32)", Arrays.toString(shape)); + } } static class Tensor_float32 extends Tensor { private final FloatBuffer data; - Tensor_float32(FloatBuffer data, long[] dims) { - super(dims); + Tensor_float32(FloatBuffer data, long[] shape) { + super(shape); this.data = data; } @@ -187,8 +527,8 @@ public float[] getDataAsFloatArray() { } @Override - int getTypeCode() { - return TYPE_CODE_FLOAT32; + public int dtype() { + return DTYPE_FLOAT32; } @Override @@ -198,23 +538,21 @@ Buffer getRawDataBuffer() { @Override public String toString() { - return String.format( - "Tensor_float32{dims:%s data:%s}", - Arrays.toString(dims), Arrays.toString(getDataAsFloatArray())); + return String.format("Tensor(%s, dtype=torch.float32)", Arrays.toString(shape)); } } - static class Tensor_int32 extends Tensor { - private final IntBuffer data; + static class Tensor_int64 extends Tensor { + private final LongBuffer data; - private Tensor_int32(IntBuffer data, long[] dims) { - super(dims); + private Tensor_int64(LongBuffer data, long[] shape) { + super(shape); this.data = data; } @Override - int getTypeCode() { - return TYPE_CODE_INT32; + public int dtype() { + return DTYPE_INT64; } @Override @@ -223,32 +561,30 @@ Buffer getRawDataBuffer() { } @Override - public int[] getDataAsIntArray() { + public long[] getDataAsLongArray() { data.rewind(); - int[] arr = new int[data.remaining()]; + long[] arr = new long[data.remaining()]; data.get(arr); return arr; } @Override public String toString() { - return String.format( - "Tensor_int32{dims:%s data:%s}", - Arrays.toString(dims), Arrays.toString(getDataAsIntArray())); + return String.format("Tensor(%s, dtype=torch.int64)", Arrays.toString(shape)); } } - static class Tensor_byte extends Tensor { - private final ByteBuffer data; + static class Tensor_float64 extends Tensor { + private final DoubleBuffer data; - private Tensor_byte(ByteBuffer data, long[] dims) { - super(dims); + private Tensor_float64(DoubleBuffer data, long[] shape) { + super(shape); this.data = data; } @Override - int getTypeCode() { - return TYPE_CODE_BYTE; + public int dtype() { + return DTYPE_FLOAT64; } @Override @@ -257,18 +593,16 @@ Buffer getRawDataBuffer() { } @Override - public byte[] getDataAsByteArray() { + public double[] getDataAsDoubleArray() { data.rewind(); - byte[] arr = new byte[data.remaining()]; + double[] arr = new double[data.remaining()]; data.get(arr); return arr; } @Override public String toString() { - return String.format( - "Tensor_byte{dims:%s data:%s}", - Arrays.toString(dims), Arrays.toString(getDataAsByteArray())); + return String.format("Tensor(%s, dtype=torch.float64)", Arrays.toString(shape)); } } @@ -279,48 +613,40 @@ private static void checkArgument(boolean expression, String errorMessage, Objec } } - private static void checkDims(long[] dims) { - checkArgument(dims != null, ERROR_MSG_DIMS_NOT_NULL); - checkArgument(dims.length > 0, ERROR_MSG_DIMS_NOT_EMPTY); - for (int i = 0; i < dims.length; i++) { - checkArgument(dims[i] >= 0, ERROR_MSG_DIMS_NON_NEGATIVE); + private static void checkShape(long[] shape) { + checkArgument(shape != null, ERROR_MSG_SHAPE_NOT_NULL); + checkArgument(shape.length > 0, ERROR_MSG_SHAPE_NOT_EMPTY); + for (int i = 0; i < shape.length; i++) { + checkArgument(shape[i] >= 0, ERROR_MSG_SHAPE_NON_NEGATIVE); } } - private static void checkIndex(int[] index, long dims[]) { - checkArgument(dims != null, ERROR_MSG_INDEX_NOT_NULL); - - if (index.length != dims.length) { - throw new IllegalArgumentException(invalidIndexErrorMessage(index, dims)); - } - - for (int i = 0; i < index.length; i++) { - if (index[i] >= dims[i]) { - throw new IllegalArgumentException(invalidIndexErrorMessage(index, dims)); - } - } - } - - private static void checkDimsAndDataCapacityConsistency(int dataCapacity, long[] dims) { - final long numElements = numElements(dims); + private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[] shape) { + final long numel = numel(shape); checkArgument( - numElements == dataCapacity, - "Inconsistent data capacity:%d and dims number elements:%d dims:%s", + numel == dataCapacity, + "Inconsistent data capacity:%d and shape number elements:%d shape:%s", dataCapacity, - numElements, - Arrays.toString(dims)); + numel, + Arrays.toString(shape)); } // endregion checks // Called from native - private static Tensor nativeNewTensor(ByteBuffer data, long[] dims, int typeCode) { - if (TYPE_CODE_FLOAT32 == typeCode) { - return new Tensor_float32(data.asFloatBuffer(), dims); - } else if (TYPE_CODE_INT32 == typeCode) { - return new Tensor_int32(data.asIntBuffer(), dims); - } else if (TYPE_CODE_BYTE == typeCode) { - return new Tensor_byte(data, dims); - } - throw new IllegalArgumentException("Unknown Tensor typeCode"); + private static Tensor nativeNewTensor(ByteBuffer data, long[] shape, int dtype) { + if (DTYPE_FLOAT32 == dtype) { + return new Tensor_float32(data.asFloatBuffer(), shape); + } else if (DTYPE_INT32 == dtype) { + return new Tensor_int32(data.asIntBuffer(), shape); + } else if (DTYPE_INT64 == dtype) { + return new Tensor_int64(data.asLongBuffer(), shape); + } else if (DTYPE_FLOAT64 == dtype) { + return new Tensor_float64(data.asDoubleBuffer(), shape); + } else if (DTYPE_UINT8 == dtype) { + return new Tensor_uint8(data, shape); + } else if (DTYPE_INT8 == dtype) { + return new Tensor_int8(data, shape); + } + throw new IllegalArgumentException("Unknown Tensor dtype"); } } diff --git a/android/pytorch_android_torchvision/build.gradle b/android/pytorch_android_torchvision/build.gradle index 382a0bc1ef68b..3f4c6b542e946 100644 --- a/android/pytorch_android_torchvision/build.gradle +++ b/android/pytorch_android_torchvision/build.gradle @@ -49,7 +49,7 @@ apply from: rootProject.file('gradle/release.gradle') task sourcesJar(type: Jar) { from android.sourceSets.main.java.srcDirs - archiveClassifier.set('sources') + classifier = 'sources' } artifacts.add('archives', sourcesJar) diff --git a/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java index ecb64840d2481..305bcc48fad63 100644 --- a/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java +++ b/android/pytorch_android_torchvision/src/androidTest/java/org/pytorch/torchvision/TorchVisionInstrumentedTests.java @@ -23,6 +23,6 @@ public void setUp() { public void smokeTest() { Bitmap bitmap = Bitmap.createBitmap(320, 240, Bitmap.Config.ARGB_8888); Tensor tensor = TensorImageUtils.bitmapToFloatTensorTorchVisionForm(bitmap); - assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.dims); + assertArrayEquals(new long[] {1l, 3l, 240l, 320l}, tensor.shape); } } diff --git a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java index 70222d482cdf5..a194d68c22a49 100644 --- a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java +++ b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java @@ -39,8 +39,8 @@ public static Tensor bitmapToFloatTensorTorchVisionForm( floatArray[offset_g + i] = (g - NORM_MEAN_G) / NORM_STD_G; floatArray[offset_b + i] = (b - NORM_MEAN_B) / NORM_STD_B; } - final long dims[] = new long[] {1, 3, height, width}; - return Tensor.newFloatTensor(dims, floatArray); + final long shape[] = new long[] {1, 3, height, width}; + return Tensor.newFloat32Tensor(shape, floatArray); } public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm( @@ -130,8 +130,8 @@ public static Tensor imageYUV420CenterCropToFloatTensorTorchVisionForm( floatArray[tensorInputOffsetB + offset] = ((b / 255.f) - NORM_MEAN_B) / NORM_STD_B; } } - final long dims[] = new long[] {1, 3, tensorHeight, tensorHeight}; - return Tensor.newFloatTensor(dims, floatArray); + final long shape[] = new long[] {1, 3, tensorHeight, tensorHeight}; + return Tensor.newFloat32Tensor(shape, floatArray); } private static final int clamp(int c, int min, int max) { diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h index b741126ca1615..1a39455700900 100644 --- a/aten/src/ATen/ATen.h +++ b/aten/src/ATen/ATen.h @@ -10,9 +10,7 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif #include #include #include diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h index ed5ecbb461748..9f91bcdcdcd65 100644 --- a/aten/src/ATen/AccumulateType.h +++ b/aten/src/ATen/AccumulateType.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include // Defines the accumulation type for a scalar type. // Example: @@ -31,6 +32,7 @@ template <> struct AccumulateType { using type = int64_t; }; template <> struct AccumulateType { using type = int64_t; }; template <> struct AccumulateType { using type = int64_t; }; template <> struct AccumulateType { using type = int64_t; }; +template <> struct AccumulateType { using type = float; }; template <> struct AccumulateType { using type = double; }; template <> struct AccumulateType { using type = double; }; template <> struct AccumulateType { using type = int64_t; }; diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index ede456078cb5a..4d1aa4a5fc75d 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -210,7 +210,7 @@ endif() if(AT_NNPACK_ENABLED) include_directories(${NNPACK_INCLUDE_DIRS}) - list(APPEND ATen_CPU_DEPENDENCY_LIBS nnpack pthreadpool) # cpuinfo is added below + list(APPEND ATen_CPU_DEPENDENCY_LIBS nnpack) # cpuinfo is added below endif() if(MKLDNN_FOUND) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 6648c27837e53..f9ac82bc5b843 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -39,6 +39,14 @@ void Context::setUserEnabledCuDNN(bool e) { enabled_cudnn = e; } +bool Context::userEnabledMkldnn() const { + return enabled_mkldnn; +} + +void Context::setUserEnabledMkldnn(bool e) { + enabled_mkldnn = e; +} + bool Context::deterministicCuDNN() const { return deterministic_cudnn; } @@ -87,6 +95,32 @@ bool Context::hasLAPACK() const { #endif } +at::QEngine Context::qEngine() const { + return quantized_engine; +} + +void Context::setQEngine(at::QEngine e) { + const auto& qengines = supportedQEngines(); + if (std::find(qengines.begin(), qengines.end(), e) != qengines.end()) { + quantized_engine = e; + return; + } + TORCH_CHECK(false, "quantized engine ", toString(e), "is not supported"); +} + +std::vector Context::supportedQEngines() const { + static auto supported_qengines = { + at::kNoQEngine, + #ifdef USE_FBGEMM + at::kFBGEMM, + #endif + #ifdef USE_PYTORCH_QNNPACK + at::kQNNPACK, + #endif + }; + return supported_qengines; +} + bool Context::setFlushDenormal(bool on) { return at::cpu::set_flush_denormal(on); } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index 2553b91c2b06a..0994178cf713d 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -99,11 +100,17 @@ class CAFFE2_API Context { // to test this instead bool userEnabledCuDNN() const; void setUserEnabledCuDNN(bool e); + bool userEnabledMkldnn() const; + void setUserEnabledMkldnn(bool e); bool benchmarkCuDNN() const; void setBenchmarkCuDNN(bool); bool deterministicCuDNN() const; void setDeterministicCuDNN(bool); -private: + at::QEngine qEngine() const; + void setQEngine(at::QEngine e); + std::vector supportedQEngines() const; + + private: void initCUDAIfNeeded(DeviceType p) { if (p == DeviceType::CUDA) { lazyInitCUDA(); @@ -119,6 +126,15 @@ class CAFFE2_API Context { bool enabled_cudnn = true; bool deterministic_cudnn = false; bool benchmark_cudnn = false; + bool enabled_mkldnn = true; + at::QEngine quantized_engine = +#ifdef USE_FBGEMM + at::kFBGEMM; +#elif defined(USE_PYTORCH_QNNPACK) + at::kQNNPACK; +#else + at::kNoQEngine; +#endif std::unique_ptr thc_state; std::unique_ptr thh_state; }; diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 076936c493e61..968e49e9449d3 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -33,8 +33,7 @@ - THStorage* source - long storage_offset - IntArrayRefSize size - - arg: IntArrayRef stride - default: {} + - IntArrayRef stride ]] [[ name: _th_fill_ @@ -274,8 +273,7 @@ - THTensor* self - THIndexTensor* index - THTensor* source - - arg: bool accumulate - default: "false" + - bool accumulate ]] [[ name: _th_index_add_ @@ -1041,8 +1039,7 @@ - THTensor* self - arg: long dim wrap_dim: self - - arg: bool keepdim - default: "false" + - bool keepdim ]] [[ name: _th_max @@ -1081,8 +1078,7 @@ - THTensor* self - arg: long dim wrap_dim: self - - arg: bool keepdim - default: "false" + - bool keepdim ]] [[ name: _th_mode @@ -1098,9 +1094,7 @@ - THTensor* self - arg: long dim wrap_dim: self - default: __last_dim - - arg: bool keepdim - default: "false" + - bool keepdim ]] [[ name: _th_sort @@ -1115,10 +1109,8 @@ output: True - THTensor* self - arg: long dim - default: __last_dim wrap_dim: self - - arg: bool descending - default: "false" + - bool descending ]] [[ name: _th_topk @@ -1136,12 +1128,9 @@ - THTensor* self - long k - arg: long dim - default: __last_dim wrap_dim: self - - arg: bool largest - default: "true" - - arg: bool sorted - default: "true" + - bool largest + - bool sorted ]] [[ name: _th_abs @@ -1231,7 +1220,6 @@ types: - floating_point backends: - - CPU - CUDA variants: - function @@ -1246,7 +1234,6 @@ types: - floating_point backends: - - CPU - CUDA cname: lgamma variants: function @@ -1255,64 +1242,6 @@ - THTensor* self - THTensor* self ]] -[[ - name: _th_digamma - cname: digamma - types: - - floating_point - backends: - - CUDA - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] -[[ - name: _th_digamma_ - types: - - floating_point - backends: - - CUDA - cname: digamma - variants: function - return: self - arguments: - - THTensor* self - - THTensor* self -]] -[[ - name: _th_polygamma - cname: polygamma - types: - - floating_point - backends: - - CUDA - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - int64_t n - - THTensor* self -]] -[[ - name: _th_polygamma_ - types: - - floating_point - backends: - - CUDA - cname: polygamma - variants: function - return: self - arguments: - - THTensor* self - - int64_t n - - THTensor* self -]] [[ name: _th_exp cname: exp @@ -1509,20 +1438,6 @@ output: True - THTensor* self ]] -[[ - name: _th_rsqrt - cname: rsqrt - types: - - floating_point - backends: - - CUDA - variants: function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] [[ name: _th_floor cname: floor @@ -1537,20 +1452,6 @@ output: True - THTensor* self ]] -[[ - name: _th_round - cname: round - types: - - floating_point - backends: - - CUDA - variants: function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self -]] [[ name: _th_trunc cname: trunc @@ -1602,15 +1503,12 @@ - CUDA variants: function options: - - cname: varall + - cname: var_all return: accreal arguments: - THTensor* self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - cname: var + - bool unbiased + - cname: var_single return: argument 0 scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1) arguments: @@ -1619,12 +1517,8 @@ - THTensor* self - arg: long dim wrap_dim: self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - arg: bool keepdim - default: "false" + - bool unbiased + - bool keepdim ]] [[ name: _th_std @@ -1635,15 +1529,12 @@ - CUDA variants: function options: - - cname: stdall + - cname: std_all return: accreal arguments: - THTensor* self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - cname: std + - bool unbiased + - cname: std_single return: argument 0 scalar_check: self_->dim() == 0 || (keepdim == false && self_->dim() == 1) arguments: @@ -1652,12 +1543,8 @@ - THTensor* self - arg: long dim wrap_dim: self - - arg: bool unbiased - if_true: 0 - if_false: 1 - default: 0 - - arg: bool keepdim - default: "false" + - bool unbiased + - bool keepdim ]] [[ name: _th_renorm @@ -1713,8 +1600,7 @@ - arg: THTensor* self broadcast: other fallback - THTensor* other - - arg: real p - default: AS_REAL(2) + - real p ]] [[ name: _th_reciprocal @@ -1732,73 +1618,6 @@ output: True - THTensor* self ]] -[[ - name: _th_pow - backends: - - CUDA - cname: pow - variants: - - function - return: argument 0 - options: - - cname: pow - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - real exponent -]] -[[ - name: _th_pow - backends: - - CUDA - variants: - - function - return: argument 0 - options: - - cname: cpow - arguments: - - arg: THTensor* result - output: True - - arg: THTensor* self - broadcast: exponent fallback - - THTensor* exponent -]] -[[ - name: _th_pow - backends: - - CUDA - variants: - - function - return: argument 0 - options: - - cname: tpow - arguments: - - arg: THTensor* result - output: True - - real self - - THTensor* exponent -]] -[[ - name: _th_pow_ - backends: - - CUDA - return: argument 0 - cname: pow - variants: function - options: - - cname: pow - arguments: - - THTensor* self - - THTensor* self - - real exponent - - cname: cpow - arguments: - - THTensor* self - - arg: THTensor* self - broadcast: exponent inplace fallback - - THTensor* exponent -]] [[ name: _th_histc cname: histc @@ -1815,12 +1634,9 @@ - arg: THTensor* result output: True - THTensor* self - - arg: long bins - default: 100 - - arg: real min - default: 0 - - arg: real max - default: 0 + - long bins + - real min + - real max ]] [[ name: _th_zero_ @@ -2022,8 +1838,7 @@ - arg: THTensor* result output: True - THTensor* self - - arg: long diagonal - default: 0 + - long diagonal aten_custom_call: | if (self_->dim() == 0) { throw std::runtime_error("Input must be 1-d or 2-d"); @@ -2043,16 +1858,12 @@ - arguments: - arg: THTensor* result output: True - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - arg: THTensor* self broadcast: mat1,mat2 dims:mat1.dim0,mat2.dim1 - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* mat1 - THTensor* mat2 + - real beta + - real alpha ]] [[ name: _th_addmm_ @@ -2063,15 +1874,11 @@ - cname: addmm arguments: - THTensor* self - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - THTensor* self - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* mat1 - THTensor* mat2 + - real beta + - real alpha ]] [[ name: _th_addmv @@ -2082,16 +1889,12 @@ arguments: - arg: THTensor* result output: True - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - arg: THTensor* self broadcast: mat,vec dims:mat.dim0 - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* mat - THTensor* vec + - real beta + - real alpha ]] [[ name: _th_addmv_ @@ -2101,15 +1904,11 @@ return: self arguments: - THTensor* self - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - THTensor* self - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* mat - THTensor* vec + - real beta + - real alpha ]] [[ name: _th_addr @@ -2121,16 +1920,12 @@ arguments: - arg: THTensor* result output: True - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - arg: THTensor* self broadcast: vec1,vec2 dims:vec1.dim0,vec2.dim0 - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* vec1 - THTensor* vec2 + - real beta + - real alpha ]] [[ name: _th_addr_ @@ -2140,15 +1935,11 @@ variants: function arguments: - THTensor* self - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - THTensor* self - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* vec1 - THTensor* vec2 + - real beta + - real alpha ]] [[ name: _th_ger @@ -2162,11 +1953,11 @@ output: True resize: [ [self,0], [vec2,0] ] resize_scalar: True - - CONSTANT AS_REAL(0) - argument 0 - - CONSTANT AS_REAL(1) - THTensor* self - THTensor* vec2 + - CONSTANT AS_REAL(0) + - CONSTANT AS_REAL(1) ]] [[ name: _th_mv @@ -2180,11 +1971,11 @@ output: True resize: [ [self, 0] ] cpu_zero: True - - CONSTANT AS_REAL(0) - argument 0 - - CONSTANT AS_REAL(1) - THTensor* self - THTensor* vec + - CONSTANT AS_REAL(0) + - CONSTANT AS_REAL(1) ]] [[ name: _th_mm @@ -2199,11 +1990,11 @@ output: True resize: [ [self, 0], [mat2,1] ] cpu_zero: True - - CONSTANT AS_REAL(0) - argument 0 - - CONSTANT AS_REAL(1) - THTensor* self - THTensor* mat2 + - CONSTANT AS_REAL(0) + - CONSTANT AS_REAL(1) ]] [[ name: _th_bmm @@ -2219,11 +2010,11 @@ output: True resize: [ [self,0], [self,1], [mat2,2] ] cpu_zero: True - - CONSTANT AS_REAL(0) - argument 0 - - CONSTANT AS_REAL(1) - THTensor* self - THTensor* mat2 + - CONSTANT AS_REAL(0) + - CONSTANT AS_REAL(1) ]] [[ name: _th_addbmm @@ -2235,16 +2026,12 @@ arguments: - arg: THTensor* result output: True - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - arg: THTensor* self broadcast: batch1,batch2 dims:batch1.dim1,batch2.dim2 - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* batch1 - THTensor* batch2 + - real beta + - real alpha ]] [[ name: _th_addbmm_ @@ -2253,15 +2040,11 @@ return: self arguments: - THTensor* self - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - THTensor* self - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* batch1 - THTensor* batch2 + - real beta + - real alpha ]] [[ name: _th_baddbmm @@ -2275,16 +2058,12 @@ arguments: - arg: THTensor* result output: True - - arg: real beta - default: AS_REAL(1) - kwarg_only: True - arg: THTensor* self broadcast: batch1,batch2 dims:batch1.dim0,batch1.dim1,batch2.dim2 - - arg: real alpha - default: AS_REAL(1) - kwarg_only: True - THTensor* batch1 - THTensor* batch2 + - real beta + - real alpha ]] [[ name: _th_gels @@ -2324,10 +2103,7 @@ - arg: THTensor* res2 output: True - THTensor* self - - arg: bool eigenvectors - if_true: V - if_false: N - default: N + - bool eigenvectors ]] [[ name: _th_potri @@ -2346,10 +2122,7 @@ - arg: THTensor* output output: True - THTensor* self - - arg: bool upper - if_true: U - if_false: L - default: U + - bool upper ]] [[ name: _th_geqrf @@ -2406,14 +2179,8 @@ - THTensor* self - THTensor* input2 - THTensor* input3 - - arg: bool left - if_true: L - if_false: R - default: L - - arg: bool transpose - if_true: T - if_false: N - default: N + - bool left + - bool transpose ]] [[ name: _th_random_ @@ -2426,24 +2193,18 @@ - cname: random arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True + - THGenerator* generator - cname: cappedRandom arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - int64_t to + - THGenerator* generator - cname: clampedRandom arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - int64_t from - int64_t to + - THGenerator* generator ]] [[ name: _th_multinomial_alias_setup @@ -2477,12 +2238,10 @@ arguments: - arg: THIndexTensor* result output: True - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - THTensor* q - THIndexTensor* J - long num_samples + - THGenerator* generator ]] [[ name: _th_multinomial @@ -2497,13 +2256,10 @@ arguments: - arg: THIndexTensor* result output: True - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - THTensor* self - long num_samples - - arg: bool replacement - default: "false" + - bool replacement + - THGenerator* generator ]] [[ name: _th_uniform_ @@ -2516,13 +2272,9 @@ return: self arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - - arg: double from - default: 0 - - arg: double to - default: 1 + - double from + - double to + - THGenerator* generator ]] [[ name: _th_normal @@ -2539,30 +2291,23 @@ arguments: - arg: THTensor* output output: True - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - THTensor* mean - - arg: double std - default: 1 + - double std + - THGenerator* generator - cname: normal_stddevs arguments: - arg: THTensor* output output: True - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - arg: double mean - THTensor* std + - THGenerator* generator - cname: normal_means_stddevs arguments: - arg: THTensor* output output: True - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - THTensor* mean - THTensor* std + - THGenerator* generator ]] [[ name: _th_normal_ @@ -2575,13 +2320,9 @@ return: self arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - - arg: double mean - default: 0 - - arg: double std - default: 1 + - double mean + - double std + - THGenerator* generator ]] [[ name: _th_cauchy_ @@ -2594,13 +2335,9 @@ return: self arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - - arg: double median - default: 0 - - arg: double sigma - default: 1 + - double median + - double sigma + - THGenerator* generator ]] [[ name: _th_log_normal_ @@ -2613,13 +2350,9 @@ return: self arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - - arg: double mean - default: 1 - - arg: double std - default: 2 + - double mean + - double std + - THGenerator* generator ]] [[ name: _th_exponential_ @@ -2632,11 +2365,8 @@ return: self arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - - arg: double lambd - default: 1 + - double lambd + - THGenerator* generator ]] [[ name: _th_geometric_ @@ -2647,10 +2377,8 @@ return: self arguments: - THTensor* self - - arg: THGenerator* generator - default: nullptr - kwarg_only: True - double p + - THGenerator* generator ]] [[ @@ -2678,6 +2406,5 @@ - arg: THTensor* self output: True - TensorList tensors - - arg: int64_t dim - default: 0 + - int64_t dim ]] diff --git a/aten/src/ATen/NamedTensorUtils.cpp b/aten/src/ATen/NamedTensorUtils.cpp index 705beb5e1a60a..5441d4b4213f2 100644 --- a/aten/src/ATen/NamedTensorUtils.cpp +++ b/aten/src/ATen/NamedTensorUtils.cpp @@ -1,9 +1,9 @@ -#ifdef BUILD_NAMEDTENSOR - #include +#include #include #include +#ifdef BUILD_NAMEDTENSOR namespace at { // Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W'). @@ -20,21 +20,10 @@ int64_t dimname_to_position(const Tensor& tensor, Dimname dim) { "Name ", dim, " not found in ", toDimnameRepr(tensor), "."); const auto names = tensor.names(); - const auto it = std::find_if( - names.begin(), names.end(), - [&dim](const Dimname& candidate) { return dim.can_refer_to(candidate); }); + const auto it = std::find(names.begin(), names.end(), dim); TORCH_CHECK(it != names.end(), "Name ", dim, " not found in ", toDimnameRepr(tensor), "."); - // Check that it can't refer to another dimension - const auto dup = std::find_if( - it + 1, names.end(), - [&dim](const Dimname& candidate) { return dim.can_refer_to(candidate); }); - TORCH_CHECK( - dup == names.end(), - "Name ", dim, " could refer to multiple dimensions in ", - toDimnameRepr(tensor), ". Please disambiguate by using a more ", - "specific name like ", *it, " or ", dup, "."); return std::distance(names.begin(), it); } @@ -65,11 +54,10 @@ static void check_for_misalignment( DimnameList names, DimnameList other_names, const char* action) { - if (name.is_wildcard()) { + if (name.isWildcard()) { return; } - auto it = std::find_if(other_names.begin(), other_names.end(), - [&](const Dimname& candidate) { return name.can_refer_to(candidate); }); + auto it = std::find(other_names.begin(), other_names.end(), name); // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds TORCH_CHECK(it == other_names.end(), "Misaligned dims when attempting to ", action, " dims ", names, " and dims ", @@ -94,20 +82,15 @@ std::vector unify_from_right( const auto& name = names_it == names.rend() ? wildcard : *names_it; const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it; - // TODO(zou3519): Don't support tagged names for now. They're a little weird. - if (name.is_tagged() || other_name.is_tagged()) { - TORCH_INTERNAL_ASSERT("unify_from_right: NYI: tagged names."); - } - // Step 1: Check that the names match - const auto maybeName = unify(name, other_name); + const auto maybeName = name.unify(other_name); if (!maybeName) { report_positional_error(name, other_name, names, other_names, action); } *result_it = *maybeName; // Step 2: Check that the names are not misaligned - if (!name.is_normal() || !other_name.is_normal()) { + if (!name.isBasic() || !other_name.isBasic()) { // Let: N = max(len(names), len(other_names)) // K = # of special names among names and other_names. // This search (including the outer loop) is O(N*K) but typically # of dims is small. @@ -316,7 +299,7 @@ static DimnameList feature_dims(DimnameList names) { static bool are_distinct(DimnameList batch_dims, DimnameList feature_dims) { for (const auto& target : feature_dims) { - if (target.is_wildcard()) { + if (target.isWildcard()) { continue; } if (std::any_of(batch_dims.begin(), batch_dims.end(), diff --git a/aten/src/ATen/NamedTensorUtils.h b/aten/src/ATen/NamedTensorUtils.h index e4d47cf57f6d8..54fc9d83356e8 100644 --- a/aten/src/ATen/NamedTensorUtils.h +++ b/aten/src/ATen/NamedTensorUtils.h @@ -1,11 +1,12 @@ #pragma once -#ifdef BUILD_NAMEDTENSOR - +#include #include + #include #include #include +#ifdef BUILD_NAMEDTENSOR namespace at { using NameVector = SmallVector; @@ -38,7 +39,7 @@ namespace namedinference { // 2) If result has names, then `names` must be equal to result.names void propagate_names(Tensor& result, optional names); void propagate_names(Tensor& result, std::vector&& names, bool validate_names); -void propagate_names(Tensor& result, optional>&& maybe_names, bool validate_names); +CAFFE2_API void propagate_names(Tensor& result, optional>&& maybe_names, bool validate_names); void propagate_names(TensorImpl* result, optional names); void propagate_names(TensorImpl* result, std::vector&& names, bool validate_names); void propagate_names(TensorImpl* result, optional>&& maybe_names, bool validate_names); diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h index 95274f79199bc..b809d5b550a90 100644 --- a/aten/src/ATen/NumericUtils.h +++ b/aten/src/ATen/NumericUtils.h @@ -1,5 +1,6 @@ #include #include +#include namespace at { @@ -19,4 +20,6 @@ inline bool _isnan(T val) { return std::isnan(val); } +inline bool _isnan(at::BFloat16 val) { return std::isnan(float(val)); } + } // namespace at diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index 62ab4c16cf12d..dd467d93982fe 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -19,9 +19,9 @@ namespace at { template struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { // public constructor for now... - OpaqueTensorImpl(at::TensorTypeId type_id, const caffe2::TypeMeta& data_type, c10::Device device, + OpaqueTensorImpl(at::TensorTypeSet type_set, const caffe2::TypeMeta& data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes) - : TensorImpl(type_id, data_type, device), + : TensorImpl(type_set, data_type, device), opaque_handle_(std::move(opaque_handle)) { sizes_ = sizes.vec(); @@ -87,7 +87,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - type_id(), dtype(), device(), opaque_handle_, sizes_); + type_set(), dtype(), device(), opaque_handle_, sizes_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -104,7 +104,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { * see NOTE [ TensorImpl Shallow-Copying ]. */ void shallow_copy_from(const c10::intrusive_ptr& impl) override { - AT_ASSERT(has_compatible_shallow_copy_type(impl->type_id())); + AT_ASSERT(has_compatible_shallow_copy_type(impl->type_set())); auto opaque_impl = static_cast*>(impl.get()); copy_tensor_metadata( /*src_impl=*/opaque_impl, diff --git a/aten/src/ATen/ParallelOpenMP.cpp b/aten/src/ATen/ParallelOpenMP.cpp index a47ebbfc42b66..ffff685583ad3 100644 --- a/aten/src/ATen/ParallelOpenMP.cpp +++ b/aten/src/ATen/ParallelOpenMP.cpp @@ -86,7 +86,7 @@ void intraop_launch(std::function func) { std::shared_ptr intraop_launch_future( std::function func) { func(); - auto future = std::make_shared(); + auto future = std::make_shared(NoneType::get()); future->markCompleted(); return future; } diff --git a/aten/src/ATen/ScalarOps.h b/aten/src/ATen/ScalarOps.h index f25557b6697d7..a6cea66c6f637 100644 --- a/aten/src/ATen/ScalarOps.h +++ b/aten/src/ATen/ScalarOps.h @@ -9,14 +9,16 @@ namespace c10 { // FIXME: this should be (and was) Scalar::toTensor, but there is currently no way // to implement this without going through Derived Types (which are not part of core). -inline at::Tensor scalar_to_tensor(Scalar s) { +inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) { if (s.isFloatingPoint()) { - return at::scalar_tensor(s, at::device(at::kCPU).dtype(at::kDouble)); + return at::scalar_tensor(s, at::device(device).dtype(at::kDouble)); + } else if (s.isBoolean()) { + return at::scalar_tensor(s, at::device(device).dtype(at::kBool)); } else if (s.isComplex()) { - return at::scalar_tensor(s, at::device(at::kCPU).dtype(at::kComplexDouble)); + return at::scalar_tensor(s, at::device(device).dtype(at::kComplexDouble)); } else { - AT_ASSERT(s.isIntegral()); - return at::scalar_tensor(s, at::device(at::kCPU).dtype(at::kLong)); + AT_ASSERT(s.isIntegral(false)); + return at::scalar_tensor(s, at::device(device).dtype(at::kLong)); } } diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index b60501ee7b851..6af4b579210bc 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -6,13 +6,13 @@ namespace at { namespace { - DeviceType sparseTensorIdToDeviceType(TensorTypeId type_id) { - if (type_id == TensorTypeId::SparseCPUTensorId) { + DeviceType sparseTensorSetToDeviceType(TensorTypeSet type_set) { + if (type_set.has(TensorTypeId::SparseCPUTensorId)) { return kCPU; - } else if (type_id == TensorTypeId::SparseCUDATensorId) { + } else if (type_set.has(TensorTypeId::SparseCUDATensorId)) { return kCUDA; } else { - AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_id); + AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_set); } } } @@ -30,13 +30,13 @@ namespace { // // This means that we allocate a [1,0] size indices tensor and a [0] size // values tensor for such an empty tensor. -SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeMeta& data_type) - : SparseTensorImpl(type_id, data_type - , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(ScalarType::Long)) - , at::empty({0}, at::initialTensorOptions().device(sparseTensorIdToDeviceType(type_id)).dtype(data_type))) {} +SparseTensorImpl::SparseTensorImpl(at::TensorTypeSet type_set, const caffe2::TypeMeta& data_type) + : SparseTensorImpl(type_set, data_type + , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(type_set)).dtype(ScalarType::Long)) + , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(type_set)).dtype(data_type))) {} -SparseTensorImpl::SparseTensorImpl(at::TensorTypeId type_id, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) - : TensorImpl(type_id, data_type, values.device()) +SparseTensorImpl::SparseTensorImpl(at::TensorTypeSet type_set, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) + : TensorImpl(type_set, data_type, values.device()) , sparse_dim_(1) , dense_dim_(0) , indices_(std::move(indices)) diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index 5f9a9deec64c1..0a52bcab9c83c 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -31,7 +31,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { public: // Public for now... - explicit SparseTensorImpl(at::TensorTypeId, const caffe2::TypeMeta&); + explicit SparseTensorImpl(at::TensorTypeSet, const caffe2::TypeMeta&); int64_t nnz() const { return values_.size(0); } int64_t sparse_dim() const { return sparse_dim_; } @@ -192,7 +192,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { c10::intrusive_ptr shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { - auto impl = c10::make_intrusive(type_id(), dtype()); + auto impl = c10::make_intrusive(type_set(), dtype()); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -209,7 +209,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { * see NOTE [ TensorImpl Shallow-Copying ]. */ void shallow_copy_from(const c10::intrusive_ptr& impl) override { - AT_ASSERT(has_compatible_shallow_copy_type(impl->type_id())); + AT_ASSERT(has_compatible_shallow_copy_type(impl->type_set())); auto sparse_impl = static_cast(impl.get()); copy_tensor_metadata( /*src_impl=*/sparse_impl, @@ -219,7 +219,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { refresh_numel(); } private: - explicit SparseTensorImpl(at::TensorTypeId, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); + explicit SparseTensorImpl(at::TensorTypeSet, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); /** * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) diff --git a/aten/src/ATen/SparseTensorUtils.h b/aten/src/ATen/SparseTensorUtils.h index ecc52b2cb3734..45aa79eef91e4 100644 --- a/aten/src/ATen/SparseTensorUtils.h +++ b/aten/src/ATen/SparseTensorUtils.h @@ -1,3 +1,5 @@ +#pragma once + #include #include diff --git a/aten/src/ATen/TensorOperators.h b/aten/src/ATen/TensorOperators.h index 1828e7c9d2a06..6c3dd59a4e38f 100644 --- a/aten/src/ATen/TensorOperators.h +++ b/aten/src/ATen/TensorOperators.h @@ -45,7 +45,7 @@ inline Tensor& Tensor::operator/=(Scalar other) { return div_(other); } inline Tensor Tensor::operator[](Scalar index) const { - if (!index.isIntegral()) { + if (!index.isIntegral(false)) { AT_INDEX_ERROR("Can only index tensors with integral scalars"); } return select(0, index.toLong()); diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index fdd8fd2d80ce2..5358bf35f779b 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -71,8 +71,8 @@ static inline TensorImpl* checked_tensor_unwrap(const Tensor& expr, const char * if(allowNull && !expr.defined()) { return nullptr; } - if (tensorTypeIdToBackend(expr.type_id()) != backend) { - AT_ERROR("Expected object of backend ", backend, " but got backend ", tensorTypeIdToBackend(expr.type_id()), + if (tensorTypeIdToBackend(impl::dispatchTypeId(expr.type_set())) != backend) { + AT_ERROR("Expected object of backend ", backend, " but got backend ", tensorTypeIdToBackend(impl::dispatchTypeId(expr.type_set())), " for argument #", pos, " '", name, "' in call to ", api); } if (expr.scalar_type() != scalar_type) { @@ -91,8 +91,8 @@ static inline std::vector checked_tensor_list_unwrap(ArrayRef checked_tensor_list_unwrap(ArrayRef -std::array check_intlist(ArrayRef list, const char * name, int pos, ArrayRef def={}) { +std::array check_intlist(ArrayRef list, const char * name, int pos) { if (list.empty()) { - list = def; + // TODO: is this necessary? We used to treat nullptr-vs-not in IntList differently + // with strides as a way of faking optional. + list = {}; } auto res = std::array(); if (list.size() == 1 && N > 1) { diff --git a/aten/src/ATen/common_with_cwrap.py b/aten/src/ATen/common_with_cwrap.py index 2578241c0f2c8..f09803417718e 100644 --- a/aten/src/ATen/common_with_cwrap.py +++ b/aten/src/ATen/common_with_cwrap.py @@ -53,7 +53,7 @@ def set_declaration_defaults(declaration): def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self): def exclude_arg(arg): - return arg.get('ignore_check') or arg['type'] == 'CONSTANT' + return arg['type'] == 'CONSTANT' def exclude_arg_with_self_check(arg): return exclude_arg(arg) or (remove_self and arg['name'] == 'self') @@ -91,10 +91,10 @@ def signature(option, kwarg_only_count): return unique -def sort_by_number_of_options(declaration, reverse=True): - def num_checked_args(option): - return sum(map(lambda a: not a.get('ignore_check', False), option['arguments'])) - declaration['options'].sort(key=num_checked_args, reverse=reverse) +def sort_by_number_of_args(declaration, reverse=True): + def num_args(option): + return len(option['arguments']) + declaration['options'].sort(key=num_args, reverse=reverse) class Function(object): diff --git a/aten/src/ATen/core/ATenDispatch.cpp b/aten/src/ATen/core/ATenDispatch.cpp index 26deaef09af8f..5f60b8e6dee2c 100644 --- a/aten/src/ATen/core/ATenDispatch.cpp +++ b/aten/src/ATen/core/ATenDispatch.cpp @@ -7,4 +7,37 @@ ATenDispatch & globalATenDispatch() { return singleton; } +void* ATenOpTable::getFallbackOp(TensorTypeId tid) const { + // TODO: an alternate strategy here would be to mask out the dead key + // and then redispatch gain (automatic delegation). I haven't done this + // for now to make it easier to smoke out error cases. + if (function_table_[static_cast(TensorTypeId::UndefinedTensorId)] == nullptr) { + std::ostringstream oss; + bool first = true; + for (int64_t i = 0; i < static_cast(TensorTypeId::NumTensorIds); i++) { + if (function_table_[i] != nullptr) { + if (!first) oss << ", "; + oss << toString(static_cast(i)); + first = false; + } + } + + // If there is no fallback dispatch, and dispatch failed because we didn't + // find any valid keys to dispatch on, this usually means the user gave + // us a non-empty list of tensors. So report a better error in this case. + // TODO: Maybe we should reword this error message + if (tid == TensorTypeId::UndefinedTensorId) { + TORCH_CHECK(false, + "There were no tensor arguments to this function (e.g., you passed an " + "empty list of Tensors), but no fallback function is registered for schema ", schema_, + ". This usually means that this function requires a non-empty list of Tensors. " + "Available functions are ", oss.str()) + } + TORCH_CHECK(false, + "No function is registered for schema ", schema_, " on tensor type ", toString(tid), + "; available functions are ", oss.str()); + } + return function_table_[static_cast(TensorTypeId::UndefinedTensorId)]; +} + } // namespace at diff --git a/aten/src/ATen/core/ATenDispatch.h b/aten/src/ATen/core/ATenDispatch.h index d830d93d19b77..c9307b8624750 100644 --- a/aten/src/ATen/core/ATenDispatch.h +++ b/aten/src/ATen/core/ATenDispatch.h @@ -1,11 +1,17 @@ #pragma once +#include #include +#include #include +#include +#include #include #include #include +// TODO: Rewrite this comment +// // This dispatch class serves as a replacement for our previous dispatch // mechanism, in which all functions were members of a Type class. A derived // class existed for each backend (and Variable), and the vtable was used to @@ -15,6 +21,25 @@ namespace at { +namespace impl { + +// Take a TensorTypeSet for a Tensor, and combine it with the current thread +// local valid (implemented) and enabled (not implemented) TensorTypeSets +// to determine what the actual dispatch TensorTypeId should be. Unlike +// Tensor::type_set(), the value of this on a tensor can change depending +// on TLS. +// +// NB: I didn't make this take a Tensor to avoid header include shenanigans. +// +// TODO: I'm not sure if this should live in this header or not; the operant +// question is whether or not we have access to all the relevant TLS at this +// point. +static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) { + return (ts - c10::impl::tls_excluded_tensor_type_set()).highestPriorityTypeId(); +} + +} + // ATenOpTable stores the implementations for each backend, in addition to // an implementation for variables. class CAFFE2_API ATenOpTable { @@ -23,70 +48,53 @@ class CAFFE2_API ATenOpTable { : schema_(std::move(schema)) {} template - FuncType* getOp(Backend backend, bool is_variable) const { - if (is_variable) { - return reinterpret_cast(getVariableOp()); - } - return reinterpret_cast(getBaseOp(backend)); + FuncType* getOp(TensorTypeSet ts) const { + return reinterpret_cast(getOp(impl::dispatchTypeId(ts))); } private: - void registerOp(Backend backend, void* fn) { - TORCH_CHECK(function_table_[static_cast(backend)] == nullptr, - "Attempting to register variable function for schema ", schema_, - " and backend ", toString(backend), + void registerOp(TensorTypeId tid, void* fn) { + TORCH_CHECK(function_table_[static_cast(tid)] == nullptr, + "Attempting to register function for schema ", schema_, + " and tensor type ", toString(tid), " but there is already a function registered"); - function_table_[static_cast(backend)] = fn; + function_table_[static_cast(tid)] = fn; } - void registerVariableOp(void* fn) { - TORCH_CHECK(variable_function_ == nullptr, - "Attempting to register variable function for schema ", schema_, - " but there is already a function registered"); - variable_function_ = fn; - } + void* getFallbackOp(TensorTypeId tid) const; - void* getBaseOp(Backend backend) const { - if (function_table_[static_cast(backend)] == nullptr) { - TORCH_CHECK(function_table_[static_cast(Backend::Undefined)] != nullptr, - "No function is registered for schema ", schema_, " on backend ", toString(backend)); - return function_table_[static_cast(Backend::Undefined)]; + void* getOp(TensorTypeId tid) const { + // You might think we can minorly optimize this further by maintaining a + // bitmask of registered operator keys, so we don't select dispatch ids + // which don't have implementations here. But the net effect is that if you + // get a Variable CPUTensor, if there is no variable registration, you'll + // fall back to the CPU implementation. Is this what you want? Unlikely... + if (function_table_[static_cast(tid)] == nullptr) { + return getFallbackOp(tid); } - return function_table_[static_cast(backend)]; - } - - void* getVariableOp() const { - TORCH_CHECK(variable_function_ != nullptr, - "No variable function registered for ", schema_); - return variable_function_; + return function_table_[static_cast(tid)]; } friend class ATenDispatch; std::string schema_; - void* function_table_[static_cast(Backend::NumOptions)] = {nullptr}; - void* variable_function_ = nullptr; + void* function_table_[static_cast(TensorTypeId::NumTensorIds)] = {nullptr}; }; class CAFFE2_API ATenDispatch { public: template - ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) { + ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) { std::lock_guard lock(mutex_); if (op_tables_.find(schema) == op_tables_.end()) { op_tables_.insert(std::make_pair(schema, ATenOpTable(schema))); } - op_tables_.at(schema).registerOp(backend, reinterpret_cast(fn)); + op_tables_.at(schema).registerOp(id, reinterpret_cast(fn)); return *this; } - template - ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) { - std::lock_guard lock(mutex_); - if (op_tables_.find(schema) == op_tables_.end()) { - op_tables_.insert(std::make_pair(schema, ATenOpTable(schema))); - } - op_tables_.at(schema).registerVariableOp(reinterpret_cast(fn)); - return *this; + template + ATenDispatch& registerOp(Backend b, const char* schema, FuncType* fn) { + return registerOp(backendToTensorTypeId(b), schema, fn); } const ATenOpTable* getOpTable(const char* schema) const { diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h index 5a73605fc3db1..304afbdfde7fe 100644 --- a/aten/src/ATen/core/Dict.h +++ b/aten/src/ATen/core/Dict.h @@ -3,8 +3,8 @@ #include #include #include -#include #include +#include #include #include @@ -39,22 +39,18 @@ struct DictKeyEqualTo { }; struct DictImpl final : public c10::intrusive_ptr_target { - using dict_map_type = ska::flat_hash_map; + using dict_map_type = ska_ordered::order_preserving_flat_hash_map; struct DictElementTypes final { TypePtr keyType; TypePtr valueType; }; - explicit DictImpl(dict_map_type dict_, optional elementTypes_) + explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_) : dict(std::move(dict_)) - , elementTypes(std::move(elementTypes_)) { - TORCH_INTERNAL_ASSERT(!elementTypes.has_value() || (nullptr != elementTypes->keyType.get() && nullptr != elementTypes->valueType.get()), "Key and value type must not be nullptr"); - } - + , elementTypes(std::move(elementTypes_)) {} dict_map_type dict; - // TODO Right now, this is optional, but we want to make it mandatory for all dicts to know their types - optional elementTypes; + DictElementTypes elementTypes; intrusive_ptr copy() const; }; @@ -183,7 +179,6 @@ class DictIterator final : public std::iterator Dict toTypedDict(Dict dict); template Dict toGenericDict(Dict dict); -struct deprecatedUntypedDict final {}; } /** @@ -207,9 +202,9 @@ class Dict final { private: static_assert((std::is_same::value && std::is_same::value) || guts::typelist::contains::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string."); - // impl_ stores the underlying map as a ska::flat_hash_map. + // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map. // We intentionally don't offer conversion from/to - // ska::flat_hash_map, return references to it or something like that, + // order_preserving_flat_hash_map, return references to it or something like that, // because such operations would get expensive if we switch out // the actual map implementation. // This is an intrusive_ptr because Dict is a pointer type. @@ -240,14 +235,6 @@ class Dict final { */ explicit Dict(TypePtr keyType, TypePtr valueType); - /** - * Creates an untyped dict, i.e. a Dict that doesn't know its types and - * doesn't do type checking. - * Please don't use this if you can avoid it. We want to get rid of untyped - * dicts. - */ - explicit Dict(impl::deprecatedUntypedDict); - ~Dict() = default; Dict(const Dict&) = default; @@ -354,8 +341,20 @@ class Dict final { // private API for now because the return type will change to TypePtr // instead of optional once types are mandatory. - optional _keyType() const; - optional _valueType() const; + TypePtr keyType() const; + TypePtr valueType() const; + + // [unsafe set type] + // These functions mutate the tagged type of this dictionary in place. + // There is no checking that the members of the dictionary are instances + // of the new types, nor is there a check that other IValues which + // hold references to this dictionary have the right static type. + // This functionality is used only in the unpickler, where at + // creation type the real type of the dictionary is unknown, but + // then later recovered from the static type information of the + // unpickled object. + void unsafeSetKeyType(TypePtr t); + void unsafeSetValueType(TypePtr t); }; namespace impl { diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h index 7266699623e1d..a0e57fb1828fc 100644 --- a/aten/src/ATen/core/Dict_inl.h +++ b/aten/src/ATen/core/Dict_inl.h @@ -31,10 +31,8 @@ inline bool shallowEquals(const IValue& lhs, const IValue& rhs) { template Dict toTypedDict(GenericDict dict) { - if (dict.impl_->elementTypes.has_value()) { - TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes->keyType, "Tried to cast a Dict<", toString(dict.impl_->elementTypes->keyType), ", ", toString(dict.impl_->elementTypes->valueType) ,"> to a Dict<", toString(getTypePtr()), ", ", toString(getTypePtr()), ">. Key types mismatch."); - TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes->valueType, "Tried to cast a Dict<", toString(dict.impl_->elementTypes->keyType), ", ", toString(dict.impl_->elementTypes->valueType) ,"> to a Dict<", toString(getTypePtr()), ", ", toString(getTypePtr()), ">. Value types mismatch."); - } + TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(dict.impl_->elementTypes.keyType), ", ", toString(dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(getTypePtr()), ", ", toString(getTypePtr()), ">. Key types mismatch."); + TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(dict.impl_->elementTypes.keyType), ", ", toString(dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(getTypePtr()), ", ", toString(getTypePtr()), ">. Value types mismatch."); return Dict(std::move(dict.impl_)); } @@ -87,15 +85,6 @@ Dict::Dict(TypePtr keyType, TypePtr valueType) static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict."); } -template -Dict::Dict(impl::deprecatedUntypedDict) -: Dict(make_intrusive( - detail::DictImpl::dict_map_type(), - c10::nullopt)) { - static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict."); - static_assert(std::is_same::value, "This constructor is only valid for c10::impl::GenericDict."); -} - template Dict::Dict(Dict&& rhs) noexcept: impl_(std::move(rhs.impl_)) { rhs.impl_ = make_intrusive(detail::DictImpl::dict_map_type(), impl_->elementTypes); @@ -194,19 +183,22 @@ void Dict::reserve(size_type count) const { } template -optional Dict::_keyType() const { - if (!impl_->elementTypes.has_value()) { - return c10::nullopt; - } - return impl_->elementTypes->keyType; +TypePtr Dict::keyType() const { + return impl_->elementTypes.keyType; } template -optional Dict::_valueType() const { - if (!impl_->elementTypes.has_value()) { - return c10::nullopt; - } - return impl_->elementTypes->valueType; +TypePtr Dict::valueType() const { + return impl_->elementTypes.valueType; +} +template +void Dict::unsafeSetKeyType(TypePtr t) { + impl_->elementTypes.keyType = std::move(t); +} + +template +void Dict::unsafeSetValueType(TypePtr t) { + impl_->elementTypes.valueType = std::move(t); } } diff --git a/aten/src/ATen/core/Dimname.cpp b/aten/src/ATen/core/Dimname.cpp index d9075be595d6a..56ff9853bacf3 100644 --- a/aten/src/ATen/core/Dimname.cpp +++ b/aten/src/ATen/core/Dimname.cpp @@ -1,25 +1,26 @@ -#ifdef BUILD_NAMEDTENSOR #include #include +#include +#include +#ifdef BUILD_NAMEDTENSOR namespace at { std::ostream& operator<<(std::ostream& out, const Dimname& dimname) { if (dimname.type() == NameType::WILDCARD) { out << "None"; } else { - out << "'" << dimname.full_name().toUnqualString() << "'"; + out << "'" << dimname.symbol().toUnqualString() << "'"; } return out; } -bool is_valid_identifier(const std::string& name) { - std::locale loc; +bool Dimname::isValidName(const std::string& name) { if (name.length() == 0) { return false; } for (auto it = name.begin(); it != name.end(); ++it) { - if (std::isalpha(*it, loc) || *it == '_') { + if (std::isalpha(*it) || *it == '_') { continue; } return false; @@ -27,77 +28,42 @@ bool is_valid_identifier(const std::string& name) { return true; } -bool Dimname::can_refer_to(const Dimname& other) const { - switch (type()) { - case NameType::WILDCARD: - return false; - - // "C" can be used to refer to "C" or "C.in". - case NameType::NORMAL: - return untagged_name() == other.untagged_name(); - - default: - return full_name() == other.full_name(); - } -} - static void check_valid_identifier(const std::string& name) { TORCH_CHECK( - is_valid_identifier(name), + Dimname::isValidName(name), "Invalid name: a valid identifier must contain alphabetical characters and/or underscore, got: '", name, "'."); } -Dimname Dimname::fromSymbol(Symbol full_name) { - TORCH_INTERNAL_ASSERT(full_name.is_dimname()); - if (full_name == kWildcard) { +Dimname Dimname::fromSymbol(Symbol name) { + TORCH_INTERNAL_ASSERT(name.is_dimname()); + if (name == kWildcard) { return Dimname::wildcard(); } - const std::string delimiter = "."; - const std::string str(full_name.toUnqualString()); - auto it = str.find(delimiter); - - // Check for normal name - if (it == std::string::npos) { - check_valid_identifier(str); - return Dimname(full_name); - } - - // Check for tagged name - auto second_dot = str.find(delimiter, it + 1); - TORCH_CHECK( - second_dot == std::string::npos, - "Invalid name '", str, "': A tagged name can only contain one '.'"); - auto untagged_name = str.substr(0, it); - auto tag = str.substr(it + 1); - check_valid_identifier(untagged_name); - check_valid_identifier(tag); - return Dimname(NameType::TAGGED, full_name, Symbol::dimname(untagged_name)); + check_valid_identifier(name.toUnqualString()); + return Dimname(name); } Dimname Dimname::wildcard() { - static Dimname result(NameType::WILDCARD, kWildcard, kWildcard); + static Dimname result(kWildcard, NameType::WILDCARD); return result; } -optional unify(Dimname dimname, Dimname other) { +optional Dimname::unify(Dimname other) const { if (other.type() == NameType::WILDCARD) { - return dimname; + return *this; } - if (dimname.type() == NameType::WILDCARD) { + if (type_ == NameType::WILDCARD) { return other; } - if (dimname.full_name() == other.full_name()) { - return dimname; - } - if (dimname.untagged_name() == other.untagged_name()) { - return Dimname::fromSymbol(dimname.untagged_name()); + if (name_ == other.symbol()) { + return *this; } return c10::nullopt; } -bool match(Dimname dimname, Dimname other) { - return unify(dimname, other).has_value(); +bool Dimname::matches(Dimname other) const { + return unify(other).has_value(); } } // namespace at diff --git a/aten/src/ATen/core/Dimname.h b/aten/src/ATen/core/Dimname.h index bb2cc51a8b5bb..dfa6fcae4c2c8 100644 --- a/aten/src/ATen/core/Dimname.h +++ b/aten/src/ATen/core/Dimname.h @@ -1,6 +1,7 @@ #pragma once -#ifdef BUILD_NAMEDTENSOR +#include +#ifdef BUILD_NAMEDTENSOR #include #include #include @@ -8,52 +9,40 @@ namespace at { -enum class NameType: uint8_t { NORMAL, WILDCARD, TAGGED }; +enum class NameType: uint8_t { BASIC, WILDCARD }; struct CAFFE2_API Dimname { static Dimname fromSymbol(Symbol name); static Dimname wildcard(); + static bool isValidName(const std::string& name); NameType type() const { return type_; } - Symbol full_name() const { return full_name_; } - Symbol untagged_name() const { return untagged_name_; } + Symbol symbol() const { return name_; } - bool can_refer_to(const Dimname& other) const; + bool isBasic() const { return type_ == NameType::BASIC; } + bool isWildcard() const { return type_ == NameType::WILDCARD; } - bool is_normal() const { return type_ == NameType::NORMAL; } - bool is_wildcard() const { return type_ == NameType::WILDCARD; } - bool is_tagged() const { return type_ == NameType::TAGGED; } + bool matches(Dimname other) const; + optional unify(Dimname other) const; private: Dimname(Symbol name) - : untagged_name_(name), full_name_(name), type_(NameType::NORMAL) {} - Dimname(NameType type, Symbol full_name, Symbol untagged_name) - : untagged_name_(untagged_name), full_name_(full_name), type_(type) {} + : name_(name), type_(NameType::BASIC) {} + Dimname(Symbol name, NameType type) + : name_(name), type_(type) {} - // [Dimname Terminology] - // - // For "C.in": - // - "C.in" is the "full name" - // - "C" is the "untagged name" - // - "in" is the "tag" - Symbol untagged_name_; - Symbol full_name_; + Symbol name_; NameType type_; - // Will need more fields for other special name types. }; using DimnameList = c10::ArrayRef; static Symbol kWildcard = Symbol::dimname("*"); -bool CAFFE2_API is_valid_identifier(const std::string& name); - -CAFFE2_API c10::optional unify(Dimname dimname, Dimname other); -CAFFE2_API bool match(Dimname dimname, Dimname other); CAFFE2_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); inline bool operator==(const Dimname& lhs, const Dimname& rhs) { - return lhs.full_name() == rhs.full_name(); + return lhs.symbol() == rhs.symbol(); } inline bool operator!=(const Dimname& lhs, const Dimname& rhs) { diff --git a/aten/src/ATen/core/EnableNamedTensor.h b/aten/src/ATen/core/EnableNamedTensor.h new file mode 100644 index 0000000000000..7fd679149bb61 --- /dev/null +++ b/aten/src/ATen/core/EnableNamedTensor.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +// We are working on removing the BUILD_NAMEDTENSOR flag from the codebase. +// +// PyTorch's codegen also uses a similar flag. You can find it in +// - aten/src/ATen/env.py +#if !defined(CAFFE2_IS_XPLAT_BUILD) && (!defined(C10_MOBILE) || defined(FEATURE_TORCH_MOBILE)) +#ifndef BUILD_NAMEDTENSOR +#define BUILD_NAMEDTENSOR +#endif +#endif diff --git a/aten/src/ATen/core/LegacyTypeDispatch.h b/aten/src/ATen/core/LegacyTypeDispatch.h index dd0fb5d10bfa0..f3dc9e457c67c 100644 --- a/aten/src/ATen/core/LegacyTypeDispatch.h +++ b/aten/src/ATen/core/LegacyTypeDispatch.h @@ -13,12 +13,18 @@ #include #include #include +#include namespace at { class CAFFE2_API LegacyTypeDispatch { public: - void initForBackend(Backend b) { + void initForTensorTypeSet(TensorTypeSet ts) { + // TODO: Avoid use of legacyExtractTypeId here. The key + // problem is that you may get a TensorTypeSet with + // VariableTensorId set; should you initialize the "underlying" + // type in that case? Hard to say. + auto b = tensorTypeIdToBackend(legacyExtractTypeId(ts)); auto p = backendToDeviceType(b); static std::once_flag cpu_once; static std::once_flag cuda_once; diff --git a/aten/src/ATen/core/List.h b/aten/src/ATen/core/List.h index c8cde840f97cf..3126da8fb7fb4 100644 --- a/aten/src/ATen/core/List.h +++ b/aten/src/ATen/core/List.h @@ -23,16 +23,13 @@ template struct ListImpl final : public c10::intrusive_ptr_target { using list_type = std::vector; - explicit ListImpl(list_type list_, optional elementType_) + explicit ListImpl(list_type list_, TypePtr elementType_) : list(std::move(list_)) - , elementType(std::move(elementType_)) { - TORCH_INTERNAL_ASSERT(!elementType.has_value() || nullptr != elementType->get(), "Element type must not be nullptr"); - } + , elementType(std::move(elementType_)) {} list_type list; - // TODO Right now, this is optional, but we want to make it mandatory for all lists to know their types - optional elementType; + TypePtr elementType; intrusive_ptr copy() const { return make_intrusive(list, elementType); @@ -182,7 +179,6 @@ template List toGenericList(List list); const IValue* ptr_to_first_element(const List& list); template List toList(std::vector list); template const std::vector& toVector(const List& list); -struct deprecatedUntypedList final {}; } template bool list_is_equal(const List& lhs, const List& rhs); @@ -253,14 +249,6 @@ class List final { */ explicit List(TypePtr elementType); - /** - * Creates an untyped list, i.e. a List that doesn't know its type and - * doesn't do type checking. - * Please don't use this if you can avoid it. We want to get rid of untyped - * lists. - */ - explicit List(impl::deprecatedUntypedList); - List(const List&) = default; List& operator=(const List&) = default; List(List&&) noexcept; @@ -435,9 +423,10 @@ class List final { // TODO Test use_count size_t use_count() const; - // private API for now because the return type will change to TypePtr - // instead of optional once types are mandatory. - optional _elementType() const; + TypePtr elementType() const; + + // See [unsafe set type] for why this exists. + void unsafeSetElementType(TypePtr t); private: explicit List(c10::intrusive_ptr>&& elements); diff --git a/aten/src/ATen/core/List_inl.h b/aten/src/ATen/core/List_inl.h index 9b1a66cb35f77..d954d4d5cc048 100644 --- a/aten/src/ATen/core/List_inl.h +++ b/aten/src/ATen/core/List_inl.h @@ -17,14 +17,7 @@ List::List() : List(make_intrusive::StorageT>>( typename detail::ListImpl::StorageT>::list_type(), getTypePtr())) { - static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead, or if you absolutely have to, use c10::impl::GenericList(c10::impl::deprecatedUntypedList())."); -} - -template -inline List::List(c10::impl::deprecatedUntypedList) -: List(make_intrusive>( - typename detail::ListImpl::list_type(), - c10::nullopt)) { + static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead."); } template @@ -32,7 +25,7 @@ List::List(ArrayRef values) : List(make_intrusive::StorageT>>( typename detail::ListImpl::StorageT>::list_type(), getTypePtr())) { - static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead, or if you absolutely have to, use c10::impl::GenericList(c10::impl::deprecatedUntypedList())."); + static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); impl_->list.reserve(values.size()); for (const T& element : values) { impl_->list.push_back(element); @@ -42,7 +35,7 @@ List::List(ArrayRef values) template List::List(std::initializer_list initial_values) : List(ArrayRef(initial_values)) { - static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType) instead, or if you absolutely have to, use c10::impl::GenericList(c10::impl::deprecatedUntypedList())."); + static_assert(!std::is_same::value, "This constructor is not valid for List. Please use c10::impl::GenericList(elementType)."); } template @@ -57,9 +50,7 @@ namespace impl { template List toTypedList(impl::GenericList list) { static_assert(std::is_same::StorageT>::value, "Can only call toTypedList with lists that store their elements as IValues."); - if (list.impl_->elementType.has_value()) { - TORCH_INTERNAL_ASSERT(*getTypePtr() == **list.impl_->elementType, "Tried to cast a List<", toString(*list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); - } + TORCH_INTERNAL_ASSERT(*getTypePtr() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr()), ">. Types mismatch."); return List(std::move(list.impl_)); } @@ -289,9 +280,13 @@ size_t List::use_count() const { return impl_.use_count(); } -template -optional List::_elementType() const { +template +TypePtr List::elementType() const { return impl_->elementType; } +template +void List::unsafeSetElementType(TypePtr t) { + impl_->elementType = std::move(t); +} } diff --git a/aten/src/ATen/core/NamedTensor.cpp b/aten/src/ATen/core/NamedTensor.cpp index 6d2151795c65f..1d078c9e54497 100644 --- a/aten/src/ATen/core/NamedTensor.cpp +++ b/aten/src/ATen/core/NamedTensor.cpp @@ -1,5 +1,7 @@ -#ifdef BUILD_NAMEDTENSOR #include +#include + +#ifdef BUILD_NAMEDTENSOR #include #include @@ -40,38 +42,22 @@ DimnameList default_names(size_t len) { return DimnameList(&all_unnamed.front(), len); } -namespace impl { - -// Two Dimnames cannot be in the same Tensor if one of them can refer to the other. -// In practice, this constraint means that a Tensor cannot have duplicate names -// unless they are tagged and the tags are different. -static DimnameList::const_iterator find_incompatible_name( - DimnameList::const_iterator begin, - DimnameList::const_iterator end, - const Dimname& target) { - return std::find_if(begin, end, - [&target](const Dimname& candidate) { - return target.can_refer_to(candidate) || candidate.can_refer_to(target); - }); +void check_names_valid_for(const Tensor& tensor, DimnameList names) { + return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names); } +namespace impl { + static void check_unique_names(DimnameList names) { // Strategy: Compare each element with the ones that come after it. // Although this is O(N^2), in practice N is small (no more than 25). for (auto it = names.begin(); it != names.end(); ++it) { - auto dup = find_incompatible_name(it + 1, names.end(), *it); + if (it->isWildcard()) continue; + auto dup = std::find(it + 1, names.end(), *it); while (dup != names.end()) { - // Simple error message if you're not using tags - TORCH_CHECK(it->type() == NameType::TAGGED || dup->type() == NameType::TAGGED, + TORCH_CHECK(false, "Cannot construct a tensor with duplicate names. Got names: ", names, "."); - - // Complicated error message if you're using tags - TORCH_CHECK(false, - "Cannot construct a tensor with duplicate names unless they are tagged ", - "and have different tags. Got names: ", names, ", offending names: (", - *it, " and ", *dup, ")."); - dup = find_incompatible_name(dup + 1, names.end(), *it); } } } @@ -90,7 +76,7 @@ static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) { return static_cast(impl->named_tensor_meta()); } -void check_valid_names(TensorImpl* impl, DimnameList names) { +void check_names_valid_for(TensorImpl* impl, DimnameList names) { auto ndim = impl->dim(); TORCH_CHECK( ndim <= kMaxNamedTensorDim, @@ -108,7 +94,7 @@ void internal_set_names_inplace(TensorImpl* impl, optional names) { impl->set_named_tensor_meta(nullptr); return; } - check_valid_names(impl, *names); + check_names_valid_for(impl, *names); auto* meta = get_named_tensor_meta(impl); if (meta == nullptr) { impl->set_named_tensor_meta(c10::guts::make_unique(*names)); @@ -119,7 +105,7 @@ void internal_set_names_inplace(TensorImpl* impl, optional names) { void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names) { if (validate_names) { - check_valid_names(impl, names); + check_names_valid_for(impl, names); } auto* meta = get_named_tensor_meta(impl); if (meta == nullptr) { diff --git a/aten/src/ATen/core/NamedTensor.h b/aten/src/ATen/core/NamedTensor.h index 28f40c0c4f374..e9d51cdf994c9 100644 --- a/aten/src/ATen/core/NamedTensor.h +++ b/aten/src/ATen/core/NamedTensor.h @@ -1,10 +1,11 @@ #pragma once -#ifdef BUILD_NAMEDTENSOR +#include #include #include #include +#ifdef BUILD_NAMEDTENSOR namespace at { // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. @@ -33,6 +34,11 @@ struct CAFFE2_API NamedTensorMeta : public c10::NamedTensorMetaInterface { bool has_names() const; DimnameList names() const { return names_; } + // Used for an assertion in TensorImpl.h + int64_t slow_dim() const override { + return names_.size(); + } + void set_names(DimnameList new_names) { TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); std::copy(new_names.begin(), new_names.end(), names_.begin()); @@ -68,6 +74,7 @@ struct CAFFE2_API NoNamesGuard { bool prev_mode; }; +void check_names_valid_for(const Tensor& tensor, DimnameList names); // Sets the names of `tensor` to be `names`. CAFFE2_API Tensor& internal_set_names_inplace(Tensor& tensor, optional names); @@ -84,6 +91,8 @@ namespace impl { CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, optional names); CAFFE2_API void internal_set_names_inplace(TensorImpl* impl, std::vector&& names, bool validate_names); +void check_names_valid_for(TensorImpl* impl, DimnameList names); + // Returns true if the tensor's names exist and are not all 'None'. // Returns false if the tensor's names don't exist (were not allocated), // or if all names are 'None'. diff --git a/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp b/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp new file mode 100644 index 0000000000000..1a33833eba2ba --- /dev/null +++ b/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp @@ -0,0 +1,1390 @@ +#include +#include + +#include +#include +#include +#include +#include + +// @generated by aten/src/ATen/gen.py + +// TODO Once all ATen ops are moved to c10, this file should be removed + +namespace at { + +namespace { +struct OpNameEquals final { + bool operator()(const std::pair& lhs, const std::pair& rhs) const { + return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second); + } +}; + +struct OpNameHash final { + size_t operator()(const std::pair& p) const { + // use std::hash because std::hash would hash pointers and not pointed-to strings + return std::hash()(p.first) ^ (~ std::hash()(p.second)); + } +}; +} + +bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) { + static std::unordered_set, OpNameHash, OpNameEquals> ops { + {"aten::_cast_Byte", ""}, + {"aten::_cast_Char", ""}, + {"aten::_cast_Double", ""}, + {"aten::_cast_Float", ""}, + {"aten::_cast_Int", ""}, + {"aten::_cast_Long", ""}, + {"aten::_cast_Short", ""}, + {"aten::_cast_Half", ""}, + {"aten::data", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::align_as", ""}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::align_tensors", ""}, + #endif + {"aten::_cudnn_ctc_loss", ""}, + {"aten::_cudnn_rnn_flatten_weight", ""}, + {"aten::_debug_has_internal_overlap", ""}, + {"aten::_masked_scale", ""}, + {"aten::_sobol_engine_ff_", ""}, + {"aten::_sobol_engine_scramble_", ""}, + {"aten::_sobol_engine_initialize_state_", ""}, + {"aten::_reshape_from_tensor", ""}, + {"aten::_shape_as_tensor", ""}, + {"aten::dropout", ""}, + {"aten::dropout_", ""}, + {"aten::feature_dropout", ""}, + {"aten::feature_dropout_", ""}, + {"aten::alpha_dropout", ""}, + {"aten::alpha_dropout_", ""}, + {"aten::feature_alpha_dropout", ""}, + {"aten::feature_alpha_dropout_", ""}, + {"aten::abs", ""}, + {"aten::abs_", ""}, + {"aten::acos", ""}, + {"aten::acos_", ""}, + {"aten::avg_pool1d", ""}, + {"aten::adaptive_avg_pool1d", ""}, + {"aten::adaptive_max_pool1d", ""}, + {"aten::add", "Tensor"}, + {"aten::add_", "Tensor"}, + {"aten::add", "Scalar"}, + {"aten::add_", "Scalar"}, + {"aten::addmv", ""}, + {"aten::addmv_", ""}, + {"aten::addr", ""}, + {"aten::addr_", ""}, + {"aten::affine_grid_generator", ""}, + {"aten::affine_grid_generator_backward", ""}, + {"aten::all", "dim"}, + {"aten::allclose", ""}, + {"aten::any", "dim"}, + {"aten::_dim_arange", ""}, + {"aten::argmax", ""}, + {"aten::argmin", ""}, + {"aten::as_strided", ""}, + {"aten::as_strided_", ""}, + {"aten::asin", ""}, + {"aten::asin_", ""}, + {"aten::atan", ""}, + {"aten::atan_", ""}, + {"aten::baddbmm", ""}, + {"aten::baddbmm_", ""}, + {"aten::_baddbmm_mkl_", ""}, + {"aten::bitwise_not", ""}, + {"aten::bitwise_not_", ""}, + {"aten::logical_not", ""}, + {"aten::logical_not_", ""}, + {"aten::logical_xor", ""}, + {"aten::logical_xor_", ""}, + {"aten::bmm", ""}, + {"aten::broadcast_tensors", ""}, + {"aten::cat", ""}, + {"aten::ceil", ""}, + {"aten::ceil_", ""}, + {"aten::chain_matmul", ""}, + {"aten::chunk", ""}, + {"aten::clamp", ""}, + {"aten::clamp_", ""}, + {"aten::clamp_max", ""}, + {"aten::clamp_max_", ""}, + {"aten::clamp_min", ""}, + {"aten::clamp_min_", ""}, + {"aten::cudnn_is_acceptable", ""}, + {"aten::constant_pad_nd", ""}, + {"aten::conv_tbc", ""}, + {"aten::conv_tbc_backward", ""}, + {"aten::copy_", ""}, + {"aten::_copy_from", ""}, + {"aten::cos", ""}, + {"aten::cos_", ""}, + {"aten::cosh", ""}, + {"aten::cosh_", ""}, + {"aten::cosine_embedding_loss", ""}, + {"aten::cudnn_affine_grid_generator", ""}, + {"aten::cudnn_affine_grid_generator_backward", ""}, + {"aten::cudnn_convolution_backward_input", ""}, + {"aten::cudnn_convolution_backward", ""}, + {"aten::cudnn_convolution_backward_bias", ""}, + {"aten::cudnn_convolution_backward_weight", ""}, + {"aten::cudnn_convolution_transpose_backward", ""}, + {"aten::cudnn_convolution_transpose_backward_bias", ""}, + {"aten::cudnn_convolution_transpose_backward_input", ""}, + {"aten::cudnn_convolution_transpose_backward_weight", ""}, + {"aten::cudnn_grid_sampler", ""}, + {"aten::cudnn_grid_sampler_backward", ""}, + {"aten::ctc_loss", "IntList"}, + {"aten::ctc_loss", "Tensor"}, + {"aten::_ctc_loss", ""}, + {"aten::_ctc_loss_backward", ""}, + {"aten::det", ""}, + {"aten::diag_embed", ""}, + {"aten::diagflat", ""}, + {"aten::diagonal", ""}, + {"aten::fill_diagonal_", ""}, + {"aten::div", "Tensor"}, + {"aten::div_", "Tensor"}, + {"aten::div", "Scalar"}, + {"aten::div_", "Scalar"}, + {"aten::dot", ""}, + {"aten::einsum", ""}, + {"aten::embedding", ""}, + {"aten::embedding_backward", ""}, + {"aten::embedding_dense_backward", ""}, + {"aten::embedding_renorm_", ""}, + {"aten::embedding_sparse_backward", ""}, + {"aten::_embedding_bag_per_sample_weights_backward", ""}, + {"aten::resize_", ""}, + {"aten::empty_like", ""}, + {"aten::erf", ""}, + {"aten::erf_", ""}, + {"aten::erfc", ""}, + {"aten::erfc_", ""}, + {"aten::exp", ""}, + {"aten::exp_", ""}, + {"aten::expm1", ""}, + {"aten::expm1_", ""}, + {"aten::expand", ""}, + {"aten::expand_as", ""}, + {"aten::flatten", "using_ints"}, + {"aten::fill_", "Scalar"}, + {"aten::fill_", "Tensor"}, + {"aten::floor", ""}, + {"aten::floor_", ""}, + {"aten::frac", ""}, + {"aten::frac_", ""}, + {"aten::full_like", ""}, + {"aten::grid_sampler", ""}, + {"aten::grid_sampler_2d", ""}, + {"aten::grid_sampler_2d_backward", ""}, + {"aten::grid_sampler_3d", ""}, + {"aten::grid_sampler_3d_backward", ""}, + {"aten::hinge_embedding_loss", ""}, + {"aten::ger", ""}, + {"aten::fft", ""}, + {"aten::ifft", ""}, + {"aten::rfft", ""}, + {"aten::irfft", ""}, + {"aten::_fft_with_size", ""}, + {"aten::_cufft_get_plan_cache_size", ""}, + {"aten::_cufft_get_plan_cache_max_size", ""}, + {"aten::index_copy_", ""}, + {"aten::index_copy", ""}, + {"aten::inverse", ""}, + {"aten::_inverse_helper", ""}, + {"aten::isclose", ""}, + {"aten::isnan", ""}, + {"aten::is_distributed", ""}, + {"aten::is_floating_point", ""}, + {"aten::is_complex", ""}, + {"aten::is_nonzero", ""}, + {"aten::is_same_size", ""}, + {"aten::is_signed", ""}, + {"aten::kl_div", ""}, + {"aten::kl_div_backward", ""}, + {"aten::kthvalue", ""}, + {"aten::fbgemm_linear_int8_weight_fp32_activation", ""}, + {"aten::fbgemm_linear_int8_weight", ""}, + {"aten::fbgemm_linear_quantize_weight", ""}, + {"aten::fbgemm_pack_gemm_matrix_fp16", ""}, + {"aten::fbgemm_linear_fp16_weight_fp32_activation", ""}, + {"aten::fbgemm_linear_fp16_weight", ""}, + {"aten::fbgemm_pack_quantized_matrix", ""}, + {"aten::fbgemm_pack_quantized_matrix", "KN"}, + {"aten::fbgemm_is_cpu_supported", ""}, + {"aten::log", ""}, + {"aten::log_", ""}, + {"aten::log10", ""}, + {"aten::log10_", ""}, + {"aten::log1p", ""}, + {"aten::log1p_", ""}, + {"aten::log2", ""}, + {"aten::log2_", ""}, + {"aten::logdet", ""}, + {"aten::_log_softmax", ""}, + {"aten::_log_softmax_backward_data", ""}, + {"aten::logsumexp", ""}, + {"aten::margin_ranking_loss", ""}, + {"aten::matmul", ""}, + {"aten::matrix_rank", "tol"}, + {"aten::matrix_rank", ""}, + {"aten::matrix_power", ""}, + {"aten::max", "dim"}, + {"aten::max_values", ""}, + {"aten::max_pool1d_with_indices", ""}, + {"aten::max_pool1d", ""}, + {"aten::max_pool2d", ""}, + {"aten::mkldnn_max_pool2d", ""}, + {"aten::quantized_max_pool2d", ""}, + {"aten::max_pool3d", ""}, + {"aten::median", "dim"}, + {"aten::min", "dim"}, + {"aten::min_values", ""}, + {"aten::mkldnn_convolution_backward_input", ""}, + {"aten::mkldnn_convolution_backward_weights", ""}, + {"aten::mkldnn_convolution_backward", ""}, + {"aten::miopen_convolution_backward_input", ""}, + {"aten::miopen_convolution_backward", ""}, + {"aten::miopen_convolution_backward_bias", ""}, + {"aten::miopen_convolution_backward_weight", ""}, + {"aten::miopen_convolution_transpose_backward", ""}, + {"aten::miopen_convolution_transpose_backward_input", ""}, + {"aten::miopen_convolution_transpose_backward_weight", ""}, + {"aten::miopen_depthwise_convolution_backward_input", ""}, + {"aten::miopen_depthwise_convolution_backward", ""}, + {"aten::miopen_depthwise_convolution_backward_weight", ""}, + {"aten::mm", ""}, + {"aten::_sparse_mm", ""}, + {"aten::mode", ""}, + {"aten::mul", "Tensor"}, + {"aten::mul_", "Tensor"}, + {"aten::mul", "Scalar"}, + {"aten::mul_", "Scalar"}, + {"aten::mv", ""}, + {"aten::mvlgamma", ""}, + {"aten::mvlgamma_", ""}, + {"aten::narrow_copy", ""}, + {"aten::narrow", ""}, + {"aten::batch_norm_stats", ""}, + {"aten::_nnpack_available", ""}, + {"aten::_nnpack_spatial_convolution_backward", ""}, + {"aten::_nnpack_spatial_convolution_backward_input", ""}, + {"aten::_nnpack_spatial_convolution_backward_weight", ""}, + {"aten::ones_like", ""}, + {"aten::pairwise_distance", ""}, + {"aten::cdist", ""}, + {"aten::_cdist_backward", ""}, + {"aten::pdist", ""}, + {"aten::_pdist_forward", ""}, + {"aten::_pdist_backward", ""}, + {"aten::cosine_similarity", ""}, + {"aten::permute", ""}, + {"aten::numpy_T", ""}, + {"aten::pixel_shuffle", ""}, + {"aten::is_pinned", ""}, + {"aten::pin_memory", ""}, + {"aten::pinverse", ""}, + {"aten::poisson_nll_loss", ""}, + {"aten::rand_like", ""}, + {"aten::randint_like", ""}, + {"aten::randint_like", "low"}, + {"aten::randn_like", ""}, + {"aten::reciprocal", ""}, + {"aten::reciprocal_", ""}, + {"aten::neg", ""}, + {"aten::neg_", ""}, + {"aten::repeat", ""}, + {"aten::repeat_interleave", "Tensor"}, + {"aten::repeat_interleave", "self_Tensor"}, + {"aten::repeat_interleave", "self_int"}, + {"aten::reshape", ""}, + {"aten::_mkldnn_reshape", ""}, + {"aten::reshape_as", ""}, + {"aten::round", ""}, + {"aten::round_", ""}, + {"aten::relu", ""}, + {"aten::relu_", ""}, + {"aten::prelu", ""}, + {"aten::prelu_backward", ""}, + {"aten::gelu", ""}, + {"aten::gelu_backward", ""}, + {"aten::hardshrink", ""}, + {"aten::hardshrink_backward", ""}, + {"aten::rsqrt", ""}, + {"aten::rsqrt_", ""}, + {"aten::select", "int"}, + {"aten::selu", ""}, + {"aten::selu_", ""}, + {"aten::celu", ""}, + {"aten::celu_", ""}, + {"aten::sigmoid", ""}, + {"aten::sigmoid_", ""}, + {"aten::sin", ""}, + {"aten::sin_", ""}, + {"aten::sinh", ""}, + {"aten::sinh_", ""}, + {"aten::detach", ""}, + {"aten::detach_", ""}, + {"aten::size", "int"}, + {"aten::slice", "Tensor"}, + {"aten::slogdet", ""}, + {"aten::smm", ""}, + {"aten::_softmax", ""}, + {"aten::_softmax_backward_data", ""}, + {"aten::split", "Tensor"}, + {"aten::split_with_sizes", ""}, + {"aten::squeeze", ""}, + {"aten::squeeze", "dim"}, + {"aten::squeeze_", ""}, + {"aten::squeeze_", "dim"}, + {"aten::sspaddmm", ""}, + {"aten::stack", ""}, + {"aten::stride", "int"}, + {"aten::sum_to_size", ""}, + {"aten::sqrt", ""}, + {"aten::sqrt_", ""}, + {"aten::std", ""}, + {"aten::std", "dim"}, + {"aten::std_mean", ""}, + {"aten::std_mean", "dim"}, + {"aten::t", ""}, + {"aten::t_", ""}, + {"aten::tan", ""}, + {"aten::tan_", ""}, + {"aten::tanh", ""}, + {"aten::tanh_", ""}, + {"aten::tensordot", ""}, + {"aten::threshold", ""}, + {"aten::threshold_", ""}, + {"aten::threshold_backward", ""}, + {"aten::transpose", "int"}, + {"aten::_mkldnn_transpose", ""}, + {"aten::transpose_", ""}, + {"aten::_mkldnn_transpose_", ""}, + {"aten::one_hot", ""}, + {"aten::flip", ""}, + {"aten::roll", ""}, + {"aten::rot90", ""}, + {"aten::trapz", "x"}, + {"aten::trapz", "dx"}, + {"aten::_trilinear", ""}, + {"aten::triplet_margin_loss", ""}, + {"aten::trunc", ""}, + {"aten::trunc_", ""}, + {"aten::type_as", ""}, + {"aten::_has_compatible_shallow_copy_type", ""}, + {"aten::_unique", ""}, + {"aten::unique_dim", ""}, + {"aten::unique_consecutive", ""}, + {"aten::unique_dim_consecutive", ""}, + {"aten::_unique2", ""}, + {"aten::_unsafe_view", ""}, + {"aten::unsqueeze", ""}, + {"aten::unsqueeze_", ""}, + {"aten::var", ""}, + {"aten::var", "dim"}, + {"aten::var_mean", ""}, + {"aten::var_mean", "dim"}, + {"aten::view_as", ""}, + {"aten::where", "self"}, + {"aten::where", ""}, + {"aten::_s_where", ""}, + {"aten::norm_except_dim", ""}, + {"aten::_weight_norm", ""}, + {"aten::_weight_norm_cuda_interface", ""}, + {"aten::_weight_norm_cuda_interface_backward", ""}, + {"aten::_weight_norm_differentiable_backward", ""}, + {"aten::zeros_like", ""}, + {"aten::_standard_gamma_grad", ""}, + {"aten::_dirichlet_grad", ""}, + {"aten::native_norm", ""}, + {"aten::_sparse_sum", ""}, + {"aten::_sparse_sum", "dim"}, + {"aten::_sparse_sum_backward", ""}, + {"aten::norm", "Scalar"}, + {"aten::norm", "ScalarOpt_dim"}, + {"aten::frobenius_norm", ""}, + {"aten::frobenius_norm", "dim"}, + {"aten::nuclear_norm", ""}, + {"aten::nuclear_norm", "dim"}, + {"aten::clone", ""}, + {"aten::resize_as_", ""}, + {"aten::pow", "Tensor_Scalar"}, + {"aten::zero_", ""}, + {"aten::sub", "Tensor"}, + {"aten::sub_", "Tensor"}, + {"aten::sub", "Scalar"}, + {"aten::sub_", "Scalar"}, + {"aten::rsub", "Tensor"}, + {"aten::rsub", "Scalar"}, + {"aten::_sparse_addmm", ""}, + {"aten::addmm", ""}, + {"aten::addmm_", ""}, + {"aten::sparse_resize_", ""}, + {"aten::sparse_resize_and_clear_", ""}, + {"aten::sparse_mask", ""}, + {"aten::to_dense", ""}, + {"aten::to_dense_backward", ""}, + {"aten::sparse_dim", ""}, + {"aten::_dimI", ""}, + {"aten::dense_dim", ""}, + {"aten::_dimV", ""}, + {"aten::_nnz", ""}, + {"aten::coalesce", ""}, + {"aten::is_coalesced", ""}, + {"aten::_indices", ""}, + {"aten::_values", ""}, + {"aten::_coalesced_", ""}, + {"aten::indices", ""}, + {"aten::values", ""}, + {"aten::hspmm", ""}, + {"aten::copy_sparse_to_sparse_", ""}, + {"aten::numel", ""}, + {"aten::unbind", "int"}, + {"aten::to_sparse", "sparse_dim"}, + {"aten::to_sparse", ""}, + {"aten::to_mkldnn", ""}, + {"aten::mkldnn_reorder_conv2d_weight", ""}, + {"aten::to_mkldnn_backward", ""}, + {"aten::dequantize", ""}, + {"aten::q_scale", ""}, + {"aten::q_zero_point", ""}, + {"aten::q_per_channel_scales", ""}, + {"aten::q_per_channel_zero_points", ""}, + {"aten::int_repr", ""}, + {"aten::_per_tensor_affine_qtensor", ""}, + {"aten::_per_channel_affine_qtensor", ""}, + {"aten::fake_quantize_per_tensor_affine", ""}, + {"aten::fake_quantize_per_tensor_affine_backward", ""}, + {"aten::to", "other"}, + {"aten::meshgrid", ""}, + {"aten::cartesian_prod", ""}, + {"aten::combinations", ""}, + {"aten::item", ""}, + {"aten::_local_scalar_dense", ""}, + {"aten::_thnn_fused_gru_cell_backward", ""}, + {"aten::lstm", "input"}, + {"aten::lstm", "data"}, + {"aten::gru", "input"}, + {"aten::gru", "data"}, + {"aten::rnn_tanh", "input"}, + {"aten::rnn_tanh", "data"}, + {"aten::rnn_relu", "input"}, + {"aten::rnn_relu", "data"}, + {"aten::quantized_gru", "input"}, + {"aten::quantized_gru", "data"}, + {"aten::quantized_lstm_cell", ""}, + {"aten::quantized_gru_cell", ""}, + {"aten::quantized_rnn_relu_cell", ""}, + {"aten::quantized_rnn_tanh_cell", ""}, + {"aten::_pack_padded_sequence", ""}, + {"aten::_pack_padded_sequence_backward", ""}, + {"aten::_pad_packed_sequence", ""}, + {"aten::set_", "source_Tensor"}, + {"aten::set_", ""}, + {"aten::is_set_to", ""}, + {"aten::masked_fill_", "Scalar"}, + {"aten::masked_fill", "Scalar"}, + {"aten::masked_fill_", "Tensor"}, + {"aten::masked_fill", "Tensor"}, + {"aten::masked_scatter_", ""}, + {"aten::masked_scatter", ""}, + {"aten::view", ""}, + {"aten::put_", ""}, + {"aten::index_add_", ""}, + {"aten::index_add", ""}, + {"aten::index_fill_", "Scalar"}, + {"aten::index_fill", "Scalar"}, + {"aten::index_fill_", "Tensor"}, + {"aten::index_fill", "Tensor"}, + {"aten::scatter_", "src"}, + {"aten::scatter", "src"}, + {"aten::scatter_", "value"}, + {"aten::scatter", "value"}, + {"aten::scatter_add_", ""}, + {"aten::scatter_add", ""}, + {"aten::lt_", "Scalar"}, + {"aten::lt_", "Tensor"}, + {"aten::gt_", "Scalar"}, + {"aten::gt_", "Tensor"}, + {"aten::le_", "Scalar"}, + {"aten::le_", "Tensor"}, + {"aten::ge_", "Scalar"}, + {"aten::ge_", "Tensor"}, + {"aten::eq_", "Scalar"}, + {"aten::eq_", "Tensor"}, + {"aten::ne_", "Scalar"}, + {"aten::ne_", "Tensor"}, + {"aten::__and__", "Scalar"}, + {"aten::__and__", "Tensor"}, + {"aten::__iand__", "Scalar"}, + {"aten::__iand__", "Tensor"}, + {"aten::__or__", "Scalar"}, + {"aten::__or__", "Tensor"}, + {"aten::__ior__", "Scalar"}, + {"aten::__ior__", "Tensor"}, + {"aten::__xor__", "Scalar"}, + {"aten::__xor__", "Tensor"}, + {"aten::__ixor__", "Scalar"}, + {"aten::__ixor__", "Tensor"}, + {"aten::__lshift__", "Scalar"}, + {"aten::__lshift__", "Tensor"}, + {"aten::__ilshift__", "Scalar"}, + {"aten::__ilshift__", "Tensor"}, + {"aten::__rshift__", "Scalar"}, + {"aten::__rshift__", "Tensor"}, + {"aten::__irshift__", "Scalar"}, + {"aten::__irshift__", "Tensor"}, + {"aten::lgamma_", ""}, + {"aten::atan2_", ""}, + {"aten::tril_", ""}, + {"aten::triu_", ""}, + {"aten::digamma_", ""}, + {"aten::polygamma_", ""}, + {"aten::renorm_", ""}, + {"aten::pow_", "Scalar"}, + {"aten::pow_", "Tensor"}, + {"aten::lerp_", "Scalar"}, + {"aten::lerp_", "Tensor"}, + {"aten::fmod_", "Scalar"}, + {"aten::fmod_", "Tensor"}, + {"aten::remainder_", "Scalar"}, + {"aten::remainder_", "Tensor"}, + {"aten::addbmm_", ""}, + {"aten::addbmm", ""}, + {"aten::addcdiv_", ""}, + {"aten::diag", ""}, + {"aten::cross", ""}, + {"aten::triu", ""}, + {"aten::tril", ""}, + {"aten::trace", ""}, + {"aten::ne", "Scalar"}, + {"aten::ne", "Tensor"}, + {"aten::eq", "Scalar"}, + {"aten::eq", "Tensor"}, + {"aten::ge", "Scalar"}, + {"aten::ge", "Tensor"}, + {"aten::le", "Scalar"}, + {"aten::le", "Tensor"}, + {"aten::gt", "Scalar"}, + {"aten::gt", "Tensor"}, + {"aten::lt", "Scalar"}, + {"aten::lt", "Tensor"}, + {"aten::take", ""}, + {"aten::index_select", ""}, + {"aten::masked_select", ""}, + {"aten::nonzero", ""}, + {"aten::nonzero_numpy", ""}, + {"aten::gather", ""}, + {"aten::_gather_sparse_backward", ""}, + {"aten::addcmul", ""}, + {"aten::addcmul_", ""}, + {"aten::addcdiv", ""}, + {"aten::lstsq", ""}, + {"aten::triangular_solve", ""}, + {"aten::_triangular_solve_helper", ""}, + {"aten::symeig", ""}, + {"aten::_symeig_helper", ""}, + {"aten::eig", ""}, + {"aten::svd", ""}, + {"aten::_svd_helper", ""}, + {"aten::cholesky", ""}, + {"aten::_cholesky_helper", ""}, + {"aten::cholesky_solve", ""}, + {"aten::_cholesky_solve_helper", ""}, + {"aten::solve", ""}, + {"aten::_solve_helper", ""}, + {"aten::cholesky_inverse", ""}, + {"aten::qr", ""}, + {"aten::_qr_helper", ""}, + {"aten::geqrf", ""}, + {"aten::orgqr", ""}, + {"aten::ormqr", ""}, + {"aten::_lu_with_info", ""}, + {"aten::lu_solve", ""}, + {"aten::_lu_solve_helper", ""}, + {"aten::_multinomial_alias_setup", ""}, + {"aten::lgamma", ""}, + {"aten::digamma", ""}, + {"aten::polygamma", ""}, + {"aten::erfinv", ""}, + {"aten::erfinv_", ""}, + {"aten::sign", ""}, + {"aten::sign_", ""}, + {"aten::dist", ""}, + {"aten::atan2", ""}, + {"aten::lerp", "Scalar"}, + {"aten::lerp", "Tensor"}, + {"aten::histc", ""}, + {"aten::fmod", "Scalar"}, + {"aten::fmod", "Tensor"}, + {"aten::remainder", "Scalar"}, + {"aten::remainder", "Tensor"}, + {"aten::min", "other"}, + {"aten::min", ""}, + {"aten::max", "other"}, + {"aten::max", ""}, + {"aten::median", ""}, + {"aten::sort", ""}, + {"aten::argsort", ""}, + {"aten::topk", ""}, + {"aten::all", ""}, + {"aten::any", ""}, + {"aten::renorm", ""}, + {"aten::unfold", ""}, + {"aten::equal", ""}, + {"aten::pow", "Tensor_Tensor"}, + {"aten::pow", "Scalar"}, + {"aten::alias", ""}, + {"aten::_addr", ""}, + {"aten::_addr_", ""}, + {"aten::_index_copy_", ""}, + {"aten::_cumsum", ""}, + {"aten::_cumprod", ""}, + {"aten::_var", ""}, + {"aten::_std", ""}, + {"aten::_cat", ""}, + {"aten::_mode", ""}, + {"aten::_max", ""}, + {"aten::_min", ""}, + {"aten::mse_loss", ""}, + {"aten::mse_loss_backward", ""}, + {"aten::l1_loss", ""}, + {"aten::l1_loss_backward", ""}, + {"aten::multilabel_margin_loss", ""}, + {"aten::multilabel_margin_loss_forward", ""}, + {"aten::multilabel_margin_loss_backward", ""}, + {"aten::smooth_l1_loss", ""}, + {"aten::smooth_l1_loss_backward", ""}, + {"aten::soft_margin_loss", ""}, + {"aten::soft_margin_loss_backward", ""}, + {"aten::elu", ""}, + {"aten::elu_backward", ""}, + {"aten::elu_", ""}, + {"aten::glu", ""}, + {"aten::glu_backward", ""}, + {"aten::hardtanh", ""}, + {"aten::hardtanh_backward", ""}, + {"aten::hardtanh_", ""}, + {"aten::leaky_relu", ""}, + {"aten::leaky_relu_backward", ""}, + {"aten::leaky_relu_", ""}, + {"aten::log_sigmoid", ""}, + {"aten::log_sigmoid_forward", ""}, + {"aten::log_sigmoid_backward", ""}, + {"aten::rrelu_with_noise_backward", ""}, + {"aten::softplus", ""}, + {"aten::softplus_backward", ""}, + {"aten::softshrink", ""}, + {"aten::softshrink_backward", ""}, + {"aten::adaptive_avg_pool2d", ""}, + {"aten::mkldnn_adaptive_avg_pool2d", ""}, + {"aten::_adaptive_avg_pool2d", ""}, + {"aten::_adaptive_avg_pool2d_backward", ""}, + {"aten::adaptive_avg_pool3d", ""}, + {"aten::adaptive_avg_pool3d_backward", ""}, + {"aten::adaptive_max_pool2d", ""}, + {"aten::adaptive_max_pool2d_backward", ""}, + {"aten::adaptive_max_pool3d", ""}, + {"aten::adaptive_max_pool3d_backward", ""}, + {"aten::avg_pool2d", ""}, + {"aten::avg_pool2d_backward", ""}, + {"aten::avg_pool3d", ""}, + {"aten::avg_pool3d_backward", ""}, + {"aten::fractional_max_pool2d", ""}, + {"aten::fractional_max_pool2d_backward", ""}, + {"aten::fractional_max_pool3d", ""}, + {"aten::fractional_max_pool3d_backward", ""}, + {"aten::max_pool2d_with_indices", ""}, + {"aten::max_pool2d_with_indices_backward", ""}, + {"aten::max_pool3d_with_indices", ""}, + {"aten::max_pool3d_with_indices_backward", ""}, + {"aten::max_unpool2d", ""}, + {"aten::max_unpool2d_backward", ""}, + {"aten::max_unpool3d", ""}, + {"aten::max_unpool3d_backward", ""}, + {"aten::reflection_pad1d", ""}, + {"aten::reflection_pad1d_backward", ""}, + {"aten::reflection_pad2d", ""}, + {"aten::reflection_pad2d_backward", ""}, + {"aten::replication_pad1d", ""}, + {"aten::replication_pad1d_backward", ""}, + {"aten::replication_pad2d", ""}, + {"aten::replication_pad2d_backward", ""}, + {"aten::replication_pad3d", ""}, + {"aten::replication_pad3d_backward", ""}, + {"aten::upsample_linear1d", ""}, + {"aten::upsample_linear1d_backward", ""}, + {"aten::upsample_bilinear2d", ""}, + {"aten::upsample_bilinear2d_backward", ""}, + {"aten::upsample_bicubic2d", ""}, + {"aten::upsample_bicubic2d_backward", ""}, + {"aten::upsample_trilinear3d", ""}, + {"aten::upsample_trilinear3d_backward", ""}, + {"aten::upsample_nearest1d", ""}, + {"aten::upsample_nearest1d_backward", ""}, + {"aten::upsample_nearest2d", ""}, + {"aten::upsample_nearest2d_backward", ""}, + {"aten::upsample_nearest3d", ""}, + {"aten::upsample_nearest3d_backward", ""}, + {"aten::sigmoid_backward", ""}, + {"aten::tanh_backward", ""}, + {"aten::slow_conv_transpose2d_backward", "output_mask"}, + {"aten::slow_conv_transpose3d_backward", "output_mask"}, + {"aten::thnn_conv2d_backward", "output_mask"}, + {"aten::thnn_conv_depthwise2d_backward", "output_mask"}, + {"aten::thnn_conv3d_backward", "output_mask"}, + {"aten::slow_conv_dilated2d_backward", ""}, + {"aten::slow_conv_dilated3d_backward", ""}, + {"aten::col2im", ""}, + {"aten::col2im_backward", ""}, + {"aten::im2col", ""}, + {"aten::im2col_backward", ""}, + {"", ""} + }; + return ops.count(std::make_pair(opName.name.c_str(), opName.overload_name.c_str())) != 0; +} + +bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) { + static std::unordered_set, OpNameHash, OpNameEquals> ops { + {"aten::backward", ""}, + {"aten::set_data", ""}, + {"aten::is_leaf", ""}, + {"aten::output_nr", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::names_", ""}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::renamed", ""}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::align_to", ""}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::refine_names", ""}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::unflatten", ""}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::unflatten", ""}, + #endif + {"aten::_cudnn_rnn", ""}, + {"aten::_cudnn_rnn_backward", ""}, + {"aten::_cudnn_init_dropout_state", ""}, + {"aten::_fused_dropout", ""}, + {"aten::_sobol_engine_draw", ""}, + {"aten::abs", "out"}, + {"aten::acos", "out"}, + {"aten::add", "out"}, + {"aten::addmv", "out"}, + {"aten::addr", "out"}, + {"aten::all", "out"}, + {"aten::any", "out"}, + {"aten::arange", ""}, + {"aten::arange", "start"}, + {"aten::arange", "start_step"}, + {"aten::arange", "out"}, + {"aten::arange", "start_out"}, + {"aten::asin", "out"}, + {"aten::atan", "out"}, + {"aten::baddbmm", "out"}, + {"aten::bartlett_window", ""}, + {"aten::bartlett_window", "periodic"}, + {"aten::batch_norm", ""}, + {"aten::_batch_norm_impl_index", ""}, + {"aten::_batch_norm_impl_index_backward", ""}, + {"aten::bernoulli", ""}, + {"aten::bernoulli", "out"}, + {"aten::bernoulli_", "Tensor"}, + {"aten::bernoulli_", "float"}, + {"aten::bernoulli", "p"}, + {"aten::bilinear", ""}, + {"aten::binary_cross_entropy_with_logits", ""}, + {"aten::binary_cross_entropy_with_logits_backward", ""}, + {"aten::bincount", ""}, + {"aten::bitwise_not", "out"}, + {"aten::logical_not", "out"}, + {"aten::logical_xor", "out"}, + {"aten::blackman_window", ""}, + {"aten::blackman_window", "periodic"}, + {"aten::bmm", "out"}, + {"aten::cat", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::cat", "names"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::cat", "names_out"}, + #endif + {"aten::ceil", "out"}, + {"aten::clamp", "out"}, + {"aten::clamp_max", "out"}, + {"aten::clamp_min", "out"}, + {"aten::contiguous", ""}, + {"aten::convolution", ""}, + {"aten::convolution_overrideable", ""}, + {"aten::convolution_backward_overrideable", ""}, + {"aten::_convolution", ""}, + {"aten::_convolution_nogroup", ""}, + {"aten::_convolution_double_backward", ""}, + {"aten::conv1d", ""}, + {"aten::conv2d", ""}, + {"aten::conv3d", ""}, + {"aten::conv_transpose1d", ""}, + {"aten::conv_transpose2d", "input"}, + {"aten::conv_transpose3d", "input"}, + {"aten::cos", "out"}, + {"aten::cosh", "out"}, + {"aten::cudnn_batch_norm", ""}, + {"aten::cudnn_batch_norm_backward", ""}, + {"aten::cudnn_convolution", ""}, + {"aten::cudnn_convolution_transpose", ""}, + {"aten::cumsum", ""}, + {"aten::cumsum", "out"}, + {"aten::cumprod", ""}, + {"aten::cumprod", "out"}, + {"aten::div", "out"}, + {"aten::dot", "out"}, + {"aten::embedding_bag", ""}, + {"aten::_embedding_bag", ""}, + {"aten::_embedding_bag_backward", ""}, + {"aten::_embedding_bag_sparse_backward", ""}, + {"aten::_embedding_bag_dense_backward", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::empty", "names"}, + #endif + {"aten::empty", "memory_format"}, + {"aten::new_empty", ""}, + {"aten::new_full", ""}, + {"aten::_empty_affine_quantized", ""}, + {"aten::_empty_per_channel_affine_quantized_like", ""}, + {"aten::empty", "out"}, + {"aten::empty_like", "dtype"}, + {"aten::empty_strided", ""}, + {"aten::erf", "out"}, + {"aten::erfc", "out"}, + {"aten::exp", "out"}, + {"aten::expm1", "out"}, + {"aten::eye", ""}, + {"aten::eye", "m"}, + {"aten::eye", "out"}, + {"aten::eye", "m_out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::flatten", "named_out_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::flatten", "using_names"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::flatten", "DimnameList"}, + #endif + {"aten::floor", "out"}, + {"aten::frac", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::full", "names"}, + #endif + {"aten::full", ""}, + {"aten::full", "out"}, + {"aten::full_like", "dtype"}, + {"aten::from_file", ""}, + {"aten::hann_window", ""}, + {"aten::hann_window", "periodic"}, + {"aten::hamming_window", ""}, + {"aten::hamming_window", "periodic"}, + {"aten::hamming_window", "periodic_alpha"}, + {"aten::hamming_window", "periodic_alpha_beta"}, + {"aten::ger", "out"}, + {"aten::group_norm", ""}, + {"aten::_cufft_set_plan_cache_max_size", ""}, + {"aten::_cufft_clear_plan_cache", ""}, + {"aten::index", "Tensor"}, + {"aten::index_put_", ""}, + {"aten::index_put", ""}, + {"aten::_index_put_impl_", ""}, + {"aten::instance_norm", ""}, + {"aten::inverse", "out"}, + {"aten::kthvalue", "values"}, + {"aten::layer_norm", ""}, + {"aten::native_layer_norm", ""}, + {"aten::native_layer_norm_backward", ""}, + {"aten::native_layer_norm_double_backward", ""}, + {"aten::linear", ""}, + {"aten::mkldnn_linear", ""}, + {"aten::linspace", ""}, + {"aten::linspace", "out"}, + {"aten::log", "out"}, + {"aten::log10", "out"}, + {"aten::log1p", "out"}, + {"aten::log2", "out"}, + {"aten::logspace", ""}, + {"aten::logspace", "out"}, + {"aten::log_softmax", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::log_softmax", ""}, + #endif + {"aten::logsumexp", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::logsumexp", "names"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::logsumexp", "names_out"}, + #endif + {"aten::matmul", "out"}, + {"aten::max", "dim_max"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::max", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::max", "names_dim_max"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::max_values", "names"}, + #endif + {"aten::mean", ""}, + {"aten::mean", "dim"}, + {"aten::mean", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::mean", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::mean", "names_out"}, + #endif + {"aten::median", "dim_values"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::median", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::median", "names_dim_values"}, + #endif + {"aten::min", "dim_min"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::min", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::min", "names_dim_min"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::min_values", "names"}, + #endif + {"aten::mkldnn_convolution", ""}, + {"aten::miopen_batch_norm", ""}, + {"aten::miopen_batch_norm_backward", ""}, + {"aten::miopen_convolution", ""}, + {"aten::miopen_convolution_transpose", ""}, + {"aten::miopen_depthwise_convolution", ""}, + {"aten::miopen_rnn", ""}, + {"aten::miopen_rnn_backward", ""}, + {"aten::mm", "out"}, + {"aten::mode", "values"}, + {"aten::mul", "out"}, + {"aten::mv", "out"}, + {"aten::native_batch_norm", ""}, + {"aten::batch_norm_elemt", ""}, + {"aten::batch_norm_gather_stats", ""}, + {"aten::batch_norm_gather_stats_with_counts", ""}, + {"aten::native_batch_norm_backward", ""}, + {"aten::batch_norm_backward_reduce", ""}, + {"aten::batch_norm_backward_elemt", ""}, + {"aten::batch_norm_update_stats", ""}, + {"aten::_nnpack_spatial_convolution", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::ones", "names"}, + #endif + {"aten::ones", ""}, + {"aten::ones", "out"}, + {"aten::ones_like", "dtype"}, + {"aten::scalar_tensor", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::rand", "names"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::rand", "generator_with_names"}, + #endif + {"aten::rand", ""}, + {"aten::rand", "generator"}, + {"aten::rand", "out"}, + {"aten::rand", "generator_out"}, + {"aten::rand_like", "dtype"}, + {"aten::randint", ""}, + {"aten::randint", "generator"}, + {"aten::randint", "low"}, + {"aten::randint", "low_generator"}, + {"aten::randint", "out"}, + {"aten::randint", "generator_out"}, + {"aten::randint", "low_out"}, + {"aten::randint", "low_generator_out"}, + {"aten::randint_like", "dtype"}, + {"aten::randint_like", "low_dtype"}, + {"aten::randn", ""}, + {"aten::randn", "generator"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::randn", "names"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::randn", "generator_with_names"}, + #endif + {"aten::randn", "out"}, + {"aten::randn", "generator_out"}, + {"aten::randn_like", "dtype"}, + {"aten::randperm", ""}, + {"aten::randperm", "generator"}, + {"aten::randperm", "out"}, + {"aten::randperm", "generator_out"}, + {"aten::range", "step"}, + {"aten::range", ""}, + {"aten::range", "out"}, + {"aten::reciprocal", "out"}, + {"aten::neg", "out"}, + {"aten::round", "out"}, + {"aten::rrelu", ""}, + {"aten::rrelu_", ""}, + {"aten::rsqrt", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::select", "Dimname"}, + #endif + {"aten::sigmoid", "out"}, + {"aten::sin", "out"}, + {"aten::sinh", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::size", "Dimname"}, + #endif + {"aten::softmax", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::softmax", ""}, + #endif + {"aten::sspaddmm", "out"}, + {"aten::stack", "out"}, + {"aten::stft", ""}, + #ifdef BUILD_NAMEDTENSOR + {"aten::stride", "Dimname"}, + #endif + {"aten::sum", ""}, + {"aten::sum", "dim_IntList"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::sum", "dim_DimnameList"}, + #endif + {"aten::sum", "IntList_out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::sum", "DimnameList_out"}, + #endif + {"aten::sqrt", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::std_mean", "names_dim"}, + #endif + {"aten::std", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::std", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::std", "names_out"}, + #endif + {"aten::prod", ""}, + {"aten::prod", "dim_int"}, + {"aten::prod", "int_out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::prod", "dim_Dimname"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::prod", "Dimname_out"}, + #endif + {"aten::tan", "out"}, + {"aten::tanh", "out"}, + {"aten::threshold", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::transpose", "Dimname"}, + #endif + {"aten::trunc", "out"}, + {"aten::var", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::var", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::var", "names_out"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::var_mean", "names_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::zeros", "names"}, + #endif + {"aten::zeros", ""}, + {"aten::zeros", "out"}, + {"aten::zeros_like", "dtype"}, + {"aten::_standard_gamma", ""}, + {"aten::_sample_dirichlet", ""}, + {"aten::poisson", ""}, + {"aten::_sparse_sum", "dtype"}, + {"aten::_sparse_sum", "dim_dtype"}, + {"aten::norm", "ScalarOpt_dtype"}, + {"aten::norm", "ScalarOpt_dim_dtype"}, + {"aten::norm", "dtype_out"}, + {"aten::norm", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::norm", "names_ScalarOpt_dim_dtype"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::norm", "names_ScalarOpt_dim"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::norm", "names_dtype_out"}, + #endif + #ifdef BUILD_NAMEDTENSOR + {"aten::norm", "names_out"}, + #endif + {"aten::frobenius_norm", "out"}, + {"aten::nuclear_norm", "out"}, + {"aten::nuclear_norm", "dim_out"}, + {"aten::pow", "Tensor_Scalar_out"}, + {"aten::sub", "out"}, + {"aten::addmm", "out"}, + {"aten::sparse_coo_tensor", "size"}, + {"aten::sparse_coo_tensor", "indices"}, + {"aten::sparse_coo_tensor", "indices_size"}, + {"aten::_sparse_coo_tensor_unsafe", ""}, + {"aten::_sparse_coo_tensor_with_dims", ""}, + {"aten::_sparse_coo_tensor_with_dims_and_tensors", ""}, + {"aten::hspmm", "out"}, + #ifdef BUILD_NAMEDTENSOR + {"aten::unbind", "Dimname"}, + #endif + {"aten::quantize_linear", ""}, + {"aten::quantize_linear_per_channel", ""}, + {"aten::_dequantize_linear", ""}, + {"aten::q_per_channel_axis", ""}, + {"aten::qscheme", ""}, + {"aten::to", "dtype_layout"}, + {"aten::to", "device"}, + {"aten::to", "dtype"}, + {"aten::_thnn_fused_lstm_cell", ""}, + {"aten::_thnn_fused_lstm_cell_backward", ""}, + {"aten::_thnn_fused_gru_cell", ""}, + {"aten::lstm_cell", ""}, + {"aten::gru_cell", ""}, + {"aten::rnn_tanh_cell", ""}, + {"aten::rnn_relu_cell", ""}, + {"aten::quantized_lstm", ""}, + {"aten::set_", "source_Storage"}, + {"aten::set_", "source_Storage_storage_offset"}, + {"aten::set_quantizer_", ""}, + {"aten::addbmm", "out"}, + {"aten::random_", "from"}, + {"aten::random_", "to"}, + {"aten::random_", ""}, + {"aten::uniform_", ""}, + {"aten::normal_", ""}, + {"aten::cauchy_", ""}, + {"aten::log_normal_", ""}, + {"aten::exponential_", ""}, + {"aten::geometric_", ""}, + {"aten::diag", "out"}, + {"aten::cross", "out"}, + {"aten::triu", "out"}, + {"aten::tril", "out"}, + {"aten::tril_indices", ""}, + {"aten::triu_indices", ""}, + {"aten::ne", "Scalar_out"}, + {"aten::ne", "Tensor_out"}, + {"aten::eq", "Scalar_out"}, + {"aten::eq", "Tensor_out"}, + {"aten::ge", "Scalar_out"}, + {"aten::ge", "Tensor_out"}, + {"aten::le", "Scalar_out"}, + {"aten::le", "Tensor_out"}, + {"aten::gt", "Scalar_out"}, + {"aten::gt", "Tensor_out"}, + {"aten::lt", "Scalar_out"}, + {"aten::lt", "Tensor_out"}, + {"aten::take", "out"}, + {"aten::index_select", "out"}, + {"aten::masked_select", "out"}, + {"aten::nonzero", "out"}, + {"aten::gather", "out"}, + {"aten::addcmul", "out"}, + {"aten::addcdiv", "out"}, + {"aten::lstsq", "X"}, + {"aten::triangular_solve", "X"}, + {"aten::symeig", "e"}, + {"aten::eig", "e"}, + {"aten::svd", "U"}, + {"aten::cholesky", "out"}, + {"aten::cholesky_solve", "out"}, + {"aten::solve", "solution"}, + {"aten::cholesky_inverse", "out"}, + {"aten::qr", "Q"}, + {"aten::geqrf", "a"}, + {"aten::orgqr", "out"}, + {"aten::ormqr", "out"}, + {"aten::lu_solve", "out"}, + {"aten::multinomial", "out"}, + {"aten::multinomial", ""}, + {"aten::_multinomial_alias_draw", ""}, + {"aten::lgamma", "out"}, + {"aten::digamma", "out"}, + {"aten::polygamma", "out"}, + {"aten::erfinv", "out"}, + {"aten::sign", "out"}, + {"aten::atan2", "out"}, + {"aten::lerp", "Scalar_out"}, + {"aten::lerp", "Tensor_out"}, + {"aten::histc", "out"}, + {"aten::fmod", "Scalar_out"}, + {"aten::fmod", "Tensor_out"}, + {"aten::remainder", "Scalar_out"}, + {"aten::remainder", "Tensor_out"}, + {"aten::min", "out"}, + {"aten::max", "out"}, + {"aten::sort", "values"}, + {"aten::topk", "values"}, + {"aten::renorm", "out"}, + {"aten::pow", "Tensor_Tensor_out"}, + {"aten::pow", "Scalar_out"}, + {"aten::normal", "Tensor_float_out"}, + {"aten::normal", "Tensor_float"}, + {"aten::normal", "float_Tensor_out"}, + {"aten::normal", "float_Tensor"}, + {"aten::normal", "Tensor_Tensor_out"}, + {"aten::normal", "Tensor_Tensor"}, + {"aten::normal", "float_float"}, + {"aten::normal", "float_float_out"}, + {"aten::_addr", "out"}, + {"aten::_cumsum", "out"}, + {"aten::_cumprod", "out"}, + {"aten::_cat", "out"}, + {"aten::_mode", "values"}, + {"aten::_max", "max"}, + {"aten::_min", "min"}, + {"aten::binary_cross_entropy", "out"}, + {"aten::binary_cross_entropy", ""}, + {"aten::binary_cross_entropy_backward", "grad_input"}, + {"aten::binary_cross_entropy_backward", ""}, + {"aten::mse_loss", "out"}, + {"aten::mse_loss_backward", "grad_input"}, + {"aten::l1_loss", "out"}, + {"aten::l1_loss_backward", "grad_input"}, + {"aten::multi_margin_loss", "out"}, + {"aten::multi_margin_loss", ""}, + {"aten::multi_margin_loss_backward", "grad_input"}, + {"aten::multi_margin_loss_backward", ""}, + {"aten::multilabel_margin_loss", "out"}, + {"aten::multilabel_margin_loss_forward", "output"}, + {"aten::multilabel_margin_loss_backward", "grad_input"}, + {"aten::nll_loss", "out"}, + {"aten::nll_loss", ""}, + {"aten::nll_loss_forward", "output"}, + {"aten::nll_loss_forward", ""}, + {"aten::nll_loss_backward", "grad_input"}, + {"aten::nll_loss_backward", ""}, + {"aten::nll_loss2d", "out"}, + {"aten::nll_loss2d", ""}, + {"aten::nll_loss2d_forward", "output"}, + {"aten::nll_loss2d_forward", ""}, + {"aten::nll_loss2d_backward", "grad_input"}, + {"aten::nll_loss2d_backward", ""}, + {"aten::smooth_l1_loss", "out"}, + {"aten::smooth_l1_loss_backward", "grad_input"}, + {"aten::soft_margin_loss", "out"}, + {"aten::soft_margin_loss_backward", "grad_input"}, + {"aten::elu", "out"}, + {"aten::elu_backward", "grad_input"}, + {"aten::glu", "out"}, + {"aten::glu_backward", "grad_input"}, + {"aten::hardtanh", "out"}, + {"aten::hardtanh_backward", "grad_input"}, + {"aten::leaky_relu", "out"}, + {"aten::leaky_relu_backward", "grad_input"}, + {"aten::log_sigmoid", "out"}, + {"aten::log_sigmoid_forward", "output"}, + {"aten::log_sigmoid_backward", "grad_input"}, + {"aten::rrelu_with_noise", "out"}, + {"aten::rrelu_with_noise", ""}, + {"aten::rrelu_with_noise_backward", "grad_input"}, + {"aten::rrelu_with_noise_", ""}, + {"aten::softplus", "out"}, + {"aten::softplus_backward", "grad_input"}, + {"aten::softshrink", "out"}, + {"aten::softshrink_backward", "grad_input"}, + {"aten::adaptive_avg_pool2d", "out"}, + {"aten::adaptive_avg_pool3d", "out"}, + {"aten::adaptive_avg_pool3d_backward", "grad_input"}, + {"aten::adaptive_max_pool2d", "out"}, + {"aten::adaptive_max_pool2d_backward", "grad_input"}, + {"aten::adaptive_max_pool3d", "out"}, + {"aten::adaptive_max_pool3d_backward", "grad_input"}, + {"aten::avg_pool2d", "out"}, + {"aten::avg_pool2d_backward", "grad_input"}, + {"aten::avg_pool3d", "out"}, + {"aten::avg_pool3d_backward", "grad_input"}, + {"aten::fractional_max_pool2d", "output"}, + {"aten::fractional_max_pool2d_backward", "grad_input"}, + {"aten::fractional_max_pool3d", "output"}, + {"aten::fractional_max_pool3d_backward", "grad_input"}, + {"aten::max_pool2d_with_indices", "out"}, + {"aten::max_pool2d_with_indices_backward", "grad_input"}, + {"aten::max_pool3d_with_indices", "out"}, + {"aten::max_pool3d_with_indices_backward", "grad_input"}, + {"aten::max_unpool2d", "out"}, + {"aten::max_unpool2d_backward", "grad_input"}, + {"aten::max_unpool3d", "out"}, + {"aten::max_unpool3d_backward", "grad_input"}, + {"aten::reflection_pad1d", "out"}, + {"aten::reflection_pad1d_backward", "grad_input"}, + {"aten::reflection_pad2d", "out"}, + {"aten::reflection_pad2d_backward", "grad_input"}, + {"aten::replication_pad1d", "out"}, + {"aten::replication_pad1d_backward", "grad_input"}, + {"aten::replication_pad2d", "out"}, + {"aten::replication_pad2d_backward", "grad_input"}, + {"aten::replication_pad3d", "out"}, + {"aten::replication_pad3d_backward", "grad_input"}, + {"aten::upsample_linear1d", "out"}, + {"aten::upsample_linear1d_backward", "grad_input"}, + {"aten::upsample_bilinear2d", "out"}, + {"aten::upsample_bilinear2d_backward", "grad_input"}, + {"aten::upsample_bicubic2d", "out"}, + {"aten::upsample_bicubic2d_backward", "grad_input"}, + {"aten::upsample_trilinear3d", "out"}, + {"aten::upsample_trilinear3d_backward", "grad_input"}, + {"aten::upsample_nearest1d", "out"}, + {"aten::upsample_nearest1d_backward", "grad_input"}, + {"aten::upsample_nearest2d", "out"}, + {"aten::upsample_nearest2d_backward", "grad_input"}, + {"aten::upsample_nearest3d", "out"}, + {"aten::upsample_nearest3d_backward", "grad_input"}, + {"aten::sigmoid_backward", "grad_input"}, + {"aten::tanh_backward", "grad_input"}, + {"aten::slow_conv_transpose2d", "out"}, + {"aten::slow_conv_transpose2d", ""}, + {"aten::slow_conv_transpose2d_backward", "grad_output"}, + {"aten::slow_conv_transpose3d", "out"}, + {"aten::slow_conv_transpose3d", ""}, + {"aten::slow_conv_transpose3d_backward", "grad_output"}, + {"aten::thnn_conv2d", "out"}, + {"aten::thnn_conv2d", ""}, + {"aten::thnn_conv2d_forward", "output"}, + {"aten::thnn_conv2d_forward", ""}, + {"aten::thnn_conv2d_backward", "grad_input"}, + {"aten::thnn_conv_depthwise2d", "out"}, + {"aten::thnn_conv_depthwise2d", ""}, + {"aten::thnn_conv_depthwise2d_forward", "out"}, + {"aten::thnn_conv_depthwise2d_forward", ""}, + {"aten::thnn_conv_depthwise2d_backward", "grad_input"}, + {"aten::thnn_conv3d", "out"}, + {"aten::thnn_conv3d", ""}, + {"aten::thnn_conv3d_forward", "output"}, + {"aten::thnn_conv3d_forward", ""}, + {"aten::thnn_conv3d_backward", "grad_input"}, + {"aten::slow_conv_dilated2d", ""}, + {"aten::slow_conv_dilated3d", ""}, + {"aten::col2im", "out"}, + {"aten::col2im_backward", "grad_input"}, + {"aten::im2col", "out"}, + {"aten::im2col_backward", "grad_input"}, + {"", ""} + }; + return ops.count(std::make_pair(opName.name.c_str(), opName.overload_name.c_str())) != 0; +} + +} diff --git a/aten/src/ATen/core/OpsAlreadyMovedToC10.h b/aten/src/ATen/core/OpsAlreadyMovedToC10.h new file mode 100644 index 0000000000000..73cea49c50953 --- /dev/null +++ b/aten/src/ATen/core/OpsAlreadyMovedToC10.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace c10 { +struct OperatorName; +} + +namespace at { + +/* +There's semantically three sets of operators: + +- aten_ops_already_moved_to_c10 +- aten_ops_not_moved_to_c10_yet +- non_aten_ops (e.g. custom ops) + +register_c10_ops.cpp needs to decide between aten_ops_already_moved_to_c10 +and union(aten_ops_not_moved_to_c10_yet, non_aten_ops). +The c10 operator registry needs to decide between aten_ops_not_moved_to_c10_yet +and union(aten_ops_already_moved_to_c10, non_aten_ops), which is different to what +register_c10_ops.cpp needs. We need to store two sets to be able to make both decisions. +*/ + +// list of ATen ops that got already moved to the c10 dispatcher +CAFFE2_API bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName); + +// list of ATen ops that are still on the globalATenDispatch dispatcher. +CAFFE2_API bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName); + +} diff --git a/aten/src/ATen/core/TensorAccessor.h b/aten/src/ATen/core/TensorAccessor.h index 0116964e03622..95f37fcb09510 100644 --- a/aten/src/ATen/core/TensorAccessor.h +++ b/aten/src/ATen/core/TensorAccessor.h @@ -1,12 +1,13 @@ #pragma once #include +#include #include #include namespace at { -// The PtrTraits argument to the TensorAccessor/PackedTensorAccessor +// The PtrTraits argument to the TensorAccessor/GenericPackedTensorAccessor // is used to enable the __restrict__ keyword/modifier for the data // passed to cuda. template @@ -62,7 +63,7 @@ class TensorAccessorBase { // The `TensorAccessor` is typically instantiated for CPU `Tensor`s using // `Tensor.accessor()`. -// For CUDA `Tensor`s, `PackedTensorAccessor` is used on the host and only +// For CUDA `Tensor`s, `GenericPackedTensorAccessor` is used on the host and only // indexing on the device uses `TensorAccessor`s. template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> class TensorAccessor : public TensorAccessorBase { @@ -103,7 +104,7 @@ class TensorAccessor : public TensorAccessorBase : public TensorAccessorBase class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> -class PackedTensorAccessorBase { +class GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; - C10_HOST PackedTensorAccessorBase( + C10_HOST GenericPackedTensorAccessorBase( PtrType data_, const index_t* sizes_, const index_t* strides_) @@ -126,7 +127,7 @@ class PackedTensorAccessorBase { // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> - C10_HOST PackedTensorAccessorBase( + C10_HOST GenericPackedTensorAccessorBase( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) @@ -156,23 +157,23 @@ class PackedTensorAccessorBase { }; template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> -class PackedTensorAccessor : public PackedTensorAccessorBase { +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} C10_DEVICE TensorAccessor operator[](index_t i) { index_t* new_sizes = this->sizes_ + 1; @@ -188,22 +189,22 @@ class PackedTensorAccessor : public PackedTensorAccessorBase class PtrTraits, typename index_t> -class PackedTensorAccessor : public PackedTensorAccessorBase { +class GenericPackedTensorAccessor : public GenericPackedTensorAccessorBase { public: typedef typename PtrTraits::PtrType PtrType; - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const index_t* sizes_, const index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} // if index_t is not int64_t, we want to have an int64_t constructor template ::value>::type> - C10_HOST PackedTensorAccessor( + C10_HOST GenericPackedTensorAccessor( PtrType data_, const source_index_t* sizes_, const source_index_t* strides_) - : PackedTensorAccessorBase(data_, sizes_, strides_) {} + : GenericPackedTensorAccessorBase(data_, sizes_, strides_) {} C10_DEVICE T & operator[](index_t i) { return this->data_[this->strides_[0] * i]; @@ -213,4 +214,19 @@ class PackedTensorAccessor : public PackedTensorAccessorB } }; -} + +// Can't put this directly into the macro function args because of commas +#define AT_X GenericPackedTensorAccessor + +// Old name for `GenericPackedTensorAccessor` +template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> +C10_DEFINE_DEPRECATED_USING(PackedTensorAccessor, AT_X) + +#undef AT_X + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor32 = GenericPackedTensorAccessor; + +template class PtrTraits = DefaultPtrTraits> +using PackedTensorAccessor64 = GenericPackedTensorAccessor; +} // namespace at diff --git a/aten/src/ATen/core/TensorBody.h b/aten/src/ATen/core/TensorBody.h index 30b36cd4509de..1549e470478dd 100644 --- a/aten/src/ATen/core/TensorBody.h +++ b/aten/src/ATen/core/TensorBody.h @@ -11,14 +11,14 @@ #include #include #include +#include #include #include #include #include #include -#ifdef BUILD_NAMEDTENSOR +#include #include -#endif namespace caffe2 { class Tensor; @@ -42,6 +42,7 @@ struct Quantizer; // This is temporary typedef to enable Quantizer in aten native function API // we'll remove them when we are actually exposing Quantizer class // to frontend +using QuantizerPtr = c10::intrusive_ptr; using ConstQuantizerPtr = const c10::intrusive_ptr&; // Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which @@ -219,12 +220,12 @@ class CAFFE2_API Tensor { DeprecatedTypeProperties & type() const { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( - tensorTypeIdToBackend(type_id()), + tensorTypeIdToBackend(legacyExtractTypeId(type_set())), scalar_type(), is_variable()); } - TensorTypeId type_id() const { - return impl_->type_id(); + TensorTypeSet type_set() const { + return impl_->type_set(); } ScalarType scalar_type() const { return typeMetaToScalarType(impl_->dtype()); @@ -274,6 +275,10 @@ class CAFFE2_API Tensor { /// Returns if a `Tensor` has quantized backend. bool is_quantized() const; + /// If a tensor is a quantized tensor, returns its quantizer + /// TODO: it's not in native_functions.yaml yet as it's not exposed to python + QuantizerPtr quantizer() const; + #ifdef BUILD_NAMEDTENSOR /// Returns if a `Tensor` has any dimension names bool has_names() const; @@ -317,19 +322,42 @@ class CAFFE2_API Tensor { template TensorAccessor accessor() && = delete; - // Return a `PackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and // dimension. You can optionally specify RestrictPtrTraits as a template parameter to // cast the data pointer to a __restrict__ pointer. - // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor // as an argument. template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() const& { + GenericPackedTensorAccessor generic_packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return PackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); + return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() && = delete; + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; Tensor operator-() const; Tensor& operator+=(const Tensor & other); @@ -372,15 +400,30 @@ class CAFFE2_API Tensor { //Tensor * add(Tensor & b); void backward(const Tensor & gradient={}, bool keep_graph=false, bool create_graph=false) const; void set_data(const Tensor & new_data) const; + Tensor data() const; + bool is_leaf() const; + int64_t output_nr() const; #ifdef BUILD_NAMEDTENSOR Tensor & names_(c10::optional names) const; #endif #ifdef BUILD_NAMEDTENSOR - Tensor view_names(c10::optional names) const; + Tensor renamed(c10::optional names) const; #endif #ifdef BUILD_NAMEDTENSOR Tensor align_to(DimnameList names) const; #endif + #ifdef BUILD_NAMEDTENSOR + Tensor align_as(const Tensor & other) const; + #endif + #ifdef BUILD_NAMEDTENSOR + Tensor refine_names(DimnameList names) const; + #endif + #ifdef BUILD_NAMEDTENSOR + Tensor unflatten(Dimname dim, IntArrayRef sizes, DimnameList names) const; + #endif + #ifdef BUILD_NAMEDTENSOR + Tensor unflatten(int64_t dim, IntArrayRef sizes, DimnameList names) const; + #endif Tensor abs() const; Tensor & abs_() const; Tensor acos() const; @@ -459,6 +502,15 @@ class CAFFE2_API Tensor { Tensor expand(IntArrayRef size, bool implicit=false) const; Tensor expand_as(const Tensor & other) const; Tensor flatten(int64_t start_dim=0, int64_t end_dim=-1) const; + #ifdef BUILD_NAMEDTENSOR + Tensor flatten(int64_t start_dim, int64_t end_dim, Dimname out_dim) const; + #endif + #ifdef BUILD_NAMEDTENSOR + Tensor flatten(Dimname start_dim, Dimname end_dim, Dimname out_dim) const; + #endif + #ifdef BUILD_NAMEDTENSOR + Tensor flatten(DimnameList dims, Dimname out_dim) const; + #endif Tensor & fill_(Scalar value) const; Tensor & fill_(const Tensor & value) const; Tensor floor() const; @@ -680,6 +732,9 @@ class CAFFE2_API Tensor { Tensor values() const; int64_t numel() const; std::vector unbind(int64_t dim=0) const; + #ifdef BUILD_NAMEDTENSOR + std::vector unbind(Dimname dim) const; + #endif Tensor to_sparse(int64_t sparse_dim) const; Tensor to_sparse() const; Tensor to_mkldnn() const; @@ -688,6 +743,7 @@ class CAFFE2_API Tensor { int64_t q_zero_point() const; Tensor q_per_channel_scales() const; Tensor q_per_channel_zero_points() const; + IntArrayRef q_per_channel_axis() const; Tensor int_repr() const; QScheme qscheme() const; Tensor to(const TensorOptions & options, bool non_blocking=false, bool copy=false) const; @@ -878,7 +934,7 @@ class CAFFE2_API Tensor { }; namespace detail { -// Helper creator for Tensor clas which doesn't requires the users to pass +// Helper creator for Tensor class which doesn't requires the users to pass // in an intrusive_ptr instead it just converts the argument passed to // requested intrusive_ptr type. template @@ -886,23 +942,10 @@ Tensor make_tensor(Args&&... args) { return Tensor(c10::make_intrusive(std::forward(args)...)); } -inline Backend infer_backend(const Tensor & t) { - TORCH_CHECK(t.defined(), "undefined Tensor"); - return tensorTypeIdToBackend(t.type_id()); -} -inline Backend infer_backend(const TensorList & tl) { - TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); - return tensorTypeIdToBackend(tl[0].type_id()); -} +} // namespace detail -inline bool infer_is_variable(const Tensor & t) { - TORCH_CHECK(t.defined(), "undefined Tensor"); - return t.is_variable(); +static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { + return legacyExtractTypeId(t.type_set()); } -inline bool infer_is_variable(const TensorList & tl) { - TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); - return tl[0].is_variable(); -} -} // namespace detail } // namespace at diff --git a/aten/src/ATen/core/TensorMethods.h b/aten/src/ATen/core/TensorMethods.h index 2b4a094f70f47..0db5a861f4b3d 100644 --- a/aten/src/ATen/core/TensorMethods.h +++ b/aten/src/ATen/core/TensorMethods.h @@ -8,12 +8,11 @@ #include #include #include -#if !defined(CAFFE2_IS_XPLAT_BUILD) #include -#endif -#ifdef BUILD_NAMEDTENSOR +#include #include -#endif +#include + #ifdef USE_STATIC_DISPATCH #include #include @@ -23,6 +22,30 @@ namespace at { +namespace detail { + +struct MultiDispatchTensorTypeSet : IterArgs { + TensorTypeSet ts; + void operator()(const at::Tensor& x) { + ts = ts | x.type_set(); + } + void operator()(TensorOptions x) { + ts = ts | x.type_set(); + } + void operator()(at::ArrayRef xs) { + for (const auto& x : xs) { + ts = ts | x.type_set(); + } + } +}; + +template +TensorTypeSet multi_dispatch_tensor_type_set(Args&&... args) { + return MultiDispatchTensorTypeSet().apply(std::forward(args)...).ts; +} + +} + struct Quantizer; // This is temporary typedef to enable Quantizer in aten native function API // we'll remove them when we are actually exposing Quantizer class @@ -64,7 +87,7 @@ inline void Tensor::backward(const Tensor & gradient, bool keep_graph, bool crea TypeDefault::backward(const_cast(*this), gradient, keep_graph, create_graph); #else static auto table = globalATenDispatch().getOpTable("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), gradient, keep_graph, create_graph); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, gradient))(const_cast(*this), gradient, keep_graph, create_graph); #endif } inline void Tensor::set_data(const Tensor & new_data) const { @@ -72,7 +95,32 @@ inline void Tensor::set_data(const Tensor & new_data) const { TypeDefault::set_data(const_cast(*this), new_data); #else static auto table = globalATenDispatch().getOpTable("aten::set_data(Tensor(a!) self, Tensor new_data) -> void"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), new_data); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, new_data))(const_cast(*this), new_data); +#endif +} +inline Tensor Tensor::data() const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::data(const_cast(*this)); +#else + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::data", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); +#endif +} +inline bool Tensor::is_leaf() const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::is_leaf(const_cast(*this)); +#else + static auto table = globalATenDispatch().getOpTable("aten::is_leaf(Tensor self) -> bool"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); +#endif +} +inline int64_t Tensor::output_nr() const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::output_nr(const_cast(*this)); +#else + static auto table = globalATenDispatch().getOpTable("aten::output_nr(Tensor self) -> int"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); #endif } #ifdef BUILD_NAMEDTENSOR @@ -81,17 +129,17 @@ inline Tensor & Tensor::names_(c10::optional names) const { return TypeDefault::names_(const_cast(*this), names); #else static auto table = globalATenDispatch().getOpTable("aten::names_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), names); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); #endif } #endif #ifdef BUILD_NAMEDTENSOR -inline Tensor Tensor::view_names(c10::optional names) const { +inline Tensor Tensor::renamed(c10::optional names) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::view_names(const_cast(*this), names); + return TypeDefault::renamed(const_cast(*this), names); #else - static auto table = globalATenDispatch().getOpTable("aten::view_names(Tensor(a) self, Dimname[]? names) -> Tensor(a)"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), names); + static auto table = globalATenDispatch().getOpTable("aten::renamed(Tensor(a) self, Dimname[]? names) -> Tensor(a)"); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); #endif } #endif @@ -100,8 +148,49 @@ inline Tensor Tensor::align_to(DimnameList names) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::align_to(const_cast(*this), names); #else - static auto table = globalATenDispatch().getOpTable("aten::align_to(Tensor self, DimnameList names) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), names); + static auto table = globalATenDispatch().getOpTable("aten::align_to(Tensor(a) self, DimnameList names) -> Tensor(a)"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); +#endif +} +#endif +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::align_as(const Tensor & other) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::align_as(const_cast(*this), other); +#else + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::align_as", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); +#endif +} +#endif +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::refine_names(DimnameList names) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::refine_names(const_cast(*this), names); +#else + static auto table = globalATenDispatch().getOpTable("aten::refine_names(Tensor(a) self, DimnameList names) -> Tensor(a)"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), names); +#endif +} +#endif +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::unflatten(Dimname dim, IntArrayRef sizes, DimnameList names) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::unflatten(const_cast(*this), dim, sizes, names); +#else + static auto table = globalATenDispatch().getOpTable("aten::unflatten(Tensor self, Dimname dim, int[] sizes, DimnameList names) -> Tensor"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, sizes, names); +#endif +} +#endif +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::unflatten(int64_t dim, IntArrayRef sizes, DimnameList names) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::unflatten(const_cast(*this), dim, sizes, names); +#else + static auto table = globalATenDispatch().getOpTable("aten::unflatten(Tensor self, int dim, int[] sizes, DimnameList names) -> Tensor"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, sizes, names); #endif } #endif @@ -109,49 +198,53 @@ inline Tensor Tensor::abs() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::abs(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::abs(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::abs", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::abs_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::abs_(const_cast(*this)); break; default: - AT_ERROR("abs_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("abs_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::abs_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::abs_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::acos() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::acos(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::acos(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::acos", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::acos_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::acos_(const_cast(*this)); break; default: - AT_ERROR("acos_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("acos_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::acos_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::acos_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::add(const_cast(*this), other, alpha); break; @@ -159,16 +252,17 @@ inline Tensor Tensor::add(const Tensor & other, Scalar alpha) const { return SparseCPUType::add(const_cast(*this), other, alpha); break; default: - AT_ERROR("add not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("add not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::add_(const_cast(*this), other, alpha); break; @@ -176,116 +270,128 @@ inline Tensor & Tensor::add_(const Tensor & other, Scalar alpha) const { return SparseCPUType::add_(const_cast(*this), other, alpha); break; default: - AT_ERROR("add_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("add_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor Tensor::add(Scalar other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::add(const_cast(*this), other, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor & Tensor::add_(Scalar other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::add_(const_cast(*this), other, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::add_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor Tensor::addmv(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::addmv(const_cast(*this), mat, vec, beta, alpha); break; default: - AT_ERROR("addmv not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("addmv not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat, vec, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmv", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec))) + .callUnboxed(const_cast(*this), mat, vec, beta, alpha); #endif } inline Tensor & Tensor::addmv_(const Tensor & mat, const Tensor & vec, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::addmv_(const_cast(*this), mat, vec, beta, alpha); break; default: - AT_ERROR("addmv_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("addmv_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat, vec, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmv_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat, vec))) + .callUnboxed(const_cast(*this), mat, vec, beta, alpha); #endif } inline Tensor Tensor::addr(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::addr(const_cast(*this), vec1, vec2, beta, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), vec1, vec2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addr", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2))) + .callUnboxed(const_cast(*this), vec1, vec2, beta, alpha); #endif } inline Tensor & Tensor::addr_(const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::addr_(const_cast(*this), vec1, vec2, beta, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), vec1, vec2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addr_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec1, vec2))) + .callUnboxed(const_cast(*this), vec1, vec2, beta, alpha); #endif } inline Tensor Tensor::all(int64_t dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::all(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::all", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, keepdim); #endif } inline bool Tensor::allclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::allclose(const_cast(*this), other, rtol, atol, equal_nan); #else - static auto table = globalATenDispatch().getOpTable("aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, rtol, atol, equal_nan); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::allclose", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, rtol, atol, equal_nan); #endif } inline Tensor Tensor::any(int64_t dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::any(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::any", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, keepdim); #endif } inline Tensor Tensor::argmax(c10::optional dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::argmax(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"); - return table->getOp, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::argmax", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, bool>(const_cast(*this), dim, keepdim); #endif } inline Tensor Tensor::argmin(c10::optional dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::argmin(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor"); - return table->getOp, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::argmin", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, bool>(const_cast(*this), dim, keepdim); #endif } inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::as_strided(const_cast(*this), size, stride, storage_offset); break; @@ -293,91 +399,99 @@ inline Tensor Tensor::as_strided(IntArrayRef size, IntArrayRef stride, c10::opti return QuantizedCPUType::as_strided(const_cast(*this), size, stride, storage_offset); break; default: - AT_ERROR("as_strided not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("as_strided not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a)"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, stride, storage_offset); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::as_strided", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed>(const_cast(*this), size, stride, storage_offset); #endif } inline Tensor & Tensor::as_strided_(IntArrayRef size, IntArrayRef stride, c10::optional storage_offset) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::as_strided_(const_cast(*this), size, stride, storage_offset); #else - static auto table = globalATenDispatch().getOpTable("aten::as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!)"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, stride, storage_offset); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::as_strided_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed>(const_cast(*this), size, stride, storage_offset); #endif } inline Tensor Tensor::asin() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::asin(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::asin(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::asin", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::asin_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::asin_(const_cast(*this)); break; default: - AT_ERROR("asin_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("asin_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::asin_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::asin_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::atan() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::atan(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::atan(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::atan_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::atan_(const_cast(*this)); break; default: - AT_ERROR("atan_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("atan_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::atan_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::baddbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::baddbmm(const_cast(*this), batch1, batch2, beta, alpha); break; default: - AT_ERROR("baddbmm not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("baddbmm not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), batch1, batch2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::baddbmm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))) + .callUnboxed(const_cast(*this), batch1, batch2, beta, alpha); #endif } inline Tensor & Tensor::baddbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::baddbmm_(const_cast(*this), batch1, batch2, beta, alpha); break; default: - AT_ERROR("baddbmm_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("baddbmm_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), batch1, batch2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::baddbmm_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))) + .callUnboxed(const_cast(*this), batch1, batch2, beta, alpha); #endif } inline Tensor Tensor::bernoulli(Generator * generator) const { @@ -385,35 +499,35 @@ inline Tensor Tensor::bernoulli(Generator * generator) const { return TypeDefault::bernoulli(const_cast(*this), generator); #else static auto table = globalATenDispatch().getOpTable("aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), generator); #endif } inline Tensor & Tensor::bernoulli_(const Tensor & p, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::bernoulli_(const_cast(*this), p, generator); break; default: - AT_ERROR("bernoulli_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("bernoulli_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, p))(const_cast(*this), p, generator); #endif } inline Tensor & Tensor::bernoulli_(double p, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::bernoulli_(const_cast(*this), p, generator); break; default: - AT_ERROR("bernoulli_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("bernoulli_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); #endif } inline Tensor Tensor::bernoulli(double p, Generator * generator) const { @@ -421,173 +535,189 @@ inline Tensor Tensor::bernoulli(double p, Generator * generator) const { return TypeDefault::bernoulli(const_cast(*this), p, generator); #else static auto table = globalATenDispatch().getOpTable("aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); #endif } inline Tensor Tensor::bincount(const Tensor & weights, int64_t minlength) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::bincount(const_cast(*this), weights, minlength); break; default: - AT_ERROR("bincount not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("bincount not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::bincount(Tensor self, Tensor? weights=None, int minlength=0) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), weights, minlength); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, weights))(const_cast(*this), weights, minlength); #endif } inline Tensor Tensor::bitwise_not() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::bitwise_not(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::bitwise_not(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bitwise_not", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::bitwise_not_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::bitwise_not_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bitwise_not_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::logical_not() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::logical_not(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::logical_not(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_not", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::logical_not_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::logical_not_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::logical_not_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_not_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::logical_xor(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::logical_xor(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::logical_xor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_xor", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::logical_xor_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::logical_xor_(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logical_xor_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::bmm(const Tensor & mat2) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::bmm(const_cast(*this), mat2); break; default: - AT_ERROR("bmm not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("bmm not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::bmm(Tensor self, Tensor mat2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::bmm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat2))) + .callUnboxed(const_cast(*this), mat2); #endif } inline Tensor Tensor::ceil() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::ceil(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::ceil(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ceil", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::ceil_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::ceil_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::ceil_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ceil_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline std::vector Tensor::chunk(int64_t chunks, int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::chunk(const_cast(*this), chunks, dim); #else - static auto table = globalATenDispatch().getOpTable("aten::chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[]"); - return table->getOp (const Tensor &, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), chunks, dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::chunk", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, int64_t>(const_cast(*this), chunks, dim); #endif } inline Tensor Tensor::clamp(c10::optional min, c10::optional max) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::clamp(const_cast(*this), min, max); #else - static auto table = globalATenDispatch().getOpTable("aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"); - return table->getOp, c10::optional)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), min, max); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, c10::optional>(const_cast(*this), min, max); #endif } inline Tensor & Tensor::clamp_(c10::optional min, c10::optional max) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::clamp_(const_cast(*this), min, max); break; default: - AT_ERROR("clamp_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("clamp_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)"); - return table->getOp, c10::optional)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), min, max); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, c10::optional>(const_cast(*this), min, max); #endif } inline Tensor Tensor::clamp_max(Scalar max) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::clamp_max(const_cast(*this), max); #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_max(Tensor self, Scalar max) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), max); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_max", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), max); #endif } inline Tensor & Tensor::clamp_max_(Scalar max) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::clamp_max_(const_cast(*this), max); break; default: - AT_ERROR("clamp_max_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("clamp_max_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), max); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_max_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), max); #endif } inline Tensor Tensor::clamp_min(Scalar min) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::clamp_min(const_cast(*this), min); #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_min(Tensor self, Scalar min) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), min); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_min", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), min); #endif } inline Tensor & Tensor::clamp_min_(Scalar min) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::clamp_min_(const_cast(*this), min); break; default: - AT_ERROR("clamp_min_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("clamp_min_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), min); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clamp_min_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), min); #endif } inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { @@ -595,59 +725,64 @@ inline Tensor Tensor::contiguous(MemoryFormat memory_format) const { return TypeDefault::contiguous(const_cast(*this), memory_format); #else static auto table = globalATenDispatch().getOpTable("aten::contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), memory_format); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), memory_format); #endif } inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::copy_(const_cast(*this), src, non_blocking); #else - static auto table = globalATenDispatch().getOpTable("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), src, non_blocking); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::copy_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, src))) + .callUnboxed(const_cast(*this), src, non_blocking); #endif } inline Tensor Tensor::cos() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::cos(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::cos(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cos", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::cos_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::cos_(const_cast(*this)); break; default: - AT_ERROR("cos_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("cos_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::cos_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cos_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::cosh() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::cosh(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::cosh(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cosh", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::cosh_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::cosh_(const_cast(*this)); break; default: - AT_ERROR("cosh_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("cosh_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::cosh_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cosh_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const { @@ -655,7 +790,7 @@ inline Tensor Tensor::cumsum(int64_t dim, c10::optional dtype) const return TypeDefault::cumsum(const_cast(*this), dim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); #endif } inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) const { @@ -663,93 +798,121 @@ inline Tensor Tensor::cumprod(int64_t dim, c10::optional dtype) cons return TypeDefault::cumprod(const_cast(*this), dim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); #endif } inline Tensor Tensor::det() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::det(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::det(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::det", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::diag_embed(int64_t offset, int64_t dim1, int64_t dim2) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::diag_embed(const_cast(*this), offset, dim1, dim2); #else - static auto table = globalATenDispatch().getOpTable("aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), offset, dim1, dim2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diag_embed", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), offset, dim1, dim2); #endif } inline Tensor Tensor::diagflat(int64_t offset) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::diagflat(const_cast(*this), offset); #else - static auto table = globalATenDispatch().getOpTable("aten::diagflat(Tensor self, int offset=0) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), offset); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diagflat", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), offset); #endif } inline Tensor Tensor::diagonal(int64_t offset, int64_t dim1, int64_t dim2) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::diagonal(const_cast(*this), offset, dim1, dim2); #else - static auto table = globalATenDispatch().getOpTable("aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), offset, dim1, dim2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diagonal", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), offset, dim1, dim2); #endif } inline Tensor & Tensor::fill_diagonal_(Scalar fill_value, bool wrap) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::fill_diagonal_(const_cast(*this), fill_value, wrap); #else - static auto table = globalATenDispatch().getOpTable("aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), fill_value, wrap); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fill_diagonal_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), fill_value, wrap); #endif } inline Tensor Tensor::div(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::div(const_cast(*this), other); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::CPU: + return CPUType::div(const_cast(*this), other); + break; + case Backend::SparseCPU: + return SparseCPUType::div(const_cast(*this), other); + break; + default: + AT_ERROR("div not implemented for ", at::toString(type_set())); + } #else - static auto table = globalATenDispatch().getOpTable("aten::div.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::div_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::div_(const_cast(*this), other); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::CPU: + return CPUType::div_(const_cast(*this), other); + break; + case Backend::SparseCPU: + return SparseCPUType::div_(const_cast(*this), other); + break; + default: + AT_ERROR("div_ not implemented for ", at::toString(type_set())); + } #else - static auto table = globalATenDispatch().getOpTable("aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::div(Scalar other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::div(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::div.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::div_(Scalar other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::div_(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::div_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::dot(const Tensor & tensor) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::dot(const_cast(*this), tensor); break; default: - AT_ERROR("dot not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("dot not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::dot(Tensor self, Tensor tensor) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), tensor); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dot", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor))) + .callUnboxed(const_cast(*this), tensor); #endif } inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) const { @@ -757,7 +920,7 @@ inline Tensor Tensor::new_empty(IntArrayRef size, const TensorOptions & options) return TypeDefault::new_empty(const_cast(*this), size, options); #else static auto table = globalATenDispatch().getOpTable("aten::new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, options); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, options))(const_cast(*this), size, options); #endif } inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const TensorOptions & options) const { @@ -765,263 +928,318 @@ inline Tensor Tensor::new_full(IntArrayRef size, Scalar fill_value, const Tensor return TypeDefault::new_full(const_cast(*this), size, fill_value, options); #else static auto table = globalATenDispatch().getOpTable("aten::new_full(Tensor self, int[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, fill_value, options); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, options))(const_cast(*this), size, fill_value, options); #endif } inline Tensor & Tensor::resize_(IntArrayRef size) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::resize_(const_cast(*this), size); break; default: - AT_ERROR("resize_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("resize_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::resize_(Tensor(a!) self, int[] size) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::resize_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), size); #endif } inline Tensor Tensor::erf() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::erf(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::erf(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erf", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::erf_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::erf_(const_cast(*this)); break; default: - AT_ERROR("erf_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("erf_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::erf_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erf_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::erfc() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::erfc(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::erfc(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfc", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::erfc_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::erfc_(const_cast(*this)); break; default: - AT_ERROR("erfc_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("erfc_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::erfc_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfc_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::exp() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::exp(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::exp(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::exp", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::exp_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::exp_(const_cast(*this)); break; default: - AT_ERROR("exp_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("exp_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::exp_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::exp_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::expm1() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::expm1(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::expm1(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expm1", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::expm1_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::expm1_(const_cast(*this)); break; default: - AT_ERROR("expm1_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("expm1_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::expm1_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expm1_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::expand(IntArrayRef size, bool implicit) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::expand(const_cast(*this), size, implicit); #else - static auto table = globalATenDispatch().getOpTable("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, implicit); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expand", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), size, implicit); #endif } inline Tensor Tensor::expand_as(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::expand_as(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::expand_as(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::expand_as", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::flatten(const_cast(*this), start_dim, end_dim); #else - static auto table = globalATenDispatch().getOpTable("aten::flatten(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), start_dim, end_dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::flatten", "using_ints"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), start_dim, end_dim); +#endif +} +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::flatten(int64_t start_dim, int64_t end_dim, Dimname out_dim) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); +#else + static auto table = globalATenDispatch().getOpTable("aten::flatten.named_out_dim(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim, out_dim); +#endif +} +#endif +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::flatten(Dimname start_dim, Dimname end_dim, Dimname out_dim) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::flatten(const_cast(*this), start_dim, end_dim, out_dim); +#else + static auto table = globalATenDispatch().getOpTable("aten::flatten.using_names(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), start_dim, end_dim, out_dim); #endif } +#endif +#ifdef BUILD_NAMEDTENSOR +inline Tensor Tensor::flatten(DimnameList dims, Dimname out_dim) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::flatten(const_cast(*this), dims, out_dim); +#else + static auto table = globalATenDispatch().getOpTable("aten::flatten.DimnameList(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dims, out_dim); +#endif +} +#endif inline Tensor & Tensor::fill_(Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::fill_(const_cast(*this), value); #else - static auto table = globalATenDispatch().getOpTable("aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fill_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), value); #endif } inline Tensor & Tensor::fill_(const Tensor & value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::fill_(const_cast(*this), value); #else - static auto table = globalATenDispatch().getOpTable("aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fill_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, value))) + .callUnboxed(const_cast(*this), value); #endif } inline Tensor Tensor::floor() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::floor(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::floor(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::floor", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::floor_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::floor_(const_cast(*this)); break; default: - AT_ERROR("floor_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("floor_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::floor_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::floor_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::frac() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::frac(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::frac(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::frac", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::frac_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::frac_(const_cast(*this)); break; default: - AT_ERROR("frac_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("frac_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::frac_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::frac_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::ger(const Tensor & vec2) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ger(const_cast(*this), vec2); break; default: - AT_ERROR("ger not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ger not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ger(Tensor self, Tensor vec2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), vec2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ger", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec2))) + .callUnboxed(const_cast(*this), vec2); #endif } inline Tensor Tensor::fft(int64_t signal_ndim, bool normalized) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::fft(const_cast(*this), signal_ndim, normalized); #else - static auto table = globalATenDispatch().getOpTable("aten::fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), signal_ndim, normalized); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fft", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), signal_ndim, normalized); #endif } inline Tensor Tensor::ifft(int64_t signal_ndim, bool normalized) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::ifft(const_cast(*this), signal_ndim, normalized); #else - static auto table = globalATenDispatch().getOpTable("aten::ifft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), signal_ndim, normalized); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ifft", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), signal_ndim, normalized); #endif } inline Tensor Tensor::rfft(int64_t signal_ndim, bool normalized, bool onesided) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::rfft(const_cast(*this), signal_ndim, normalized, onesided); #else - static auto table = globalATenDispatch().getOpTable("aten::rfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), signal_ndim, normalized, onesided); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rfft", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), signal_ndim, normalized, onesided); #endif } inline Tensor Tensor::irfft(int64_t signal_ndim, bool normalized, bool onesided, IntArrayRef signal_sizes) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::irfft(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); #else - static auto table = globalATenDispatch().getOpTable("aten::irfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True, int[] signal_sizes=[]) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::irfft", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), signal_ndim, normalized, onesided, signal_sizes); #endif } inline Tensor Tensor::index(TensorList indices) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::index(const_cast(*this), indices); #else - static auto table = globalATenDispatch().getOpTable("aten::index(Tensor self, Tensor?[] indices) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), indices); + static auto table = globalATenDispatch().getOpTable("aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices))(const_cast(*this), indices); #endif } inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::index_copy_(const_cast(*this), dim, index, source); #else - static auto table = globalATenDispatch().getOpTable("aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_copy_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source))) + .callUnboxed(const_cast(*this), dim, index, source); #endif } inline Tensor Tensor::index_copy(int64_t dim, const Tensor & index, const Tensor & source) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::index_copy(const_cast(*this), dim, index, source); #else - static auto table = globalATenDispatch().getOpTable("aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_copy", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source))) + .callUnboxed(const_cast(*this), dim, index, source); #endif } inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) const { @@ -1029,7 +1247,7 @@ inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bo return TypeDefault::index_put_(const_cast(*this), indices, values, accumulate); #else static auto table = globalATenDispatch().getOpTable("aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), indices, values, accumulate); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices, values))(const_cast(*this), indices, values, accumulate); #endif } inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const { @@ -1037,136 +1255,150 @@ inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool return TypeDefault::index_put(const_cast(*this), indices, values, accumulate); #else static auto table = globalATenDispatch().getOpTable("aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), indices, values, accumulate); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, indices, values))(const_cast(*this), indices, values, accumulate); #endif } inline Tensor Tensor::inverse() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::inverse(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::inverse(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::inverse", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::isclose(const Tensor & other, double rtol, double atol, bool equal_nan) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::isclose(const_cast(*this), other, rtol, atol, equal_nan); #else - static auto table = globalATenDispatch().getOpTable("aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, rtol, atol, equal_nan); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::isclose", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, rtol, atol, equal_nan); #endif } inline bool Tensor::is_distributed() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_distributed(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::is_distributed(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_distributed", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline bool Tensor::is_floating_point() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_floating_point(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::is_floating_point(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_floating_point", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline bool Tensor::is_complex() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_complex(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::is_complex(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_complex", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline bool Tensor::is_nonzero() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_nonzero(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::is_nonzero(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_nonzero", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline bool Tensor::is_same_size(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_same_size(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::is_same_size(Tensor self, Tensor other) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_same_size", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline bool Tensor::is_signed() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_signed(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::is_signed(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_signed", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline std::tuple Tensor::kthvalue(int64_t k, int64_t dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::kthvalue(const_cast(*this), k, dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, int64_t, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), k, dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::kthvalue", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, int64_t, bool>(const_cast(*this), k, dim, keepdim); #endif } inline Tensor Tensor::log() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::log(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::log(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::log_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::log_(const_cast(*this)); break; default: - AT_ERROR("log_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("log_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::log_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::log10() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::log10(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::log10(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log10", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::log10_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::log10_(const_cast(*this)); break; default: - AT_ERROR("log10_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("log10_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::log10_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log10_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::log1p() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::log1p(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::log1p(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log1p", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::log1p_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::log1p_(const_cast(*this)); break; @@ -1174,41 +1406,45 @@ inline Tensor & Tensor::log1p_() const { return SparseCPUType::log1p_(const_cast(*this)); break; default: - AT_ERROR("log1p_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("log1p_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::log1p_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log1p_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::log2() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::log2(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::log2(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log2", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::log2_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::log2_(const_cast(*this)); break; default: - AT_ERROR("log2_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("log2_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::log2_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::log2_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::logdet() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::logdet(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::logdet(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logdet", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) const { @@ -1216,7 +1452,7 @@ inline Tensor Tensor::log_softmax(int64_t dim, c10::optional dtype) return TypeDefault::log_softmax(const_cast(*this), dim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1225,7 +1461,7 @@ inline Tensor Tensor::log_softmax(Dimname dim, c10::optional dtype) return TypeDefault::log_softmax(const_cast(*this), dim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::log_softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); #endif } #endif @@ -1233,8 +1469,9 @@ inline Tensor Tensor::logsumexp(IntArrayRef dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::logsumexp", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1243,7 +1480,7 @@ inline Tensor Tensor::logsumexp(DimnameList dim, bool keepdim) const { return TypeDefault::logsumexp(const_cast(*this), dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); #endif } #endif @@ -1251,32 +1488,36 @@ inline Tensor Tensor::matmul(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::matmul(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::matmul(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::matmul", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::matrix_power(int64_t n) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::matrix_power(const_cast(*this), n); #else - static auto table = globalATenDispatch().getOpTable("aten::matrix_power(Tensor self, int n) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), n); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::matrix_power", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), n); #endif } inline std::tuple Tensor::max(int64_t dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::max(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, bool>(const_cast(*this), dim, keepdim); #endif } inline Tensor Tensor::max_values(IntArrayRef dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::max_values(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max_values", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1285,7 +1526,7 @@ inline std::tuple Tensor::max(Dimname dim, bool keepdim) const { return TypeDefault::max(const_cast(*this), dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, Dimname, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); #endif } #endif @@ -1295,7 +1536,7 @@ inline Tensor Tensor::max_values(DimnameList dim, bool keepdim) const { return TypeDefault::max_values(const_cast(*this), dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::max_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); #endif } #endif @@ -1304,7 +1545,7 @@ inline Tensor Tensor::mean(c10::optional dtype) const { return TypeDefault::mean(const_cast(*this), dtype); #else static auto table = globalATenDispatch().getOpTable("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); #endif } inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional dtype) const { @@ -1312,7 +1553,7 @@ inline Tensor Tensor::mean(IntArrayRef dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1321,7 +1562,7 @@ inline Tensor Tensor::mean(DimnameList dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); #endif } #endif @@ -1329,8 +1570,9 @@ inline std::tuple Tensor::median(int64_t dim, bool keepdim) const #ifdef USE_STATIC_DISPATCH return TypeDefault::median(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::median", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, bool>(const_cast(*this), dim, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1339,7 +1581,7 @@ inline std::tuple Tensor::median(Dimname dim, bool keepdim) const return TypeDefault::median(const_cast(*this), dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, Dimname, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); #endif } #endif @@ -1347,16 +1589,18 @@ inline std::tuple Tensor::min(int64_t dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::min(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, bool>(const_cast(*this), dim, keepdim); #endif } inline Tensor Tensor::min_values(IntArrayRef dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::min_values(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min_values", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1365,7 +1609,7 @@ inline std::tuple Tensor::min(Dimname dim, bool keepdim) const { return TypeDefault::min(const_cast(*this), dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, Dimname, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + return table->getOp (const Tensor &, Dimname, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); #endif } #endif @@ -1375,13 +1619,13 @@ inline Tensor Tensor::min_values(DimnameList dim, bool keepdim) const { return TypeDefault::min_values(const_cast(*this), dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::min_values.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim); #endif } #endif inline Tensor Tensor::mm(const Tensor & mat2) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::mm(const_cast(*this), mat2); break; @@ -1389,24 +1633,26 @@ inline Tensor Tensor::mm(const Tensor & mat2) const { return SparseCPUType::mm(const_cast(*this), mat2); break; default: - AT_ERROR("mm not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("mm not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::mm(Tensor self, Tensor mat2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat2))) + .callUnboxed(const_cast(*this), mat2); #endif } inline std::tuple Tensor::mode(int64_t dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::mode(const_cast(*this), dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mode", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, bool>(const_cast(*this), dim, keepdim); #endif } inline Tensor Tensor::mul(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::mul(const_cast(*this), other); break; @@ -1414,16 +1660,17 @@ inline Tensor Tensor::mul(const Tensor & other) const { return SparseCPUType::mul(const_cast(*this), other); break; default: - AT_ERROR("mul not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("mul not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::mul.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::mul_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::mul_(const_cast(*this), other); break; @@ -1431,62 +1678,68 @@ inline Tensor & Tensor::mul_(const Tensor & other) const { return SparseCPUType::mul_(const_cast(*this), other); break; default: - AT_ERROR("mul_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("mul_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::mul(Scalar other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::mul(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::mul.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::mul_(Scalar other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::mul_(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mul_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::mv(const Tensor & vec) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::mv(const_cast(*this), vec); break; default: - AT_ERROR("mv not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("mv not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::mv(Tensor self, Tensor vec) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), vec); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mv", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, vec))) + .callUnboxed(const_cast(*this), vec); #endif } inline Tensor Tensor::mvlgamma(int64_t p) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::mvlgamma(const_cast(*this), p); #else - static auto table = globalATenDispatch().getOpTable("aten::mvlgamma(Tensor self, int p) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mvlgamma", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), p); #endif } inline Tensor & Tensor::mvlgamma_(int64_t p) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::mvlgamma_(const_cast(*this), p); #else - static auto table = globalATenDispatch().getOpTable("aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::mvlgamma_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), p); #endif } inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::narrow_copy(const_cast(*this), dim, start, length); break; @@ -1494,164 +1747,176 @@ inline Tensor Tensor::narrow_copy(int64_t dim, int64_t start, int64_t length) co return SparseCPUType::narrow_copy(const_cast(*this), dim, start, length); break; default: - AT_ERROR("narrow_copy not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("narrow_copy not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::narrow_copy(Tensor self, int dim, int start, int length) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, start, length); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::narrow_copy", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, start, length); #endif } inline Tensor Tensor::narrow(int64_t dim, int64_t start, int64_t length) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::narrow(const_cast(*this), dim, start, length); #else - static auto table = globalATenDispatch().getOpTable("aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, start, length); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::narrow", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, start, length); #endif } inline Tensor Tensor::permute(IntArrayRef dims) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::permute(const_cast(*this), dims); #else - static auto table = globalATenDispatch().getOpTable("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dims); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::permute", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dims); #endif } inline Tensor Tensor::numpy_T() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::numpy_T(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::numpy_T(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::numpy_T", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline bool Tensor::is_pinned() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::is_pinned(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::is_pinned(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_pinned", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::pin_memory() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::pin_memory(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::pin_memory(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pin_memory", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::pinverse(double rcond) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::pinverse(const_cast(*this), rcond); #else - static auto table = globalATenDispatch().getOpTable("aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), rcond); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pinverse", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), rcond); #endif } inline Tensor Tensor::reciprocal() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::reciprocal(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::reciprocal(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reciprocal", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::reciprocal_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::reciprocal_(const_cast(*this)); break; default: - AT_ERROR("reciprocal_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("reciprocal_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reciprocal_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::neg() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::neg(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::neg(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::neg", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::neg_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::neg_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::neg_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::neg_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::repeat(IntArrayRef repeats) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::repeat(const_cast(*this), repeats); #else - static auto table = globalATenDispatch().getOpTable("aten::repeat(Tensor self, int[] repeats) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), repeats); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::repeat", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), repeats); #endif } inline Tensor Tensor::repeat_interleave(const Tensor & repeats, c10::optional dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::repeat_interleave(const_cast(*this), repeats, dim); #else - static auto table = globalATenDispatch().getOpTable("aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), repeats, dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::repeat_interleave", "self_Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, repeats))) + .callUnboxed>(const_cast(*this), repeats, dim); #endif } inline Tensor Tensor::repeat_interleave(int64_t repeats, c10::optional dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::repeat_interleave(const_cast(*this), repeats, dim); #else - static auto table = globalATenDispatch().getOpTable("aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), repeats, dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::repeat_interleave", "self_int"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed>(const_cast(*this), repeats, dim); #endif } inline Tensor Tensor::reshape(IntArrayRef shape) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::reshape(const_cast(*this), shape); #else - static auto table = globalATenDispatch().getOpTable("aten::reshape(Tensor self, int[] shape) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), shape); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reshape", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), shape); #endif } inline Tensor Tensor::reshape_as(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::reshape_as(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::reshape_as(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::reshape_as", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::round() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::round(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::round(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::round", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::round_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::round_(const_cast(*this)); - break; - default: - AT_ERROR("round_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); - } + return TypeDefault::round_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::round_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::round_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::relu() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::relu(const_cast(*this)); break; @@ -1659,16 +1924,17 @@ inline Tensor Tensor::relu() const { return QuantizedCPUType::relu(const_cast(*this)); break; default: - AT_ERROR("relu not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("relu not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::relu(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::relu", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::relu_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::relu_(const_cast(*this)); break; @@ -1676,89 +1942,90 @@ inline Tensor & Tensor::relu_() const { return QuantizedCPUType::relu_(const_cast(*this)); break; default: - AT_ERROR("relu_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("relu_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::relu_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::relu_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::prelu(const Tensor & weight) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::prelu(const_cast(*this), weight); break; default: - AT_ERROR("prelu not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("prelu not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::prelu(Tensor self, Tensor weight) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), weight); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::prelu", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, weight))) + .callUnboxed(const_cast(*this), weight); #endif } inline std::tuple Tensor::prelu_backward(const Tensor & grad_output, const Tensor & weight) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::prelu_backward(grad_output, const_cast(*this), weight); break; default: - AT_ERROR("prelu_backward not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("prelu_backward not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)"); - return table->getOp (const Tensor &, const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(grad_output, const_cast(*this), weight); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::prelu_backward", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(grad_output, *this, weight))) + .callUnboxed, const Tensor &, const Tensor &, const Tensor &>(grad_output, const_cast(*this), weight); #endif } inline Tensor Tensor::hardshrink(Scalar lambd) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::hardshrink(const_cast(*this), lambd); break; default: - AT_ERROR("hardshrink not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("hardshrink not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), lambd); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::hardshrink", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), lambd); #endif } inline Tensor Tensor::hardshrink_backward(const Tensor & grad_out, Scalar lambd) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::hardshrink_backward(grad_out, const_cast(*this), lambd); break; default: - AT_ERROR("hardshrink_backward not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("hardshrink_backward not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(grad_out, const_cast(*this), lambd); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::hardshrink_backward", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(grad_out, *this))) + .callUnboxed(grad_out, const_cast(*this), lambd); #endif } inline Tensor Tensor::rsqrt() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::rsqrt(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::rsqrt(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rsqrt", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::rsqrt_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::rsqrt_(const_cast(*this)); - break; - default: - AT_ERROR("rsqrt_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); - } + return TypeDefault::rsqrt_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::rsqrt_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rsqrt_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1767,7 +2034,7 @@ inline Tensor Tensor::select(Dimname dim, int64_t index) const { return TypeDefault::select(const_cast(*this), dim, index); #else static auto table = globalATenDispatch().getOpTable("aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, index); #endif } #endif @@ -1775,104 +2042,114 @@ inline Tensor Tensor::select(int64_t dim, int64_t index) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::select(const_cast(*this), dim, index); #else - static auto table = globalATenDispatch().getOpTable("aten::select.int(Tensor(a) self, int dim, int index) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::select", "int"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, index); #endif } inline Tensor Tensor::sigmoid() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::sigmoid(const_cast(*this)); break; default: - AT_ERROR("sigmoid not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sigmoid not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sigmoid(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sigmoid", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::sigmoid_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::sigmoid_(const_cast(*this)); break; default: - AT_ERROR("sigmoid_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sigmoid_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sigmoid_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::sin() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sin(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::sin(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sin", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::sin_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::sin_(const_cast(*this)); break; default: - AT_ERROR("sin_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sin_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sin_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sin_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::sinh() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sinh(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::sinh(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sinh", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::sinh_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::sinh_(const_cast(*this)); break; default: - AT_ERROR("sinh_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sinh_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sinh_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sinh_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::detach() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::detach(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::detach(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::detach", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::detach_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::detach_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::detach_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::detach_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::size(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::size(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::size.int(Tensor self, int dim) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::size", "int"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1881,7 +2158,7 @@ inline int64_t Tensor::size(Dimname dim) const { return TypeDefault::size(const_cast(*this), dim); #else static auto table = globalATenDispatch().getOpTable("aten::size.Dimname(Tensor self, Dimname dim) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); #endif } #endif @@ -1889,24 +2166,27 @@ inline Tensor Tensor::slice(int64_t dim, int64_t start, int64_t end, int64_t ste #ifdef USE_STATIC_DISPATCH return TypeDefault::slice(const_cast(*this), dim, start, end, step); #else - static auto table = globalATenDispatch().getOpTable("aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, start, end, step); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::slice", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, start, end, step); #endif } inline std::tuple Tensor::slogdet() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::slogdet(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)"); - return table->getOp (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::slogdet", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &>(const_cast(*this)); #endif } inline Tensor Tensor::smm(const Tensor & mat2) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::smm(const_cast(*this), mat2); #else - static auto table = globalATenDispatch().getOpTable("aten::smm(Tensor self, Tensor mat2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::smm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat2))) + .callUnboxed(const_cast(*this), mat2); #endif } inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) const { @@ -1914,7 +2194,7 @@ inline Tensor Tensor::softmax(int64_t dim, c10::optional dtype) cons return TypeDefault::softmax(const_cast(*this), dim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); #endif } #ifdef BUILD_NAMEDTENSOR @@ -1923,7 +2203,7 @@ inline Tensor Tensor::softmax(Dimname dim, c10::optional dtype) cons return TypeDefault::softmax(const_cast(*this), dim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, dtype); #endif } #endif @@ -1931,56 +2211,63 @@ inline std::vector Tensor::split(int64_t split_size, int64_t dim) const #ifdef USE_STATIC_DISPATCH return TypeDefault::split(const_cast(*this), split_size, dim); #else - static auto table = globalATenDispatch().getOpTable("aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[]"); - return table->getOp (const Tensor &, int64_t, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), split_size, dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::split", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, int64_t>(const_cast(*this), split_size, dim); #endif } inline std::vector Tensor::split_with_sizes(IntArrayRef split_sizes, int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::split_with_sizes(const_cast(*this), split_sizes, dim); #else - static auto table = globalATenDispatch().getOpTable("aten::split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]"); - return table->getOp (const Tensor &, IntArrayRef, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), split_sizes, dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::split_with_sizes", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, IntArrayRef, int64_t>(const_cast(*this), split_sizes, dim); #endif } inline Tensor Tensor::squeeze() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::squeeze(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::squeeze(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::squeeze(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim); #endif } inline Tensor & Tensor::squeeze_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::squeeze_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::squeeze_(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::squeeze_(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::squeeze_", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim); #endif } inline Tensor Tensor::sspaddmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sspaddmm(const_cast(*this), mat1, mat2, beta, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat1, mat2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sspaddmm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))) + .callUnboxed(const_cast(*this), mat1, mat2, beta, alpha); #endif } inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10::optional win_length, const Tensor & window, bool normalized, bool onesided) const { @@ -1988,15 +2275,16 @@ inline Tensor Tensor::stft(int64_t n_fft, c10::optional hop_length, c10 return TypeDefault::stft(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); #else static auto table = globalATenDispatch().getOpTable("aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool onesided=True) -> Tensor"); - return table->getOp, c10::optional, const Tensor &, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); + return table->getOp, c10::optional, const Tensor &, bool, bool)>(at::detail::multi_dispatch_tensor_type_set(*this, window))(const_cast(*this), n_fft, hop_length, win_length, window, normalized, onesided); #endif } inline int64_t Tensor::stride(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::stride(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::stride.int(Tensor self, int dim) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::stride", "int"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2005,7 +2293,7 @@ inline int64_t Tensor::stride(Dimname dim) const { return TypeDefault::stride(const_cast(*this), dim); #else static auto table = globalATenDispatch().getOpTable("aten::stride.Dimname(Tensor self, Dimname dim) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); #endif } #endif @@ -2014,7 +2302,7 @@ inline Tensor Tensor::sum(c10::optional dtype) const { return TypeDefault::sum(const_cast(*this), dtype); #else static auto table = globalATenDispatch().getOpTable("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); #endif } inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional dtype) const { @@ -2022,7 +2310,7 @@ inline Tensor Tensor::sum(IntArrayRef dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2031,7 +2319,7 @@ inline Tensor Tensor::sum(DimnameList dim, bool keepdim, c10::optional(*this), dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); #endif } #endif @@ -2039,46 +2327,51 @@ inline Tensor Tensor::sum_to_size(IntArrayRef size) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sum_to_size(const_cast(*this), size); #else - static auto table = globalATenDispatch().getOpTable("aten::sum_to_size(Tensor self, int[] size) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sum_to_size", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), size); #endif } inline Tensor Tensor::sqrt() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sqrt(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::sqrt(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sqrt", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::sqrt_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::sqrt_(const_cast(*this)); break; default: - AT_ERROR("sqrt_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sqrt_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sqrt_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sqrt_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::std(bool unbiased) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::std(const_cast(*this), unbiased); #else - static auto table = globalATenDispatch().getOpTable("aten::std(Tensor self, bool unbiased=True) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), unbiased); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::std", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), unbiased); #endif } inline Tensor Tensor::std(IntArrayRef dim, bool unbiased, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::std.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, unbiased, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::std", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, unbiased, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2087,7 +2380,7 @@ inline Tensor Tensor::std(DimnameList dim, bool unbiased, bool keepdim) const { return TypeDefault::std(const_cast(*this), dim, unbiased, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, unbiased, keepdim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); #endif } #endif @@ -2096,7 +2389,7 @@ inline Tensor Tensor::prod(c10::optional dtype) const { return TypeDefault::prod(const_cast(*this), dtype); #else static auto table = globalATenDispatch().getOpTable("aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype); #endif } inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional dtype) const { @@ -2104,7 +2397,7 @@ inline Tensor Tensor::prod(int64_t dim, bool keepdim, c10::optional return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2113,7 +2406,7 @@ inline Tensor Tensor::prod(Dimname dim, bool keepdim, c10::optional return TypeDefault::prod(const_cast(*this), dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, keepdim, dtype); + return table->getOp)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, keepdim, dtype); #endif } #endif @@ -2121,68 +2414,75 @@ inline Tensor Tensor::t() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::t(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::t(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::t", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::t_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::t_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::t_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::t_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::tan() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::tan(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::tan(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tan", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::tan_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::tan_(const_cast(*this)); break; default: - AT_ERROR("tan_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("tan_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::tan_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tan_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::tanh() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::tanh(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::tanh(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tanh", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::tanh_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::tanh_(const_cast(*this)); break; default: - AT_ERROR("tanh_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("tanh_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::tanh_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tanh_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::transpose(int64_t dim0, int64_t dim1) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::transpose(const_cast(*this), dim0, dim1); #else - static auto table = globalATenDispatch().getOpTable("aten::transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim0, dim1); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::transpose", "int"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim0, dim1); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2190,8 +2490,8 @@ inline Tensor Tensor::transpose(Dimname dim0, Dimname dim1) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::transpose(const_cast(*this), dim0, dim1); #else - static auto table = globalATenDispatch().getOpTable("aten::transpose(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim0, dim1); + static auto table = globalATenDispatch().getOpTable("aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim0, dim1); #endif } #endif @@ -2199,106 +2499,117 @@ inline Tensor & Tensor::transpose_(int64_t dim0, int64_t dim1) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::transpose_(const_cast(*this), dim0, dim1); #else - static auto table = globalATenDispatch().getOpTable("aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim0, dim1); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::transpose_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim0, dim1); #endif } inline Tensor Tensor::flip(IntArrayRef dims) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::flip(const_cast(*this), dims); break; default: - AT_ERROR("flip not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("flip not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::flip(Tensor self, int[] dims) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dims); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::flip", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dims); #endif } inline Tensor Tensor::roll(IntArrayRef shifts, IntArrayRef dims) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::roll(const_cast(*this), shifts, dims); break; default: - AT_ERROR("roll not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("roll not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), shifts, dims); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::roll", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), shifts, dims); #endif } inline Tensor Tensor::rot90(int64_t k, IntArrayRef dims) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::rot90(const_cast(*this), k, dims); #else - static auto table = globalATenDispatch().getOpTable("aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), k, dims); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::rot90", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), k, dims); #endif } inline Tensor Tensor::trunc() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::trunc(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::trunc(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trunc", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::trunc_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::trunc_(const_cast(*this)); break; default: - AT_ERROR("trunc_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("trunc_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::trunc_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trunc_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::type_as(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::type_as(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::type_as(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::type_as", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::unsqueeze(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::unsqueeze(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unsqueeze", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim); #endif } inline Tensor & Tensor::unsqueeze_(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::unsqueeze_(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unsqueeze_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim); #endif } inline Tensor Tensor::var(bool unbiased) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::var(const_cast(*this), unbiased); #else - static auto table = globalATenDispatch().getOpTable("aten::var(Tensor self, bool unbiased=True) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), unbiased); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::var", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), unbiased); #endif } inline Tensor Tensor::var(IntArrayRef dim, bool unbiased, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::var.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, unbiased, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::var", "dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, unbiased, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2307,7 +2618,7 @@ inline Tensor Tensor::var(DimnameList dim, bool unbiased, bool keepdim) const { return TypeDefault::var(const_cast(*this), dim, unbiased, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, unbiased, keepdim); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim, unbiased, keepdim); #endif } #endif @@ -2315,16 +2626,18 @@ inline Tensor Tensor::view_as(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::view_as(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::view_as(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::view_as", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::where(const Tensor & condition, const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::where(condition, const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(condition, const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::where", "self"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(condition, *this, other))) + .callUnboxed(condition, const_cast(*this), other); #endif } inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { @@ -2332,15 +2645,16 @@ inline Tensor Tensor::norm(c10::optional p, ScalarType dtype) const { return TypeDefault::norm(const_cast(*this), p, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor"); - return table->getOp, ScalarType)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dtype); + return table->getOp, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dtype); #endif } inline Tensor Tensor::norm(Scalar p) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::norm(const_cast(*this), p); #else - static auto table = globalATenDispatch().getOpTable("aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::norm", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), p); #endif } inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim, ScalarType dtype) const { @@ -2348,15 +2662,16 @@ inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdi return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); - return table->getOp, IntArrayRef, bool, ScalarType)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dim, keepdim, dtype); + return table->getOp, IntArrayRef, bool, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim, dtype); #endif } inline Tensor Tensor::norm(c10::optional p, IntArrayRef dim, bool keepdim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::norm(const_cast(*this), p, dim, keepdim); #else - static auto table = globalATenDispatch().getOpTable("aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp, IntArrayRef, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dim, keepdim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::norm", "ScalarOpt_dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, IntArrayRef, bool>(const_cast(*this), p, dim, keepdim); #endif } #ifdef BUILD_NAMEDTENSOR @@ -2365,7 +2680,7 @@ inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdi return TypeDefault::norm(const_cast(*this), p, dim, keepdim, dtype); #else static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor"); - return table->getOp, DimnameList, bool, ScalarType)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dim, keepdim, dtype); + return table->getOp, DimnameList, bool, ScalarType)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim, dtype); #endif } #endif @@ -2375,13 +2690,13 @@ inline Tensor Tensor::norm(c10::optional p, DimnameList dim, bool keepdi return TypeDefault::norm(const_cast(*this), p, dim, keepdim); #else static auto table = globalATenDispatch().getOpTable("aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor"); - return table->getOp, DimnameList, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dim, keepdim); + return table->getOp, DimnameList, bool)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, dim, keepdim); #endif } #endif inline Tensor Tensor::clone() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::clone(const_cast(*this)); break; @@ -2392,16 +2707,17 @@ inline Tensor Tensor::clone() const { return SparseCPUType::clone(const_cast(*this)); break; default: - AT_ERROR("clone not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("clone not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::clone(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::clone", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::resize_as_(const_cast(*this), the_template); break; @@ -2409,16 +2725,17 @@ inline Tensor & Tensor::resize_as_(const Tensor & the_template) const { return SparseCPUType::resize_as_(const_cast(*this), the_template); break; default: - AT_ERROR("resize_as_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("resize_as_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::resize_as_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), the_template); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::resize_as_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, the_template))) + .callUnboxed(const_cast(*this), the_template); #endif } inline Tensor Tensor::pow(Scalar exponent) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::pow(const_cast(*this), exponent); break; @@ -2426,16 +2743,17 @@ inline Tensor Tensor::pow(Scalar exponent) const { return SparseCPUType::pow(const_cast(*this), exponent); break; default: - AT_ERROR("pow not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("pow not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), exponent); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow", "Tensor_Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), exponent); #endif } inline Tensor & Tensor::zero_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::zero_(const_cast(*this)); break; @@ -2443,439 +2761,533 @@ inline Tensor & Tensor::zero_() const { return SparseCPUType::zero_(const_cast(*this)); break; default: - AT_ERROR("zero_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("zero_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::zero_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::zero_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::sub(const Tensor & other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::sub(const_cast(*this), other, alpha); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::CPU: + return CPUType::sub(const_cast(*this), other, alpha); + break; + case Backend::SparseCPU: + return SparseCPUType::sub(const_cast(*this), other, alpha); + break; + default: + AT_ERROR("sub not implemented for ", at::toString(type_set())); + } #else - static auto table = globalATenDispatch().getOpTable("aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor & Tensor::sub_(const Tensor & other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::sub_(const_cast(*this), other, alpha); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::CPU: + return CPUType::sub_(const_cast(*this), other, alpha); + break; + case Backend::SparseCPU: + return SparseCPUType::sub_(const_cast(*this), other, alpha); + break; + default: + AT_ERROR("sub_ not implemented for ", at::toString(type_set())); + } #else - static auto table = globalATenDispatch().getOpTable("aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor Tensor::sub(Scalar other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sub(const_cast(*this), other, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor & Tensor::sub_(Scalar other, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sub_(const_cast(*this), other, alpha); #else - static auto table = globalATenDispatch().getOpTable("aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sub_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other, alpha); #endif } inline Tensor Tensor::addmm(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::addmm(const_cast(*this), mat1, mat2, beta, alpha); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::CPU: + return CPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); + break; + case Backend::SparseCPU: + return SparseCPUType::addmm(const_cast(*this), mat1, mat2, beta, alpha); + break; + default: + AT_ERROR("addmm not implemented for ", at::toString(type_set())); + } #else - static auto table = globalATenDispatch().getOpTable("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat1, mat2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))) + .callUnboxed(const_cast(*this), mat1, mat2, beta, alpha); #endif } inline Tensor & Tensor::addmm_(const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - return TypeDefault::addmm_(const_cast(*this), mat1, mat2, beta, alpha); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::CPU: + return CPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); + break; + case Backend::SparseCPU: + return SparseCPUType::addmm_(const_cast(*this), mat1, mat2, beta, alpha); + break; + default: + AT_ERROR("addmm_ not implemented for ", at::toString(type_set())); + } #else - static auto table = globalATenDispatch().getOpTable("aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mat1, mat2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addmm_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mat1, mat2))) + .callUnboxed(const_cast(*this), mat1, mat2, beta, alpha); #endif } inline Tensor & Tensor::sparse_resize_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::sparse_resize_(const_cast(*this), size, sparse_dim, dense_dim); break; default: - AT_ERROR("sparse_resize_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sparse_resize_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, sparse_dim, dense_dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_resize_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), size, sparse_dim, dense_dim); #endif } inline Tensor & Tensor::sparse_resize_and_clear_(IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::sparse_resize_and_clear_(const_cast(*this), size, sparse_dim, dense_dim); break; default: - AT_ERROR("sparse_resize_and_clear_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sparse_resize_and_clear_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size, sparse_dim, dense_dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_resize_and_clear_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), size, sparse_dim, dense_dim); #endif } inline Tensor Tensor::sparse_mask(const Tensor & mask) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::sparse_mask(const_cast(*this), mask); + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::SparseCPU: + return SparseCPUType::sparse_mask(const_cast(*this), mask); break; default: - AT_ERROR("sparse_mask not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sparse_mask not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_mask(Tensor self, Tensor mask) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_mask", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask))) + .callUnboxed(const_cast(*this), mask); #endif } inline Tensor Tensor::to_dense() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::to_dense(const_cast(*this)); break; default: - AT_ERROR("to_dense not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("to_dense not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::to_dense(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_dense", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::sparse_dim() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::sparse_dim(const_cast(*this)); break; default: - AT_ERROR("sparse_dim not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sparse_dim not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sparse_dim(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sparse_dim", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::_dimI() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::_dimI(const_cast(*this)); break; default: - AT_ERROR("_dimI not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("_dimI not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::_dimI(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_dimI", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::dense_dim() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::dense_dim(const_cast(*this)); break; default: - AT_ERROR("dense_dim not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("dense_dim not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::dense_dim(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dense_dim", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::_dimV() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::_dimV(const_cast(*this)); break; default: - AT_ERROR("_dimV not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("_dimV not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::_dimV(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_dimV", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::_nnz() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::_nnz(const_cast(*this)); break; default: - AT_ERROR("_nnz not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("_nnz not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::_nnz(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_nnz", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::coalesce() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::coalesce(const_cast(*this)); break; default: - AT_ERROR("coalesce not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("coalesce not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::coalesce(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::coalesce", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline bool Tensor::is_coalesced() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::is_coalesced(const_cast(*this)); break; default: - AT_ERROR("is_coalesced not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("is_coalesced not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::is_coalesced(Tensor self) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_coalesced", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::_indices() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::_indices(const_cast(*this)); break; default: - AT_ERROR("_indices not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("_indices not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::_indices(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_indices", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::_values() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::_values(const_cast(*this)); break; default: - AT_ERROR("_values not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("_values not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::_values(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_values", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::_coalesced_(bool coalesced) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::_coalesced_(const_cast(*this), coalesced); break; default: - AT_ERROR("_coalesced_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("_coalesced_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), coalesced); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::_coalesced_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), coalesced); #endif } inline Tensor Tensor::indices() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::indices(const_cast(*this)); break; default: - AT_ERROR("indices not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("indices not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::indices(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::indices", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::values() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::SparseCPU: return SparseCPUType::values(const_cast(*this)); break; default: - AT_ERROR("values not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("values not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::values(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::values", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::numel() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::numel(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::numel(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::numel", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline std::vector Tensor::unbind(int64_t dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::unbind(const_cast(*this), dim); #else - static auto table = globalATenDispatch().getOpTable("aten::unbind(Tensor(a) self, int dim=0) -> Tensor(a)[]"); - return table->getOp (const Tensor &, int64_t)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unbind", "int"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t>(const_cast(*this), dim); #endif } +#ifdef BUILD_NAMEDTENSOR +inline std::vector Tensor::unbind(Dimname dim) const { +#ifdef USE_STATIC_DISPATCH + return TypeDefault::unbind(const_cast(*this), dim); +#else + static auto table = globalATenDispatch().getOpTable("aten::unbind.Dimname(Tensor(a) self, Dimname dim) -> Tensor(a)[]"); + return table->getOp (const Tensor &, Dimname)>(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dim); +#endif +} +#endif inline Tensor Tensor::to_sparse(int64_t sparse_dim) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::to_sparse(const_cast(*this), sparse_dim); break; default: - AT_ERROR("to_sparse not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("to_sparse not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), sparse_dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_sparse", "sparse_dim"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), sparse_dim); #endif } inline Tensor Tensor::to_sparse() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::to_sparse(const_cast(*this)); break; default: - AT_ERROR("to_sparse not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("to_sparse not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::to_sparse(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_sparse", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::to_mkldnn() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::to_mkldnn(const_cast(*this)); break; default: - AT_ERROR("to_mkldnn not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("to_mkldnn not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::to_mkldnn(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to_mkldnn", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::dequantize() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::dequantize(const_cast(*this)); break; default: - AT_ERROR("dequantize not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("dequantize not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::dequantize(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dequantize", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline double Tensor::q_scale() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::q_scale(const_cast(*this)); break; default: - AT_ERROR("q_scale not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("q_scale not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::q_scale(Tensor self) -> float"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_scale", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline int64_t Tensor::q_zero_point() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::q_zero_point(const_cast(*this)); break; default: - AT_ERROR("q_zero_point not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("q_zero_point not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::q_zero_point(Tensor self) -> int"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_zero_point", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::q_per_channel_scales() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::q_per_channel_scales(const_cast(*this)); break; default: - AT_ERROR("q_per_channel_scales not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("q_per_channel_scales not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_scales(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_per_channel_scales", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::q_per_channel_zero_points() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::q_per_channel_zero_points(const_cast(*this)); break; default: - AT_ERROR("q_per_channel_zero_points not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("q_per_channel_zero_points not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_zero_points(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::q_per_channel_zero_points", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); +#endif +} +inline IntArrayRef Tensor::q_per_channel_axis() const { +#ifdef USE_STATIC_DISPATCH + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { + case Backend::QuantizedCPU: + return QuantizedCPUType::q_per_channel_axis(const_cast(*this)); + break; + default: + AT_ERROR("q_per_channel_axis not implemented for ", at::toString(type_set())); + } +#else + static auto table = globalATenDispatch().getOpTable("aten::q_per_channel_axis(Tensor self) -> int[]"); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); #endif } inline Tensor Tensor::int_repr() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::int_repr(const_cast(*this)); break; default: - AT_ERROR("int_repr not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("int_repr not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::int_repr(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::int_repr", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline QScheme Tensor::qscheme() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::qscheme(const_cast(*this)); break; default: - AT_ERROR("qscheme not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("qscheme not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::qscheme(Tensor self) -> QScheme"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this)); #endif } inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool copy) const { @@ -2883,7 +3295,7 @@ inline Tensor Tensor::to(const TensorOptions & options, bool non_blocking, bool return TypeDefault::to(const_cast(*this), options, non_blocking, copy); #else static auto table = globalATenDispatch().getOpTable("aten::to.dtype_layout(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), options, non_blocking, copy); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this, options))(const_cast(*this), options, non_blocking, copy); #endif } inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, bool copy) const { @@ -2891,7 +3303,7 @@ inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking, boo return TypeDefault::to(const_cast(*this), device, dtype, non_blocking, copy); #else static auto table = globalATenDispatch().getOpTable("aten::to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), device, dtype, non_blocking, copy); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), device, dtype, non_blocking, copy); #endif } inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const { @@ -2899,42 +3311,44 @@ inline Tensor Tensor::to(ScalarType dtype, bool non_blocking, bool copy) const { return TypeDefault::to(const_cast(*this), dtype, non_blocking, copy); #else static auto table = globalATenDispatch().getOpTable("aten::to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dtype, non_blocking, copy); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), dtype, non_blocking, copy); #endif } inline Tensor Tensor::to(const Tensor & other, bool non_blocking, bool copy) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::to(const_cast(*this), other, non_blocking, copy); #else - static auto table = globalATenDispatch().getOpTable("aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, non_blocking, copy); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::to", "other"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, non_blocking, copy); #endif } inline Scalar Tensor::item() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::item(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::item(Tensor self) -> Scalar"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::item", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::set_(Storage source) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::set_(const_cast(*this), source); break; default: - AT_ERROR("set_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("set_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), source); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), source); #endif } inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::set_(const_cast(*this), source, storage_offset, size, stride); break; @@ -2942,138 +3356,147 @@ inline Tensor & Tensor::set_(Storage source, int64_t storage_offset, IntArrayRef return QuantizedCPUType::set_(const_cast(*this), source, storage_offset, size, stride); break; default: - AT_ERROR("set_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("set_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, int storage_offset, int[] size, int[] stride=[]) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), source, storage_offset, size, stride); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), source, storage_offset, size, stride); #endif } inline Tensor & Tensor::set_(const Tensor & source) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::set_(const_cast(*this), source); break; default: - AT_ERROR("set_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("set_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::set_", "source_Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, source))) + .callUnboxed(const_cast(*this), source); #endif } inline Tensor & Tensor::set_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::set_(const_cast(*this)); break; default: - AT_ERROR("set_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("set_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::set_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::set_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::set_quantizer_(ConstQuantizerPtr quantizer) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::QuantizedCPU: return QuantizedCPUType::set_quantizer_(const_cast(*this), quantizer); break; default: - AT_ERROR("set_quantizer_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("set_quantizer_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::set_quantizer_(Tensor(a!) self, ConstQuantizerPtr quantizer) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), quantizer); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), quantizer); #endif } inline bool Tensor::is_set_to(const Tensor & tensor) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::is_set_to(const_cast(*this), tensor); break; default: - AT_ERROR("is_set_to not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("is_set_to not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::is_set_to(Tensor self, Tensor tensor) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), tensor); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::is_set_to", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor))) + .callUnboxed(const_cast(*this), tensor); #endif } inline Tensor & Tensor::masked_fill_(const Tensor & mask, Scalar value) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::masked_fill_(const_cast(*this), mask, value); break; default: - AT_ERROR("masked_fill_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("masked_fill_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask))) + .callUnboxed(const_cast(*this), mask, value); #endif } inline Tensor Tensor::masked_fill(const Tensor & mask, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::masked_fill(const_cast(*this), mask, value); #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask))) + .callUnboxed(const_cast(*this), mask, value); #endif } inline Tensor & Tensor::masked_fill_(const Tensor & mask, const Tensor & value) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::masked_fill_(const_cast(*this), mask, value); break; default: - AT_ERROR("masked_fill_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("masked_fill_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, value))) + .callUnboxed(const_cast(*this), mask, value); #endif } inline Tensor Tensor::masked_fill(const Tensor & mask, const Tensor & value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::masked_fill(const_cast(*this), mask, value); #else - static auto table = globalATenDispatch().getOpTable("aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_fill", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, value))) + .callUnboxed(const_cast(*this), mask, value); #endif } inline Tensor & Tensor::masked_scatter_(const Tensor & mask, const Tensor & source) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::masked_scatter_(const_cast(*this), mask, source); break; default: - AT_ERROR("masked_scatter_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("masked_scatter_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask, source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_scatter_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, source))) + .callUnboxed(const_cast(*this), mask, source); #endif } inline Tensor Tensor::masked_scatter(const Tensor & mask, const Tensor & source) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::masked_scatter(const_cast(*this), mask, source); #else - static auto table = globalATenDispatch().getOpTable("aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask, source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_scatter", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask, source))) + .callUnboxed(const_cast(*this), mask, source); #endif } inline Tensor Tensor::view(IntArrayRef size) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::view(const_cast(*this), size); break; @@ -3081,1028 +3504,1085 @@ inline Tensor Tensor::view(IntArrayRef size) const { return QuantizedCPUType::view(const_cast(*this), size); break; default: - AT_ERROR("view not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("view not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::view(Tensor(a) self, int[] size) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), size); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::view", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), size); #endif } inline Tensor & Tensor::put_(const Tensor & index, const Tensor & source, bool accumulate) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::put_(const_cast(*this), index, source, accumulate); break; default: - AT_ERROR("put_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("put_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), index, source, accumulate); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::put_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source))) + .callUnboxed(const_cast(*this), index, source, accumulate); #endif } inline Tensor & Tensor::index_add_(int64_t dim, const Tensor & index, const Tensor & source) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::index_add_(const_cast(*this), dim, index, source); break; default: - AT_ERROR("index_add_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("index_add_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_add_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source))) + .callUnboxed(const_cast(*this), dim, index, source); #endif } inline Tensor Tensor::index_add(int64_t dim, const Tensor & index, const Tensor & source) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::index_add(const_cast(*this), dim, index, source); #else - static auto table = globalATenDispatch().getOpTable("aten::index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, source); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_add", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, source))) + .callUnboxed(const_cast(*this), dim, index, source); #endif } inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, Scalar value) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::index_fill_(const_cast(*this), dim, index, value); break; default: - AT_ERROR("index_fill_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("index_fill_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill_.Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), dim, index, value); #endif } inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::index_fill(const_cast(*this), dim, index, value); #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill.Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), dim, index, value); #endif } inline Tensor & Tensor::index_fill_(int64_t dim, const Tensor & index, const Tensor & value) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::index_fill_(const_cast(*this), dim, index, value); break; default: - AT_ERROR("index_fill_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("index_fill_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill_.Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, value))) + .callUnboxed(const_cast(*this), dim, index, value); #endif } inline Tensor Tensor::index_fill(int64_t dim, const Tensor & index, const Tensor & value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::index_fill(const_cast(*this), dim, index, value); #else - static auto table = globalATenDispatch().getOpTable("aten::index_fill.Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_fill", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, value))) + .callUnboxed(const_cast(*this), dim, index, value); #endif } inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, const Tensor & src) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::scatter_(const_cast(*this), dim, index, src); break; default: - AT_ERROR("scatter_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("scatter_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, src); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_", "src"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src))) + .callUnboxed(const_cast(*this), dim, index, src); #endif } inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, const Tensor & src) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::scatter(const_cast(*this), dim, index, src); #else - static auto table = globalATenDispatch().getOpTable("aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, src); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter", "src"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src))) + .callUnboxed(const_cast(*this), dim, index, src); #endif } inline Tensor & Tensor::scatter_(int64_t dim, const Tensor & index, Scalar value) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::scatter_(const_cast(*this), dim, index, value); break; default: - AT_ERROR("scatter_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("scatter_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_", "value"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), dim, index, value); #endif } inline Tensor Tensor::scatter(int64_t dim, const Tensor & index, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::scatter(const_cast(*this), dim, index, value); #else - static auto table = globalATenDispatch().getOpTable("aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter", "value"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), dim, index, value); #endif } inline Tensor & Tensor::scatter_add_(int64_t dim, const Tensor & index, const Tensor & src) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::scatter_add_(const_cast(*this), dim, index, src); break; default: - AT_ERROR("scatter_add_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("scatter_add_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, src); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_add_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src))) + .callUnboxed(const_cast(*this), dim, index, src); #endif } inline Tensor Tensor::scatter_add(int64_t dim, const Tensor & index, const Tensor & src) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::scatter_add(const_cast(*this), dim, index, src); #else - static auto table = globalATenDispatch().getOpTable("aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, src); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::scatter_add", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index, src))) + .callUnboxed(const_cast(*this), dim, index, src); #endif } inline Tensor & Tensor::lt_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lt_(const_cast(*this), other); break; default: - AT_ERROR("lt_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lt_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::lt_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lt_(const_cast(*this), other); break; default: - AT_ERROR("lt_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lt_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::gt_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::gt_(const_cast(*this), other); break; default: - AT_ERROR("gt_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("gt_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::gt_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::gt_(const_cast(*this), other); break; default: - AT_ERROR("gt_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("gt_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::le_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::le_(const_cast(*this), other); break; default: - AT_ERROR("le_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("le_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::le_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::le_(const_cast(*this), other); break; default: - AT_ERROR("le_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("le_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::ge_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ge_(const_cast(*this), other); break; default: - AT_ERROR("ge_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ge_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::ge_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ge_(const_cast(*this), other); break; default: - AT_ERROR("ge_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ge_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::eq_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::eq_(const_cast(*this), other); break; default: - AT_ERROR("eq_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("eq_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::eq_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::eq_(const_cast(*this), other); break; default: - AT_ERROR("eq_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("eq_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::ne_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ne_(const_cast(*this), other); break; default: - AT_ERROR("ne_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ne_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::ne_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ne_(const_cast(*this), other); break; default: - AT_ERROR("ne_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ne_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__and__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__and__(const_cast(*this), other); break; default: - AT_ERROR("__and__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__and__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__and__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__and__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__and__(const_cast(*this), other); break; default: - AT_ERROR("__and__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__and__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__and__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__iand__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__iand__(const_cast(*this), other); break; default: - AT_ERROR("__iand__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__iand__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__iand__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__iand__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__iand__(const_cast(*this), other); break; default: - AT_ERROR("__iand__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__iand__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__iand__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__or__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__or__(const_cast(*this), other); break; default: - AT_ERROR("__or__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__or__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__or__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__or__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__or__(const_cast(*this), other); break; default: - AT_ERROR("__or__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__or__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__or__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__ior__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__ior__(const_cast(*this), other); break; default: - AT_ERROR("__ior__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__ior__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ior__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__ior__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__ior__(const_cast(*this), other); break; default: - AT_ERROR("__ior__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__ior__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ior__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__xor__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__xor__(const_cast(*this), other); break; default: - AT_ERROR("__xor__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__xor__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__xor__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__xor__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__xor__(const_cast(*this), other); break; default: - AT_ERROR("__xor__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__xor__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__xor__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__ixor__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__ixor__(const_cast(*this), other); break; default: - AT_ERROR("__ixor__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__ixor__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ixor__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__ixor__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__ixor__(const_cast(*this), other); break; default: - AT_ERROR("__ixor__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__ixor__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ixor__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__lshift__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__lshift__(const_cast(*this), other); break; default: - AT_ERROR("__lshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__lshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__lshift__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__lshift__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__lshift__(const_cast(*this), other); break; default: - AT_ERROR("__lshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__lshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__lshift__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__ilshift__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__ilshift__(const_cast(*this), other); break; default: - AT_ERROR("__ilshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__ilshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ilshift__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__ilshift__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__ilshift__(const_cast(*this), other); break; default: - AT_ERROR("__ilshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__ilshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__ilshift__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__rshift__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__rshift__(const_cast(*this), other); break; default: - AT_ERROR("__rshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__rshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__rshift__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::__rshift__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__rshift__(const_cast(*this), other); break; default: - AT_ERROR("__rshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__rshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__rshift__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__irshift__(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__irshift__(const_cast(*this), other); break; default: - AT_ERROR("__irshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__irshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__irshift__", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::__irshift__(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::__irshift__(const_cast(*this), other); break; default: - AT_ERROR("__irshift__ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("__irshift__ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::__irshift__", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::lgamma_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lgamma_(const_cast(*this)); break; default: - AT_ERROR("lgamma_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lgamma_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lgamma_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lgamma_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::atan2_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::atan2_(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan2_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::tril_(int64_t diagonal) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::tril_(const_cast(*this), diagonal); break; default: - AT_ERROR("tril_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("tril_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), diagonal); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tril_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), diagonal); #endif } inline Tensor & Tensor::triu_(int64_t diagonal) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::triu_(const_cast(*this), diagonal); break; default: - AT_ERROR("triu_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("triu_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), diagonal); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::triu_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), diagonal); #endif } inline Tensor & Tensor::digamma_() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::digamma_(const_cast(*this)); - break; - default: - AT_ERROR("digamma_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); - } + return TypeDefault::digamma_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::digamma_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::digamma_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::polygamma_(int64_t n) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::polygamma_(const_cast(*this), n); - break; - default: - AT_ERROR("polygamma_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); - } + return TypeDefault::polygamma_(const_cast(*this), n); #else - static auto table = globalATenDispatch().getOpTable("aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), n); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::polygamma_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), n); #endif } inline Tensor & Tensor::renorm_(Scalar p, int64_t dim, Scalar maxnorm) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::renorm_(const_cast(*this), p, dim, maxnorm); break; default: - AT_ERROR("renorm_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("renorm_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dim, maxnorm); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::renorm_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), p, dim, maxnorm); #endif } inline Tensor & Tensor::pow_(Scalar exponent) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::pow_(const_cast(*this), exponent); break; default: - AT_ERROR("pow_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("pow_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), exponent); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), exponent); #endif } inline Tensor & Tensor::pow_(const Tensor & exponent) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::pow_(const_cast(*this), exponent); break; default: - AT_ERROR("pow_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("pow_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), exponent); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, exponent))) + .callUnboxed(const_cast(*this), exponent); #endif } inline Tensor & Tensor::lerp_(const Tensor & end, Scalar weight) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lerp_(const_cast(*this), end, weight); break; default: - AT_ERROR("lerp_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lerp_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), end, weight); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end))) + .callUnboxed(const_cast(*this), end, weight); #endif } inline Tensor & Tensor::lerp_(const Tensor & end, const Tensor & weight) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lerp_(const_cast(*this), end, weight); break; default: - AT_ERROR("lerp_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lerp_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), end, weight); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end, weight))) + .callUnboxed(const_cast(*this), end, weight); #endif } inline Tensor & Tensor::fmod_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::fmod_(const_cast(*this), other); break; default: - AT_ERROR("fmod_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("fmod_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::fmod_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::fmod_(const_cast(*this), other); break; default: - AT_ERROR("fmod_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("fmod_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::remainder_(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::remainder_(const_cast(*this), other); break; default: - AT_ERROR("remainder_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("remainder_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder_", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::remainder_(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::remainder_(const_cast(*this), other); break; default: - AT_ERROR("remainder_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("remainder_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder_", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor & Tensor::addbmm_(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::addbmm_(const_cast(*this), batch1, batch2, beta, alpha); break; default: - AT_ERROR("addbmm_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("addbmm_ not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), batch1, batch2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addbmm_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))) + .callUnboxed(const_cast(*this), batch1, batch2, beta, alpha); #endif } inline Tensor Tensor::addbmm(const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::addbmm(const_cast(*this), batch1, batch2, beta, alpha); break; default: - AT_ERROR("addbmm not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("addbmm not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), batch1, batch2, beta, alpha); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addbmm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, batch1, batch2))) + .callUnboxed(const_cast(*this), batch1, batch2, beta, alpha); #endif } inline Tensor & Tensor::addcdiv_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::addcdiv_(const_cast(*this), tensor1, tensor2, value); #else - static auto table = globalATenDispatch().getOpTable("aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), tensor1, tensor2, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcdiv_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))) + .callUnboxed(const_cast(*this), tensor1, tensor2, value); #endif } inline Tensor & Tensor::random_(int64_t from, int64_t to, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::random_(const_cast(*this), from, to, generator); break; default: - AT_ERROR("random_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("random_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::random_.from(Tensor(a!) self, int from, int to, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), from, to, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), from, to, generator); #endif } inline Tensor & Tensor::random_(int64_t to, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::random_(const_cast(*this), to, generator); break; default: - AT_ERROR("random_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("random_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), to, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), to, generator); #endif } inline Tensor & Tensor::random_(Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::random_(const_cast(*this), generator); break; default: - AT_ERROR("random_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("random_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), generator); #endif } inline Tensor & Tensor::uniform_(double from, double to, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::uniform_(const_cast(*this), from, to, generator); break; default: - AT_ERROR("uniform_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("uniform_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), from, to, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), from, to, generator); #endif } inline Tensor & Tensor::normal_(double mean, double std, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::normal_(const_cast(*this), mean, std, generator); break; default: - AT_ERROR("normal_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("normal_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mean, std, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), mean, std, generator); #endif } inline Tensor & Tensor::cauchy_(double median, double sigma, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::cauchy_(const_cast(*this), median, sigma, generator); break; default: - AT_ERROR("cauchy_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("cauchy_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), median, sigma, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), median, sigma, generator); #endif } inline Tensor & Tensor::log_normal_(double mean, double std, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::log_normal_(const_cast(*this), mean, std, generator); break; default: - AT_ERROR("log_normal_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("log_normal_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mean, std, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), mean, std, generator); #endif } inline Tensor & Tensor::exponential_(double lambd, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::exponential_(const_cast(*this), lambd, generator); break; default: - AT_ERROR("exponential_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("exponential_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), lambd, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), lambd, generator); #endif } inline Tensor & Tensor::geometric_(double p, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::geometric_(const_cast(*this), p, generator); break; default: - AT_ERROR("geometric_ not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("geometric_ not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), p, generator); #endif } inline Tensor Tensor::diag(int64_t diagonal) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::diag(const_cast(*this), diagonal); break; default: - AT_ERROR("diag not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("diag not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::diag(Tensor self, int diagonal=0) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), diagonal); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::diag", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), diagonal); #endif } inline Tensor Tensor::cross(const Tensor & other, c10::optional dim) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::cross(const_cast(*this), other, dim); #else - static auto table = globalATenDispatch().getOpTable("aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor"); - return table->getOp)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, dim); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cross", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed>(const_cast(*this), other, dim); #endif } inline Tensor Tensor::triu(int64_t diagonal) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::triu(const_cast(*this), diagonal); #else - static auto table = globalATenDispatch().getOpTable("aten::triu(Tensor self, int diagonal=0) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), diagonal); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::triu", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), diagonal); #endif } inline Tensor Tensor::tril(int64_t diagonal) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::tril(const_cast(*this), diagonal); #else - static auto table = globalATenDispatch().getOpTable("aten::tril(Tensor self, int diagonal=0) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), diagonal); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::tril", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), diagonal); #endif } inline Tensor Tensor::trace() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::trace(const_cast(*this)); break; default: - AT_ERROR("trace not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("trace not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::trace(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::trace", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::ne(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ne(const_cast(*this), other); break; @@ -4110,16 +4590,17 @@ inline Tensor Tensor::ne(Scalar other) const { return QuantizedCPUType::ne(const_cast(*this), other); break; default: - AT_ERROR("ne not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ne not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ne.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::ne(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ne(const_cast(*this), other); break; @@ -4127,16 +4608,17 @@ inline Tensor Tensor::ne(const Tensor & other) const { return QuantizedCPUType::ne(const_cast(*this), other); break; default: - AT_ERROR("ne not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ne not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ne.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ne", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::eq(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::eq(const_cast(*this), other); break; @@ -4144,16 +4626,17 @@ inline Tensor Tensor::eq(Scalar other) const { return QuantizedCPUType::eq(const_cast(*this), other); break; default: - AT_ERROR("eq not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("eq not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::eq.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::eq(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::eq(const_cast(*this), other); break; @@ -4161,16 +4644,17 @@ inline Tensor Tensor::eq(const Tensor & other) const { return QuantizedCPUType::eq(const_cast(*this), other); break; default: - AT_ERROR("eq not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("eq not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::eq.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eq", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::ge(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ge(const_cast(*this), other); break; @@ -4178,16 +4662,17 @@ inline Tensor Tensor::ge(Scalar other) const { return QuantizedCPUType::ge(const_cast(*this), other); break; default: - AT_ERROR("ge not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ge not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ge.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::ge(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ge(const_cast(*this), other); break; @@ -4195,16 +4680,17 @@ inline Tensor Tensor::ge(const Tensor & other) const { return QuantizedCPUType::ge(const_cast(*this), other); break; default: - AT_ERROR("ge not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ge not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ge.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ge", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::le(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::le(const_cast(*this), other); break; @@ -4212,16 +4698,17 @@ inline Tensor Tensor::le(Scalar other) const { return QuantizedCPUType::le(const_cast(*this), other); break; default: - AT_ERROR("le not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("le not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::le.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::le(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::le(const_cast(*this), other); break; @@ -4229,16 +4716,17 @@ inline Tensor Tensor::le(const Tensor & other) const { return QuantizedCPUType::le(const_cast(*this), other); break; default: - AT_ERROR("le not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("le not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::le.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::le", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::gt(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::gt(const_cast(*this), other); break; @@ -4246,16 +4734,17 @@ inline Tensor Tensor::gt(Scalar other) const { return QuantizedCPUType::gt(const_cast(*this), other); break; default: - AT_ERROR("gt not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("gt not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::gt.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::gt(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::gt(const_cast(*this), other); break; @@ -4263,16 +4752,17 @@ inline Tensor Tensor::gt(const Tensor & other) const { return QuantizedCPUType::gt(const_cast(*this), other); break; default: - AT_ERROR("gt not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("gt not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::gt.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gt", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::lt(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lt(const_cast(*this), other); break; @@ -4280,16 +4770,17 @@ inline Tensor Tensor::lt(Scalar other) const { return QuantizedCPUType::lt(const_cast(*this), other); break; default: - AT_ERROR("lt not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lt not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lt.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::lt(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lt(const_cast(*this), other); break; @@ -4297,30 +4788,32 @@ inline Tensor Tensor::lt(const Tensor & other) const { return QuantizedCPUType::lt(const_cast(*this), other); break; default: - AT_ERROR("lt not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lt not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lt.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lt", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::take(const Tensor & index) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::take(const_cast(*this), index); break; default: - AT_ERROR("take not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("take not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::take(Tensor self, Tensor index) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), index); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::take", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), index); #endif } inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::index_select(const_cast(*this), dim, index); break; @@ -4328,460 +4821,487 @@ inline Tensor Tensor::index_select(int64_t dim, const Tensor & index) const { return SparseCPUType::index_select(const_cast(*this), dim, index); break; default: - AT_ERROR("index_select not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("index_select not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::index_select", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), dim, index); #endif } inline Tensor Tensor::masked_select(const Tensor & mask) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::masked_select(const_cast(*this), mask); break; default: - AT_ERROR("masked_select not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("masked_select not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::masked_select(Tensor self, Tensor mask) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), mask); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::masked_select", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, mask))) + .callUnboxed(const_cast(*this), mask); #endif } inline Tensor Tensor::nonzero() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::nonzero(const_cast(*this)); break; default: - AT_ERROR("nonzero not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("nonzero not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::nonzero(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::nonzero", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline std::vector Tensor::nonzero_numpy() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::nonzero_numpy(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::nonzero_numpy(Tensor self) -> Tensor[]"); - return table->getOp (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::nonzero_numpy", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &>(const_cast(*this)); #endif } inline Tensor Tensor::gather(int64_t dim, const Tensor & index, bool sparse_grad) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::gather(const_cast(*this), dim, index, sparse_grad); break; default: - AT_ERROR("gather not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("gather not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, index, sparse_grad); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::gather", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, index))) + .callUnboxed(const_cast(*this), dim, index, sparse_grad); #endif } inline Tensor Tensor::addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::addcmul(const_cast(*this), tensor1, tensor2, value); #else - static auto table = globalATenDispatch().getOpTable("aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), tensor1, tensor2, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcmul", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))) + .callUnboxed(const_cast(*this), tensor1, tensor2, value); #endif } inline Tensor & Tensor::addcmul_(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::addcmul_(const_cast(*this), tensor1, tensor2, value); #else - static auto table = globalATenDispatch().getOpTable("aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), tensor1, tensor2, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcmul_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))) + .callUnboxed(const_cast(*this), tensor1, tensor2, value); #endif } inline Tensor Tensor::addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::addcdiv(const_cast(*this), tensor1, tensor2, value); #else - static auto table = globalATenDispatch().getOpTable("aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), tensor1, tensor2, value); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::addcdiv", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, tensor1, tensor2))) + .callUnboxed(const_cast(*this), tensor1, tensor2, value); #endif } inline std::tuple Tensor::lstsq(const Tensor & A) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lstsq(const_cast(*this), A); break; default: - AT_ERROR("lstsq not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lstsq not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR)"); - return table->getOp (const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), A); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lstsq", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, A))) + .callUnboxed, const Tensor &, const Tensor &>(const_cast(*this), A); #endif } inline std::tuple Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::triangular_solve(const_cast(*this), A, upper, transpose, unitriangular); #else - static auto table = globalATenDispatch().getOpTable("aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)"); - return table->getOp (const Tensor &, const Tensor &, bool, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), A, upper, transpose, unitriangular); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::triangular_solve", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, A))) + .callUnboxed, const Tensor &, const Tensor &, bool, bool, bool>(const_cast(*this), A, upper, transpose, unitriangular); #endif } inline std::tuple Tensor::symeig(bool eigenvectors, bool upper) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::symeig(const_cast(*this), eigenvectors, upper); #else - static auto table = globalATenDispatch().getOpTable("aten::symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors)"); - return table->getOp (const Tensor &, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), eigenvectors, upper); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::symeig", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, bool, bool>(const_cast(*this), eigenvectors, upper); #endif } inline std::tuple Tensor::eig(bool eigenvectors) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::eig(const_cast(*this), eigenvectors); break; default: - AT_ERROR("eig not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("eig not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)"); - return table->getOp (const Tensor &, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), eigenvectors); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::eig", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, bool>(const_cast(*this), eigenvectors); #endif } inline std::tuple Tensor::svd(bool some, bool compute_uv) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::svd(const_cast(*this), some, compute_uv); #else - static auto table = globalATenDispatch().getOpTable("aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)"); - return table->getOp (const Tensor &, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), some, compute_uv); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::svd", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, bool, bool>(const_cast(*this), some, compute_uv); #endif } inline Tensor Tensor::cholesky(bool upper) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::cholesky(const_cast(*this), upper); #else - static auto table = globalATenDispatch().getOpTable("aten::cholesky(Tensor self, bool upper=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), upper); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), upper); #endif } inline Tensor Tensor::cholesky_solve(const Tensor & input2, bool upper) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::cholesky_solve(const_cast(*this), input2, upper); #else - static auto table = globalATenDispatch().getOpTable("aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), input2, upper); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky_solve", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, input2))) + .callUnboxed(const_cast(*this), input2, upper); #endif } inline std::tuple Tensor::solve(const Tensor & A) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::solve(const_cast(*this), A); #else - static auto table = globalATenDispatch().getOpTable("aten::solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU)"); - return table->getOp (const Tensor &, const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), A); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::solve", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, A))) + .callUnboxed, const Tensor &, const Tensor &>(const_cast(*this), A); #endif } inline Tensor Tensor::cholesky_inverse(bool upper) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::cholesky_inverse(const_cast(*this), upper); break; default: - AT_ERROR("cholesky_inverse not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("cholesky_inverse not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), upper); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky_inverse", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), upper); #endif } inline std::tuple Tensor::qr(bool some) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::qr(const_cast(*this), some); #else - static auto table = globalATenDispatch().getOpTable("aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)"); - return table->getOp (const Tensor &, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), some); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::qr", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, bool>(const_cast(*this), some); #endif } inline std::tuple Tensor::geqrf() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::geqrf(const_cast(*this)); break; default: - AT_ERROR("geqrf not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("geqrf not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)"); - return table->getOp (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::geqrf", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &>(const_cast(*this)); #endif } inline Tensor Tensor::orgqr(const Tensor & input2) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::orgqr(const_cast(*this), input2); break; default: - AT_ERROR("orgqr not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("orgqr not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::orgqr(Tensor self, Tensor input2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), input2); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::orgqr", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, input2))) + .callUnboxed(const_cast(*this), input2); #endif } inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::ormqr(const_cast(*this), input2, input3, left, transpose); break; default: - AT_ERROR("ormqr not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("ormqr not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), input2, input3, left, transpose); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::ormqr", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, input2, input3))) + .callUnboxed(const_cast(*this), input2, input3, left, transpose); #endif } inline Tensor Tensor::lu_solve(const Tensor & LU_data, const Tensor & LU_pivots) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::lu_solve(const_cast(*this), LU_data, LU_pivots); #else - static auto table = globalATenDispatch().getOpTable("aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), LU_data, LU_pivots); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lu_solve", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, LU_data, LU_pivots))) + .callUnboxed(const_cast(*this), LU_data, LU_pivots); #endif } inline Tensor Tensor::multinomial(int64_t num_samples, bool replacement, Generator * generator) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::multinomial(const_cast(*this), num_samples, replacement, generator); break; default: - AT_ERROR("multinomial not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("multinomial not implemented for ", at::toString(type_set())); } #else static auto table = globalATenDispatch().getOpTable("aten::multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), num_samples, replacement, generator); + return table->getOp(at::detail::multi_dispatch_tensor_type_set(*this))(const_cast(*this), num_samples, replacement, generator); #endif } inline Tensor Tensor::lgamma() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lgamma(const_cast(*this)); break; default: - AT_ERROR("lgamma not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lgamma not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lgamma(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lgamma", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::digamma() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::digamma(const_cast(*this)); - break; - default: - AT_ERROR("digamma not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); - } + return TypeDefault::digamma(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::digamma(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::digamma", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::polygamma(int64_t n) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { - case Backend::CPU: - return CPUType::polygamma(n, const_cast(*this)); - break; - default: - AT_ERROR("polygamma not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); - } + return TypeDefault::polygamma(n, const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::polygamma(int n, Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(n, const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::polygamma", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(n, const_cast(*this)); #endif } inline Tensor Tensor::erfinv() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::erfinv(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::erfinv(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfinv", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::erfinv_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::erfinv_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::erfinv_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::erfinv_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::sign() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sign(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::sign(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sign", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor & Tensor::sign_() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::sign_(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::sign_(Tensor(a!) self) -> Tensor(a!)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sign_", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::dist(const Tensor & other, Scalar p) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::dist(const_cast(*this), other, p); break; default: - AT_ERROR("dist not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("dist not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other, p); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::dist", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other, p); #endif } inline Tensor Tensor::atan2(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::atan2(const_cast(*this), other); #else - static auto table = globalATenDispatch().getOpTable("aten::atan2(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::atan2", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::lerp(const Tensor & end, Scalar weight) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lerp(const_cast(*this), end, weight); break; default: - AT_ERROR("lerp not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lerp not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), end, weight); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end))) + .callUnboxed(const_cast(*this), end, weight); #endif } inline Tensor Tensor::lerp(const Tensor & end, const Tensor & weight) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::lerp(const_cast(*this), end, weight); break; default: - AT_ERROR("lerp not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("lerp not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), end, weight); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::lerp", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, end, weight))) + .callUnboxed(const_cast(*this), end, weight); #endif } inline Tensor Tensor::histc(int64_t bins, Scalar min, Scalar max) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::histc(const_cast(*this), bins, min, max); break; default: - AT_ERROR("histc not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("histc not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), bins, min, max); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::histc", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), bins, min, max); #endif } inline Tensor Tensor::fmod(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::fmod(const_cast(*this), other); break; default: - AT_ERROR("fmod not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("fmod not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::fmod(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::fmod(const_cast(*this), other); break; default: - AT_ERROR("fmod not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("fmod not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::fmod", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::remainder(Scalar other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::remainder(const_cast(*this), other); break; default: - AT_ERROR("remainder not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("remainder not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder", "Scalar"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::remainder(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::remainder(const_cast(*this), other); break; default: - AT_ERROR("remainder not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("remainder not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::remainder", "Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::min(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::min(const_cast(*this), other); break; default: - AT_ERROR("min not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("min not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::min.other(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min", "other"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::min() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::min(const_cast(*this)); break; @@ -4789,30 +5309,32 @@ inline Tensor Tensor::min() const { return QuantizedCPUType::min(const_cast(*this)); break; default: - AT_ERROR("min not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("min not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::min(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::min", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::max(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::max(const_cast(*this), other); break; default: - AT_ERROR("max not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("max not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::max.other(Tensor self, Tensor other) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max", "other"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::max() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::max(const_cast(*this)); break; @@ -4820,30 +5342,32 @@ inline Tensor Tensor::max() const { return QuantizedCPUType::max(const_cast(*this)); break; default: - AT_ERROR("max not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("max not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::max(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::max", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::median() const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::median(const_cast(*this)); break; default: - AT_ERROR("median not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("median not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::median(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::median", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline std::tuple Tensor::sort(int64_t dim, bool descending) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::sort(const_cast(*this), dim, descending); break; @@ -4851,76 +5375,83 @@ inline std::tuple Tensor::sort(int64_t dim, bool descending) cons return QuantizedCPUType::sort(const_cast(*this), dim, descending); break; default: - AT_ERROR("sort not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("sort not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, descending); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::sort", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, bool>(const_cast(*this), dim, descending); #endif } inline Tensor Tensor::argsort(int64_t dim, bool descending) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::argsort(const_cast(*this), dim, descending); #else - static auto table = globalATenDispatch().getOpTable("aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dim, descending); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::argsort", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dim, descending); #endif } inline std::tuple Tensor::topk(int64_t k, int64_t dim, bool largest, bool sorted) const { #ifdef USE_STATIC_DISPATCH return TypeDefault::topk(const_cast(*this), k, dim, largest, sorted); #else - static auto table = globalATenDispatch().getOpTable("aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)"); - return table->getOp (const Tensor &, int64_t, int64_t, bool, bool)>(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), k, dim, largest, sorted); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::topk", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed, const Tensor &, int64_t, int64_t, bool, bool>(const_cast(*this), k, dim, largest, sorted); #endif } inline Tensor Tensor::all() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::all(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::all(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::all", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::any() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::any(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::any(Tensor self) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::any", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } inline Tensor Tensor::renorm(Scalar p, int64_t dim, Scalar maxnorm) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::renorm(const_cast(*this), p, dim, maxnorm); break; default: - AT_ERROR("renorm not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("renorm not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), p, dim, maxnorm); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::renorm", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), p, dim, maxnorm); #endif } inline Tensor Tensor::unfold(int64_t dimension, int64_t size, int64_t step) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::unfold(const_cast(*this), dimension, size, step); break; default: - AT_ERROR("unfold not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("unfold not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), dimension, size, step); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::unfold", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this), dimension, size, step); #endif } inline bool Tensor::equal(const Tensor & other) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::equal(const_cast(*this), other); break; @@ -4928,33 +5459,36 @@ inline bool Tensor::equal(const Tensor & other) const { return QuantizedCPUType::equal(const_cast(*this), other); break; default: - AT_ERROR("equal not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("equal not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::equal(Tensor self, Tensor other) -> bool"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), other); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::equal", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, other))) + .callUnboxed(const_cast(*this), other); #endif } inline Tensor Tensor::pow(const Tensor & exponent) const { #ifdef USE_STATIC_DISPATCH - switch(tensorTypeIdToBackend(type_id())) { + switch(tensorTypeIdToBackend(impl::dispatchTypeId(type_set()))) { case Backend::CPU: return CPUType::pow(const_cast(*this), exponent); break; default: - AT_ERROR("pow not implemented for ", at::toString(tensorTypeIdToBackend(type_id()))); + AT_ERROR("pow not implemented for ", at::toString(type_set())); } #else - static auto table = globalATenDispatch().getOpTable("aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this), exponent); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::pow", "Tensor_Tensor"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this, exponent))) + .callUnboxed(const_cast(*this), exponent); #endif } inline Tensor Tensor::alias() const { #ifdef USE_STATIC_DISPATCH return TypeDefault::alias(const_cast(*this)); #else - static auto table = globalATenDispatch().getOpTable("aten::alias(Tensor(a) self) -> Tensor(a)"); - return table->getOp(tensorTypeIdToBackend(type_id()), is_variable())(const_cast(*this)); + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::alias", ""}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this))) + .callUnboxed(const_cast(*this)); #endif } diff --git a/aten/src/ATen/core/Variadic.h b/aten/src/ATen/core/Variadic.h new file mode 100644 index 0000000000000..495e9050461f2 --- /dev/null +++ b/aten/src/ATen/core/Variadic.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { + +// This class allows you to write variadic functions which +// call a (possibly overloaded) function on each argument, +// in order. This is most commonly used in autogenerated code, +// where it is convenient to have a function that can uniformly +// take arguments of different types. If your arguments +// are homogenous consider using a std::initializer_list instead. +// +// For examples of this in use, see torch/csrc/utils/variadic.h +template +struct IterArgs { + template + inline F& apply() { + return self(); + } + + // NB: Use perfect forwarding here, otherwise we'll make value + // copies of all arguments! + template + inline F& apply(T&& arg, Args&&... args) { + self()(std::forward(arg)); + if (self().short_circuit()) { + return self(); + } else { + return apply(std::forward(args)...); + } + } + + // Here are some handy overloads which provide sensible + // defaults for container-like structures that one might + // be interested in recursing into. You can enable them + // by adding: + // + // using IterArgs::operator() + // + // to your struct. These are not enabled by default because + // you may be able to process these structures more efficiently + // than handling them one-by-one. + + template + void operator()(at::ArrayRef args) { + for (const auto& arg : args) { + self()(arg); + if (self().short_circuit()) + return; + } + } + + // NB: we need to specify std::vector manually as C++ won't + // do an implicit conversion to make a template deduction go through. + template + void operator()(const std::vector& args) { + self()(at::ArrayRef{args}); + } + + constexpr bool short_circuit() const { + return false; + } + + private: + inline F& self() { + return *static_cast(this); + } +}; + +} // namespace torch diff --git a/aten/src/ATen/core/alias_info.h b/aten/src/ATen/core/alias_info.h index c9cb3d71f403b..96c4d0a4e3de5 100644 --- a/aten/src/ATen/core/alias_info.h +++ b/aten/src/ATen/core/alias_info.h @@ -86,7 +86,7 @@ inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { && lhs.containedTypes() == rhs.containedTypes(); } -// DEBUG ONLY; this does not match the way things are represented in the schema +// this does match the way things are represented in the schema inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { out << "("; bool first = true; @@ -98,11 +98,22 @@ inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { } out << set.toUnqualString(); } - out << ")"; - - if (!aliasInfo.containedTypes().empty()) { - out << " CONTAINS " << aliasInfo.containedTypes()[0]; + if (aliasInfo.isWrite()) { + out << "!"; } + if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { + out << " -> "; + first = true; + for (const auto& set : aliasInfo.afterSets()) { + if (first) { + first = false; + } else { + out << "|"; + } + out << set.toUnqualString(); + } + } + out << ")"; return out; } } // namespace c10 diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index e1feb58b8001e..dd07afe0e6ccd 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -579,7 +579,6 @@ _(aten, rrelu_with_noise) \ _(aten, rrelu_with_noise_backward) \ _(aten, rrelu_with_noise_forward) \ _(aten, rsqrt) \ -_(aten, s_native_addmm) \ _(aten, scatter) \ _(aten, scatter_add) \ _(aten, select) \ diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 146f245f2c761..8bf81970370c3 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -102,7 +103,8 @@ class KernelTable_ final { class DispatchTable final { public: DispatchTable(const FunctionSchema& schema) - : kernels_(make_left()) + : kernels_() + , catchall_kernel_(c10::nullopt) , dispatch_strategy_(get_dispatch_strategy_(schema)) , operator_name_(schema.name()) {} @@ -115,9 +117,13 @@ class DispatchTable final { TensorTypeId dispatch_key, const DispatchTableEntry& kernel) { TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId); - TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments."); - TORCH_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", toString(dispatch_key)," for operator ", operator_name_, ", which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys."); - kernels_.left().set(dispatch_key, kernel, operator_name_); + // The following assertion is disabled because we're codegenerating + // autograd kernels for operators without tensor arguments even though + // they are never called. These, however, register kernels for + // VariableTensorId. + // TODO Stop generating these kernels and re-enable this assertion here. + //TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments."); + kernels_.set(dispatch_key, kernel, operator_name_); } /** @@ -126,8 +132,7 @@ class DispatchTable final { * @param dispatch_key Dispatch key to unregister. */ void removeKernelIfExists(TensorTypeId dispatch_key) { - TORCH_INTERNAL_ASSERT(kernels_.is_left(), "Tried to remove the kernel for dispatch key ", toString(dispatch_key), " for operator ", operator_name_, ", which only has a catch-all kernel."); - kernels_.left().removeIfExists(dispatch_key, operator_name_); + kernels_.removeIfExists(dispatch_key, operator_name_); } /** @@ -137,20 +142,18 @@ class DispatchTable final { * dispatch keys, not both. */ void setCatchallKernel(const DispatchTableEntry& kernel) { - if (kernels_.is_right()) { + if (catchall_kernel_.has_value()) { TORCH_WARN("Registered a catch-all kernel for operator ", operator_name_," that overwrote a previously registered catch-all kernel for the same operator."); - } else { - TORCH_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for operator ", operator_name_, " which already has kernels with dispatch keys. An operator can only have either a catch-all kernel or kernels with dispatch keys."); } - kernels_ = make_right(kernel); + catchall_kernel_ = kernel; } /** * Remove the catch-all kernel. */ void removeCatchallKernel() { - TORCH_INTERNAL_ASSERT(kernels_.is_right(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered."); - kernels_ = make_left(); + TORCH_INTERNAL_ASSERT(catchall_kernel_.has_value(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered."); + catchall_kernel_ = c10::nullopt; } /** @@ -161,28 +164,28 @@ class DispatchTable final { * @return Kernel function pointing to the right kernel for the given arguments. */ const DispatchTableEntry& lookup(const Stack* stack) const { - return lookup_([=] { - TORCH_INTERNAL_ASSERT(dispatch_strategy_.is_valid_, "Operator ", operator_name_, " has an invalid dispatch key but kernels registered."); + return lookup_([=] () -> c10::optional { + if (!dispatch_strategy_.is_valid_) { + return c10::nullopt; + } return dispatch_strategy_.get_dispatch_key(stack, operator_name_); }); } const DispatchTableEntry& lookup(TensorTypeId dispatchKey) const { - return lookup_([=] {return dispatchKey;}); + return lookup_([=] () -> c10::optional { return dispatchKey;}); } bool isEmpty() const { - return kernels_.map( - [] (const detail::KernelTable_& table) {return 0 == table.size();}, - [] (const DispatchTableEntry&) {return false;} - ); + return !catchall_kernel_.has_value() && kernels_.size() == 0; } std::string listAllDispatchKeys() const { - return kernels_.map( - [] (const detail::KernelTable_& table) {return table.list_all_dispatch_keys();}, - [] (const DispatchTableEntry&) {return "CATCH-ALL";} - ); + std::string result = kernels_.list_all_dispatch_keys(); + if (catchall_kernel_.has_value()) { + result += ", CATCH-ALL"; + } + return result; } private: @@ -209,14 +212,17 @@ class DispatchTable final { 0, reverse_index_of_first_tensor_arg_ ); + // TODO: This will need to get adjusted for multiple dispatch if (C10_UNLIKELY(first_tensor_arg_is_tensor_list_)) { auto tensor_list = first_tensor_arg.toTensorListRef(); if (tensor_list.size() == 0) { throw std::runtime_error("Tried to dispatch operator " + operator_name + " based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty."); } - return tensor_list[0].type_id(); + // TODO: Don't use legacy extractor; blocked on c10 understanding + // variable + return c10::legacyExtractTypeId(tensor_list[0].type_set()); } else { - return first_tensor_arg.unsafeToTensorImpl()->type_id(); + return c10::legacyExtractTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set()); } } }; @@ -239,30 +245,35 @@ class DispatchTable final { template const DispatchTableEntry& lookup_(const GetDispatchKeyFunc& getDispatchKey) const { - return kernels_.map( - [&] (const detail::KernelTable_& table) -> const DispatchTableEntry& { - // We have a dispatch table. Find the correct kernel for the inputs and return it. - TensorTypeId dispatch_key = getDispatchKey(); - auto found = table.lookup(dispatch_key); - - TORCH_CHECK(nullptr != found, "Didn't find kernel to dispatch to for operator '", operator_name_, - "'. Tried to look up kernel for dispatch key '", toString(dispatch_key), - "'. Registered dispatch keys are: ", listAllDispatchKeys()); - - return *found; - }, - [] (const DispatchTableEntry& entry) -> const DispatchTableEntry& { - // We have a catch-all kernel. Just return it. - return entry; + c10::optional dispatch_key = getDispatchKey(); + if (dispatch_key.has_value()) { + const auto* found = kernels_.lookup(*dispatch_key); + + if (nullptr != found) { + return *found; + } + } + + if (catchall_kernel_.has_value()) { + return *catchall_kernel_; + } + + if (!dispatch_key.has_value() || *dispatch_key == TensorTypeId::UndefinedTensorId) { + TORCH_CHECK(false, + "There were no tensor arguments to this function (e.g., you passed an " + "empty list of Tensors), but no fallback function is registered for schema ", operator_name_, + ". This usually means that this function requires a non-empty list of Tensors. " + "Available functions are ", listAllDispatchKeys()) } - ); + + const std::string dispatch_key_str = dispatch_key.has_value() ? toString(*dispatch_key) : "None"; + TORCH_CHECK(false, "Didn't find kernel to dispatch to for operator '", operator_name_, + "'. Tried to look up kernel for dispatch key '", dispatch_key_str, + "'. Registered dispatch keys are: ", listAllDispatchKeys()); } - // kernels_ either contains a dispatch table or - // a single catch-all kernel that is called for every backend - // The empty state (i.e. no kernels registered) is represented - // as an empty table. - either kernels_; + detail::KernelTable_ kernels_; + c10::optional catchall_kernel_; DispatchStrategy dispatch_strategy_; std::string operator_name_; }; diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index 79e174679af14..10d10931e2e27 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -126,11 +126,6 @@ RegistrationHandleRAII Dispatcher::registerCatchallKernel(const OperatorHandle& return op.operatorIterator_->op.registerCatchallKernel(DispatchTableEntry{kernel_func, std::move(cache_creator_func), unboxed_kernel_func}); } -RegistrationHandleRAII Dispatcher::registerUnboxedAutogradKernel(const OperatorHandle& op, void* unboxed_autograd_kernel) { - // note: this doesn't need the mutex to protect the iterator because write operations on the list keep iterators intact. - return op.operatorIterator_->op.registerUnboxedAutogradKernel(unboxed_autograd_kernel); -} - void Dispatcher::addRegistrationListener(std::unique_ptr listener) { std::lock_guard lock(mutex_); diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 26ee88f52a49a..8b459f25e2b13 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -90,8 +90,6 @@ class CAFFE2_API Dispatcher final { */ RegistrationHandleRAII registerCatchallKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func, void* unboxed_kernel_func); - RegistrationHandleRAII registerUnboxedAutogradKernel(const OperatorHandle& op, void* unboxed_autograd_kernel); - /** * Perform a dynamic dispatch and get the kernel for an operator. */ @@ -104,11 +102,6 @@ class CAFFE2_API Dispatcher final { // the (unboxed?) arguments the operator is to be called with. OpKernel lookup(const OperatorHandle& op, TensorTypeId dispatchKey) const; - // TODO Remove callUnboxedAutogradKernel() and instead figure out in a generic - // callKernel() wrapper if the autograd or the regular kernel need to be called. - template - Result callUnboxedAutogradKernel(const OperatorHandle& op, Args... args) const; - /** * Add a listener that gets called whenever a new op is registered or an existing * op is deregistered. Immediately after registering, this listener gets called @@ -183,14 +176,4 @@ inline OpKernel Dispatcher::lookup(const OperatorHandle& op, TensorTypeId dispat return op.operatorIterator_->op.lookupKernel(dispatchKey); } -template -inline Result Dispatcher::callUnboxedAutogradKernel(const OperatorHandle& op, Args... args) const { - void* unboxed_autograd_kernel = op.operatorIterator_->op.lookupUnboxedAutogradKernel(); - TORCH_CHECK(nullptr != unboxed_autograd_kernel, "Tried to call Dispatcher::callUnboxedAutogradKernel() for operator ", toString(op.schema()), " that doesn't have an autograd kernel."); - - using OpSignature = Result (Args...); - OpSignature* kernel = reinterpret_cast(unboxed_autograd_kernel); - return (*kernel)(std::forward(args)...); -} - } // namespace c10 diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index 99b613cd087c9..30ccb534d4947 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -20,7 +20,8 @@ namespace { OperatorEntry::OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options) : schema_(std::move(schema)) , dispatchTable_(schema_) -, kernels_(make_left>, std::list>()) +, kernels_() +, catchAllKernels_() , options_(std::move(options)) { } @@ -30,18 +31,16 @@ void OperatorEntry::prepareForDeregistration() { TORCH_INTERNAL_ASSERT(false, "Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", toString(schema_), ". Registered kernels for dispatch keys: ", dispatchTable.listAllDispatchKeys()); } }); - TORCH_INTERNAL_ASSERT(kernels_.is_left(), "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have a catch-all kernel. The operator schema is ", toString(schema_)); - TORCH_INTERNAL_ASSERT(kernels_.left().size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_.left()), ". The operator schema is ", toString(schema_)); + TORCH_INTERNAL_ASSERT(kernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_), ". The operator schema is ", toString(schema_)); + TORCH_INTERNAL_ASSERT(catchAllKernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have catch-all kernel. The operator schema is ", toString(schema_)); } RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel) { std::unique_lock lock(kernelsMutex_); - TORCH_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", toString(dispatch_key)," for an operator which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is ", toString(schema_)); - // Add the kernel to the kernels list, // possibly creating the list if this is the first kernel. - auto& k = kernels_.left()[dispatch_key]; + auto& k = kernels_[dispatch_key]; k.push_front(kernel); std::list::iterator inserted = k.begin(); // update the dispatch table, i.e. re-establish the invariant @@ -58,16 +57,10 @@ RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key, RegistrationHandleRAII OperatorEntry::registerCatchallKernel(DispatchTableEntry kernel) { std::unique_lock lock(kernelsMutex_); - if (kernels_.is_left()) { - TORCH_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys ", listAllDispatchKeys(kernels_.left()), ". An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is ", toString(schema_)); - kernels_ = make_right>, std::list>(); - } - // Add the kernel to the kernels list, // possibly creating the list if this is the first kernel. - auto& k = kernels_.right(); - k.push_front(kernel); - std::list::iterator inserted = k.begin(); + catchAllKernels_.push_front(kernel); + std::list::iterator inserted = catchAllKernels_.begin(); // update the dispatch table, i.e. re-establish the invariant // that the dispatch table points to the newest kernel updateCatchallDispatchTable_(); @@ -82,16 +75,13 @@ RegistrationHandleRAII OperatorEntry::registerCatchallKernel(DispatchTableEntry void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list::iterator kernel) { std::unique_lock lock(kernelsMutex_); - TORCH_CHECK(kernels_.is_left(), "Tried deregister a kernel for dispatch key ", toString(dispatch_key), " for an operator that only has a catch-all kernel. The operator schema is ", toString(schema_)); - - auto& kernels = kernels_.left(); - auto found = kernels.find(dispatch_key); - TORCH_INTERNAL_ASSERT(found != kernels.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator schema is ", toString(schema_)); + auto found = kernels_.find(dispatch_key); + TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator schema is ", toString(schema_)); auto& k = found->second; k.erase(kernel); if (k.empty()) { // the invariant says we don't want empty lists but instead remove the list from the map - kernels.erase(found); + kernels_.erase(found); } updateDispatchTable_(dispatch_key); @@ -100,62 +90,17 @@ void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list::iterator kernel) { std::unique_lock lock(kernelsMutex_); - TORCH_CHECK(kernels_.is_right(), "Tried to deregister a catch-all kernel for an operator that doesn't have a catch-all kernel registered. The operator schema is ", toString(schema_)); - - auto& k = kernels_.right(); - k.erase(kernel); - if (k.empty()) { - // the invariant says that the empty state is represented with is_left() - kernels_ = make_left>, std::list>(); - } + catchAllKernels_.erase(kernel); updateCatchallDispatchTable_(); } -RegistrationHandleRAII OperatorEntry::registerUnboxedAutogradKernel(void* kernel_func) { - std::unique_lock lock(unboxedAutogradKernelsMutex_); - - TORCH_INTERNAL_ASSERT(kernel_func != nullptr); - - unboxedAutogradKernels_.push_front(kernel_func); - std::list::iterator inserted = unboxedAutogradKernels_.begin(); - - updateCurrentUnboxedAutogradKernel_(); - - return RegistrationHandleRAII([this, inserted] { - // list iterators stay valid even if the list changes, - // so we can use the iterator to deregister the kernel from the list - deregisterUnboxedAutogradKernel_(inserted); - }); -} - -void OperatorEntry::deregisterUnboxedAutogradKernel_(std::list::iterator kernel) { - std::unique_lock lock(unboxedAutogradKernelsMutex_); - - unboxedAutogradKernels_.erase(kernel); - - updateCurrentUnboxedAutogradKernel_(); -} - -void OperatorEntry::updateCurrentUnboxedAutogradKernel_() { - // precondition: unboxedAutogradKernelsMutex_ is locked - - if (unboxedAutogradKernels_.empty()) { - currentUnboxedAutogradKernel_ = nullptr; - } else { - currentUnboxedAutogradKernel_ = unboxedAutogradKernels_.front(); - } -} - void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) { // precondition: kernelsMutex_ is locked - TORCH_INTERNAL_ASSERT(kernels_.is_left(), "Can't update the dispatch table a dispatch key ", toString(dispatch_key), " because the operator only has catch-all kernels. The operator schema is ", toString(schema_)); - - auto& kernels = kernels_.left(); - auto k = kernels.find(dispatch_key); + auto k = kernels_.find(dispatch_key); - if (k == kernels.end()) { + if (k == kernels_.end()) { dispatchTable_.write([&] (DispatchTable& dispatchTable) { dispatchTable.removeKernelIfExists(dispatch_key); }); @@ -169,13 +114,13 @@ void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) { void OperatorEntry::updateCatchallDispatchTable_() { // precondition: kernelsMutex_ is locked - if (kernels_.is_left()) { + if (catchAllKernels_.size() == 0) { dispatchTable_.write([&] (DispatchTable& dispatchTable) { dispatchTable.removeCatchallKernel(); }); } else { dispatchTable_.write([&] (DispatchTable& dispatchTable) { - dispatchTable.setCatchallKernel(kernels_.right().front()); + dispatchTable.setCatchallKernel(catchAllKernels_.front()); }); } } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 6542dd9e26974..37eaf0253ab9c 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -54,12 +54,14 @@ class CAFFE2_API OpKernel final { private: explicit OpKernel(KernelFunction* kernel, const KernelCacheCreatorFunction& cache_creator, void* unboxed_kernel) - : kernel_(kernel), cache_(cache_creator ? cache_creator() : nullptr), unboxed_kernel_(unboxed_kernel) {} + : kernel_(kernel), cache_(cache_creator ? cache_creator() : c10::guts::make_unique()), unboxed_kernel_(unboxed_kernel) {} friend class impl::OperatorEntry; - KernelFunction* kernel_; // can be nullptr, not all kernels have this + // All of these fields may be nullptr, but at least one of + // kernel_ or unboxed_kernel_ should be non-NULL + KernelFunction* kernel_; std::unique_ptr cache_; - void* unboxed_kernel_; // can be nullptr, not all kernels have this + void* unboxed_kernel_; }; namespace impl { @@ -93,17 +95,11 @@ class OperatorEntry final { }); } - void* lookupUnboxedAutogradKernel() const { - return currentUnboxedAutogradKernel_; - } - void prepareForDeregistration(); RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel); RegistrationHandleRAII registerCatchallKernel(DispatchTableEntry kernel); - RegistrationHandleRAII registerUnboxedAutogradKernel(void* kernel_func); - const OperatorOptions& options() { return options_; } @@ -111,21 +107,14 @@ class OperatorEntry final { private: void deregisterKernel_(TensorTypeId dispatch_key, std::list::iterator kernel); void deregisterCatchallKernel_(std::list::iterator kernel); - void deregisterUnboxedAutogradKernel_(std::list::iterator kernel); FunctionSchema schema_; // The dispatchTable stores the current kernel for each dispatch key LeftRight dispatchTable_; - // kernels_ is either: - // left: a kernel map listing mapping from a dispatch key to a list of all - // kernels for that operator, or it is - // right: a list of all catch-all kernels registered for this operator. - // An operator can only have either dispatched kernels or catch-all kernels, - // not both. - // In both cases, the list of kernels stores all registered kernels for the - // corresponding dispatch key (or for catch-all). + // kernels_ stores all registered kernels for the corresponding dispatch key + // and catchAllKernels_ stores the catch-all kernels. // If an operator library gets loaded that overwrites an already existing kernel, // both kernels will be in that list but only the newer one will be in // dispatchTable. If any of the kernels go away (say the library gets @@ -137,15 +126,13 @@ class OperatorEntry final { // kernels is a larger data structure and accessed quite infrequently // while dispatchTable is accessed often and should be kept small to fit // into CPU caches. - // Invariants (assuming kernels_.is_left()): - // - dispatchTable[dispatch_key] == kernels_.left()[dispatch_key].front() + // Invariants: + // - dispatchTable[dispatch_key] == kernels_[dispatch_key].front() // - dispatchTable[dispatch_key] does not exist if and only if - // kernels_.left()[dispatch_key] does not exist - // - If kernels_.left()[dispatch_key] exists, then it has elements. + // kernels_[dispatch_key] does not exist + // - If kernels_[dispatch_key] exists, then it has elements. // It is never an empty list. - // Analogous invariants for kernels_.is_right(). - // The empty state (i.e. no kernels registered) is represented as an empty - // map with kernels_.is_left(). + // Analogous invariants for catchAllKernels_. // // Why do we do that? // ----- @@ -155,41 +142,21 @@ class OperatorEntry final { // function schema changed between the executions, but it works as long // as the function schema didn't change. A better solution would be to // unload the old extension library from the Jupyter cell when the cell is - // re-ececuted and then only allow one kernel here, i.e. error if a kernel + // re-executed and then only allow one kernel here, i.e. error if a kernel // is already registered, but that's a lot of effort to implement and // currently not high-pri. - c10::either< - ska::flat_hash_map>, // dispatched kernels - std::list // catch-all kernels - > kernels_; - - // unboxedAutogradKernels_ stores all autograd kernels registered for this op. - // An autograd kernel has the same signature as the main op kernel and - // internally re-dispatches to call the actual kernel. - // Autograd kernels are unboxed currently. We are planning to move this - // towards a system where ops register autograd wrappers (i.e. functions that - // do some wrapping code and get a pointer to the actual kernel) instead of - // autograd functions. - // This is a list because, similar to kernels_, multiple libraries could - // be loaded that register autograd kernels for the same op. The list is - // ordered by registration time descendingly, i.e. newer registrations are - // before older registrations and the list head is the autograd kernel - // which is currently used. - // See the comment for kernels_ above for an explanation for why we do this. - std::list unboxedAutogradKernels_; - std::atomic currentUnboxedAutogradKernel_; + ska::flat_hash_map> kernels_; + std::list catchAllKernels_; // Some metadata about the operator OperatorOptions options_; std::mutex kernelsMutex_; // protects kernels_ - std::mutex unboxedAutogradKernelsMutex_; // protects unboxedAutogradKernels_ // This function re-establishes the invariant that dispatchTable // contains the front element from the kernels list for a given dispatch key. void updateDispatchTable_(TensorTypeId dispatch_key); void updateCatchallDispatchTable_(); - void updateCurrentUnboxedAutogradKernel_(); }; } diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 6837fef4ec085..e1706c3b2362e 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace c10 { @@ -12,6 +13,23 @@ namespace c10 { // errors. These objects should be constructed from C10 schema once those // are available. +struct Argument; +struct FunctionSchema; + +namespace detail { +inline bool defaultValueEquals_( + const c10::optional& lhs, + const c10::optional& rhs) { + if (lhs.has_value()) { + return rhs.has_value() && impl::shallowEquals(*lhs, *rhs); + } else { + return !rhs.has_value(); + } +} +} // namespace detail + +bool operator==(const Argument& lhs, const Argument& rhs); + struct Argument { Argument( std::string name = "", @@ -78,6 +96,15 @@ struct Argument { return Argument(name_, new_type, N_, default_value_, kwarg_only_, alias_info_); } + // this function check whether this Argument is backward compatible with + // the old one. we consider the following cases are backward compatible: + // 1) two arguments are equal + // 2) this arg's type should be subtype of old + // 3) this arg must provide the same default value if old arg has one, + bool isBackwardCompatibleWith( + const Argument& old, + std::ostream* why_not=nullptr) const; + private: std::string name_; TypePtr type_; @@ -94,16 +121,6 @@ struct Argument { bool is_inferred_type_; }; -namespace detail { -inline bool defaultValueEquals_(const c10::optional& lhs, const c10::optional& rhs) { - if (lhs.has_value()) { - return rhs.has_value() && impl::shallowEquals(*lhs, *rhs); - } else { - return !rhs.has_value(); - } -} -} - inline bool operator==(const Argument& lhs, const Argument& rhs) { return lhs.name() == rhs.name() && *lhs.type() == *rhs.type() @@ -113,10 +130,7 @@ inline bool operator==(const Argument& lhs, const Argument& rhs) { && lhs.alias_info() == rhs.alias_info(); } -struct OperatorName final { - std::string name; - std::string overload_name; -}; +bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs); struct FunctionSchema { FunctionSchema( @@ -147,6 +161,22 @@ struct FunctionSchema { is_vararg, is_varret) {} + // check whether this schema is backward compatible with the old one. + // the following conditions are considered as this schema is backward + // compatible with old: + // 1) two schemas are equal + // 2) this schema has the same or more positional args than old, + // and any positional arg in this schema is backward compatible + // with the corresponding one in old schema, which could be an arg + // or a kwarg, if it has, or it must provide a default value + // 3) this schema has the same or more kwargs than old, and all the kwargs + // in old schema can find the corresponding kwarg in this schema which + // is backward compatible with the old kwarg, and the extra kwargs in + // this schema must provide default values. + bool isBackwardCompatibleWith( + const FunctionSchema& old, + std::ostream* why_not=nullptr) const; + private: OperatorName name_; std::vector arguments_; @@ -237,9 +267,9 @@ struct FunctionSchema { return false; } - // can a function with this schema be substituted for a function of rhs's + // can a function with this schema be substituted for a function of rhs's // schema and have the program typecheck? - // as_method - if true, treat this schema as a method and ignore + // as_method - if true, treat this schema as a method and ignore // the first argument, which will be the object in both cases bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const; }; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index 04eb3068d9ffe..48a0c9579fa34 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -51,6 +51,30 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) return out; } +inline bool Argument::isBackwardCompatibleWith( + const Argument& old, + std::ostream* why_not) const { + const Argument* lhs = this; + const Argument* rhs = &old; + if (!(lhs->name() == rhs->name() + && lhs->N() == rhs->N() + && lhs->alias_info() == rhs->alias_info())) { + return false; + } + if (lhs->kwarg_only() && !rhs->kwarg_only()) { + return false; + } + if (!rhs->type()->isSubtypeOfExt(lhs->type(), why_not)) { + return false; + } + if (rhs->default_value().has_value() && + !detail::defaultValueEquals_(lhs->default_value(), + rhs->default_value())) { + return false; + } + return true; +} + inline std::string FunctionSchema::formatTypeMismatchMsg( const Argument& expected, const std::string& actual_type, @@ -74,6 +98,90 @@ inline std::string FunctionSchema::formatTypeMismatchMsg( *this); } +inline bool FunctionSchema::isBackwardCompatibleWith( + const FunctionSchema& old, + std::ostream* why_not) const { + if (!(name() == old.name() + && overload_name() == old.overload_name() + // we are conservative on is_vararg and is_varret, + // since they are only used by internal operators + && is_vararg() == old.is_vararg() + && is_varret() == old.is_varret() + && returns().size() == old.returns().size() + && arguments().size() >= old.arguments().size())) { + return false; + } + for (size_t i = 0; i < returns().size(); ++i) { + // functions are covariant in arguments but contravariant in returns + if (!old.returns().at(i).isBackwardCompatibleWith( + returns().at(i), + why_not)) { + return false; + } + } + std::vector args, old_args; + std::map kwargs, old_kwargs; + auto split_func = [](const std::vector& arguments, + std::vector* positionals, + std::map* nameds) { + for (const Argument& arg : arguments) { + if (!arg.kwarg_only()) { + positionals->emplace_back(&arg); + } + nameds->emplace(arg.name(), &arg); + } + }; + // we split args into positional and keyward parts, + split_func(arguments(), &args, &kwargs); + split_func(old.arguments(), &old_args, &old_kwargs); + if (old_args.size() > args.size()) { + return false; + } + // make sure that all the old positional args have their corresponding + // backward compatible positional args in this schema + for (size_t i = 0; i < old_args.size(); ++i) { + if (!args.at(i)->isBackwardCompatibleWith( + *old_args.at(i), + why_not)) { + return false; + } + } + // check the extra positional args in this schema either has corresponding + // backward compatible keyward args since positional args also can be used as + // a keyward arg, or provided default values + for (size_t i = old_args.size(); i < args.size(); ++i) { + if (!args.at(i)->default_value()) { + auto it = old_kwargs.find(args.at(i)->name()); + if (it == old_kwargs.end() || + !args.at(i)->isBackwardCompatibleWith( + *it->second, + why_not)) { + return false; + } + } + } + // make sure that all the keyword args in the old schema have their + // corresponding backward compatible keyward args in this schema + for (auto& kv : old_kwargs) { + auto it = kwargs.find(kv.first); + if (it == kwargs.end() || + !it->second->isBackwardCompatibleWith( + *kv.second, + why_not)) { + return false; + } + kwargs.erase(it); + } + // check all the extra keyword args in this schema provide default values + for (auto& kv : kwargs) { + if (!kv.second->default_value()) { + return false; + } + } + + return true; +} + inline void FunctionSchema::checkArg( const IValue& value, const Argument& argument, @@ -189,14 +297,6 @@ inline FunctionSchema FunctionSchema::cloneWithRemappedTypes( is_varret()); } -inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) { - return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name; -} - -inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) { - return !operator==(lhs, rhs); -} - // covariant subtyping of list of Arguments inline bool isSubtypeOfList( ArrayRef child, @@ -232,12 +332,3 @@ inline bool FunctionSchema::isSubtypeOf( } } // namespace c10 - -namespace std { - template <> - struct hash<::c10::OperatorName> { - size_t operator()(const ::c10::OperatorName& x) const { - return std::hash()(x.name) ^ (~ std::hash()(x.overload_name)); - } - }; -} diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 98750de3b00af..eed1cbcd774d4 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -191,6 +191,7 @@ namespace c10 { _(onnx, Split) \ _(onnx, ConstantOfShape) \ _(onnx, Cast) \ + _(onnx, Mod) \ FORALL_ATTR_BASE_SYMBOLS(_) \ _(attr, Subgraph) \ _(attr, ReverseSubgraph) \ @@ -297,9 +298,7 @@ struct CAFFE2_API Symbol { static Symbol prim(const std::string & s); static Symbol user(const std::string & s); static Symbol caffe2(const std::string & s); -#ifdef BUILD_NAMEDTENSOR static Symbol dimname(const std::string & s); -#endif // TODO: eliminate me static Symbol scope(const std::string & s); @@ -309,9 +308,7 @@ struct CAFFE2_API Symbol { bool is_onnx() const; bool is_user() const; bool is_caffe2() const; -#ifdef BUILD_NAMEDTENSOR bool is_dimname() const; -#endif // So we can switch on this constexpr operator unique_t() const { @@ -372,18 +369,14 @@ inline Symbol Symbol::prim(const std::string & s) { return Symbol::fromQualStri inline Symbol Symbol::scope(const std::string & s) { return Symbol::fromQualString("scope::" + s); } inline Symbol Symbol::user(const std::string & s) { return Symbol::fromQualString("user::" + s); } inline Symbol Symbol::caffe2(const std::string & s) { return Symbol::fromQualString("_caffe2::" + s); } -#ifdef BUILD_NAMEDTENSOR inline Symbol Symbol::dimname(const std::string & s) { return Symbol::fromQualString("dimname::" + s); } -#endif inline bool Symbol::is_attr() const { return ns() == namespaces::attr; } inline bool Symbol::is_aten() const { return ns() == namespaces::aten; } inline bool Symbol::is_prim() const { return ns() == namespaces::prim; } inline bool Symbol::is_onnx() const { return ns() == namespaces::onnx; } inline bool Symbol::is_user() const { return ns() == namespaces::user; } inline bool Symbol::is_caffe2() const { return ns() == namespaces::_caffe2; } -#ifdef BUILD_NAMEDTENSOR inline bool Symbol::is_dimname() const { return ns() == namespaces::dimname; } -#endif } // namespace c10 diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 20f647174ccab..6946e11886508 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -13,8 +13,63 @@ CAFFE2_API c10::intrusive_ptr ConstantString::create( return c10::make_intrusive(std::move(str_)); } +TupleTypePtr Tuple::type() const { + if (!type_) { + type_ = TupleType::create( + fmap(elements_, [&](const IValue& v) { return v.type(); })); + } + return type_; +} + } // namespace ivalue + +TypePtr IValue::type() const { + switch(tag) { + case Tag::None: + return NoneType::get(); + case Tag::Tensor: + return TensorType::create(toTensor()); + case Tag::Double: + return FloatType::get(); + case Tag::Int: + return IntType::get(); + case Tag::Bool: + return BoolType::get(); + case Tag::IntList: + return ListType::ofInts(); + case Tag::DoubleList: + return ListType::ofFloats(); + case Tag::BoolList: + return ListType::ofBools(); + case Tag::TensorList: + return ListType::ofTensors(); + case Tag::String: + return StringType::get(); + case Tag::Blob: + return AnyType::get(); + case Tag::GenericDict: { + auto d = toGenericDict(); + return DictType::create(d.keyType(), d.valueType()); + } + case Tag::GenericList: + return ListType::create(toGenericList().elementType()); + case Tag::Future: + return toFuture()->type(); + case Tag::Device: + return DeviceObjType::get(); + case Tag::Object: + return toObjectRef().type(); + case Tag::Uninitialized: + return AnyType::get(); + case Tag::Capsule: + return CapsuleType::get(); + case Tag::Tuple: + return toTuple()->type(); + } + // switch above is complete but this silences compiler warnings + TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()"); +} namespace { template @@ -149,8 +204,8 @@ void ivalue::Object::resizeObject(size_t slot) { slots_.resize(type()->numAttributes()); } -static bool CompareIValue(const std::pair& aWrap, - const std::pair& bWrap) { +static bool CompareKeys(const std::pair& aWrap, + const std::pair& bWrap) { const auto a = aWrap.first; const auto b = bWrap.first; if (a.isString() && b.isString()) { @@ -159,6 +214,8 @@ static bool CompareIValue(const std::pair& aWrap, return a.toInt() < b.toInt(); } else if (a.isDouble() && b.isDouble()) { return a.toDouble() < b.toDouble(); + } else if (a.isTensor() && b.isTensor()) { + return a.toTensor().unsafeGetTensorImpl() < b.toTensor().unsafeGetTensorImpl(); } AT_ERROR("Illegal dict key"); } @@ -168,7 +225,7 @@ std::vector> iterationOrder(const c10::Dict class Dict; template class List; struct IValue; struct ClassType; +struct Type; +using TypePtr = std::shared_ptr; namespace ivalue { struct Tuple; struct Future; @@ -431,6 +433,8 @@ struct CAFFE2_API IValue final { return payload.as_intrusive_ptr; } + TypePtr type() const; + private: // NOTE: IValue tags are intentionally private. In the future we may encode // this value different (e.g. using NaN boxing), and this would make it more @@ -454,7 +458,7 @@ struct CAFFE2_API IValue final { tag = Tag::None; is_intrusive_ptr = false; } -private: + union Payload { int64_t as_int; double as_double; diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index c476f0d36bcbd..0892a96ed6f79 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -135,16 +135,19 @@ struct Future; struct CAFFE2_API Tuple : c10::intrusive_ptr_target { private: - std::vector elements_; + std::vector elements_; + mutable std::shared_ptr type_; // lazily computed for unnamed tuples public: - static c10::intrusive_ptr create(std::vector elements_, std::shared_ptr type_) { - TORCH_INTERNAL_ASSERT(nullptr != type_.get(), "Type cannot be nullptr"); + // named tuples have additional type information, so we + // directly create them tagged + static c10::intrusive_ptr createNamed( + std::vector elements_, + std::shared_ptr type_) { return c10::make_intrusive(std::move(elements_), type_); } - C10_DEPRECATED_MESSAGE("Creating tuples without type information is deprecated. Please use Tuple::create(elements, type) instead.") static c10::intrusive_ptr create(std::vector elements_) { - return c10::make_intrusive(std::move(elements_), nullptr); + return c10::make_intrusive(std::move(elements_)); } const std::vector& elements() const & { @@ -164,11 +167,11 @@ struct CAFFE2_API Tuple : c10::intrusive_ptr_target { std::vector&& elements() && { return std::move(elements_); } + std::shared_ptr type() const; - std::shared_ptr type; private: - Tuple(std::vector elements, std::shared_ptr type) - : elements_(std::move(elements)), type(std::move(type)) {} + Tuple(std::vector elements, std::shared_ptr type = nullptr) + : elements_(std::move(elements)), type_(std::move(type)) {} friend class c10::intrusive_ptr; }; @@ -188,6 +191,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { } public: + Future(TypePtr type) : type_(type) {} struct CAFFE2_API FutureError final : public std::exception { FutureError(std::string&& error_msg_) : error_msg(std::move(error_msg_)) {} @@ -266,7 +270,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { } // Check if the current future has completed - bool completed() { + bool completed() const{ return completed_; } @@ -274,6 +278,10 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { std::ostream& out, const Future& v); + TypePtr type() const { + return type_; + } + private: void fireCallbacks() { AT_ASSERT(completed()); @@ -290,6 +298,7 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { std::condition_variable finished_cv_; IValue value_; // when finished the value + TypePtr type_; std::vector> callbacks; bool has_error = false; FutureError error; diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 6dcb822e70028..feeff96245ac8 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -12,6 +12,7 @@ #include #include #include +#include struct ClassType; namespace torch { @@ -29,6 +30,7 @@ struct FunctionSchema; using OptNameList = c10::optional>; #define C10_FORALL_TYPES(_) \ + _(AnyType) \ _(TensorType) \ _(TupleType) \ _(ListType) \ @@ -163,6 +165,29 @@ struct CAFFE2_API Type : std::enable_shared_from_this { } }; +struct AnyType; +using AnyTypePtr = std::shared_ptr; +// Any is the top of the type hierarchy, all other types are subtypes +// T <: Any, forall T +struct CAFFE2_API AnyType : public Type { + static AnyTypePtr create() { + return AnyTypePtr( + new AnyType()); // NOLINT(modernize-make-shared) + } + bool operator==(const Type& rhs) const override { + return rhs.kind() == kind(); + } + std::string str() const override { + return "Any"; + } + static const TypeKind Kind = TypeKind::AnyType; + // global singleton + static AnyTypePtr get(); + + private: + AnyType() : Type(TypeKind::AnyType) {} +}; + inline std::string toString(TypePtr typePtr) { return typePtr->str(); } @@ -547,7 +572,7 @@ struct CAFFE2_API TensorType : public Type { sizes_(tensor.sizes().size()), strides_(tensor.sizes().size()), requires_grad_(tensor.requires_grad()) { - if (!tensor.is_mkldnn()) { + if (!tensor.is_mkldnn() && !tensor.is_sparse()) { sizes_ = tensor.sizes().vec(); strides_ = tensor.strides().vec(); } @@ -641,6 +666,7 @@ struct CAFFE2_API DictType : public Type { static DictTypePtr create(TypePtr key, TypePtr value) { switch (key->kind()) { + case TypeKind::AnyType: case TypeKind::IntType: case TypeKind::FloatType: case TypeKind::StringType: @@ -1132,13 +1158,13 @@ inline TypePtr TensorType::fromNumberType(TypePtr typ) { } else if (typ->isSubtypeOf(BoolType::get())) { return TensorType::createContiguous(at::kLong, at::kCPU, {}); } - AT_ERROR("unknown number type", typ->str()); + TORCH_CHECK(false, "Unknown number type: ", typ->str()); } inline TypePtr TensorType::fromBoolType() { return TensorType::createContiguous(at::kLong, at::kCPU, {}); } -inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { +inline c10::optional tryScalarTypeFromJitType(const c10::TypePtr & type) { if (type == FloatType::get()) { return at::ScalarType::Double; } else if (type == IntType::get()) { @@ -1146,10 +1172,16 @@ inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { } else if (type == BoolType::get()) { return at::ScalarType::Bool; } + return c10::nullopt; +} + +inline at::ScalarType scalarTypeFromJitType(const c10::TypePtr& type) { + auto result = tryScalarTypeFromJitType(type); AT_ASSERTM( - 0, + result, "Add new condition, expected Float, Int, or Bool but got", type->str()); + return *result; } // Attempt to find the correct supertype of t1 and t2. If none is found then @@ -1231,6 +1263,13 @@ struct getTypePtr_> final { return type; } }; +template +struct getTypePtr_> final { + static TypePtr call() { + static auto type = ListType::create(getTypePtr_::call()); + return type; + } +}; template struct getTypePtr_> final { static TypePtr call() { @@ -1280,7 +1319,7 @@ struct MatchTypeReturn { } private: - MatchTypeReturn() + MatchTypeReturn() : reason_(c10::nullopt) {} c10::optional reason_; // is there is no match, this contains the reason }; @@ -1290,13 +1329,13 @@ struct MatchTypeReturn { // and a r.reason() that describes why it could not match. // note: It is possible to successfully match a formal, but for type variables // in the formal to still not be defined. In particular, None matches Optional[T] -// but does not define the value of T. +// but does not define the value of T. CAFFE2_API MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env); -// replace type variables appearing in `type` with the values in -// `type_env`. Returns nullptr if a variable used in `type` -// does not appear in `type_env` +// replace type variables appearing in `type` with the values in +// `type_env`. Returns nullptr if a variable used in `type` +// does not appear in `type_env` CAFFE2_API TypePtr tryEvalTypeVariables(TypePtr type, TypeEnv& type_env); /** diff --git a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp index 2604c03de6711..fd996e091aad8 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_legacy_test.cpp @@ -157,11 +157,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorOu auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } std::vector kernelWithTensorListOutput(const Tensor& input1, const Tensor& input2, const Tensor& input3) { @@ -178,9 +178,9 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorLi auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensorListRef()[1].type_id()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[2].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); } std::vector kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) { @@ -224,16 +224,16 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithMultiple auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[2].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[2].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result_dict.at("first").type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result_dict.at("second").type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); } Tensor kernelWithTensorInputByReferenceWithOutput(const Tensor& input1) { @@ -252,11 +252,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -268,11 +268,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } Tensor captured_input; @@ -294,11 +294,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -310,11 +310,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } int64_t captured_int_input = 0; @@ -803,9 +803,9 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenFallbackKernelWith EXPECT_EQ(4, outputs[0].toInt()); } -c10::optional called_arg2; -c10::optional called_arg3; -c10::optional called_arg4; +c10::optional called_arg2 = c10::nullopt; +c10::optional called_arg3 = c10::nullopt; +c10::optional called_arg4 = c10::nullopt; void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { called = true; @@ -825,7 +825,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptional EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -857,11 +857,11 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptional called = false; auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -890,7 +890,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptional auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); diff --git a/aten/src/ATen/core/op_registration/kernel_function_test.cpp b/aten/src/ATen/core/op_registration/kernel_function_test.cpp index 6023b519006f2..4a216a45caf5c 100644 --- a/aten/src/ATen/core/op_registration/kernel_function_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_function_test.cpp @@ -154,11 +154,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorOutput_w auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } c10::List kernelWithTensorListOutput(const Tensor& input1, const Tensor& input2, const Tensor& input3) { @@ -175,9 +175,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutp auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensorListRef()[1].type_id()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[2].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); } c10::List kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) { @@ -221,16 +221,16 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithMultipleOutput auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[2].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[2].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result_dict.at("first").type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result_dict.at("second").type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); } Tensor kernelWithTensorInputByReferenceWithOutput(const Tensor& input1) { @@ -251,11 +251,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByR auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -268,11 +268,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByV auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } Tensor captured_input; @@ -295,11 +295,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByR auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -312,11 +312,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByV auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } int64_t captured_int_input = 0; @@ -526,9 +526,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenFallbackKernelWithoutTen EXPECT_EQ(4, outputs[0].toInt()); } -c10::optional called_arg2; -c10::optional called_arg3; -c10::optional called_arg4; +c10::optional called_arg2 = c10::nullopt; +c10::optional called_arg3 = c10::nullopt; +c10::optional called_arg4 = c10::nullopt; void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { called = true; @@ -548,7 +548,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -580,11 +580,11 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs called = false; auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -613,7 +613,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); diff --git a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp index 43ee47bc41964..ce7e5d4eea4e4 100644 --- a/aten/src/ATen/core/op_registration/kernel_functor_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_functor_test.cpp @@ -169,11 +169,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorOutput_wh auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } struct KernelWithTensorListOutput final : OperatorKernel { @@ -192,9 +192,9 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutpu auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensorListRef()[1].type_id()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[2].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); } struct KernelWithIntListOutput final : OperatorKernel { @@ -242,16 +242,16 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleOutputs auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[2].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[2].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result_dict.at("first").type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result_dict.at("second").type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); } struct KernelWithTensorInputByReferenceWithOutput final : OperatorKernel { @@ -276,11 +276,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByRe auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -293,11 +293,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByVa auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } Tensor captured_input; @@ -324,11 +324,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByRe auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -341,11 +341,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByVa auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } int64_t captured_int_input = 0; @@ -675,9 +675,9 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenFallbackKernelWithoutTens EXPECT_EQ(4, outputs[0].toInt()); } -c10::optional called_arg2; -c10::optional called_arg3; -c10::optional called_arg4; +c10::optional called_arg2 = c10::nullopt; +c10::optional called_arg3 = c10::nullopt; +c10::optional called_arg4 = c10::nullopt; struct KernelWithOptInputWithoutOutput final : OperatorKernel { void operator()(Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { @@ -699,7 +699,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_ EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -733,11 +733,11 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_ called = false; auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -768,7 +768,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_ auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp index 735773bbe1f67..d0dc2c4fd814e 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_legacy_test.cpp @@ -142,11 +142,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorOutp auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { @@ -161,9 +161,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorList auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensorListRef()[1].type_id()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[2].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { @@ -203,16 +203,16 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOu auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[2].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[2].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result_dict.at("first").type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result_dict.at("second").type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { @@ -226,11 +226,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -244,11 +244,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } Tensor captured_input; @@ -264,11 +264,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -282,11 +282,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } int64_t captured_int_input = 0; @@ -729,9 +729,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenFallbackKernelWithou TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { bool called; - c10::optional called_arg2; - c10::optional called_arg3; - c10::optional called_arg4; + c10::optional called_arg2 = c10::nullopt; + c10::optional called_arg3 = c10::nullopt; + c10::optional called_arg4 = c10::nullopt; auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", @@ -750,7 +750,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -768,9 +768,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { bool called; - c10::optional called_arg2; - c10::optional called_arg3; - c10::optional called_arg4; + c10::optional called_arg2 = c10::nullopt; + c10::optional called_arg3 = c10::nullopt; + c10::optional called_arg4 = c10::nullopt; auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", @@ -787,11 +787,11 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn called = false; auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -810,9 +810,9 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { bool called; - c10::optional called_arg2; - c10::optional called_arg3; - c10::optional called_arg4; + c10::optional called_arg2 = c10::nullopt; + c10::optional called_arg3 = c10::nullopt; + c10::optional called_arg4 = c10::nullopt; auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", @@ -824,7 +824,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); diff --git a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp index 57c455cb19acb..f6e4c06867cd8 100644 --- a/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp +++ b/aten/src/ATen/core/op_registration/kernel_lambda_test.cpp @@ -136,11 +136,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whe auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { @@ -154,9 +154,9 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensorListRef()[1].type_id()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensorListRef()[2].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { @@ -196,16 +196,16 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_ auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[2].toTensorListRef()[0].type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[2].toTensorListRef()[1].type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result_dict.at("first").type_id()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result_dict.at("second").type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { @@ -220,11 +220,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByRef auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -239,11 +239,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByVal auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, result[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); } Tensor captured_input; @@ -260,11 +260,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByRef auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -279,11 +279,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByVal auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, captured_input.type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); } int64_t captured_int_input = 0; @@ -461,9 +461,9 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenFallbackKernelWithoutTenso EXPECT_EQ(4, outputs[0].toInt()); } -c10::optional called_arg2; -c10::optional called_arg3; -c10::optional called_arg4; +c10::optional called_arg2 = c10::nullopt; +c10::optional called_arg3 = c10::nullopt; +c10::optional called_arg4 = c10::nullopt; TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op( @@ -483,7 +483,7 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -515,11 +515,11 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w called = false; auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(called_arg2->type_id(), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); @@ -547,7 +547,7 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, outputs[0].toTensor().type_id()); + EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index 1cb3ea521df9d..59e45c096414b 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -12,7 +12,7 @@ static_assert(std::is_nothrow_move_assignable dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator, void* unboxed_kernel, void* unboxed_autograd_kernel) + explicit OperatorRegistrar(FunctionSchema&& schema, OperatorOptions&& operatorOptions, c10::optional dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator, void* unboxed_kernel) : op_(Dispatcher::singleton().registerSchema(std::move(schema), std::move(operatorOptions))), kernel_registration_handle_(c10::nullopt) { // cache creator can only be set if the kernel is also set TORCH_INTERNAL_ASSERT((kernel != nullptr || unboxed_kernel != nullptr) || !static_cast(cache_creator)); @@ -24,10 +24,6 @@ class RegisterOperators::OperatorRegistrar final { kernel_registration_handle_ = Dispatcher::singleton().registerCatchallKernel(op_.opHandle(), kernel, std::move(cache_creator), unboxed_kernel); } } - - if (unboxed_autograd_kernel != nullptr) { - unboxed_autograd_kernel_registration_handle_ = Dispatcher::singleton().registerUnboxedAutogradKernel(op_.opHandle(), unboxed_autograd_kernel); - } } OperatorRegistrar(OperatorRegistrar&& rhs) noexcept = default; @@ -40,58 +36,72 @@ class RegisterOperators::OperatorRegistrar final { private: c10::SchemaRegistrationHandleRAII op_; c10::optional kernel_registration_handle_; - c10::optional unboxed_autograd_kernel_registration_handle_; }; -void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNameStr, Options&& options) { - #if defined(CAFFE2_IS_XPLAT_BUILD) - throw std::logic_error("Tried to register operator " + schemaOrNameStr + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build."); - #else - either schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr); - if (schemaOrName.is_right()) { - // schema was explicitly specified. Check it matches the inferred one and register the op. - - auto schema = std::move(schemaOrName).right(); - TORCH_CHECK( - options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA || - !schema.hasAnyAliasInfo(), - "In operator registration: Tried to register operator ", - schemaOrNameStr, - " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA."); - - checkSchemaAndRegisterOp_(std::move(schema), std::move(options)); - } else { - // schema wasn't explicitly specified. Take the inferred schema for registering the op. - - FunctionSchema inferred_schema = inferSchemaFromKernels_(schemaOrNameStr, options); - OperatorName name = std::move(schemaOrName).left(); - FunctionSchema inferred_schema_with_name( - std::move(name.name), - std::move(name.overload_name), - inferred_schema.arguments(), - inferred_schema.returns(), - inferred_schema.is_vararg(), - inferred_schema.is_varret() - ); - - checkNoDuplicateKernels_(inferred_schema_with_name, options); - - // This would have unexpected behavior since an inferred schema will not - // have aliasing annotations. - TORCH_CHECK( - options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA, - "In operator registration: Tried to register operator ", - schemaOrNameStr, - " with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred."); - - // Register all kernels with the schema we inferred - registerOp_(std::move(inferred_schema_with_name), std::move(options)); +void RegisterOperators::checkSchemaAndRegisterOp_(Options&& options) { + if (options.legacyATenSchema_.has_value()) { + // Ignore legacy aten operators, don't add them to c10 + return; + } + + TORCH_CHECK(options.schemaOrName_.has_value(), "In operator registration: Tried to register an operator without specifying a schema or operator name."); + if (options.schemaOrName_->is_right()) { + // schema was explicitly specified. Check it matches the inferred one and register the op. + + const FunctionSchema& schema = options.schemaOrName_->right(); + TORCH_CHECK( + options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA || + !schema.hasAnyAliasInfo(), + "In operator registration: Tried to register operator ", + options.schemaOrName_->right(), + " with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA."); + + for (auto& kernel : options.kernels) { + if (nullptr != kernel.inferred_function_schema.get()) { + c10::optional schema_difference = findSchemaDifferences(schema, *kernel.inferred_function_schema); + if (schema_difference.has_value()) { + TORCH_CHECK(false, "In operator registration: Specified function schema [", toString(schema), "] ", + "doesn't match inferred function schema [", toString(*kernel.inferred_function_schema), "]. ", + *schema_difference); + } + } } - #endif + + checkNoDuplicateKernels_(options); + + registerOp_(std::move(options)); + } else { + // schema wasn't explicitly specified. Take the inferred schema for registering the op. + + OperatorName name = std::move(*options.schemaOrName_).left(); + FunctionSchema inferred_schema = inferSchemaFromKernels_(name, options); + + options.schemaOrName_ = c10::make_right( + std::move(name.name), + std::move(name.overload_name), + inferred_schema.arguments(), + inferred_schema.returns(), + inferred_schema.is_vararg(), + inferred_schema.is_varret() + ); + + checkNoDuplicateKernels_(options); + + // This would have unexpected behavior since an inferred schema will not + // have aliasing annotations. + TORCH_CHECK( + options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA, + "In operator registration: Tried to register operator ", + options.schemaOrName_->right(), + " with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred."); + + // Register all kernels with the schema we inferred + registerOp_(std::move(options)); + } } -c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const std::string& opNameStr, const RegisterOperators::Options& options) { - TORCH_CHECK(options.kernels.size() > 0, "Cannot infer operator schema in registration of operator ", opNameStr, " because there is no kernel specified."); +c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const OperatorName& opName, const RegisterOperators::Options& options) { + TORCH_CHECK(options.kernels.size() > 0, "Cannot infer operator schema in registration of operator ", toString(opName), " because there is no kernel specified."); c10::optional inferred_schema = c10::nullopt; for (const auto& kernel : options.kernels) { @@ -108,53 +118,37 @@ c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const std::string } } } - TORCH_CHECK(inferred_schema.has_value(), "Cannot infer operator schema for this kind of kernel in registration of operator ", opNameStr,". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema."); + TORCH_CHECK(inferred_schema.has_value(), "Cannot infer operator schema for this kind of kernel in registration of operator ", toString(opName), ". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema."); return *inferred_schema; } -void RegisterOperators::checkSchemaAndRegisterOp_(FunctionSchema schema, Options&& options) { - for (auto& kernel : options.kernels) { - if (nullptr != kernel.inferred_function_schema.get()) { - c10::optional schema_difference = findSchemaDifferences(schema, *kernel.inferred_function_schema); - if (schema_difference.has_value()) { - TORCH_CHECK(false, "In operator registration: Specified function schema [", toString(schema), "] ", - "doesn't match inferred function schema [", toString(*kernel.inferred_function_schema), "]. ", - *schema_difference); - } - } - } - - checkNoDuplicateKernels_(schema, options); - - registerOp_(std::move(schema), std::move(options)); -} - -void RegisterOperators::checkNoDuplicateKernels_(const FunctionSchema& schema, const Options& options) { +void RegisterOperators::checkNoDuplicateKernels_(const Options& options) { std::unordered_set dispatch_keys; bool has_catchall_kernel = false; for (const auto& kernel : options.kernels) { if (kernel.dispatch_key.has_value()) { - TORCH_CHECK(0 == dispatch_keys.count(*kernel.dispatch_key), "In operator registration: Tried to register multiple kernels with same dispatch key ", toString(*kernel.dispatch_key), " for operator schema ", toString(schema)); + TORCH_CHECK(0 == dispatch_keys.count(*kernel.dispatch_key), "In operator registration: Tried to register multiple kernels with same dispatch key ", toString(*kernel.dispatch_key), " for operator schema ", toString(options.schemaOrName_->right())); dispatch_keys.insert(*kernel.dispatch_key); } else { - TORCH_CHECK(!has_catchall_kernel, "In operator registration: Tried to register multiple catch-all kernels for operator schema " + toString(schema)); + TORCH_CHECK(!has_catchall_kernel, "In operator registration: Tried to register multiple catch-all kernels for operator schema " + toString(options.schemaOrName_->right())); has_catchall_kernel = true; } } } -void RegisterOperators::registerOp_(FunctionSchema&& schema, Options&& options) { +void RegisterOperators::registerOp_(Options&& options) { + FunctionSchema schema = std::move(*options.schemaOrName_).right(); OperatorName op_name = schema.operator_name(); auto operatorOptions = makeOperatorOptions_(options); if (0 == options.kernels.size()) { - registerSchemaOnly_(std::move(schema), std::move(operatorOptions), options.unboxedAutogradKernel_); + registerSchemaOnly_(std::move(schema), std::move(operatorOptions)); } else { for (auto& kernel : options.kernels) { - registerSchemaAndKernel_(schema, std::move(kernel), std::move(operatorOptions), options.unboxedAutogradKernel_); + registerSchemaAndKernel_(schema, std::move(kernel), std::move(operatorOptions)); } } @@ -169,14 +163,14 @@ OperatorOptions RegisterOperators::makeOperatorOptions_(const RegisterOperators: return result; } -void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions, void* unboxedAutogradKernel) { +void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions) { TORCH_INTERNAL_ASSERT((kernel.kernel_func != nullptr || kernel.unboxed_kernel_func != nullptr), "Kernel must be set"); - registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, kernel.kernel_func, std::move(kernel.cache_creator_func), kernel.unboxed_kernel_func, unboxedAutogradKernel); + registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, kernel.kernel_func, std::move(kernel.cache_creator_func), kernel.unboxed_kernel_func); } -void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions, void* unboxedAutogradKernel) { - registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, nullptr, nullptr, nullptr, unboxedAutogradKernel); +void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions) { + registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, nullptr, nullptr, nullptr); } RegisterOperators::RegisterOperators() = default; diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index b210030627394..099a33ee14a1a 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -10,6 +10,11 @@ #include #include #include +#if !defined(CAFFE2_IS_XPLAT_BUILD) +#include +#endif +#include +#include namespace c10 { @@ -28,7 +33,8 @@ namespace c10 { * > } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .kernel(TensorTypeId::CPUTensorId)); */ class CAFFE2_API RegisterOperators final { @@ -58,6 +64,50 @@ class CAFFE2_API RegisterOperators final { return std::move(*this).kernel(c10::nullopt, kernel_func, std::move(cache_creator), nullptr, nullptr); } + // internal only for registering caffe2 ops + Options&& schema(FunctionSchema&& schema) { + TORCH_CHECK(!schemaOrName_.has_value(), "You can only specify the schema once per operator registration."); + schemaOrName_ = c10::make_right(std::move(schema)); + return std::move(*this); + } + + /** + * Use this to specify the schema for an operator. You can also specify + * the operator name only to have the function signature part of the + * schema be inferred from the kernel function. + * + * Example: + * + * > // Infer function signature from my_kernel_cpu + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") + * > .kernel(TensorTypeId::CPUTensorId)); + * > + * > + * > // Explicitly specify full schema + * > static auto registry = c10::RegisterOperators() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op(Tensor a) -> Tensor") + * > .kernel(TensorTypeId::CPUTensorId)); + */ + Options&& schema(const std::string& schemaOrName) { + TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration."); + TORCH_CHECK(!legacyATenSchema_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration."); + + if (Options::op_is_still_on_aten_dispatcher_(schemaOrName.c_str())) { + TORCH_CHECK(kernels.size() == 0, "For legacy aten ops, the schema() call must happen before any kernel() calls. Operator was ", schemaOrName); + legacyATenSchema_ = schemaOrName; + } else { + #if defined(CAFFE2_IS_XPLAT_BUILD) + throw std::logic_error("Tried to register operator " + schemaOrName + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build."); + #else + schemaOrName_ = torch::jit::parseSchemaOrName(schemaOrName); + #endif + } + return std::move(*this); + } + /** * Use this to register an operator whose kernel is implemented as a functor. * The kernel is only called for inputs matching the given dispatch key. @@ -73,7 +123,8 @@ class CAFFE2_API RegisterOperators final { * > } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .kernel(TensorTypeId::CPUTensorId)); * * The functor constructor can take arguments to configure the kernel. @@ -91,7 +142,8 @@ class CAFFE2_API RegisterOperators final { * > } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .kernel(TensorTypeId::CPUTensorId, "some_configuration", 3, true)); */ template @@ -118,7 +170,8 @@ class CAFFE2_API RegisterOperators final { * > } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .catchAllKernel()); * * The functor constructor can take arguments to configure the kernel. @@ -136,7 +189,8 @@ class CAFFE2_API RegisterOperators final { * > } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .catchAllKernel("some_configuration", 3, true)); */ template @@ -158,7 +212,8 @@ class CAFFE2_API RegisterOperators final { * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .kernel(TensorTypeId::CPUTensorId)); */ template @@ -180,7 +235,8 @@ class CAFFE2_API RegisterOperators final { * > namespace { Tensor my_kernel_cpu(Tensor a, Tensor b) {...} } * > * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .catchAllKernel()); */ template @@ -199,7 +255,14 @@ class CAFFE2_API RegisterOperators final { static_assert(!std::is_same::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); - return std::move(*this).kernelFunctorUnboxedOnly::type>(dispatch_key); + if (legacyATenSchema_.has_value()) { + // TODO Remove this once all ops are moved to c10. + TORCH_INTERNAL_ASSERT(!schemaOrName_.has_value()); + at::globalATenDispatch().registerOp(dispatch_key, legacyATenSchema_->c_str(), kernel_func); + return std::move(*this); + } else { + return std::move(*this).kernelFunctorUnboxedOnly::type>(dispatch_key); + } } // TODO Remove impl_unboxedOnlyCatchAllKernel once all of aten can generate boxed kernels @@ -209,7 +272,14 @@ class CAFFE2_API RegisterOperators final { static_assert(!std::is_same::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); - return std::move(*this).kernelFunctorUnboxedOnly::type>(c10::nullopt); + if (legacyATenSchema_.has_value()) { + // TODO Remove this once all ops are moved to c10. + TORCH_INTERNAL_ASSERT(!schemaOrName_.has_value()); + at::globalATenDispatch().registerOp(TensorTypeId::UndefinedTensorId, legacyATenSchema_->c_str(), kernel_func); + return std::move(*this); + } else { + return std::move(*this).kernelFunctorUnboxedOnly::type>(c10::nullopt); + } } /** @@ -224,7 +294,8 @@ class CAFFE2_API RegisterOperators final { * Example: * * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .kernel(TensorTypeId::CPUTensorId, [] (Tensor a) -> Tensor {...})); */ template @@ -255,7 +326,8 @@ class CAFFE2_API RegisterOperators final { * Example: * * > static auto registry = c10::RegisterOperators() - * > .op("my_op", c10::RegisterOperators::options() + * > .op(c10::RegisterOperators::options() + * > .schema("my_op") * > .catchAllKernel([] (Tensor a) -> Tensor {...})); */ template @@ -280,16 +352,35 @@ class CAFFE2_API RegisterOperators final { return std::move(*this); } - template - Options&& impl_unboxedAutogradKernel(Result (*kernel)(Args...)) && { - // TODO Infer and check schema - TORCH_CHECK(kernel != nullptr, "Kernel function pointer cannot be nullptr"); - TORCH_CHECK(unboxedAutogradKernel_ == nullptr, "You can only call impl_unboxedAutogradKernel() once per operator registration."); - unboxedAutogradKernel_ = reinterpret_cast(kernel); - return std::move(*this); + private: + static c10::OperatorName parse_operator_name_(const char* schema) { + // TODO Remove this function once all aten ops are on c10 + // We can't depend on the jit function schema parser here, but parsing + // the op name is trivial. Let's just do it by hand. + std::string schema_str(schema); + size_t name_end_pos = schema_str.find_first_of(".("); + if (name_end_pos == std::string::npos) { + name_end_pos = schema_str.size(); + } + size_t overload_name_end_pos = name_end_pos + 1; + if (schema_str[name_end_pos] == '.') { + overload_name_end_pos = schema_str.find_first_of('(', name_end_pos); + if (overload_name_end_pos == std::string::npos) { + overload_name_end_pos = name_end_pos + 1; + } + } + return c10::OperatorName{ + schema_str.substr(0, name_end_pos), + (overload_name_end_pos > name_end_pos + 1) ? schema_str.substr(name_end_pos + 1, overload_name_end_pos - name_end_pos - 1) : "" + }; + } + + static bool op_is_still_on_aten_dispatcher_(const char* schema_string) { + // TODO Remove this function once all aten ops are on c10 + const auto op_name = parse_operator_name_(schema_string); + return at::aten_op_is_not_moved_to_c10_yet(op_name); } - private: Options&& kernel(c10::optional&& dispatch_key, KernelFunction* kernel_func, KernelCacheCreatorFunction&& cache_creator, void* unboxed_kernel_func, std::unique_ptr&& inferred_function_schema) && { KernelRegistrationConfig config; config.dispatch_key = dispatch_key; @@ -314,16 +405,32 @@ class CAFFE2_API RegisterOperators final { template Options&& kernelFunctorUnboxedOnly(c10::optional&& dispatch_key, ConstructorParameters&&... constructorParameters) && { + // Setting cache_creator to nullptr so calling the kernel doesn't need to call it, which would be expensive. + // Since the dispatcher static_cast's cache objects into our functor type to call their operator(), this nullptr + // will cause it to create and static_cast an invalid cache object, which is technically illegal in the C++ standard, + // but it works as long as operator() does not access any functor members. + // Exception: Backend extensions use runtime function pointers and store these in the functor as members, + // so we need a cache if sizeof...(ConstructorParameters) != 0 + auto cache_creator = + (sizeof...(ConstructorParameters) == 0) + ? KernelCacheCreatorFunction(nullptr) + : detail::KernelFactory...>(std::forward(constructorParameters)...); + return std::move(*this).kernel( std::move(dispatch_key), nullptr, - nullptr, // setting cache creator to nullptr so calling the kernel doesn't need to call it, which would be expensive + std::move(cache_creator), reinterpret_cast(&detail::wrap_kernel_functor_unboxed::call), detail::FunctionSchemaInferer()() ); } - Options() = default; + Options() + : schemaOrName_(c10::nullopt) + , legacyATenSchema_(c10::nullopt) + , kernels() + , aliasAnalysisKind_(c10::nullopt) + {} // KernelRegistrationConfig accumulates all information from the config // parameters passed to a RegisterOperators::op() call into one object. @@ -343,9 +450,16 @@ class CAFFE2_API RegisterOperators final { std::unique_ptr inferred_function_schema; }; + // For all modern ops, schemaOrName_ is set. + // For legacy ATen ops (i.e. ops on globalATenDispatch()), legacyATenSchema_ + // is set. We never set both. + // TODO This is just a hack to forward some registrations to globalATenDispatch(). + // We should remove legacyATenSchema_ once all ops are on the c10 dispatcher. + c10::optional> schemaOrName_; + c10::optional legacyATenSchema_; + std::vector kernels; optional aliasAnalysisKind_; - void* unboxedAutogradKernel_; // can be nullptr, not all kernels have this friend class RegisterOperators; }; @@ -362,15 +476,23 @@ class CAFFE2_API RegisterOperators final { /** * Call this to register an operator. See class doc comment for examples. */ - RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && { - checkSchemaAndRegisterOp_(schemaOrName, std::move(options)); + RegisterOperators&& op(Options&& options) && { + checkSchemaAndRegisterOp_(std::move(options)); return std::move(*this); } + /** + * This is a shorthand for RegisterOperators::op(Options) where you can + * specify the operator schema outside of the options parameter. + * See class doc comment for examples. + */ + RegisterOperators&& op(const std::string& schemaOrName, Options&& options = RegisterOperators::options()) && { + return std::move(*this).op(std::move(options).schema(schemaOrName)); + } + // internal only for registering caffe2 ops RegisterOperators&& op(FunctionSchema schema, Options&& options) && { - checkSchemaAndRegisterOp_(std::move(schema), std::move(options)); - return std::move(*this); + return std::move(*this).op(std::move(options).schema(std::move(schema))); } template @@ -417,7 +539,7 @@ class CAFFE2_API RegisterOperators final { guts::enable_if_t::value && !std::is_same::value, RegisterOperators&&> op(const std::string& schemaOrName, FuncType* func, Options&& options = RegisterOperators::options()) && { constexpr bool AllowLegacyTypes = true; - return std::move(*this).op(schemaOrName, std::move(options).kernelFunctor>, AllowLegacyTypes>(c10::nullopt, func)); + return std::move(*this).op(std::move(options).schema(schemaOrName).kernelFunctor>, AllowLegacyTypes>(c10::nullopt, func)); } /** @@ -442,7 +564,7 @@ class CAFFE2_API RegisterOperators final { static_assert(!std::is_base_of::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead."); constexpr bool AllowLegacyTypes = true; - return std::move(*this).op(schemaOrName, std::move(options).kernelFunctor>, AllowLegacyTypes>(c10::nullopt, std::forward(func))); + return std::move(*this).op(std::move(options).schema(schemaOrName).kernelFunctor>, AllowLegacyTypes>(c10::nullopt, std::forward(func))); } template @@ -453,18 +575,17 @@ class CAFFE2_API RegisterOperators final { static_assert(!std::is_base_of::value, "c10::OperatorKernel is part of the new kernel registration API and shouldn't be used together with the deprecated registration API. Please use the new RegisterOperators::options().kernel() based API instead."); constexpr bool AllowLegacyTypes = true; - return std::move(*this).op(schemaOrName, std::move(options).kernelFunctor>, AllowLegacyTypes>(c10::nullopt, std::forward(func))); + return std::move(*this).op(std::move(options).schema(schemaOrName).kernelFunctor>, AllowLegacyTypes>(c10::nullopt, std::forward(func))); } private: - void checkSchemaAndRegisterOp_(FunctionSchema schema, Options&& config); - void checkSchemaAndRegisterOp_(const std::string& schemaOrName, Options&& config); - - static c10::FunctionSchema inferSchemaFromKernels_(const std::string& opNameStr, const Options& options); - void checkNoDuplicateKernels_(const FunctionSchema& schema, const Options& options); - void registerOp_(FunctionSchema&& schema, Options&& options); - void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options, void* unboxedAutogradKernel); - void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options, void* unboxedAutogradKernel); + void checkSchemaAndRegisterOp_(Options&& config); + + static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options); + void checkNoDuplicateKernels_(const Options& options); + void registerOp_(Options&& options); + void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options); + void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options); static OperatorOptions makeOperatorOptions_(const Options& options); class OperatorRegistrar; diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 6a639bb0685d2..2c00c3f5aef22 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -19,6 +19,7 @@ using c10::RegisterOperators; using c10::OperatorKernel; using c10::Dispatcher; using c10::IValue; +using c10::TensorTypeId; using at::Tensor; namespace { @@ -36,6 +37,57 @@ struct MockKernel final : OperatorKernel { private: bool* called_; }; + +TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObject_thenCanBeCalled) { + bool called = false; + auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy(Tensor dummy) -> ()").catchAllKernel(&called)); + + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called); + callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + EXPECT_TRUE(called); +} + +TEST(OperatorRegistrationTest, whenRegisteringWithSchemaAfterKernelInOptionsObject_thenCanBeCalled) { + bool called = false; + auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel(&called).schema("_test::dummy(Tensor dummy) -> ()")); + + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called); + callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + EXPECT_TRUE(called); +} + +TEST(OperatorRegistrationTest, whenRegisteringWithNameBeforeKernelInOptionsObject_thenCanBeCalled) { + bool called = false; + auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().schema("_test::dummy").catchAllKernel(&called)); + + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called); + callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + EXPECT_TRUE(called); +} + +TEST(OperatorRegistrationTest, whenRegisteringWithNameAfterKernelInOptionsObject_thenCanBeCalled) { + bool called = false; + auto registrar = c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel(&called).schema("_test::dummy")); + + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); + EXPECT_FALSE(called); + callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + EXPECT_TRUE(called); +} + +TEST(OperatorRegistrationTest, whenRegisteringWithoutSchema_thenFails) { + expectThrows([] { + c10::RegisterOperators().op(c10::RegisterOperators::options().catchAllKernel()); + }, "In operator registration: Tried to register an operator without specifying a schema or operator name."); +} + TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); @@ -57,22 +109,23 @@ TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCalls EXPECT_TRUE(called); } -TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) { - bool called = false; - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); - expectThrows([&] { - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); - }, "for an operator which already has a catch-all kernel registered"); -} - -TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernelInSameOpCall_thenFails) { - bool called = false; - expectThrows([&] { - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .catchAllKernel(&called) - .kernel(c10::TensorTypeId::CPUTensorId, &called)); - }, "for an operator which already has a catch-all kernel registered"); -} +// TODO Rewrite (since this is now allowed) and reenable +// TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) { +// bool called = false; +// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); +// expectThrows([&] { +// c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); +// }, "for an operator which already has a catch-all kernel registered"); +// } + +// TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernelInSameOpCall_thenFails) { +// bool called = false; +// expectThrows([&] { +// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() +// .catchAllKernel(&called) +// .kernel(c10::TensorTypeId::CPUTensorId, &called)); +// }, "for an operator which already has a catch-all kernel registered"); +// } TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegisteringCatchallKernelAndCallingOp_thenCallsCatchallKernel) { bool called = false; @@ -89,22 +142,23 @@ TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegiste EXPECT_TRUE(called); } -TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) { - bool called = false; - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); - expectThrows([&] { - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); - }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy"); -} - -TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernelInSameOpCall_thenFails) { - bool called = false; - expectThrows([&] { - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CPUTensorId, &called) - .catchAllKernel(&called)); - }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy"); -} +// TODO Rewrite (since this is now allowed) and reenable +// TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) { +// bool called = false; +// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); +// expectThrows([&] { +// c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); +// }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy"); +// } +// +// TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernelInSameOpCall_thenFails) { +// bool called = false; +// expectThrows([&] { +// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() +// .kernel(c10::TensorTypeId::CPUTensorId, &called) +// .catchAllKernel(&called)); +// }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy"); +// } TEST(OperatorRegistrationTest, givenOpWithCatchallKernelOutOfScope_whenRegisteringDispatchedKernelAndCallingOp_thenCallsCatchallKernel) { bool called = false; @@ -566,37 +620,42 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the }, "Tried to register kernels for same operator that infer a different function schema"); } -int64_t increment_kernel(int64_t a) { - return a + 1; +bool called_autograd = false; +bool called_catchall = false; + +void catchall_kernel(Tensor a) { + called_catchall = true; } -int64_t decrement_kernel(int64_t a) { - return a - 1; +void autograd_kernel(Tensor a) { + called_autograd = true; } TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) { - auto registrar = c10::RegisterOperators().op("_test::dummy(int dummy) -> int", c10::RegisterOperators::options() - .impl_unboxedAutogradKernel(&increment_kernel)); + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() + .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - int64_t result = c10::Dispatcher::singleton().callUnboxedAutogradKernel(*op, 4); - EXPECT_EQ(5, result); + + called_autograd = false; + c10::Dispatcher::singleton().lookup(*op, TensorTypeId::VariableTensorId).callUnboxed(dummyTensor(TensorTypeId::VariableTensorId)); + EXPECT_TRUE(called_autograd); } TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) { - auto registrar = c10::RegisterOperators().op("_test::dummy(int dummy) -> int", c10::RegisterOperators::options() - .catchAllKernel() - .impl_unboxedAutogradKernel(&increment_kernel)); + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() + .impl_unboxedOnlyCatchAllKernel() + .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - int64_t result = c10::Dispatcher::singleton().callUnboxedAutogradKernel(*op, 4); - EXPECT_EQ(5, result); -} -// TODO Test cases that adding multiple autograd kernels, removing some, and so on works -// (similar to test cases above for regular kernels "_whenNewerAndThenOlderKernelDeletedAndOpCalled") + called_catchall = called_autograd = false; + c10::Dispatcher::singleton().lookup(*op, TensorTypeId::VariableTensorId).callUnboxed(dummyTensor(TensorTypeId::VariableTensorId)); + EXPECT_FALSE(called_catchall); + EXPECT_TRUE(called_autograd); +} /** * This is used to check that a given type works correctly when passed as input @@ -725,8 +784,8 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str a) -> str"); testArgTypes::test( - dummyTensor(c10::TensorTypeId::CPUTensorId), [] (const Tensor& v) {EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.type_id());}, - dummyTensor(c10::TensorTypeId::CUDATensorId), [] (const IValue& v) {EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.toTensor().type_id());}, + dummyTensor(c10::TensorTypeId::CPUTensorId), [] (const Tensor& v) {EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v));}, + dummyTensor(c10::TensorTypeId::CUDATensorId), [] (const IValue& v) {EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.toTensor()));}, "(Tensor a) -> Tensor"); @@ -752,35 +811,35 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str? a) -> str?"); testArgTypes>::test( - c10::optional(dummyTensor(c10::TensorTypeId::CPUTensorId)), [] (const c10::optional& v) {EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.value().type_id());}, - c10::optional(dummyTensor(c10::TensorTypeId::CUDATensorId)), [] (const IValue& v) {EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.toTensor().type_id());}, + c10::optional(dummyTensor(c10::TensorTypeId::CPUTensorId)), [] (const c10::optional& v) {EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.value()));}, + c10::optional(dummyTensor(c10::TensorTypeId::CUDATensorId)), [] (const IValue& v) {EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.toTensor()));}, "(Tensor? a) -> Tensor?"); // optional types (with has_value() == false) testArgTypes>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + c10::optional(c10::nullopt), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(float? a) -> float?"); testArgTypes>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + c10::optional(c10::nullopt), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(int? a) -> int?"); testArgTypes>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + c10::optional(c10::nullopt), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(bool? a) -> bool?"); testArgTypes>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + c10::optional(c10::nullopt), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(bool? a) -> bool?"); testArgTypes>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + c10::optional(c10::nullopt), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(str? a) -> str?"); testArgTypes>::test( - c10::optional(), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, - c10::optional(), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, + c10::optional(c10::nullopt), [] (const c10::optional& v) {EXPECT_FALSE(v.has_value());}, + c10::optional(c10::nullopt), [] (const IValue& v) {EXPECT_TRUE(v.isNone());}, "(Tensor? a) -> Tensor?"); @@ -801,10 +860,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { c10::List(), [] (const c10::List& v) {EXPECT_EQ(0, v.size());}, c10::List(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());}, "(str[] a) -> str[]"); - testArgTypes>::test( - c10::List({}), [] (const c10::List& v) {EXPECT_EQ(0, v.size());}, - c10::List({}), [] (const IValue& v) {EXPECT_EQ(0, v.to>().size());}, - "(Tensor[] a) -> Tensor[]"); // list types (with non-empty list) @@ -831,13 +886,13 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { testArgTypes>::test( c10::List({dummyTensor(c10::TensorTypeId::CPUTensorId), dummyTensor(c10::TensorTypeId::CUDATensorId)}), [] (const c10::List& v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.get(0).type_id()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.get(1).type_id()); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.get(0))); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.get(1))); }, c10::List({dummyTensor(c10::TensorTypeId::CUDATensorId), dummyTensor(c10::TensorTypeId::CPUTensorId)}), [] (const IValue& v) { EXPECT_EQ(2, v.to>().size()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.to>().get(0).type_id()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.to>().get(1).type_id()); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.to>().get(0))); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.to>().get(1))); }, "(Tensor[] a) -> Tensor[]"); @@ -855,10 +910,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { std::vector(), [] (const std::vector& v) {EXPECT_EQ(0, v.size());}, std::vector(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());}, "(str[] a) -> str[]"); - testArgTypes>::test( - std::vector({}), [] (const std::vector& v) {EXPECT_EQ(0, v.size());}, - std::vector({}), [] (const IValue& v) {EXPECT_EQ(0, v.to>().size());}, - "(Tensor[] a) -> Tensor[]"); // deprecated list types (with non-empty list) @@ -882,13 +933,13 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { testArgTypes>::test( std::vector({dummyTensor(c10::TensorTypeId::CPUTensorId), dummyTensor(c10::TensorTypeId::CUDATensorId)}), [] (const std::vector& v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.at(0).type_id()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.at(1).type_id()); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.at(0))); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.at(1))); }, std::vector({dummyTensor(c10::TensorTypeId::CUDATensorId), dummyTensor(c10::TensorTypeId::CPUTensorId)}), [] (const IValue& v) { EXPECT_EQ(2, v.to>().size()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.to>().get(0).type_id()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.to>().get(1).type_id()); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.to>().get(0))); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.to>().get(1))); }, "(Tensor[] a) -> Tensor[]"); @@ -945,14 +996,14 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { testArgTypes>::test( tensor_dict, [] (c10::Dict v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.at(1).type_id()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.at(2).type_id()); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.at(1))); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.at(2))); }, tensor_dict, [] (const IValue& v) { c10::Dict dict = c10::impl::toTypedDict(v.toGenericDict()); EXPECT_EQ(2, dict.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, dict.at(1).type_id()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, dict.at(2).type_id()); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(dict.at(1))); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(dict.at(2))); }, "(Dict(int, Tensor) a) -> Dict(int, Tensor)"); @@ -979,14 +1030,14 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { testArgTypes>::test( tensor_map, [] (std::unordered_map v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, v.at(1).type_id()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, v.at(2).type_id()); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.at(1))); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.at(2))); }, tensor_map, [] (const IValue& v) { c10::Dict dict = c10::impl::toTypedDict(v.toGenericDict()); EXPECT_EQ(2, dict.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, dict.at(1).type_id()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, dict.at(2).type_id()); + EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(dict.at(1))); + EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(dict.at(2))); }, "(Dict(int, Tensor) a) -> Dict(int, Tensor)"); diff --git a/aten/src/ATen/core/op_registration/test_helpers.h b/aten/src/ATen/core/op_registration/test_helpers.h index e03f95db676a5..d5dceee332436 100644 --- a/aten/src/ATen/core/op_registration/test_helpers.h +++ b/aten/src/ATen/core/op_registration/test_helpers.h @@ -79,3 +79,9 @@ void expectListEquals(c10::ArrayRef expected, std::vector actual) { EXPECT_EQ(expected[i], actual[i]); } } + +// NB: This is not really sound, but all of the type sets constructed here +// are singletons so it's fine +static inline c10::TensorTypeId extractTypeId(const at::Tensor& t) { + return legacyExtractTypeId(t.type_set()); +} diff --git a/aten/src/ATen/core/operator_name.h b/aten/src/ATen/core/operator_name.h new file mode 100644 index 0000000000000..b95098e38fc1a --- /dev/null +++ b/aten/src/ATen/core/operator_name.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace c10 { + +struct OperatorName final { + std::string name; + std::string overload_name; +}; + +inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) { + return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name; +} + +inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) { + return !operator==(lhs, rhs); +} + +inline std::string toString(const OperatorName& opName) { + std::string result = opName.name; + if (opName.overload_name.size() != 0) { + result += "." + opName.overload_name; + } + return result; +} + +} + +namespace std { + template <> + struct hash<::c10::OperatorName> { + size_t operator()(const ::c10::OperatorName& x) const { + return std::hash()(x.name) ^ (~ std::hash()(x.overload_name)); + } + }; +} diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index a084fae151415..b0e96f15c161c 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -65,6 +65,11 @@ std::ostream& operator<<(std::ostream & out, const Type & t) { return out; } +AnyTypePtr AnyType::get() { + static auto value = AnyType::create(); + return value; +} + TensorTypePtr TensorType::get() { static auto value = TensorType::create( {}, @@ -319,51 +324,70 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2) { return c10::nullopt; } -MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type_env) { - if(!formal->hasFreeVariables()) { +c10::optional unifyTypeList(at::ArrayRef elements) { + if (elements.size() == 0) { + return c10::nullopt; + } + + c10::optional ret_type = elements[0]; + for (size_t i = 1; i < elements.size() && ret_type; ++i) { + ret_type = unifyTypes(*ret_type, elements[i]); + } + + return ret_type; +} + +MatchTypeReturn matchTypeVariables( + TypePtr formal, + TypePtr actual, + TypeEnv& type_env) { + if (!formal->hasFreeVariables()) { return MatchTypeReturn::Success(); } - if(auto vt = formal->cast()) { + if (auto vt = formal->cast()) { auto it = type_env.find(vt->name()); - if(it == type_env.end()) { + if (it == type_env.end()) { type_env[vt->name()] = actual; return MatchTypeReturn::Success(); - } else if(auto unified = unifyTypes(it->second, actual)) { + } else if (auto unified = unifyTypes(it->second, actual)) { type_env[vt->name()] = *unified; return MatchTypeReturn::Success(); } std::stringstream ss; - ss << "Type variable '" << vt->name() << "' previously matched to type " << - it->second->python_str() << " is matched to type " << actual->python_str(); + ss << "Type variable '" << vt->name() << "' previously matched to type " + << it->second->python_str() << " is matched to type " + << actual->python_str(); return ss.str(); - } else if(auto lt_formal = formal->cast()) { - if(auto lt_actual = actual->cast()) { + } else if (auto lt_formal = formal->cast()) { + if (auto lt_actual = actual->cast()) { const auto innerMatch = matchTypeVariables( - lt_formal->getElementType(), - lt_actual->getElementType(), - type_env); + lt_formal->getElementType(), lt_actual->getElementType(), type_env); if (!innerMatch.success()) { // propagate the errMsg onward return innerMatch; } return MatchTypeReturn::Success(); - } else { - std::stringstream ss; - ss << "Cannot match " << lt_formal->python_str() << " to " - << actual->python_str(); - return ss.str(); + } else if (auto tup_type = actual->cast()) { + auto maybe_tuple_unified = unifyTypeList(tup_type->elements()); + if (maybe_tuple_unified) { + return matchTypeVariables( + lt_formal->getElementType(), *maybe_tuple_unified, type_env); + } } - } else if(auto tp_formal = formal->cast()) { - if(auto tp_actual = actual->cast()) { - if(tp_formal->elements().size() != tp_actual->elements().size()) { + + std::stringstream ss; + ss << "Cannot match " << lt_formal->python_str() << " to " + << actual->python_str(); + return ss.str(); + } else if (auto tp_formal = formal->cast()) { + if (auto tp_actual = actual->cast()) { + if (tp_formal->elements().size() != tp_actual->elements().size()) { return MatchTypeReturn("Cannot match tuples of mismatched size"); } - for(size_t i = 0; i < tp_formal->elements().size(); ++i) { + for (size_t i = 0; i < tp_formal->elements().size(); ++i) { const auto result = matchTypeVariables( - tp_formal->elements()[i], - tp_actual->elements()[i], - type_env); + tp_formal->elements()[i], tp_actual->elements()[i], type_env); if (!result.success()) { return result; } @@ -401,26 +425,20 @@ MatchTypeReturn matchTypeVariables(TypePtr formal, TypePtr actual, TypeEnv& type // unknown type). return matchTypeVariables(opt_formal->getElementType(), actual, type_env); } - // note: if actual was non here we potentially did not fill in the type variables - // contained in the formal. It is still a valid match because None matches Optional[T] - // later error checking on tryEvalTypeVariables will report the problem if we never match - // variables in type T + // note: if actual was non here we potentially did not fill in the type + // variables contained in the formal. It is still a valid match because None + // matches Optional[T] later error checking on tryEvalTypeVariables will + // report the problem if we never match variables in type T return MatchTypeReturn::Success(); } else if (auto dict_formal = formal->cast()) { if (auto dict_actual = actual->cast()) { auto key_match = matchTypeVariables( - dict_formal->getKeyType(), - dict_actual->getKeyType(), - type_env - ); + dict_formal->getKeyType(), dict_actual->getKeyType(), type_env); if (!key_match.success()) { return key_match; } auto value_match = matchTypeVariables( - dict_formal->getValueType(), - dict_actual->getValueType(), - type_env - ); + dict_formal->getValueType(), dict_actual->getValueType(), type_env); if (!value_match.success()) { return value_match; } @@ -471,7 +489,7 @@ const char * typeKindToString(TypeKind kind) { } bool Type::isSubtypeOfExt(const TypePtr rhs, std::ostream* why_not) const { - if (*this == *rhs) { + if (rhs->kind() == TypeKind::AnyType || *this == *rhs) { return true; } if(auto rhs_ = rhs->cast()) { diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index ce0f3d65360a1..89db59cd03ff4 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -10,6 +10,7 @@ #include #include #include +#include #if defined(__GNUC__) #define __at_align32__ __attribute__((aligned(32))) @@ -236,6 +237,7 @@ struct Vec256 { return map([](T x) -> T { return -x; }); } Vec256 round() const { + // We do not use std::round because we would like to round midway numbers to the nearest even integer. return map(std::nearbyint); } Vec256 sin() const { @@ -253,6 +255,9 @@ struct Vec256 { Vec256 trunc() const { return map(std::trunc); } + Vec256 lgamma() const { + return map(std::lgamma); + } Vec256 sqrt() const { return map(std::sqrt); } diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h index 04ef7fb0f1d5f..62f5525fc77a5 100644 --- a/aten/src/ATen/cpu/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_double.h @@ -161,6 +161,9 @@ template <> class Vec256 { Vec256 trunc() const { return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } + Vec256 lgamma() const { + return Vec256(Sleef_lgammad4_u10(values)); + } Vec256 sqrt() const { return _mm256_sqrt_pd(values); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h index 40ff536e47e27..f3c75b6a90989 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_float.h @@ -169,6 +169,9 @@ template <> class Vec256 { Vec256 trunc() const { return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); } + Vec256 lgamma() const { + return Vec256(Sleef_lgammaf8_u10(values)); + } Vec256 sqrt() const { return _mm256_sqrt_ps(values); } diff --git a/aten/src/ATen/cpu/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec256/vec256_qint.h index 6a2147999174a..6adccbc8503af 100644 --- a/aten/src/ATen/cpu/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec256/vec256_qint.h @@ -39,6 +39,109 @@ namespace { #if defined(__AVX__) && !defined(_MSC_VER) +#if defined(__AVX2__) && defined(__FMA__) +template +__m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + T min_val, + T max_val); + +template <> +__m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + int8_t min_val, + int8_t max_val) { + __m256i packed_and_sat = _mm256_packs_epi16(first, second); + return _mm256_max_epi8( + _mm256_set1_epi8(min_val), + _mm256_min_epi8(packed_and_sat, _mm256_set1_epi8(max_val))); +} + +template <> +__m256i pack_saturate_and_clamp( + __m256i first, + __m256i second, + uint8_t min_val, + uint8_t max_val) { + __m256i packed_and_sat = _mm256_packus_epi16(first, second); + return _mm256_max_epu8( + _mm256_set1_epi8(min_val), + _mm256_min_epu8(packed_and_sat, _mm256_set1_epi8(max_val))); +} +#endif + +template +inline void __attribute__((always_inline)) QuantizeAvx2( + const float* src, + typename T::underlying* dst, + int len, + float inverse_scale, + int64_t zero_point) { +#if defined(__AVX2__) && defined(__FMA__) + constexpr int VLEN = 8; + constexpr auto min_val = std::numeric_limits::min(); + constexpr auto max_val = std::numeric_limits::max(); + int i = 0; + __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + __m256i permute_mask_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + int len_aligned = len / (VLEN * 4) * (VLEN * 4); + for (; i < len_aligned; i += 4 * VLEN) { + // x + __m256 x_vals = _mm256_load_ps(src + i); + __m256 x_transformed_v = + _mm256_fmadd_ps(x_vals, inverse_scale_v, _mm256_set1_ps(zero_point)); + // y + __m256 y_vals = _mm256_load_ps(src + i + VLEN); + __m256 y_transformed_v = + _mm256_fmadd_ps(y_vals, inverse_scale_v, _mm256_set1_ps(zero_point)); + // z + __m256 z_vals = _mm256_load_ps(src + i + 2 * VLEN); + __m256 z_transformed_v = + _mm256_fmadd_ps(z_vals, inverse_scale_v, _mm256_set1_ps(zero_point)); + // w + __m256 w_vals = _mm256_load_ps(src + i + 3 * VLEN); + __m256 w_transformed_v = + _mm256_fmadd_ps(w_vals, inverse_scale_v, _mm256_set1_ps(zero_point)); + + __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v); + __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v); + __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v); + __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v); + + __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); + __m256i xyzw_clamped_v = pack_saturate_and_clamp( + xy_packed_v, zw_packed_v, min_val, max_val); + + xyzw_clamped_v = + _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst + i), xyzw_clamped_v); + } + + for (; i < len; ++i) { + float transformed = zero_point + src[i] * inverse_scale; + float clipped = + std::min(std::max(transformed, float(min_val)), float(max_val)); + // Not exactly the same behavior as the vectorized code. + // The vectorized code above always rounds to even in halfway cases + // (https://software.intel.com/en-us/node/523819), but std::nearbyint + // does the same only when the current rounding mode is FE_TONEAREST. + // However, in practice, this should not be a problem because most cases + // use the default rounding mode FE_TONEAREST. + // Note that we cannot implement the same behavior as the vectorized code + // using std::round because it does rounding away from zero in halfway + // cases. + dst[i] = nearbyint(clipped); + } +#else + at::quantize_vec( + 1.0f / inverse_scale, zero_point, src, reinterpret_cast(dst), len); +#endif +} + template<> struct Vec256 { static constexpr int size() { @@ -99,53 +202,73 @@ struct Vec256 { #endif } - // This needs to be a separate template function because _mm256_extract_epi64 - // requires an immediate operand for the index - template - Vec256 extract_and_dequantize(Vec256 scale, Vec256 zero_point) const { - __m128i int_val; - int_val[0] = _mm256_extract_epi64(vals, idx); - __m256 float_val = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val)); - // TODO this could probably be an FMA - return scale * (Vec256(float_val) - zero_point); - } - public: - float_vec_return_type dequantize(Vec256 scale, Vec256 zero_point) const { - return { - extract_and_dequantize<0>(scale, zero_point), - extract_and_dequantize<1>(scale, zero_point), - extract_and_dequantize<2>(scale, zero_point), - extract_and_dequantize<3>(scale, zero_point) - }; - } - + float_vec_return_type dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3)); + +#if defined(__AVX2__) && defined(__FMA__) + auto val0 = + vec256::fmadd(scale, Vec256(float_val0), scale_zp_premul); + auto val1 = + vec256::fmadd(scale, Vec256(float_val1), scale_zp_premul); + auto val2 = + vec256::fmadd(scale, Vec256(float_val2), scale_zp_premul); + auto val3 = + vec256::fmadd(scale, Vec256(float_val3), scale_zp_premul); +#else + auto val0 = scale * (Vec256(float_val0) - zero_point); + auto val1 = scale * (Vec256(float_val1) - zero_point); + auto val2 = scale * (Vec256(float_val2) - zero_point); + auto val3 = scale * (Vec256(float_val3) - zero_point); +#endif + return {val0, val1, val2, val3}; + } - static Vec256 quantize(const float_vec_return_type& rhs, float scale, int32_t zero_point) { - Vec256 retval; - auto *rhs_data = (float*)rhs.data(); - at::quantize_vec(scale, zero_point,rhs_data, (c10::qint8*)&retval.vals, 32); - return retval; - } + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + int8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vec256::loadu(quantized_values); + } - Vec256 relu(Vec256 zero_point) { + Vec256 maximum(Vec256 b) const { #ifdef __AVX2__ - return _mm256_max_epi8(vals, zero_point.vals); + return _mm256_max_epi8(vals, b.vals); #else // Pray the compiler can autovectorize this int8_t int_vals[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - int8_t zero_point_vals[size()]; + int8_t b_vals[size()]; _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); + reinterpret_cast<__m256i*>(&b_vals), b.vals); int8_t result_vals[size()]; for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], zero_point_vals[i]); + result_vals[i] = std::max(int_vals[i], b_vals[i]); } return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); #endif } + Vec256 relu(Vec256 zero_point) const { + return maximum(zero_point); + } + Vec256 relu6( Vec256 zero_point, Vec256 q_six) { @@ -187,6 +310,11 @@ struct Vec256 { } }; +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + template<> struct Vec256 { static constexpr int size() { @@ -244,52 +372,73 @@ struct Vec256 { #endif } - // This needs to be a separate template function because _mm256_extract_epi64 - // requires an immediate operand for the index - template - Vec256 extract_and_dequantize(Vec256 scale, Vec256 zero_point) const { - __m128i int_val; - int_val[0] = _mm256_extract_epi64(vals, idx); - __m256 float_val = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val)); - // TODO this could probably be an FMA - return scale * (Vec256(float_val) - zero_point); - } - public: - float_vec_return_type dequantize(Vec256 scale, Vec256 zero_point) const { - return { - extract_and_dequantize<0>(scale, zero_point), - extract_and_dequantize<1>(scale, zero_point), - extract_and_dequantize<2>(scale, zero_point), - extract_and_dequantize<3>(scale, zero_point) - }; - } + float_vec_return_type dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { + __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0)); + __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1)); + __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2)); + __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3)); + + __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0)); + __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1)); + __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2)); + __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3)); + +#if defined(__AVX2__) && defined(__FMA__) + auto val0 = + vec256::fmadd(scale, Vec256(float_val0), scale_zp_premul); + auto val1 = + vec256::fmadd(scale, Vec256(float_val1), scale_zp_premul); + auto val2 = + vec256::fmadd(scale, Vec256(float_val2), scale_zp_premul); + auto val3 = + vec256::fmadd(scale, Vec256(float_val3), scale_zp_premul); +#else + auto val0 = scale * (Vec256(float_val0) - zero_point); + auto val1 = scale * (Vec256(float_val1) - zero_point); + auto val2 = scale * (Vec256(float_val2) - zero_point); + auto val3 = scale * (Vec256(float_val3) - zero_point); +#endif + return {val0, val1, val2, val3}; + } - static Vec256 quantize(const float_vec_return_type& rhs, float scale, int32_t zero_point) { - Vec256 retval; - auto *rhs_data = (float*)rhs.data(); - at::quantize_vec(scale, zero_point,rhs_data, (c10::quint8*)&retval.vals, 32); - return retval; - } + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + auto* rhs_data = (float*)rhs.data(); + uint8_t quantized_values[32]; + QuantizeAvx2( + rhs_data, quantized_values, 32, inverse_scale, zero_point); + return Vec256::loadu(quantized_values); + } - Vec256 relu(Vec256 zero_point) { + Vec256 maximum(Vec256 b) const { #ifdef __AVX2__ - return _mm256_max_epu8(vals, zero_point.vals); + return _mm256_max_epu8(vals, b.vals); #else // Pray the compiler can autovectorize this uint8_t int_vals[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - uint8_t zero_point_vals[size()]; + uint8_t b_vals[size()]; _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); + reinterpret_cast<__m256i*>(&b_vals), b.vals); uint8_t result_vals[size()]; for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], zero_point_vals[i]); + result_vals[i] = std::max(int_vals[i], b_vals[i]); } return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); #endif } + Vec256 relu(Vec256 zero_point) const { + return maximum(zero_point); + } + Vec256 relu6( Vec256 zero_point, Vec256 q_six) { @@ -331,6 +480,11 @@ struct Vec256 { } }; +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + template<> struct Vec256 { static constexpr int size() { @@ -366,41 +520,52 @@ struct Vec256 { return Vec256(ptr); } - float_vec_return_type dequantize(Vec256 scale, Vec256 zero_point) const { + float_vec_return_type dequantize( + Vec256 scale, + Vec256 zero_point, + Vec256 scale_zp_premul) const { __m256 float_vals = _mm256_cvtepi32_ps(vals); +#if defined(__AVX2__) && defined(__FMA__) + return {vec256::fmadd(scale, Vec256(float_vals), scale_zp_premul)}; +#else return {scale * (Vec256(float_vals) - zero_point)}; +#endif } - static Vec256 quantize(const float_vec_return_type& rhs, float scale, int32_t zero_point) { - Vec256 retval; - auto rhs_data = (__m256)rhs[0]; - at::quantize_vec( - scale, - zero_point, - (float*)&rhs_data, - (c10::qint32*)&retval.vals, - 8); - return retval; + static Vec256 quantize( + const float_vec_return_type& rhs, + float scale, + int32_t zero_point, + float inverse_scale) { + Vec256 retval; + auto rhs_data = (__m256)rhs[0]; + at::quantize_vec( + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8); + return retval; } - Vec256 relu(Vec256 zero_point) { + Vec256 maximum(Vec256 b) const { #ifdef __AVX2__ - return _mm256_max_epi32(vals, zero_point.vals); + return _mm256_max_epi32(vals, b.vals); #else // Pray the compiler can autovectorize this int32_t int_vals[size()]; _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals); - int32_t zero_point_vals[size()]; + int32_t b_vals[size()]; _mm256_storeu_si256( - reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals); + reinterpret_cast<__m256i*>(&b_vals), b.vals); int32_t result_vals[size()]; for (size_t i = 0; i < size(); ++i) { - result_vals[i] = std::max(int_vals[i], zero_point_vals[i]); + result_vals[i] = std::max(int_vals[i], b_vals[i]); } return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals)); #endif } + Vec256 relu(Vec256 zero_point) const { + return maximum(zero_point); + } + Vec256 relu6( Vec256 zero_point, Vec256 q_six) { @@ -442,6 +607,11 @@ struct Vec256 { } }; +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + #else // NOTE: These are low-performance implementations that we fall back on @@ -483,7 +653,8 @@ struct Vec256QuantizedConverter { float_vec_return_type dequantize( Vec256 scale, - Vec256 zero_point) const { + Vec256 zero_point, + Vec256 scale_zp_premul) const { float_vec_return_type rv; for (int i = 0; i < float_num_vecs(); ++i) { for (int j = 0; j < 8; ++j) { @@ -524,7 +695,8 @@ struct Vec256 : public Vec256QuantizedConverter< static Vec256 quantize( const float_vec_return_type& rhs, float scale, - int32_t zero_point) { + int32_t zero_point, + float inverse_scale) { value_type qvals[size()]; float float_vals[float_num_vecs() * 8]; @@ -542,14 +714,18 @@ struct Vec256 : public Vec256QuantizedConverter< return Vec256::loadu(qvals); } - Vec256 relu(Vec256 zero_point) { + Vec256 maximum(Vec256 b) const { Vec256 retval; for (size_t i = 0; i < size(); ++i) { - retval.vals[i] = std::max(vals[i], zero_point.vals[i]); + retval.vals[i] = std::max(vals[i], b.vals[i]); } return retval; } + Vec256 relu(Vec256 zero_point) const { + return maximum(zero_point); + } + Vec256 relu6( Vec256 zero_point, Vec256 q_six) { @@ -565,6 +741,11 @@ struct Vec256 : public Vec256QuantizedConverter< Vec256() {} }; +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + template <> struct Vec256 : public Vec256QuantizedConverter< c10::quint8, @@ -584,7 +765,8 @@ struct Vec256 : public Vec256QuantizedConverter< static Vec256 quantize( const float_vec_return_type& rhs, float scale, - int32_t zero_point) { + int32_t zero_point, + float inverse_scale) { value_type qvals[size()]; float float_vals[float_num_vecs() * 8]; @@ -602,14 +784,19 @@ struct Vec256 : public Vec256QuantizedConverter< return Vec256::loadu(qvals); } - Vec256 relu(Vec256 zero_point) { + Vec256 maximum(Vec256 b) const { Vec256 retval; for (size_t i = 0; i < size(); ++i) { - retval.vals[i] = std::max(vals[i], zero_point.vals[i]); + retval.vals[i] = std::max(vals[i], b.vals[i]); } return retval; } + Vec256 relu(Vec256 zero_point) const { + return maximum(zero_point); + } + + Vec256 relu6( Vec256 zero_point, Vec256 q_six) { @@ -625,6 +812,11 @@ struct Vec256 : public Vec256QuantizedConverter< Vec256() {} }; +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + template <> struct Vec256 : public Vec256QuantizedConverter< c10::qint32, @@ -644,7 +836,8 @@ struct Vec256 : public Vec256QuantizedConverter< static Vec256 quantize( const float_vec_return_type& rhs, float scale, - int32_t zero_point) { + int32_t zero_point, + float inverse_scale) { value_type qvals[size()]; float float_vals[float_num_vecs() * 8]; @@ -662,14 +855,19 @@ struct Vec256 : public Vec256QuantizedConverter< return Vec256::loadu(qvals); } - Vec256 relu(Vec256 zero_point) { + Vec256 maximum(Vec256 b) const { Vec256 retval; for (size_t i = 0; i < size(); ++i) { - retval.vals[i] = std::max(vals[i], zero_point.vals[i]); + retval.vals[i] = std::max(vals[i], b.vals[i]); } return retval; } + Vec256 relu(Vec256 zero_point) const { + return maximum(zero_point); + } + + Vec256 relu6( Vec256 zero_point, Vec256 q_six) { @@ -685,6 +883,11 @@ struct Vec256 : public Vec256QuantizedConverter< Vec256() {} }; +template <> +Vec256 inline maximum(const Vec256& a, const Vec256& b) { + return a.maximum(b); +} + #endif // defined(__AVX__) && !defined(_MSC_VER) -}}} \ No newline at end of file +}}} diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index f104f3eb85cd7..c976e8a5e5346 100644 --- a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -123,6 +123,8 @@ IMPLEMENT_VML(rsqrt) IMPLEMENT_VML_BUG(tan) IMPLEMENT_VML_BUG(tanh) IMPLEMENT_VML_BUG(trunc) +IMPLEMENT_VML_BUG(lgamma) + #if AT_MKL_ENABLED() && !defined(__APPLE__) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index a8c64760c8da0..8b73489964114 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -217,9 +217,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, - 0, - NULL, - NULL)); + 0)); #else # if CUDA_VERSION >= 9000 diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h index 56a4b6893f15d..36b7bd99acffc 100644 --- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h +++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h @@ -2,10 +2,7 @@ #include #include - -#ifndef __HIP_PLATFORM_HCC__ #include -#endif namespace at { namespace cuda { @@ -64,14 +61,22 @@ namespace at { namespace cuda { // list above. // // HIP doesn't have -// nvrtc* // cuOccupancyMaxActiveBlocksPerMultiprocessor // cuGetErrorString (maps to non-functional hipGetErrorString___) #define AT_FORALL_NVRTC(_) \ + _(nvrtcVersion) \ + _(nvrtcCreateProgram) \ + _(nvrtcDestroyProgram) \ + _(nvrtcGetPTXSize) \ + _(nvrtcGetPTX) \ _(cuModuleLoadData) \ _(cuModuleGetFunction) \ + _(nvrtcGetErrorString) \ + _(nvrtcGetProgramLogSize) \ + _(nvrtcGetProgramLog) \ _(cuLaunchKernel) \ + _(nvrtcCompileProgram) \ _(cuCtxGetCurrent) \ _(cuModuleUnload) \ _(cuDevicePrimaryCtxGetState) diff --git a/aten/src/ATen/detail/ScalarTypeConversions.h b/aten/src/ATen/detail/ScalarTypeConversions.h index ef04271397163..10f53bd1736b3 100644 --- a/aten/src/ATen/detail/ScalarTypeConversions.h +++ b/aten/src/ATen/detail/ScalarTypeConversions.h @@ -9,14 +9,14 @@ namespace at { namespace detail { template inline T load(const void* data, ScalarType src_type) { - return AT_DISPATCH_ALL_TYPES(src_type, "load", [&]() { + return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src_type, "load", [&]() { return at::convert(*(scalar_t*)data); }); } template inline void store(T value, void* dst, ScalarType dst_type) { - AT_DISPATCH_ALL_TYPES(dst_type, "store", [&]() { + AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, dst_type, "store", [&]() { *(scalar_t*)dst = at::convert(value); }); } diff --git a/aten/src/ATen/env.py b/aten/src/ATen/env.py index 00eb22eb4416d..4f3d4e8dee1e1 100644 --- a/aten/src/ATen/env.py +++ b/aten/src/ATen/env.py @@ -9,4 +9,4 @@ def check_env_flag(name, default=''): def check_negative_env_flag(name, default=''): return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N'] -BUILD_NAMEDTENSOR = check_env_flag('BUILD_NAMEDTENSOR') +BUILD_NAMEDTENSOR = True diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 8dea9fddec613..be1cff9eed5dd 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -113,10 +113,16 @@ def TypedDict(name, attrs, total=True): # type: ignore """) DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\ -.registerOp<${return_type} (${formals_types})>(Backend::Undefined, "${schema_string}", &TypeDefault::${api_name}) +.op(torch::RegisterOperators::options() + .schema("${schema_string}") + .impl_unboxedOnlyCatchAllKernel<${return_type} (${formals_types}), &TypeDefault::${api_name}>() + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\ -.registerOp<${return_type} (${formals_types})>(Backend::${Backend}, "${schema_string}", &${Type}::${api_name}) +.op(torch::RegisterOperators::options() + .schema("${schema_string}") + .impl_unboxedOnlyKernel<${return_type} (${formals_types}), &${Type}::${api_name}>(TensorTypeId::${Backend}TensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) # Generate a file that lists all functions and their schema string. Used for XLA @@ -132,10 +138,21 @@ def TypedDict(name, attrs, total=True): # type: ignore TENSOR_METHOD_DEFINITION = CodeTemplate("""\ inline ${return_type} Tensor::${api_name}(${method_formals}) const { #ifdef USE_STATIC_DISPATCH - ${mobile_method_body} + ${static_dispatch_method_body} #else static auto table = globalATenDispatch().getOpTable("${schema_string}"); - return table->getOp<${return_type} (${formals_types})>(tensorTypeIdToBackend(type_id()), is_variable())(${method_actuals}); + return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${method_actuals}); +#endif +} +""") +C10_TENSOR_METHOD_DEFINITION = CodeTemplate("""\ +inline ${return_type} Tensor::${api_name}(${method_formals}) const { +#ifdef USE_STATIC_DISPATCH + ${static_dispatch_method_body} +#else + static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::${name}", "${overload_name}"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(${inferred_type_set})) + .callUnboxed<${formals_types_with_return}>(${method_actuals}); #endif } """) @@ -151,27 +168,40 @@ def TypedDict(name, attrs, total=True): # type: ignore FUNCTION_DEFINITION = CodeTemplate("""\ static inline ${return_type} ${api_name}(${formals}) { #ifdef USE_STATIC_DISPATCH - ${mobile_function_body} + ${static_dispatch_function_body} #else static auto table = globalATenDispatch().getOpTable("${schema_string}"); - return table->getOp<${return_type} (${formals_types})>(${inferred_backend}, ${inferred_is_variable})(${native_actuals}); + return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals}); +#endif +} +""") + +C10_FUNCTION_DEFINITION = CodeTemplate("""\ +static inline ${return_type} ${api_name}(${formals}) { +#ifdef USE_STATIC_DISPATCH + ${static_dispatch_function_body} +#else + static c10::OperatorHandle op = c10::Dispatcher::singleton() + .findSchema({"aten::${name}", "${overload_name}"}).value(); + return c10::Dispatcher::singleton().lookup(op, impl::dispatchTypeId(${inferred_type_set})) + .callUnboxed<${formals_types_with_return}>(${native_actuals}); #endif } """) -# for mobile builds, we rely on the linker to strip unused ops. -# this requires us to dispatch statically in Functions.h and TensorMethods.h -MOBILE_FUNCTION_DEFAULT_BODY = CodeTemplate("""\ +# In order to rely on the linker to strip unused ops, it requires us to dispatch statically +# in Functions.h and TensorMethods.h. +STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\ ${return_call} TypeDefault::${native_type_method_dispatch}(${native_arguments}); """) -MOBILE_FUNCTION_SWITCH_BODY = CodeTemplate("""\ -switch(${backend}) { - ${mobile_function_switches} +STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\ +switch(tensorTypeIdToBackend(impl::dispatchTypeId(${type_set}))) { + ${static_dispatch_function_switches} default: - AT_ERROR("${api_name} not implemented for ", at::toString(${backend})); + AT_ERROR("${api_name} not implemented for ", at::toString(${type_set})); } """) -MOBILE_FUNCTION_SWITCH_STATEMENT = CodeTemplate("""\ +STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT = CodeTemplate("""\ case Backend::${backend}: ${return_call} ${backend}Type::${api_name}(${native_arguments}); break; @@ -186,11 +216,11 @@ def TypedDict(name, attrs, total=True): # type: ignore FACTORY_DEFINITION = CodeTemplate("""\ static inline ${return_type} ${api_name}(${formals}) { #ifdef USE_STATIC_DISPATCH - ${mobile_function_body} + ${static_dispatch_function_body} #else - globalLegacyTypeDispatch().initForBackend(${inferred_backend}); + globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set}); static auto table = globalATenDispatch().getOpTable("${schema_string}"); - return table->getOp<${return_type} (${formals_types})>(${inferred_backend}, ${inferred_is_variable})(${native_actuals}); + return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals}); #endif } """) @@ -217,6 +247,10 @@ def TypedDict(name, attrs, total=True): # type: ignore CALL_TEMPLATE = CodeTemplate("${cname}(${actuals})") +OPERATOR_NAME = CodeTemplate("""\ + {"aten::${operator_name}", "${overload_name}"}, +""") + NAMEDTENSOR_CHECK = CodeTemplate("""\ #ifdef BUILD_NAMEDTENSOR ${code} @@ -236,7 +270,7 @@ def TypedDict(name, attrs, total=True): # type: ignore ('BFloat16', 'BFloat16', 'BFloat16AccrealNotDefined', True), ] -mobile_backends = ['CPU', 'QuantizedCPU', 'SparseCPU'] +static_dispatch_backends = ['CPU', 'QuantizedCPU', 'SparseCPU'] class NYIError(Exception): @@ -348,7 +382,7 @@ def __init__(self, reason): 'TensorList': CodeTemplate( 'checked_tensor_list_unwrap(${arg_name},"${arg_name}",${arg_pos}, ' 'Backend::${Backend}, ScalarType::${ScalarName})'), - 'IntArrayRef': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})') + 'IntArrayRef': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos})') } CHECKED_USE = { @@ -396,13 +430,11 @@ def __init__(self, reason): # Replacements for constants when calling into TH CONSTANT_REPLACEMENTS = [ ('AS_REAL', '${ScalarType}'), - ('__last_dim', 'self.ndimension()-1'), ] # Replacements for constants in header file function definitions HEADER_CONSTANT_REPLACEMENTS = [ (r'AS_REAL\((.*)\)', r'\1'), - ('__last_dim', '-1'), ] @@ -432,6 +464,8 @@ def __getitem__(self, x): 'type_registrations': List[str], 'type_headers': List[str], 'function_registrations': List[str], + 'c10_ops_already_moved_from_aten_to_c10': List[str], + 'c10_ops_not_moved_from_aten_to_c10_yet': List[str], 'type_method_declarations': List[str], 'type_method_definitions': List[str], 'tensor_method_declarations': List[str], @@ -452,15 +486,10 @@ def __getitem__(self, x): 'kwarg_only': bool, 'is_nullable': bool, 'default': str, - 'default_init': str, 'output': bool, 'size': int, - 'declared_type': str, - 'ignore_check': bool, 'allocate': bool, 'mask': bool, - 'if_true': bool, - 'if_false': bool, 'wrap_dim': str, # Broadcast is originally a str but gets unwrapped to a List or Dict in-place 'broadcast': Any, @@ -478,7 +507,6 @@ def __getitem__(self, x): 'kwarg_only': bool, 'is_nullable': bool, 'default': str, - 'default_init': str, 'output': bool, 'size': int, }, total=False) @@ -538,6 +566,7 @@ def __getitem__(self, x): 'device_guard': bool, 'device_guard_declaration': str, 'dispatch_scalar_type_declaration': str, + 'use_c10_dispatcher': bool, 'with_gil': bool, 'cpu_half': bool, 'cpu_bfloat16': bool, @@ -551,8 +580,8 @@ def __getitem__(self, x): 'formals_with_defaults': List[str], 'formals': List[str], 'formals_types': List[str], - 'inferred_backend': str, - 'inferred_is_variable': str, + 'formals_types_with_return': List[str], + 'inferred_type_set': str, 'inplace': bool, 'matches_jit_signature': bool, # This controls whether or not we generate the interface in Type or @@ -566,8 +595,10 @@ def __getitem__(self, x): 'mode': str, 'python_module': str, 'name': str, + 'operator_name': str, 'overload_name': str, 'native_actuals': List[str], + 'native_actuals_with_comma_prefix': str, 'native_type_method_dispatch': str, # options should be List[FunctionOption] 'options': Any, @@ -593,7 +624,9 @@ def __getitem__(self, x): OutputDeclaration = NamedTuple('OutputDeclaration', [ ('name', str), + ('operator_name', str), ('overload_name', str), + ('use_c10_dispatcher', bool), ('matches_jit_signature', bool), ('schema_string', str), ('method_prefix_derived', str), @@ -629,9 +662,9 @@ def device_guard(option, dispatch_options, dispatch_tensor): def named_guard(option, tensors, tensorlists): - if not option.get('named_guard', True) or (len(tensors) + len(tensorlists) == 0): + if option.get('supports_named_tensor', False) or (len(tensors) + len(tensorlists) == 0): return '' - # Override: named_guard = True for _th_ functions. This is because: + # Override: supports_named_tensor = False for _th_ functions. This is because: # There is always some at:: function that calls the _th_ function. if option['name'].startswith('_th_'): return '' @@ -642,7 +675,10 @@ def named_guard(option, tensors, tensorlists): named_conditions.append('at::has_names({})'.format(tensorlist)) return ("""\ if ({named_conditions}) {{ - AT_ERROR("{op}: no named inference rule implemented."); + AT_ERROR( + "{op} is not yet supported with named tensors. Please drop names via " + "`tensor = tensor.renamed(None)`, call the op with an unnamed tensor, " + "and set names on the result of the operation."); }}""".format(named_conditions=' || '.join(named_conditions), op=option['name'])) @@ -700,8 +736,6 @@ def translate_default(argument, type_str, default): if default is None: # cause the default constructor for the object to run return '{}' - if 'if_true' in argument: - return argument['default'] == argument['if_true'] for pattern, replacement in HEADER_CONSTANT_REPLACEMENTS: default = re.sub(pattern, replacement, str(default)) if type_str in {'Scalar', 'int64_t', 'double'}: @@ -735,7 +769,6 @@ def translate_formal(argument, option): if 'default' in argument: default = translate_default(argument, type_str, argument['default']) translated['default'] = default - translated['default_init'] = argument.get('default_init', default) if argument.get('output'): translated['output'] = True if argument.get('size'): @@ -823,7 +856,13 @@ def find_tensorlists(formals): def find_dispatch_tensor(formals): # type: (List[AtFormal]) -> Optional[str] - # dispatch to self if it's a parameter + # Determine legacy TH-style single dispatch tensor. + # + # Also used to determine what tensor should be used to provide a default + # DeviceGuard. Unlike dispatch, we don't guard on ALL tensor arguments + # (because this is not actually a thing you can do.) Guarding on the + # first argument is best effort to help people avoid doing this + # themselves. for formal in formals: if formal['name'] == 'self' and is_any_tensor_type(formal) and not formal.get('is_nullable', False): @@ -836,6 +875,28 @@ def find_dispatch_tensor(formals): return None + def find_multidispatch_tensors(formals): + # type: (List[AtFormal]) -> List[str] + # Compute the list of all tensor arguments which should be considered + # for multiple dispatch. Note that this doesn't completely replace + # find_dispatch_tensor because we use the "dispatch tensor" to determine + # device guards. TensorOptions is included as part of this calculation. + # + # The interaction of multiple dispatch with TensorOptions + # is quite interesting. In particular, suppose I have: + # + # cuda_tensor.new_like(1, device='cpu') + # + # Multiple dispatch will attempt a dispatch to CUDA, even though + # the end tensor that should be produced here is a CPU one. The + # upshot is that if you have an operator with mixed TensorOptions + # and Tensor arguments, you MUST only ever register it generically. + r = [] + for formal in formals: + if formal['dynamic_type'] in ['TensorOptions', 'TensorList'] or is_any_tensor_type(formal): + r.append(formal['name']) + return r + def format_formal(f): # type: (AtFormal) -> str return '{} {}'.format(f['type'], f['name']) @@ -878,32 +939,10 @@ def get_broadcast_actuals(broadcast_arg, broadcast_inplace, broadcast_dims): return broadcast_actuals - def emit_nn_body(option): - # type: (FunctionOption) -> Union[str, List[str]] - # Concrete definition on Type.cpp for NN functions. Delegates to the - # xxx_forward variant variant after creating any necessary buffers. - actuals = option['actuals'] - base_name = option['name'][:-1] if option['inplace'] else option['name'] - fwd_name = option['api_name'].replace(base_name, base_name + '_forward') - - if len(option['buffers']) == 0: - return 'return {}({});'.format(fwd_name, ', '.join(actuals)) - - body = [] # type: List[str] - if option['api_name'].endswith('_out'): - # _out variants must create buffers and insert them in the - # arguments list between output and input arguments - for buffer in option['buffers']: - body.append('Tensor {} = at::empty({{0}}, this->options());'.format(buffer['name'])) - actuals = [arg['name'] for arg in option['arguments'] if arg.get('output')] - actuals += [buffer['name'] for buffer in option['buffers']] - actuals += [arg['name'] for arg in option['arguments'] if not arg.get('output')] - - body.append('return std::get<0>({}({}));'.format(fwd_name, ', '.join(actuals))) - return body - - def process_option(option): + def process_legacy_th_option(option): # type: (FunctionOption) -> None + # Mutably populate option with derived values computed from values + # passed in to option. option['inplace'] = re.search( '(^__i|[^_]_$)', option['api_name']) is not None @@ -932,6 +971,7 @@ def process_option(option): assert 'method' not in option['variants'], 'TH functions cannot be methods' is_function = 'function' in option['variants'] + # NB: TH functions don't support multiple dispatch dispatch_tensor = find_dispatch_tensor(formals) is_namespace_function = is_function and dispatch_tensor is not None @@ -1072,6 +1112,14 @@ def process_native(option): option['formals_types'] = [f['type'] for f in option['formals_list']] option['native_actuals'] = [f['name'] for f in option['formals_list']] + if len(option['native_actuals']) == 0: + option['native_actuals_with_comma_prefix'] = '' + else: + option['native_actuals_with_comma_prefix'] = ', ' + ', '.join(option['native_actuals']) + + option['formals_types_with_return'] = [option['return_type']] + if len(option['formals_types']) > 0: + option['formals_types_with_return'].extend(option['formals_types']) option['method_formals'] = [format_formal(f) for f in formals if f['name'] != 'self'] @@ -1091,65 +1139,87 @@ def find_formal(formal_name, formals): def has_named_tensor_formals(formals): return any(['Dimname' in formal['dynamic_type'] for formal in formals]) - def gen_tensor_method(option): - # type: (Any) -> FunctionCode + def gen_tensor_method(option, multidispatch_tensors): + # type: (Any, List[str]) -> FunctionCode + def swizzle_self(t): # blegh + if t == 'self': + return '*this' + else: + return t + option['inferred_type_set'] = 'at::detail::multi_dispatch_tensor_type_set({})'.format( + ', '.join(swizzle_self(t) for t in multidispatch_tensors) + ) + if isinstance(type_method_dispatch, dict): - mobile_function_switches = [] - for backend in mobile_backends: + static_dispatch_function_switches = [] + # NB: As this code is currently written, there will NEVER be + # a backend generated for variable dispatch. There is nothing + # stopping us from actually implementing this, however, if you + # really wanted variable on mobile, there's nothing stopping + # you from implementing this (however, you would have an + # annoying phase problem, since code generation for variable + # happens in tools/ which happens later than here.) + # + # If you pass in a variable to the dispatch, and variable is + # enabled, this switch will fail. This is intentional: you + # probably need to disable variable globally in the mobile + # calling code. + for backend in static_dispatch_backends: if backend in type_method_dispatch: - mobile_function_switches.append(MOBILE_FUNCTION_SWITCH_STATEMENT.substitute( + static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute( option, backend=backend, backend_function=type_method_dispatch[backend], native_arguments=option['method_actuals'])) - mobile_method_body = MOBILE_FUNCTION_SWITCH_BODY.substitute( + static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute( option, - backend='tensorTypeIdToBackend(type_id())', - mobile_function_switches=mobile_function_switches) + type_set='type_set()', + static_dispatch_function_switches=static_dispatch_function_switches) else: - mobile_method_body = MOBILE_FUNCTION_DEFAULT_BODY.substitute( + static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute( option, native_arguments=option['method_actuals']) + method_definition = (C10_TENSOR_METHOD_DEFINITION if option['use_c10_dispatcher'] else TENSOR_METHOD_DEFINITION) return FunctionCode( - declaration=TENSOR_METHOD_DECLARATION.substitute(option, mobile_method_body=mobile_method_body), - definition=TENSOR_METHOD_DEFINITION.substitute(option, mobile_method_body=mobile_method_body)) - - def gen_namespace_function(option, dispatch_tensor, dispatch_options): - # type: (Any, Optional[str], Any) -> FunctionCode - if dispatch_tensor: - option['inferred_backend'] = 'at::detail::infer_backend({})'.format(dispatch_tensor) - option['inferred_is_variable'] = 'at::detail::infer_is_variable({})'.format(dispatch_tensor) - elif dispatch_options: - option['inferred_backend'] = '{}.backend()'.format(dispatch_options['name']) - option['inferred_is_variable'] = '{}.is_variable()'.format(dispatch_options['name']) - else: - # doesn't depend on a specific backend, use CPU - option['inferred_backend'] = 'Backend::CPU' - option['inferred_is_variable'] = 'false' + declaration=TENSOR_METHOD_DECLARATION.substitute( + option, static_dispatch_method_body=static_dispatch_method_body), + definition=method_definition.substitute( + option, static_dispatch_method_body=static_dispatch_method_body)) + + def gen_namespace_function(option, multidispatch_tensors): + # type: (Any, List[str]) -> FunctionCode + option['inferred_type_set'] = ( + 'at::detail::multi_dispatch_tensor_type_set({})'.format(', '.join(multidispatch_tensors))) declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION fn_declaration = declaration.substitute(option) if isinstance(type_method_dispatch, dict): - mobile_function_switches = [] - for backend in mobile_backends: + static_dispatch_function_switches = [] + for backend in static_dispatch_backends: if backend in type_method_dispatch: - mobile_function_switches.append(MOBILE_FUNCTION_SWITCH_STATEMENT.substitute( + static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute( option, backend=backend, backend_function=type_method_dispatch[backend], native_arguments=option['native_actuals'])) - mobile_function_body = MOBILE_FUNCTION_SWITCH_BODY.substitute( + static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute( option, - backend=option['inferred_backend'], - mobile_function_switches=mobile_function_switches) + type_set=option['inferred_type_set'], + static_dispatch_function_switches=static_dispatch_function_switches) else: - mobile_function_body = MOBILE_FUNCTION_DEFAULT_BODY.substitute( + static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute( option, native_arguments=option['native_actuals']) if is_factory_method: - fn_definition = FACTORY_DEFINITION.substitute(option, mobile_function_body=mobile_function_body) + fn_definition = FACTORY_DEFINITION.substitute( + option, static_dispatch_function_body=static_dispatch_function_body) else: - fn_definition = FUNCTION_DEFINITION.substitute(option, mobile_function_body=mobile_function_body) + if not option['use_c10_dispatcher']: + fn_definition = FUNCTION_DEFINITION.substitute( + option, static_dispatch_function_body=static_dispatch_function_body) + else: + fn_definition = C10_FUNCTION_DEFINITION.substitute( + option, static_dispatch_function_body=static_dispatch_function_body) return FunctionCode(definition=fn_definition, declaration=fn_declaration) # Emit #ifdef BUILD_NAMEDTENSOR macros for any code generated here @@ -1157,7 +1227,8 @@ def gen_namespace_function(option, dispatch_tensor, dispatch_options): # TensorBody.h, TensorMethods.h) is checked into the repo and must be # the same regardless of BUILD_NAMEDTENSOR status. is_named_tensor_only = (has_named_tensor_formals(formals) or - option['api_name'] == 'align_tensors') + option['api_name'] == 'align_tensors' or + option['api_name'] == 'align_as') def check_namedtensor_enabled(code): if is_named_tensor_only: @@ -1176,13 +1247,15 @@ def add_namedtensor_enabled_macro(code): type_method_dispatch = option['type_method_definition_dispatch'] - dispatch_options = find_formal('TensorOptions', formals) - # Only dispatch via tensor if there is no Options argument - dispatch_tensor = None if dispatch_options else find_dispatch_tensor(formals) + multidispatch_tensors = find_multidispatch_tensors(formals) option['type_method_formals'] = [format_formal(f) for f in formals] option['type_method_actuals'] = [f['name'] for f in formals] option['native_actuals'] = [f['name'] for f in formals] + if len(option['native_actuals']) == 0: + option['native_actuals_with_comma_prefix'] = '' + else: + option['native_actuals_with_comma_prefix'] = ', ' + ', '.join(option['native_actuals']) is_method = 'method' in option['variants'] is_namespace_function = 'function' in option['variants'] @@ -1194,10 +1267,16 @@ def add_namedtensor_enabled_macro(code): check_methods_do_not_start_with_underscore(option['name'], is_method) option['method_prefix_derived'] = '' - option['device_guard_declaration'] = device_guard(option, dispatch_options, dispatch_tensor) + # NB: Device guard and scalar type generated code is still based on the + # first argument. Scalar type test will be removed once TH is removed. + # If you need more complex device guard behavior, you should disable + # device guard and then manually add the guards you need. + dispatch_options = find_formal('TensorOptions', formals) + guard_tensor = None if dispatch_options else find_dispatch_tensor(formals) + option['device_guard_declaration'] = device_guard(option, dispatch_options, guard_tensor) option['named_guard_declaration'] = named_guard(option, find_tensors(formals), find_tensorlists(formals)) - option['dispatch_scalar_type_declaration'] = dispatch_scalar_type(option, dispatch_options, dispatch_tensor) + option['dispatch_scalar_type_declaration'] = dispatch_scalar_type(option, dispatch_options, guard_tensor) broadcast_arg = get_broadcast_argument(option) if broadcast_arg is not None: @@ -1207,6 +1286,14 @@ def add_namedtensor_enabled_macro(code): if BUILD_NAMEDTENSOR or not is_named_tensor_only: top_env['registration_declarations'].append( REGISTRATION_DECLARATION.substitute(option)) + if option['use_c10_dispatcher']: + top_env['c10_ops_already_moved_from_aten_to_c10'].append( + check_namedtensor_enabled(OPERATOR_NAME.substitute(option)) + ) + else: + top_env['c10_ops_not_moved_from_aten_to_c10_yet'].append( + check_namedtensor_enabled(OPERATOR_NAME.substitute(option)) + ) option['native_type_method_dispatch'] = type_method_dispatch # Note [Abstract ATen methods] @@ -1247,7 +1334,7 @@ def add_namedtensor_enabled_macro(code): method_of = ['Type'] if is_method: - code = gen_tensor_method(option) + code = gen_tensor_method(option, multidispatch_tensors) if is_named_tensor_only: code = add_namedtensor_enabled_macro(code) top_env['tensor_method_declarations'].append(code.declaration) @@ -1255,7 +1342,7 @@ def add_namedtensor_enabled_macro(code): method_of.append('Tensor') if is_namespace_function: - code = gen_namespace_function(option, dispatch_tensor, dispatch_options) + code = gen_namespace_function(option, multidispatch_tensors) if is_named_tensor_only: code = add_namedtensor_enabled_macro(code) top_env['function_definitions'].append(code.definition) @@ -1266,7 +1353,9 @@ def add_namedtensor_enabled_macro(code): return None return OutputDeclaration( name=option['api_name'], + operator_name=option['operator_name'], overload_name=option['overload_name'], + use_c10_dispatcher=option['use_c10_dispatcher'], matches_jit_signature=option["matches_jit_signature"], schema_string=option["schema_string"], method_prefix_derived=option['method_prefix_derived'], @@ -1294,8 +1383,8 @@ def add_namedtensor_enabled_macro(code): option["schema_string"] = declaration["schema_string"] try: if option['mode'] != 'native': - # XXX: Does the following line do anything meaningful? - process_option(option) + # Mutably populate option with values + process_legacy_th_option(option) else: output_option = process_native(option) if output_option: @@ -1326,10 +1415,6 @@ def nullable_argument(argument): # type: (THFormal) -> bool return argument.get('is_nullable', False) - def bool_option_is_string(argument): - # type: (THFormal) -> bool - return 'if_true' in argument and isinstance(argument['if_true'], string_type) - def get_argument(env, argument, option): # type: (Environment, THFormal, FunctionOption) -> str if requires_checked_cast(argument): @@ -1339,17 +1424,7 @@ def get_argument(env, argument, option): checked_use = CHECKED_USE_NULLABLE.substitute( env={}, arg_name=argument['name'], usage=checked_use) return checked_use - elif argument['type'] == 'bool' and 'if_true' in argument: - if bool_option_is_string(argument): - tpl = '({}) ? "{}" : "{}"' - else: - tpl = '({}) ? {} : {}' - return tpl.format(argument['name'], - argument['if_true'], argument['if_false']) elif argument['type'] == 'CONSTANT': - # this is a bool that is actually a string... - if bool_option_is_string(argument): - return '"{}"'.format(argument['name']) v = str(argument.get('default', argument['name'])) for pattern, replacement in CONSTANT_REPLACEMENTS: v = re.sub(pattern, replacement, v) @@ -1543,14 +1618,11 @@ def emit_body(env, option, scalar_type_cases): # the checked cast succeeds even if the Tensor is not # defined null_okay = 'true' if nullable_argument(arg) else 'false' - default_init = [] - if 'default_init' in arg: - default_init.append(arg['default_init']) check_cast = CHECKED_CAST[arg['type']].substitute( case_env, arg_name=arg['name'], arg_pos=count, api_name=option['api_name'], null_okay=null_okay, - default_init=default_init, size=arg.get('size')) + size=arg.get('size')) case_body.append("auto {}_ = {};".format( arg['name'], check_cast)) if drop_argument(arg, option): @@ -1672,7 +1744,7 @@ def emit_body(env, option, scalar_type_cases): body.append(LEGACY_TH_DEFINITION_SWITCH_STATEMENT.substitute(env, cases=cases)) return body - def process_option(option): + def process_legacy_th_option(option): # type: (FunctionOption) -> None backend = backend_type_env['Backend'] if backend in option['backend_types']: @@ -1714,10 +1786,10 @@ def process_native(option): if option['mode'] == 'NN' and option.get('cimpls') is None: continue if option['mode'] != 'native': - process_option(option) + process_legacy_th_option(option) else: process_native(option) except NYIError: pass - return (type_object_declarations, type_object_definitions, function_registrations, legacy_th_declarations, - legacy_th_definitions) + return (type_object_declarations, type_object_definitions, function_registrations, + legacy_th_declarations, legacy_th_definitions) diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index e23dedb6a2f4a..cff113ebb0211 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -1,3 +1,4 @@ + import argparse import os @@ -119,6 +120,7 @@ def check_all_files_written(self): TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h") TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp") REGISTRATION_DECLARATIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/RegistrationDeclarations.h") +OPS_ALREADY_MOVED_TO_C10_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/OpsAlreadyMovedToC10.cpp") TENSOR_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorBody.h") TENSOR_METHODS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorMethods.h") @@ -157,6 +159,8 @@ def backend_to_devicetype(backend): 'cpu_type_headers': [], 'cuda_type_headers': [], 'function_registrations': [], + 'c10_ops_already_moved_from_aten_to_c10': [], + 'c10_ops_not_moved_from_aten_to_c10_yet': [], 'type_method_declarations': [], 'type_method_definitions': [], 'tensor_method_declarations': [], @@ -388,7 +392,7 @@ def cmpfiles_with_eol_normalization(a, b, names): def is_namedtensor_only_decl(decl): if 'Dimname' in decl['schema_string']: return True - if decl['name'] == 'align_tensors': + if decl['name'] == 'align_tensors' or decl['name'] == 'align_as': return True return False @@ -426,7 +430,8 @@ def generate_outputs(): core_files = { 'TensorBody.h': TENSOR_H, - 'TensorMethods.h': TENSOR_METHODS_H + 'TensorMethods.h': TENSOR_METHODS_H, + 'OpsAlreadyMovedToC10.cpp': OPS_ALREADY_MOVED_TO_C10_CPP, } for core_file, core_template_file in core_files.items(): diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index 1033f0bb1cc9c..7fa2bcbca01e5 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -82,7 +82,10 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI C10_HIP_CHECK_WARN(hipSetDevice(d.index())); } Stream getStream(Device d) const noexcept override { - return getCurrentHIPStreamMasqueradingAsCUDA().unwrap(); + return getCurrentHIPStreamMasqueradingAsCUDA(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultHIPStreamMasqueradingAsCUDA(d.index()); } Stream exchangeStream(Stream s) const noexcept override { HIPStreamMasqueradingAsCUDA cs(s); diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 46c5d9368b413..0c75df68b76f5 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -16,30 +16,24 @@ DEFINE_DISPATCH(div_stub); DEFINE_DISPATCH(atan2_stub); DEFINE_DISPATCH(logical_xor_stub); +static constexpr char alpha_mismatch_err[] = + "For integral input tensors, argument alpha must not be a floating point number."; + Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { - if (other.is_sparse()) { - if (self.is_sparse()) { - at::_sparse_add_out(result, self, other, alpha); - } else { - at::_sparse_dense_add_out(result, self, other, alpha); - } - return result; - } else if (self.is_sparse()) { - AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead."); - } auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true); + TORCH_CHECK(! alpha.isBoolean() || iter.dtype() == ScalarType::Bool, "Boolean alpha only supported for boolean results"); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(true), alpha_mismatch_err); add_stub(iter.device_type(), iter, alpha); + TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); return result; } Tensor add(const Tensor& self, const Tensor& other, Scalar alpha) { Tensor result; - if (other.is_sparse()) { - result = at::empty({0}, self.options()); - return native::add_out(result, self, other, alpha); - } auto iter = TensorIterator::binary_op(result, self, other); + TORCH_CHECK(! alpha.isBoolean() || iter.dtype() == ScalarType::Bool, "Boolean alpha only supported for boolean results"); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(true), alpha_mismatch_err); add_stub(iter.device_type(), iter, alpha); return iter.output(); } @@ -49,13 +43,6 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { } Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { - if (self.is_sparse()) { - if (other.dim() != 0) { - AT_ERROR("div(): sparse division only supports division by a scalar ", - "(got shape ", other.sizes(), " for argument 'other')"); - } - return at::_sparse_div_zerodim_out(result, self, other); - } auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true); div_stub(iter.device_type(), iter); @@ -64,10 +51,6 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { Tensor div(const Tensor& self, const Tensor& other) { Tensor result; - if (self.is_sparse()) { - result = at::empty({0}, self.options()); - return native::div_out(result, self, other); - } auto iter = TensorIterator::binary_op(result, self, other); div_stub(iter.device_type(), iter); return iter.output(); @@ -78,9 +61,6 @@ Tensor& div_(Tensor& self, const Tensor& other) { } Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { - if (self.is_sparse() || other.is_sparse()) { - return at::_sparse_mul_out(result, self, other); - } auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true); mul_stub(iter.device_type(), iter); @@ -89,10 +69,6 @@ Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { Tensor mul(const Tensor& self, const Tensor& other) { Tensor result; - if (self.is_sparse() || other.is_sparse()) { - result = at::empty({0}, self.options()); - return native::mul_out(result, self, other); - } auto iter = TensorIterator::binary_op(result, self, other); mul_stub(iter.device_type(), iter); return iter.output(); @@ -114,33 +90,19 @@ static inline void sub_check(const Tensor& self, const Tensor& other) { Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { sub_check(self, other); - if (other.is_sparse()) { - if (!self.sizes().equals(other.sizes())) { - AT_ERROR("sizes do not match"); - } - if (self.is_sparse()) { - at::_sparse_add_out(result, self, other, -alpha); - } else { - at::_sparse_dense_add_out(result, self, other, -alpha); - } - return result; - } else if (self.is_sparse()) { - AT_ERROR("sub(sparse, dense) is not supported. Use sub(dense, sparse) instead."); - } auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(false), alpha_mismatch_err); sub_stub(iter.device_type(), iter, alpha); + TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); return result; } Tensor sub(const Tensor& self, const Tensor& other, Scalar alpha) { sub_check(self, other); Tensor result; - if (other.is_sparse()) { - result = at::empty({0}, self.options()); - return native::sub_out(result, self, other, alpha); - } auto iter = TensorIterator::binary_op(result, self, other); + TORCH_CHECK(isFloatingType(iter.dtype()) || alpha.isIntegral(false), alpha_mismatch_err); sub_stub(iter.device_type(), iter, alpha); return iter.output(); } @@ -160,7 +122,7 @@ Tensor& atan2_out(Tensor& result, const Tensor& self, const Tensor& other) { } Tensor atan2(const Tensor& self, const Tensor& other) { - Tensor result = at::empty_like(self); + Tensor result = at::empty({0}, self.options()); return native::atan2_out(result, self, other); } @@ -186,12 +148,18 @@ Tensor& add_(Tensor& self, Scalar other, Scalar alpha) { return native::add_(self, wrapped_scalar_tensor(other), alpha); } +// WARNING: There doesn't appear to be any testing for this function +// with sparse self input. Tensor div(const Tensor& self, Scalar other) { - return native::div(self, wrapped_scalar_tensor(other)); + return self.div(wrapped_scalar_tensor(other)); // redispatch! } +// WARNING: This function, with a sparse self, is currently only +// exercised by DistributedDataParallelTest.test_sparse_gradients +// (you need to exercise it from C++, because this overload is never +// used for Python) Tensor& div_(Tensor& self, Scalar other) { - return native::div_(self, wrapped_scalar_tensor(other)); + return self.div_(wrapped_scalar_tensor(other)); // redispatch! } Tensor mul(const Tensor& self, Scalar other) { diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 22d6fc386b26d..ea939b696b86f 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -1,5 +1,3 @@ -#include "ATen/native/Convolution.h" - #include #include #include @@ -13,8 +11,6 @@ static const int MIOPEN_DIM_MAX = 4; namespace at { namespace native { -std::atomic disable_mkldnn_conv{false}; - struct ConvParams { std::vector stride; std::vector padding; @@ -147,13 +143,12 @@ auto ConvParams::use_miopen(const at::Tensor& input) const -> bool { && input.is_cuda() && input.dim() <= MIOPEN_DIM_MAX && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 - && !transposed ; } auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool { #if AT_MKLDNN_ENABLED() - if (disable_mkldnn_conv.load()) { + if (!at::globalContext().userEnabledMkldnn()) { return false; } return (input.is_mkldnn()) || // input is mkldnn Tensor @@ -337,7 +332,7 @@ static void check_shape_forward(const at::Tensor& input, const ConvParams& params, bool input_is_mkldnn) { int64_t k = input.ndimension(); int64_t weight_dim = weight.ndimension(); - std::vector weight_sizes(k); + std::vector weight_sizes(weight_dim); // mkldnn conv2d weights could have been re-ordered to 5d by // mkldnn_reorder_conv2d_weight if ((weight_dim == k + 1) && input_is_mkldnn) { @@ -347,7 +342,7 @@ static void check_shape_forward(const at::Tensor& input, weight_dim = k; } else { std::copy_n( - weight.sizes().cbegin(), k, weight_sizes.begin()); + weight.sizes().cbegin(), weight_dim, weight_sizes.begin()); } int64_t groups = params.groups; auto padding = params.padding; diff --git a/aten/src/ATen/native/Convolution.h b/aten/src/ATen/native/Convolution.h deleted file mode 100644 index 84ac861174411..0000000000000 --- a/aten/src/ATen/native/Convolution.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include - -#include - -namespace at { -namespace native { - -// *** Warning: this code is here to workaround an issue: -// https://github.com/pytorch/pytorch/issues/23825 -// -// This flag allows us to temporarily disable MKLDNN to work around cases -// where there are bugs. -extern CAFFE2_API std::atomic disable_mkldnn_conv; - -} // namespace at -} // namespace native diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index 6430278a4b560..cf5b9cf596915 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -7,9 +7,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include namespace { @@ -123,7 +122,7 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) { TORCH_CHECK(self.qscheme() == src.qscheme(), "Quantized Copy only works with same qscheme"); TORCH_CHECK(self.scalar_type() == src.scalar_type()); - self.set_quantizer_(at::make_per_tensor_affine_quantizer(src.q_scale(), src.q_zero_point(), src.scalar_type())); + self.set_quantizer_(src.quantizer()); } auto iter = TensorIterator(); diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index f6c787f322a5d..f3e6740eb4d4e 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -79,6 +79,7 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c TORCH_CHECK(x1.is_contiguous(), "_cdist_backward requires X1 to be contiguous"); TORCH_CHECK(x2.is_contiguous(), "_cdist_backward requires X2 to be contiguous"); TORCH_CHECK(cdist.is_contiguous(), "_cdist_backward requires dist to be contiguous"); + TORCH_CHECK(grad.is_contiguous(), "_cdist_backward requires grad to be contiguous"); int64_t n = x1.size(-2); int64_t m = x1.size(-1); auto device1 = x1.type().device_type(); diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 9f1debd41c024..a3b32da0b22c8 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -12,14 +13,13 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif #include #include #include #include +#include #include @@ -296,25 +296,36 @@ Tensor _s_dirichlet_cpu(const Tensor& alpha, Generator *gen) { return ret; } -Tensor multinomial_cpu(const Tensor& self, int64_t n_sample, bool with_replacement, Generator *gen) { - Tensor result = at::empty({0}, self.options().dtype(kLong)); - multinomial_out_cpu(result, self, n_sample, with_replacement, gen); - return result; -} +/* The largest consecutive integer representable in float32 (2^24) */ +constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG); -Tensor& multinomial_out_cpu(Tensor& result, const Tensor& self, int64_t n_sample, bool with_replacement, Generator *gen) { - TORCH_CHECK(at::isFloatingType(self.scalar_type()), "multinomial only supports floating-point dtypes for input, got: ", self.scalar_type()); - TORCH_CHECK(result.scalar_type() == ScalarType::Long, "multinomial expects Long tensor out, got: ", result.scalar_type()); +Tensor& multinomial_out(Tensor& result, const Tensor& self, int64_t n_sample, bool with_replacement, Generator *gen) { + TORCH_CHECK(result.device() == self.device(), "multinomial arguments must have the same device"); + TORCH_CHECK(self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim"); + TORCH_CHECK(at::isFloatingType(self.scalar_type()), + "multinomial only supports floating-point dtypes for input, got: ", self.scalar_type()); + TORCH_CHECK(result.scalar_type() == ScalarType::Long, + "multinomial expects Long tensor out, got: ", result.scalar_type()); TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples"); int64_t n_categories = self.size(-1); - TORCH_CHECK(with_replacement || (n_sample <= n_categories), "cannot sample n_sample > prob_dist.size(-1) samples without replacement"); + TORCH_CHECK(with_replacement || (n_sample <= n_categories), + "cannot sample n_sample > prob_dist.size(-1) samples without replacement"); + // Since the index tensor is float, numCategories cannot exceed max + // float integer precision + TORCH_CHECK(n_categories <= FLOAT32_MAX_CONSECUTIVE_INT, "number of categories cannot exceed 2^24"); if (self.dim() > 1) { int64_t n_dist = self.size(-2); result.resize_({n_dist, n_sample}); } else { result.resize_({n_sample}); } - multinomial_stub(kCPU, result, self, n_sample, with_replacement, gen); + multinomial_stub(result.type().device_type(), result, self, n_sample, with_replacement, gen); + return result; +} + +Tensor multinomial(const Tensor& self, int64_t n_sample, bool with_replacement, Generator *gen) { + Tensor result = at::empty({0}, self.options().dtype(kLong)); + native::multinomial_out(result, self, n_sample, with_replacement, gen); return result; } diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp index f44bccc79bfb9..2e137087ee9cf 100644 --- a/aten/src/ATen/native/Indexing.cpp +++ b/aten/src/ATen/native/Indexing.cpp @@ -55,6 +55,7 @@ #include #include #include +#include #include #include diff --git a/aten/src/ATen/native/LegacyBridge.cpp b/aten/src/ATen/native/LegacyBridge.cpp index 3e544cb14d2f2..e69de29bb2d1d 100644 --- a/aten/src/ATen/native/LegacyBridge.cpp +++ b/aten/src/ATen/native/LegacyBridge.cpp @@ -1,79 +0,0 @@ -#include -#include -#include - -namespace at { namespace native { - -// Note [Multiple dispatch to sparse] -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// In an ideal world, we would use direct support for multiple dispatch to -// say that add(Dense, Dense) should dispatch to one function, while -// add(Dense, Sparse) should dispatch to another function. -// -// In a world where we only have single dispatch, we can single dispatch on -// the first function, and then do an is_sparse() test on the second argument -// to direct ourselves to the correct argument. -// -// We are in neither of those worlds. Instead, we have a _th_addmm function -// which has legacy implementations in the single dispatch world, BUT our -// actual addmm function needs to call s_native_addmm if the function *would have* -// utilized a sparse kernel that is natively implemented. -// -// _th_addmm is "good old single dispatch" which internally handles the is_sparse() -// test and also handles broadcasting. s_native_addmm works asymmetrically: -// it doesn't handle broadcasting at all, and it ASSUMES that the relevant -// argument is a sparse tensor. Why the asymmetry? It turns out it is not -// so easy to figure out if a kernel is implemented in THS; it's not as simple -// as testing if the first argument is sparse, because, e.g., -// in addmm(Dense, Sparse), the sparse kernel is in the second argument. So, -// the trampoline function is going to know about the overloads *anyway*; it -// might as well also handle is_sparse() and broadcasting while it's at it. -// -// Why not change TH to follow this new scheme? We could... but since it's -// all going away when we finish porting the TH functions to ATen, we haven't -// done it. - -// NB: You may be tempted to implement addmm and addmm_ just as calls to addmm_out, but -// calling the actual implementing function matters, because broadcast -// will be handled differently depending on if you call addmm_ or (a seemingly -// equivalent) add_out. Arguably this mismatch in treatment is a bug, -// c.f., https://github.com/pytorch/pytorch/issues/8308 but fixing this -// bug would involve changing a lot of other places, so we leave it -// alone for now. - -Tensor& addmm_out(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { - // See Note [Multiple dispatch to sparse] - auto mat1_sparse = mat1.is_sparse(); - if (mat1_sparse) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); - return s_native_addmm_out(result, b_self, mat1, mat2, beta, alpha); - } else { - return at::_addmm_out(result, self, mat1, mat2, beta, alpha); - } -} - -Tensor addmm(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { - // See Note [Multiple dispatch to sparse] - auto mat1_sparse = mat1.is_sparse(); - if (mat1_sparse) { - Tensor b_self; - std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm"); - return s_native_addmm(b_self, mat1, mat2, beta, alpha); - } else { - return at::_addmm(self, mat1, mat2, beta, alpha); - } -} - -Tensor& addmm_(Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) { - // See Note [Multiple dispatch to sparse] - auto mat1_sparse = mat1.is_sparse(); - if (mat1_sparse) { - // inplace is not broadcasting - return s_native_addmm_(self, mat1, mat2, beta, alpha); - } else { - return at::_addmm_(self, mat1, mat2, beta, alpha); - } -} - -}} // namespace at::native diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index a8afd251787ef..5f20195b47a3b 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index b36ecd32a7b24..06a88f5668657 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -10,9 +10,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include namespace at { namespace native { @@ -28,8 +27,13 @@ static inline std::tuple _lu_det_P_diag_U(const Tensor& self) { TORCH_CHECK(infos.ge(0).all().item(), "Invalid argument passed to lu"); auto n = self.size(-1); auto num_exchanges = (at::arange(1, n + 1, pivs.options()) != pivs).sum(-1, /*keepdim=*/false, /*dtype=*/self.scalar_type()).fmod_(2); - return std::tuple(num_exchanges.mul_(-2).add_(1), - lu.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1)); + auto u_diagonal = lu.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); + + // We have to manually set the diagonal to 0 due to an issue with MAGMA's getrf_batched routine + if (self.dim() > 2 && self.is_cuda()) { + u_diagonal.index_put_(infos.nonzero_numpy(), at::zeros({}, self.options())); + } + return std::tuple(num_exchanges.mul_(-2).add_(1), u_diagonal); } Tensor det(const Tensor& self) { @@ -79,18 +83,20 @@ std::tuple slogdet(const Tensor& self) { } Tensor pinverse(const Tensor& self, double rcond) { - TORCH_CHECK(at::isFloatingType(self.scalar_type()) && self.dim() == 2, - "pinverse(", self.type(), "{", self.sizes(), "}): expected a 2D tensor " + TORCH_CHECK(at::isFloatingType(self.scalar_type()) && self.dim() >= 2, + "pinverse(", self.type(), "{", self.sizes(), "}): expected a tensor with 2 or more dimensions " "of floating types"); if (self.numel() == 0) { // Match NumPy - return at::empty({self.size(1), self.size(0)}, self.options()); + auto self_sizes = self.sizes().vec(); + std::swap(self_sizes[self.dim() - 1], self_sizes[self.dim() - 2]); + return at::empty(self_sizes, self.options()); } Tensor U, S, V; std::tie(U, S, V) = self.svd(); - Tensor max_val = S[0]; + Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); Tensor S_pseudoinv = at::where(S > rcond * max_val, S.reciprocal(), at::zeros({}, self.options())); - return V.mm(S_pseudoinv.diag().mm(U.t())); + return at::matmul(V, at::matmul(S_pseudoinv.diag_embed(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1), U.transpose(-2, -1))); } static inline Tensor _matrix_rank_helper(const Tensor& self, bool symmetric) { @@ -301,7 +307,7 @@ Tensor& bmm_out_cpu(Tensor &result, const Tensor& batch1, const Tensor& batch2) } namedinference::propagate_names( result, - std::move(namedinference::compute_bmm_outnames(result, batch1, batch2)), + namedinference::compute_bmm_outnames(result, batch1, batch2), /*validate_names=*/false); #endif return result; diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index 55a5f924e2b3c..2c55dff0ea9b2 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -84,7 +84,7 @@ std::tuple ctc_loss_cpu_template(const Tensor& log_probs, const int64_t max_input_length = log_probs.size(0); for (int64_t b = 0; b < batch_size; b++) { TORCH_CHECK(input_lengths[b] <= max_input_length, - "Expected tensor to have size at least ", max_input_length, " at dimension 1, but got size ", input_lengths[b], " for ", log_probs_arg, + "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b], " (while checking arguments for ", c, ")"); } diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index d843a86b5d67b..bb19f411e8d55 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -1,5 +1,7 @@ -#include -#include +#include +#include +#include +#include /* The next function is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c. Below is the copyright. @@ -52,7 +54,9 @@ Output was modified to be inf or -inf when input is 1 or -1. */ #define CENTRAL_RANGE 0.7 -static inline double calc_erfinv(double y) { +template +static inline typename std::enable_if::value, T>::type +calc_erfinv(T y) { /* Function to calculate inverse error function. Rational approximation is used to generate an initial approximation, which is then improved to full accuracy by two steps of Newton's method. Code is a direct @@ -60,29 +64,30 @@ translation of the erfinv m file in matlab version 2.0. Author: Gary L. Pavlis, Indiana University Date: February 1996 */ - double x,z,num,dem; /*working variables */ + T x, z, num, dem; /*working variables */ /* coefficients in rational expansion */ - double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; - double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; - double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; - double d[2]={ 3.543889200, 1.637067800}; - if(fabs(y) > 1.0) return (atof("NaN")); /* This needs IEEE constant*/ - if(fabs(y) == 1.0) return((copysign(1.0,y))*atof("INFINITY")); - if(fabs(y) <= CENTRAL_RANGE){ - z = y*y; + T a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; + T b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + T c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + T d[2]={ 3.543889200, 1.637067800}; + T y_abs = std::abs(y); + if(y_abs > 1.0) return std::numeric_limits::quiet_NaN(); + if(y_abs == 1.0) return std::copysign(std::numeric_limits::infinity(), y); + if(y_abs <= static_cast(CENTRAL_RANGE)) { + z = y * y; num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); - dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0); - x = y*num/dem; + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast(1.0)); + x = y * num / dem; } else{ - z = sqrt(-log((1.0-fabs(y))/2.0)); - num = ((c[3]*z + c[2])*z + c[1])*z + c[0]; - dem = (d[1]*z + d[0])*z + 1.0; - x = (copysign(1.0,y))*num/dem; + z = std::sqrt(-std::log((static_cast(1.0)-y_abs)/static_cast(2.0))); + num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; + dem = (d[1]*z + d[0])*z + static_cast(1.0); + x = std::copysign(num, y) / dem; } /* Two steps of Newton-Raphson correction */ - x = x - (erf(x) - y)/( (2.0/sqrt(M_PI))*exp(-x*x)); - x = x - (erf(x) - y)/( (2.0/sqrt(M_PI))*exp(-x*x)); + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(M_PI)))*std::exp(-x*x)); + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(M_PI)))*std::exp(-x*x)); return(x); } @@ -197,7 +202,7 @@ static inline double calc_digamma(double x) { * Cephes Math Library Release 2.8: June, 2000 * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier */ -static inline double calc_digamma(float x) { +static inline float calc_digamma(float x) { static float PSI_10 = 2.25175258906672110764f; if (x == 0) { return INFINITY; @@ -236,9 +241,9 @@ static inline double calc_digamma(float x) { }; float y = 0; - if (x < 1.0e17) { + if (x < 1.0e17f) { float z = 1 / (x * x); y = z * polevlf(z, A, 6); } - return result + logf(x) - (0.5 / x) - y; + return result + logf(x) - (0.5f / x) - y; } diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index b2607ff26030c..d911895934ea3 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -1,16 +1,17 @@ -#ifdef BUILD_NAMEDTENSOR #include #include #include +#include +#ifdef BUILD_NAMEDTENSOR namespace at { namespace native { Tensor& names_(Tensor& self, optional names) { return at::internal_set_names_inplace(self, names); } -Tensor view_names(const Tensor& self, optional names) { +Tensor renamed(const Tensor& self, optional names) { auto result = self.alias(); at::internal_set_names_inplace(result, names); return result; @@ -59,7 +60,6 @@ static std::vector aligned_size( ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1; ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1; for (; idx >= 0 && dim >= 0; --idx) { - TORCH_INTERNAL_ASSERT(!tensor_names[dim].is_tagged() && !aligned_names[idx].is_tagged(), "Tagged names NYI"); if (tensor_names[dim] != aligned_names[idx]) { continue; } @@ -73,7 +73,7 @@ static std::vector aligned_size( // *, a // [*, a] is a subsequence of [*, c, a, b], but in order to align them, // we'd have to move the * to create [*, c: 1, a, b: 1] - if (tensor_names[dim].is_wildcard() && + if (tensor_names[dim].isWildcard() && tensor_sizes.size() - dim != aligned_names.size() - idx) { report_moving_unnamed_dim_error( tensor_names, aligned_names, /*is_aligning_two_tensors=*/false); @@ -89,6 +89,38 @@ static std::vector aligned_size( return expanded_sizes; } +Tensor refine_names(const Tensor& self, DimnameList names) { + const auto self_names = self.names(); + TORCH_CHECK(self_names.size() == names.size(), + "refine_names: cannot coerce Tensor", self_names, " to Tensor", names, + " because they have a different number of dims (", + self_names.size(), " and ", names.size(), " respectively)."); + check_names_valid_for(self, names); + + for (size_t idx = 0; idx < self_names.size(); idx++) { + const auto& self_name = self_names[idx]; + const auto& out_name = names[idx]; + if (self_name == out_name || self_name.isWildcard()) { + continue; + } + if (out_name.isWildcard()) { + TORCH_CHECK(false, + "refine_names: cannot coerse Tensor", self_names, " to Tensor", names, + " because ", self_name, " is more specific than ", out_name, " at index ", + idx); + } + TORCH_CHECK(false, + "refine_names: cannot coerse Tensor", self_names, " to Tensor", names, + " because ", self_name, " is different from ", out_name, " at index ", + idx); + TORCH_INTERNAL_ASSERT(false); // done handling errors + } + + auto result = self.alias(); + internal_set_names_inplace(result, names); + return result; +} + // [Alignment rules] // Aligns `tensor` to names with the following rules: // 1) Check that tensor.names is a subsequence (not necessarily contiguous) of `names`. @@ -104,18 +136,42 @@ static Tensor align(const Tensor& tensor, DimnameList names, bool is_aligning_tw tensor.names(), names, is_aligning_two_tensors); - auto result = tensor.view_names(nullopt).view(expanded_sizes); + auto result = tensor.renamed(nullopt).view(expanded_sizes); at::internal_set_names_inplace(result, names); return result; } Tensor align_to(const Tensor& tensor, DimnameList names) { - TORCH_CHECK( - names.size() >= tensor.dim(), - "Cannot align tensor with dims ", - tensor.names(), - " to a shorter list of dims ", names, "."); - return align(tensor, names, /*aligning_two_tensors=*/false); + auto tensor_names = tensor.names(); + auto tensor_sizes = tensor.sizes(); + auto tensor_strides = tensor.strides(); + std::vector new_sizes(names.size(), 1); + std::vector new_strides(names.size(), 0); + + for (auto idx = 0; idx < tensor_names.size(); ++idx) { + const auto& dim = tensor_names[idx]; + TORCH_CHECK(dim.isBasic(), + "align_to: All input dims must be named. Found unnamed dim at index ", + dim, " of Tensor", tensor_names); + auto it = std::find(names.begin(), names.end(), dim); + TORCH_CHECK(it != names.end(), + "align_to: Cannot find dim ", dim, " from Tensor", names, + " in desired alignment ", names, "."); + int64_t new_idx = std::distance(names.begin(), it); + new_sizes[new_idx] = tensor_sizes[idx]; + new_strides[new_idx] = tensor_strides[idx]; + } + Tensor result; + { + NoNamesGuard guard; + result = tensor.as_strided(new_sizes, new_strides); + } + internal_set_names_inplace(result, names); + return result; +} + +Tensor align_as(const Tensor& tensor, const Tensor& other) { + return native::align_to(tensor, other.names()); } static std::vector align_tensors_to(TensorList tensors, DimnameList names) { @@ -136,5 +192,45 @@ std::vector align_tensors(TensorList tensors) { return align_tensors_to(tensors, longest_dim->names()); } +static int64_t cumprod(IntArrayRef sizes) { + int64_t result = 1; + for (auto size : sizes) { + result *= size; + } + return result; +} + +Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, DimnameList names) { + // unflatten is implemented only as a python method on tensor right now. + // The following asserts should be checked by the python method. + TORCH_INTERNAL_ASSERT(names.size() == sizes.size()); + TORCH_INTERNAL_ASSERT(sizes.size() > 0); + TORCH_CHECK( + cumprod(sizes) == self.size(dim), + "unflatten: Provided names ", names, " and sizes ", sizes, " but sizes don't multiply " + "up to the size of dim ", dim, " (", self.names()[dim], ": ", self.size(dim), + ") in Tensor", self.names()); + + auto outnames = self.names().vec(); + outnames.erase(outnames.begin() + dim); + outnames.insert(outnames.begin() + dim, names.begin(), names.end()); + + auto new_sizes = self.sizes().vec(); + new_sizes.erase(new_sizes.begin() + dim); + new_sizes.insert(new_sizes.begin() + dim, sizes.begin(), sizes.end()); + + Tensor result; + { + NoNamesGuard guard; + result = self.view(new_sizes); + } + at::internal_set_names_inplace(result, outnames); + return result; +} + +Tensor unflatten(const Tensor& self, Dimname dim, IntArrayRef sizes, DimnameList names) { + return native::unflatten(self, dimname_to_position(self, dim), sizes, names); +} + }} // namespace at::native #endif diff --git a/aten/src/ATen/native/PointwiseOps.cpp b/aten/src/ATen/native/PointwiseOps.cpp index 0a6c76b567ca4..17b0a231c766e 100644 --- a/aten/src/ATen/native/PointwiseOps.cpp +++ b/aten/src/ATen/native/PointwiseOps.cpp @@ -5,10 +5,9 @@ #include #include #include +#include -#ifdef BUILD_NAMEDTENSOR #include -#endif namespace at { namespace native { diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 8649804b33ea5..7edcdbd1b8524 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -19,13 +19,13 @@ Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) { Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { // Numpy compatibility check: - TORCH_CHECK(!(isIntegralType(base.scalar_type()) && - exp.isIntegral() && exp.toLong() < 0), + TORCH_CHECK(!(isIntegralType(base.scalar_type(), true) && + exp.isIntegral(true) && exp.toLong() < 0), "Integers to negative integer powers are not allowed."); if (exp.toDouble() == 0.0) { result.resize_as_(base).fill_(1); } else if (exp.toDouble() == 1.0) { - result.copy_(base); + result.resize_as_(base).copy_(base); } else { auto iter = TensorIterator::unary_op(result, base, /*check_mem_overlap=*/true); @@ -38,7 +38,7 @@ Tensor& pow_out(Tensor& result, Scalar base, const Tensor& exp) { if (base.toDouble() == 1.0) { result.resize_as_(exp).fill_(1); } else { - native::pow_out(result, c10::scalar_to_tensor(base), exp); + native::pow_out(result, c10::scalar_to_tensor(base, exp.device()), exp); } return result; } @@ -52,7 +52,7 @@ Tensor& pow_(Tensor& base, Scalar alpha) { } Tensor pow(const Tensor& base, const Tensor& exp) { - Tensor result = at::empty_like(base); + Tensor result = at::empty({0}, base.options()); return native::pow_out(result, base, exp); } diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp index fe57d52eb7072..7ae43978cd338 100644 --- a/aten/src/ATen/native/QuantizedLinear.cpp +++ b/aten/src/ATen/native/QuantizedLinear.cpp @@ -57,7 +57,7 @@ Tensor fbgemm_linear_int8_weight_fp32_activation( TORCH_CHECK(bias.dim() == 1); TORCH_CHECK(bias.size(0) == N); TORCH_CHECK(weight_scale.isFloatingPoint()); - TORCH_CHECK(weight_zero_point.isIntegral()); + TORCH_CHECK(weight_zero_point.isIntegral(false)); // Calculate statistics for quantization of the input Tensor float x_min, x_max; diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 4e672ceda6f59..a137145529df8 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -273,17 +273,17 @@ that case, code generation of the device guard can be disabled by adding in which case this field would go away. If you have an opinion on the matter, please write in at https://github.com/pytorch/pytorch/issues/14234 -### `named_guard` +### `supports_named_tensor` ``` -named_guard: False +supports_named_tensor: True ``` Experimental: this option is ignored unless compiling with BUILD_NAMEDTENSOR=1. -By default, (`named_guard: True`) ATen code generation will generate a check +By default, (`supports_named_tensor: True`) ATen code generation will generate a check that all tensor inputs to the function are unnamed. This is used to incrementally implement named tensors; if a function supports named tensors, then it'll have -`named_guard: False`; otherwise, passing it a named tensor will error out. +`supports_named_tensor: True`; otherwise, passing it a named tensor will error out. ### `matches_jit_signature` @@ -299,6 +299,18 @@ with other components of PyTorch in order to reduce overall complexity. If you find yourself having to set this field to False add @gchanan to your PR's set of reviewers. +### `use_c10_dispatcher` + +``` +use_c10_dispatcher: True +``` + +This will indicate that the func signature only uses features supported by +the c10 dispatcher. With this flag, the operator will be added to the +c10 operator library and be available there. If enabling this works for your +operator, please do. For a few corner cases, enabling this might not compile +successfully, so setting this to false is a workaround. Also, False is the default. + ## Writing an implementation in C++ Implementations of native functions go in an appropriate C++ file in the diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 507c0052a65b9..1b01f644370d4 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -2,6 +2,8 @@ #include #include +#include +#include #include @@ -11,7 +13,7 @@ namespace { // Check if pytorch is compiled with MIOpen. bool use_miopen(const at::Tensor& input, const double dropout_state) { - bool is_miopen_acceptable = (input.scalar_type() == at::kFloat) && + bool is_miopen_acceptable = (input.scalar_type() == at::kFloat) && (detail::getCUDAHooks().compiledWithMIOpen()) && (input.is_cuda()) && (dropout_state == 0.0); @@ -170,10 +172,10 @@ struct QuantizedCellParamsDynamic { } Tensor linear_ih(const Tensor& input_ih) const { - const auto kFuncName = "quantized::fbgemm_linear_dynamic"; + const auto kFuncName = "quantized::linear_dynamic"; const auto kOvrldName = ""; const std::vector output_ih_list = - callOp(kFuncName, kOvrldName, input_ih, w_ih, b_ih); + callOp(kFuncName, kOvrldName, input_ih, w_ih); TORCH_INTERNAL_ASSERT( output_ih_list.size() == 1, "The output vector should have exact one element"); @@ -181,10 +183,10 @@ struct QuantizedCellParamsDynamic { return output_ih; } Tensor linear_hh(const Tensor& input_hh) const { - const auto kFuncName = "quantized::fbgemm_linear_dynamic"; + const auto kFuncName = "quantized::linear_dynamic"; const auto kOvrldName = ""; const std::vector output_hh_list = - callOp(kFuncName, kOvrldName, input_hh, w_hh, b_hh); + callOp(kFuncName, kOvrldName, input_hh, w_hh); TORCH_INTERNAL_ASSERT( output_hh_list.size() == 1, "The output vector should have exact one element"); @@ -278,12 +280,23 @@ static std::vector gather_quantized_params_dynamic( static at::Tensor undefined; std::vector result; TORCH_CHECK( - params.size() % 4 == 0, + params.size() % 2 == 0, "got an incorrect number of quantized RNN parameters"); - for (size_t i = 0; i < params.size(); i += 4) { - result.emplace_back(params[i], params[i + 1], params[i + 2], params[i + 3]); + // PackedLinearWeight is only defined when USE_FBGEMM is defined +#ifdef USE_FBGEMM + for (size_t i = 0; i < params.size(); i += 2) { + auto& packed_struct_ih = + cpp_custom_type_hack::cast(params[i]); + auto& packed_struct_hh = + cpp_custom_type_hack::cast(params[i + 1]); + auto bias_ih = packed_struct_ih.bias.value_or(undefined); + auto bias_hh = packed_struct_hh.bias.value_or(undefined); + result.emplace_back(params[i], params[i + 1], bias_ih, bias_hh); } return result; +#else // USE_FBGEMM + TORCH_INTERNAL_ASSERT(false, "Tried to use quantized RNN wihtout FBGEMM!") +#endif // USE_FBGEMM } static std::vector gather_quantized_params_fp16( @@ -966,7 +979,7 @@ std::tuple lstm( lstm_cudnn_stub(_input.type().device_type(), output, hy, cy, _input, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional, batch_first); return std::make_tuple(output, hy, cy); - } + } if (use_miopen(_input, dropout_p)) { Tensor output, hy, cy; @@ -995,7 +1008,7 @@ std::tuple lstm( lstm_packed_cudnn_stub(data.type().device_type(), output, hy, cy, data, batch_sizes, hx, _params, has_biases, num_layers, dropout_p, train, bidirectional); return std::make_tuple(output, hy, cy); - } + } if (use_miopen(data, dropout_p)) { Tensor output, hy, cy; @@ -1003,7 +1016,7 @@ std::tuple lstm( _params, has_biases, num_layers, dropout_p, train, bidirectional); return std::make_tuple(output, hy, cy); } - + PackedSequence input { data, batch_sizes }; auto params = gather_params(_params, has_biases); auto result = _lstm_impl( @@ -1059,11 +1072,13 @@ std::tuple quantized_lstm( check_device(_input, _params, hx); auto input = batch_first ? _input.transpose(0, 1) : _input; TORCH_CHECK(has_biases, "quantized LSTM requires biases"); - TORCH_CHECK(result_dtype == at::kChar || result_dtype == at::kHalf, - "dtype is not supported"); + TORCH_CHECK( + result_dtype == at::kChar || result_dtype == at::kQInt8 || + result_dtype == at::kHalf, + "dtype is not supported"); std::tuple results; - if (result_dtype == at::kChar) { + if (result_dtype == at::kChar || result_dtype == at::kQInt8) { if (use_dynamic) { auto params = gather_quantized_params_dynamic(_params); results = _lstm_impl( diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 309599b17c925..2843584e21516 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -8,9 +8,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include #include #include diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index 6fb6f1241ad2a..91a2f0137f1b8 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -5,9 +5,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include namespace at { namespace native { @@ -41,7 +40,7 @@ void host_softmax(Tensor output, const Tensor& input, const int64_t dim) { for (int64_t d = 1; d < dim_size; d++) max_input = std::max(max_input, input_data[d * dim_stride]); - scalar_t tmpsum = 0; + acc_type tmpsum = 0; for (int64_t d = 0; d < dim_size; d++) { scalar_t z = std::exp(input_data[d * dim_stride] - max_input); if (!LogSoftMax) { @@ -97,7 +96,7 @@ void host_softmax_backward( const scalar_t* gradOutput_data = gradOutput_data_base + outer_idx * outer_stride + inner_idx; - scalar_t sum = 0; // TODO was accreal here + acc_type sum = 0; for (int64_t d = 0; d < dim_size; d++) if (LogSoftMax) sum += gradOutput_data[d * dim_stride]; @@ -160,9 +159,9 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_, const bool half if (input.ndimension() > 0 && dim == input.ndimension() - 1) { log_softmax_lastdim_kernel(kCPU, output, input); } else { - AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax", [&] { - host_softmax(output, input, dim); - }); + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, input.scalar_type(), "log_softmax", + [&] { host_softmax(output, input, dim); }); } return output; } @@ -224,9 +223,11 @@ Tensor log_softmax_backward_cpu( if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) { log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output); } else { - AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] { - host_softmax_backward(grad_input, grad, output, dim); - }); + AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::BFloat16, grad.scalar_type(), + "log_softmax_backward", [&] { + host_softmax_backward( + grad_input, grad, output, dim); + }); } return grad_input; } diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 405899a8a7e1b..a87804b7d93fc 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index abb7e108534ef..c9a48305f96c8 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -8,6 +8,7 @@ #include #include #include +#include namespace { template @@ -51,8 +52,22 @@ bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, Tensor isclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) { // TODO: use bitwise operator overloads once we add them + + TORCH_CHECK(self.scalar_type() == other.scalar_type(), self.scalar_type(), " did not match ", other.scalar_type()) + auto actual_error = (self - other).abs(); auto max_error = atol + rtol * other.abs(); + + // `max_error` could be a float or double depending on the type of the input + // tensors. + // Specifically, if other is an int tensor, multiplying by rtol results in + // float tensor. + // It is also possible for parameters to be 'wrapped_number's, in which case + // max_error could be promoted to double when actual error is still a float. + if (actual_error.scalar_type() != max_error.scalar_type()) { + actual_error = actual_error.to(max_error.scalar_type()); + } + auto close = actual_error <= max_error; if (isFloatingType(self.scalar_type()) && isFloatingType(other.scalar_type())) { @@ -84,8 +99,10 @@ bool is_nonzero(const Tensor& self) { Scalar localScalar = self.item(); if (localScalar.isFloatingPoint()) { return localScalar.to() != 0; - } else if (localScalar.isIntegral()){ + } else if (localScalar.isIntegral(false)){ return localScalar.to() != 0; + } else if (localScalar.isBoolean()) { + return localScalar.to(); } AT_ERROR("expected non-Tensor backed scalar"); } diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 4b74a8c3ab62f..551be074df9cf 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -17,9 +17,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include #include #include @@ -219,18 +218,32 @@ Tensor empty_like( "It is currently not supported to specify a dtype that doesn't match " "the input tensor's dtype via empty_like. Specified: ", options.dtype(), " Input tensor's dtype: ", self.dtype()); - // TODO: uncomment when qscheme diff is landed - // TORCH_INTERNAL_ASSERT(self.qscheme(), at::kPerTensorAffine, - // "empty_like for quantized Tensor only works for - // PerTensorAffine scheme right now"); - return at::_empty_affine_quantized(self.sizes(), options, - self.q_scale(), - self.q_zero_point(), - use_memory_format); + auto qscheme = self.qscheme(); + if (qscheme == kPerTensorAffine) { + return at::_empty_affine_quantized(self.sizes(), options, + self.q_scale(), + self.q_zero_point(), + use_memory_format); + } else if (qscheme == kPerChannelAffine) { + // Copy the tensors with channels to avoid accidental overrides + return at::_empty_per_channel_affine_quantized_like( + self.q_per_channel_scales().clone(), + self.q_per_channel_zero_points().clone(), + self.sizes(), + self.q_per_channel_axis(), + options, + use_memory_format); + } else { + TORCH_CHECK(false, "Unsupported qscheme: ", toString(qscheme)); + } } #ifdef BUILD_NAMEDTENSOR - return at::empty(self.sizes(), self.opt_names(), options, use_memory_format); + if (self.opt_names()) { + return at::empty(self.sizes(), self.opt_names(), options, use_memory_format); + } else { + return at::empty(self.sizes(), options, use_memory_format); + } #else return at::empty(self.sizes(), options, use_memory_format); #endif diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp index d7a911aca4a64..340a43009589a 100644 --- a/aten/src/ATen/native/TensorIterator.cpp +++ b/aten/src/ATen/native/TensorIterator.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace at { @@ -77,8 +78,13 @@ compute_result_type(at::ArrayRef operands, const F& predicate) { ScalarType dtype = ScalarType::Undefined; for (auto& op : operands) { if (!op.tensor.defined()) continue; - if (!predicate(op.tensor)) continue; - auto tensor_dtype = op.tensor.scalar_type(); + if (!predicate(op)) continue; + ScalarType tensor_dtype; + if (op.tensor.unsafeGetTensorImpl()->is_wrapped_number() && isFloatingType(op.tensor.scalar_type())) { + tensor_dtype = typeMetaToScalarType(caffe2::get_default_dtype()); + } else { + tensor_dtype = op.tensor.scalar_type(); + } if (dtype == ScalarType::Undefined) { dtype = tensor_dtype; device = op.tensor.device(); @@ -101,36 +107,120 @@ compute_result_type(at::ArrayRef operands, return compute_result_type(operands, predicates...); } -std::tuple TensorIterator::compute_common_type() { +static std::tuple compute_common_type_(at::ArrayRef operands) { // See [Result type computation] in TensorIterator.h + auto result_type = - compute_result_type(operands_, - [](const Tensor& t) { return t.dim() > 0; }, - [](const Tensor& t) { return !t.unsafeGetTensorImpl()->is_wrapped_number(); }, - [](const Tensor& t) { return true; }); + compute_result_type(operands, + [](const OperandInfo& op) { return op.tensor.dim() > 0; }, + [](const OperandInfo& op) { return !op.tensor.unsafeGetTensorImpl()->is_wrapped_number(); }, + [](const OperandInfo& op) { return true; }); + + if (ScalarType::Bool == std::get<1>(result_type)) { + auto alternate = compute_result_type(operands, + [](const OperandInfo& op) { + return op.tensor.dim() == 0; + } + ); + if (std::get<1>(alternate) != ScalarType::Undefined) { + // preserve device from original result + return std::make_tuple(std::get<0>(result_type), std::get<1>(alternate)); + } + } + + // if non-zero-dim tensor result is an integral type and there's a zero-dim + // floating point operand, we'll promote the floating point type. + if (isIntegralType(std::get<1>(result_type), false)) { + auto alternate = compute_result_type(operands, + [](const OperandInfo& op) { + return isFloatingType(op.tensor.scalar_type()) && op.tensor.dim() == 0; + } + ); + if (std::get<1>(alternate) != ScalarType::Undefined) { + // preserve device from original result + return std::make_tuple(std::get<0>(result_type), std::get<1>(alternate)); + } + } TORCH_INTERNAL_ASSERT(std::get<1>(result_type) != ScalarType::Undefined); return result_type; } +std::tuple TensorIterator::compute_common_type() { + return compute_common_type_(operands_); +} + +static void validate_dtype(OperandInfo& op, ScalarType common_dtype, int ninputs) { + if (op.tensor.defined()) { + // For binary_ops, we follow casting rules. For unary/nullary types + // we require the type to match. + if (op.is_output) { + if (!canCast(common_dtype, op.tensor.scalar_type())) + { + AT_ERROR("result type ", common_dtype, + " can't be cast to the desired output type ", + op.tensor.scalar_type()); + } + } + if (ninputs < 2 && op.dtype != op.tensor.scalar_type()) { + AT_ERROR("expected dtype ", op.dtype, " but got dtype ", op.tensor.scalar_type()); + } + } +} + +static void maybe_promote_common_dtype(OperandInfo& op, ScalarType common_dtype) { + if (op.tensor.defined() && op.tensor.scalar_type() != common_dtype) + { + op.dtype = common_dtype; + op.original_tensor = op.tensor; + op.tensor = op.tensor.to(common_dtype); + auto original_element_size = op.original_tensor.element_size(); + auto new_element_size = op.tensor.element_size(); + + // stride size (in bytes) can change if we change the dtype. + for( size_t i=0; i < op.stride_bytes.size(); i++ ) { + auto stride = op.stride_bytes[i] / original_element_size; + op.stride_bytes[i] = stride * new_element_size; + } + } +} + void TensorIterator::compute_types() { bool missing_dtypes = false; + bool missing_output_dtypes = false; + bool has_read_write_op = false; + ScalarType common_dtype = dtype(); for (auto& op : operands_) { if (!op.tensor.defined() && !op.is_type_defined()) { missing_dtypes = true; + if (op.is_output) { + missing_output_dtypes = true; + } } + if (op.is_read_write) { + has_read_write_op = true; + } + } + + if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS) { + TORCH_CHECK(!missing_output_dtypes, "unable to compute and promote common dtype based only on inputs if there are missing dtypes for outputs"); + TORCH_CHECK(!has_read_write_op, "unable to compute and promote common dtype based only on inputs if input is same as output"); } - if (missing_dtypes || compute_common_dtype_) { - auto common_type = compute_common_type(); + bool compute_common_dtype = (compute_common_dtype_strategy_ != CommonDTypeStrategy::COMPUTE_NONE); + bool compute_common_dtype_only_for_inputs = (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_INPUTS); + + if (missing_dtypes || compute_common_dtype) { + auto operands = compute_common_dtype_only_for_inputs ? at::ArrayRef(operands_).slice(noutputs()) : operands_; + auto common_type = compute_common_type_(operands); auto common_device = std::get<0>(common_type); - auto common_dtype = std::get<1>(common_type); + common_dtype = std::get<1>(common_type); bool has_cpu_scalar = false; for (auto& op : operands_) { if (!op.is_type_defined()) { op.device = common_device; op.dtype = common_dtype; - } else if (compute_common_dtype_ && + } else if (compute_common_dtype && (op.device != common_device || op.dtype != common_dtype)) { if (allow_cpu_scalars_ && op.tensor.defined() && op.tensor.dim() == 0 && common_device.is_cuda() && op.tensor.device().is_cpu() && @@ -149,26 +239,31 @@ void TensorIterator::compute_types() { op.dtype = op.tensor.scalar_type(); } else { op.device = common_device; - op.dtype = common_dtype; + if (compute_common_dtype_only_for_inputs && op.is_output) { + op.dtype = op.tensor.scalar_type(); + } else { + op.dtype = common_dtype; + } } } - } - } - for (auto& op : operands_) { - auto& tensor = op.tensor; - if (!tensor.defined()) { - continue; - } - if (op.device != tensor.device() || op.dtype != tensor.scalar_type()) { - if (op.is_output) { - AT_ERROR("output with device ", tensor.device(), " and dtype ", tensor.scalar_type(), - " doesn't match the desired device ", op.device, " and dtype ", op.dtype); - } else if (tensor.dim() == 0) { - tensor = tensor.to(op.options()); - } else { - AT_ERROR("expected device ", op.device, " and dtype ", op.dtype, - " but got device ", tensor.device(), " and dtype ", tensor.scalar_type()); + if (!compute_common_dtype_only_for_inputs) { + validate_dtype(op, common_dtype, ninputs()); + } + if (!compute_common_dtype_only_for_inputs || !op.is_output) { + maybe_promote_common_dtype(op, common_dtype); + } + + if (op.tensor.defined() && op.device != op.tensor.device()) { + if (op.is_output) { + AT_ERROR("output with device ", op.tensor.device(), + " doesn't match the desired device ", op.device); + } else if (op.tensor.dim() == 0) { + op.tensor = op.tensor.to(op.options()); + } else { + AT_ERROR("expected device ", op.device, + " but got device ", op.tensor.device()); + } } } } diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h index 2e8debba8796b..5be53910871ab 100644 --- a/aten/src/ATen/native/TensorIterator.h +++ b/aten/src/ATen/native/TensorIterator.h @@ -7,9 +7,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include // TensorIterator is a helper class for element-wise operations, such as // arithmetic, comparisions, and trigonometric functions. It handles @@ -86,11 +85,15 @@ struct CAFFE2_API OperandInfo { /// Stride after broadcasting. The stride is in bytes, not number of elements. DimVector stride_bytes; - /// The original tensor operand. Note that the strides, data pointer, and + /// The tensor operand. Note that the strides, data pointer, and /// other attributes may differ due to dimension reordering and /// coalescing. Tensor tensor; + // Save the original tensor operand in cases when an output is modified + // (e.g. if dtype is changed) + Tensor original_tensor; + /// The desired device and type for the operand. For inputs, this specifies that /// the input should be converted to this type if necessary. For outputs, this /// specifies which type to allocate. Note that there is very limited support @@ -120,6 +123,12 @@ struct CAFFE2_API OperandInfo { struct SplitUntil32Bit; +enum class CommonDTypeStrategy : uint8_t { + COMPUTE_ALL = 0, // Compute common dtype based on inputs and outputs. Try to promote common dtype to inputs and outputs + COMPUTE_INPUTS = 1, // Compute common dtype based only on inputs. Try to promote common dtype only to inputs + COMPUTE_NONE = 2, // Do not compute and promote common dtype +}; + struct CAFFE2_API TensorIterator { using DimMask = std::bitset<64>; using PtrVector = SmallVector; @@ -174,7 +183,8 @@ struct CAFFE2_API TensorIterator { /// Accessors for each operand IntArrayRef strides(int arg) const { return operands_[arg].stride_bytes; } void* data_ptr(int arg) const; - ScalarType dtype(int arg=0) const { return operands_[arg].dtype; } + ScalarType dtype(int arg=0) const { return operands_[arg].tensor.scalar_type(); } + ScalarType input_dtype(int arg=0) const { return operands_[num_outputs_ + arg].dtype; } Device device(int arg=0) const { return operands_[arg].device; } DeviceType device_type(int arg=0) const { return device(arg).type(); } int64_t element_size(int arg) const { return elementSize(dtype(arg)); } @@ -189,6 +199,17 @@ struct CAFFE2_API TensorIterator { return operands_[arg].tensor; } + void cast_outputs() { + if (compute_common_dtype_strategy_ == CommonDTypeStrategy::COMPUTE_ALL) { + for(int i=0; i < noutputs(); i++) { + if (operands_[i].original_tensor.defined() && dtype(i) != operands_[i].original_tensor.scalar_type()) { + operands_[i].original_tensor.copy_(operands_[i].tensor); + operands_[i].tensor = operands_[i].original_tensor; + } + } + } + } + Tensor input(int arg=0) const { AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); return operands_[num_outputs_ + arg].tensor; @@ -282,7 +303,11 @@ struct CAFFE2_API TensorIterator { } void dont_compute_common_dtype() { - compute_common_dtype_ = false; + compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_NONE; + } + + void compute_common_dtype_only_for_inputs() { + compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_INPUTS; } void dont_resize_outputs() { @@ -315,11 +340,11 @@ struct CAFFE2_API TensorIterator { #endif SmallVector operands_; int num_outputs_ = 0; + CommonDTypeStrategy compute_common_dtype_strategy_ = CommonDTypeStrategy::COMPUTE_ALL; bool has_coalesced_dimensions_ = false; bool accumulate_ = false; bool resize_outputs_ = true; bool is_reduction_ = false; - bool compute_common_dtype_ = true; bool allow_cpu_scalars_ = false; bool promote_gpu_output_dtypes_ = false; bool final_output_ = true; diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index de1e8c6ac5aa4..fa10c65660f54 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -2,9 +2,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include #include namespace at { diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6ee2ffb493617..7afdc8a52e03f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -13,9 +13,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include namespace at { namespace native { @@ -342,16 +341,19 @@ Tensor sum_to_size(const Tensor& self, IntArrayRef size) { Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) { auto storage_offset = storage_offset_.value_or(self.storage_offset()); - auto tid = self.type_id(); - auto result = detail::make_tensor(Storage(self.storage()), tid); + auto result = detail::make_tensor(Storage(self.storage()), self.type_set()); setStrided(result, size, stride, storage_offset); return result; } Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) { auto storage_offset = storage_offset_.value_or(self.storage_offset()); - auto tid = self.type_id(); - auto result = detail::make_tensor(Storage(self.storage()), tid, get_qtensorimpl(self)->quantizer()); + auto quantizer = get_qtensorimpl(self)->quantizer(); + TORCH_CHECK( + quantizer->qscheme() == QScheme::PER_TENSOR_AFFINE, + "Setting strides is possible only on uniformly quantized tensor"); + auto result = detail::make_tensor( + Storage(self.storage()), self.type_set(), quantizer); setStrided(result, size, stride, storage_offset); return result; } @@ -1037,6 +1039,39 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) { return self.reshape(shape); } +#ifdef BUILD_NAMEDTENSOR +Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim, Dimname out_dim) { + auto outnames = self.names().vec(); + outnames.erase(outnames.begin() + start_dim, outnames.begin() + end_dim + 1); + outnames.insert(outnames.begin() + start_dim, out_dim); + + Tensor result; + { + NoNamesGuard guard; + result = native::flatten(self, start_dim, end_dim); + } + internal_set_names_inplace(result, outnames); + return result; +} + +Tensor flatten(const Tensor& self, Dimname start_dim, Dimname end_dim, Dimname out_dim) { + auto start_pos = dimname_to_position(self, start_dim); + auto end_pos = dimname_to_position(self, end_dim); + return native::flatten(self, start_pos, end_pos, out_dim); +} + +Tensor flatten(const Tensor& self, DimnameList dims, Dimname out_dim) { + auto positions = dimnames_to_positions(self, dims); + for (size_t i = 0; i < positions.size() - 1; i++) { + if (positions[i] + 1 == positions[i + 1]) continue; + TORCH_CHECK(positions[i] + 1 == positions[i + 1], + "flatten(tensor, dims, out_dim): dims ", dims, " must be consecutive ", + "in Tensor", self.names()); + } + return native::flatten(self, *dims.begin(), *(dims.end() - 1), out_dim); +} +#endif + Tensor view_as(const Tensor& self, const Tensor& other) { return self.view(other.sizes()); } @@ -1055,6 +1090,12 @@ std::vector unbind(const Tensor &self, int64_t dim) { return tensors; } +#ifdef BUILD_NAMEDTENSOR +std::vector unbind(const Tensor& self, Dimname dim) { + return at::unbind(self, dimname_to_position(self, dim)); +} +#endif + std::vector meshgrid(TensorList tensors) { int64_t size = tensors.size(); TORCH_CHECK(size > 0, "meshgrid expects a non-empty TensorList"); @@ -1115,14 +1156,14 @@ Tensor alias(const Tensor& self) { if (self.is_quantized()) { auto impl = c10::make_intrusive( Storage(self.storage()), - self.type_id(), + self.type_set(), get_qtensorimpl(self)->quantizer()); impl->set_storage_offset(self.storage_offset()); impl->set_sizes_and_strides(self.sizes(), self.strides()); self_ = Tensor(std::move(impl)); } else { auto impl = c10::make_intrusive(Storage(self.storage()), - self.type_id()); + self.type_set()); impl->set_storage_offset(self.storage_offset()); impl->set_sizes_and_strides(self.sizes(), self.strides()); self_ = Tensor(std::move(impl)); diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index cb69da72475a3..59092a6ca088c 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -22,15 +22,7 @@ bool is_floating_point(const Tensor& self) { } bool is_signed(const Tensor &self) { - if (self.scalar_type() == ScalarType::Half) { - return true; - } - if (self.scalar_type() == ScalarType::BFloat16) { - return true; - } - return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "is_signed", [&]() -> bool { - return std::is_signed(); - }); + return at::isSignedType(self.scalar_type()); } bool is_sparse(const Tensor& self) { @@ -45,7 +37,7 @@ bool is_quantized(const Tensor& self) { // TensorImpl can be copied to `self`. bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) { return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type( - from.type_id()); + from.type_set()); } Tensor type_as(const Tensor& self, const Tensor& other) { diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index afc3c3d5bc5c5..6bf1eb92d15ce 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -16,9 +16,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include #include #include @@ -44,6 +43,9 @@ static inline Tensor& unary_op_impl_out(Tensor& result, const Tensor& self, Stub return result; } +// out_impl passed into unary_op_impl and unary_op_impl_ must go through at:: device dispatch +// otherwise it won't dispatch to out-of-source devices like XLA. +// For example it must be at::bitwise_not_out instead of bitwise_not_out(which is at::native!). template static inline Tensor unary_op_impl(const Tensor& self, OutImpl& out_impl) { Tensor result = at::empty({0}, self.options()); @@ -56,16 +58,28 @@ static inline Tensor& unary_op_impl_(Tensor& self, OutImpl& out_impl) { } Tensor& bitwise_not_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, bitwise_not_stub); } -Tensor bitwise_not(const Tensor& self) { return unary_op_impl(self, bitwise_not_out); } -Tensor& bitwise_not_(Tensor& self) { return unary_op_impl_(self, bitwise_not_out); } +Tensor bitwise_not(const Tensor& self) { return unary_op_impl(self, at::bitwise_not_out); } +Tensor& bitwise_not_(Tensor& self) { return unary_op_impl_(self, at::bitwise_not_out); } Tensor& ceil_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, ceil_stub); } -Tensor ceil(const Tensor& self) { return unary_op_impl(self, ceil_out); } -Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, ceil_out); } +Tensor ceil(const Tensor& self) { return unary_op_impl(self, at::ceil_out); } +Tensor& ceil_(Tensor& self) { return unary_op_impl_(self, at::ceil_out); } Tensor& erfinv_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, erfinv_stub); } -Tensor erfinv(const Tensor& self) { return unary_op_impl(self, erfinv_out); } -Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, erfinv_out); } +Tensor erfinv(const Tensor& self) { return unary_op_impl(self, at::erfinv_out); } +Tensor& erfinv_(Tensor& self) { return unary_op_impl_(self, at::erfinv_out); } + +Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); } +Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); } +Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); } + +Tensor& digamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, digamma_stub); } +Tensor digamma(const Tensor& self) { return unary_op_impl(self, digamma_out); } +Tensor& digamma_(Tensor& self) { return unary_op_impl_(self, digamma_out); } + +Tensor& rsqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, rsqrt_stub); } +Tensor rsqrt(const Tensor& self) { return unary_op_impl(self, at::rsqrt_out); } +Tensor& rsqrt_(Tensor& self) { return unary_op_impl_(self, at::rsqrt_out); } Tensor& neg_out(Tensor& result, const Tensor& self) { TORCH_CHECK(self.scalar_type() != kBool, @@ -73,8 +87,8 @@ Tensor& neg_out(Tensor& result, const Tensor& self) { "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead."); return unary_op_impl_out(result, self, neg_stub); } -Tensor neg(const Tensor& self) { return unary_op_impl(self, neg_out); } -Tensor& neg_(Tensor& self) { return unary_op_impl_(self, neg_out); } +Tensor neg(const Tensor& self) { return unary_op_impl(self, at::neg_out); } +Tensor& neg_(Tensor& self) { return unary_op_impl_(self, at::neg_out); } Tensor logical_not(const Tensor& self) { Tensor result = at::empty({0}, self.options().dtype(kBool)); @@ -115,15 +129,6 @@ Tensor& _clamp__cpu(Tensor& self, optional min, optional max) { return clamp_out(self, self, min, max); } -//used internally and not exposed by API -Tensor& trigamma_out(Tensor& result, const Tensor& self) { - checkBackend("trigamma", result, Backend::CPU); - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); - trigamma_stub(iter.device_type(), iter); - return result; -} - Tensor polygamma(int64_t n, const Tensor& self) { Tensor result = at::empty({0}, self.options()); at::polygamma_out(result, n, self); @@ -133,7 +138,7 @@ Tensor& polygamma_(Tensor& self, int64_t n) { return at::polygamma_out(self, n, self); } Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self) { - checkBackend("polygamma", result, Backend::CPU); + TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n."); auto iter = TensorIterator::unary_op(result, self, /*check_mem_overlap=*/true); polygamma_stub(iter.device_type(), iter, n); @@ -265,7 +270,6 @@ IMPLEMENT_UNARY_OP_VEC(asin) IMPLEMENT_UNARY_OP_VEC(atan) IMPLEMENT_UNARY_OP_VEC(cos) IMPLEMENT_UNARY_OP_VEC(cosh) -IMPLEMENT_UNARY_OP_VEC(digamma) IMPLEMENT_UNARY_OP_VEC(erf) IMPLEMENT_UNARY_OP_VEC(erfc) IMPLEMENT_UNARY_OP_VEC(exp) @@ -277,8 +281,6 @@ IMPLEMENT_UNARY_OP_VEC(log10) IMPLEMENT_UNARY_OP_VEC(log1p) IMPLEMENT_UNARY_OP_VEC(log2) IMPLEMENT_UNARY_OP_VEC(reciprocal) -IMPLEMENT_UNARY_OP_VEC(round) -IMPLEMENT_UNARY_OP_VEC(rsqrt) IMPLEMENT_UNARY_OP_VEC(sigmoid) IMPLEMENT_UNARY_OP_VEC(sin) IMPLEMENT_UNARY_OP_VEC(sinh) @@ -286,6 +288,7 @@ IMPLEMENT_UNARY_OP_VEC(sqrt) IMPLEMENT_UNARY_OP_VEC(tan) IMPLEMENT_UNARY_OP_VEC(tanh) IMPLEMENT_UNARY_OP_VEC(trunc) +IMPLEMENT_UNARY_OP_VEC(lgamma) DEFINE_DISPATCH(abs_stub); DEFINE_DISPATCH(acos_stub); @@ -323,7 +326,7 @@ DEFINE_DISPATCH(sinh_stub); DEFINE_DISPATCH(sqrt_stub); DEFINE_DISPATCH(tan_stub); DEFINE_DISPATCH(tanh_stub); -DEFINE_DISPATCH(trigamma_stub); DEFINE_DISPATCH(trunc_stub); +DEFINE_DISPATCH(lgamma_stub); } } // namespace at diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index 6a5a50828e621..0640bca5fd2a8 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -48,6 +48,7 @@ DECLARE_DISPATCH(unary_fn, tan_stub); DECLARE_DISPATCH(unary_fn, tanh_stub); DECLARE_DISPATCH(unary_fn, trigamma_stub); DECLARE_DISPATCH(unary_fn, trunc_stub); +DECLARE_DISPATCH(unary_fn, lgamma_stub); DECLARE_DISPATCH(void(*)(Tensor&, const double, Generator *), bernoulli_mkl_stub); DECLARE_DISPATCH(void(*)(TensorIterator&, const int64_t), polygamma_stub); diff --git a/aten/src/ATen/native/VariableMethodStubs.cpp b/aten/src/ATen/native/VariableMethodStubs.cpp index f5f77352d6d76..f3d01b9bce021 100644 --- a/aten/src/ATen/native/VariableMethodStubs.cpp +++ b/aten/src/ATen/native/VariableMethodStubs.cpp @@ -12,5 +12,17 @@ void set_data(const Tensor& self, const Tensor& new_data) { AT_ERROR("set_data is not implemented for Tensor"); } +Tensor data(const Tensor& self) { + AT_ERROR("data is not implemented for Tensor"); +} + +bool is_leaf(const Tensor& self) { + AT_ERROR("is_leaf is not implemented for Tensor"); +} + +int64_t output_nr(const Tensor& self) { + AT_ERROR("output_nr is not implemented for Tensor"); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp index 40b6d7c52e3e6..88a28318ca760 100644 --- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp @@ -354,7 +354,10 @@ struct Dist { const int64_t d = result.size(0); const int64_t l1_size = r1 * m; const int64_t l2_size = r2 * m; - const int64_t gs = grad.stride(-1); + //current implementation supports only tensor that can be collapsed to 1D. However, to avoid checking if grad satisfies this assumption, + //we call .contiguous() on grad before backward, thus stride is guaranteed to be 1 + //don't use grad.stride(-1), because if last dimension is 1, stride can be bogus. + const int64_t gs = 1; const scalar_t * const grad_start = grad.data_ptr(); const scalar_t * const dist_start = dist.data_ptr(); diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp index 98b3ad34d8257..5e6d044030626 100644 --- a/aten/src/ATen/native/cpu/LerpKernel.cpp +++ b/aten/src/ATen/native/cpu/LerpKernel.cpp @@ -14,6 +14,10 @@ static void lerp_kernel_scalar( const Tensor& self, const Tensor& end, Scalar weight) { + // lerp() only uses TensorIterator for CPU. Since TensorIterator would + // would attempt to promote types inconsistent with the CUDA implementation, + // restrict types explicitly here + TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); auto iter = TensorIterator::binary_op(ret, self, end, /*check_mem_overlap=*/true); AT_DISPATCH_FLOATING_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] { @@ -33,6 +37,11 @@ static void lerp_kernel_tensor( const Tensor& self, const Tensor& end, const Tensor& weights) { + // lerp() only uses TensorIterator for CPU. Since TensorIterator would + // would attempt to promote types inconsistent with the CUDA implementation, + // restrict types explicitly here + TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); + TORCH_CHECK(self.dtype() == weights.dtype(), "expected dtype ", self.dtype(), " for `weights` but got dtype ", end.dtype()); auto iter = TensorIterator(); iter.set_check_mem_overlap(true); iter.add_output(ret); diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index 1a7b4e6291e5d..e4bc9bc492b49 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -177,6 +177,7 @@ void cpu_kernel(TensorIterator& iter, func_t op) { }); } }); + iter.cast_outputs(); } template @@ -198,6 +199,7 @@ void cpu_kernel_vec(TensorIterator& iter, func_t op, vec_func_t vop) { }); } }); + iter.cast_outputs(); } template diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index faa98ba3c93f0..e806f0ccfcef7 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -15,12 +15,12 @@ namespace at { namespace native { namespace { using namespace vec256; static void sum_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, iter.dtype(), "sum_cpu", [&] { - binary_kernel_reduce_vec( - iter, - [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; }, - [=](Vec256 a, Vec256 b) { return a + b; }); - }); + AT_DISPATCH_ALL_TYPES_AND2( + ScalarType::BFloat16, ScalarType::Bool, iter.dtype(), "sum_cpu", [&] { + binary_kernel_reduce_vec( + iter, [=](scalar_t a, scalar_t b) -> scalar_t { return a + b; }, + [=](Vec256 a, Vec256 b) { return a + b; }); + }); } static void mean_kernel_impl(TensorIterator& iter) { @@ -58,7 +58,7 @@ static void norm_kernel_tensor_iterator_impl( TensorIterator& iter, Scalar p) { float val; - if (p.isIntegral()) { + if (p.isIntegral(false)) { val = p.to(); } else if (p.isFloatingPoint()) { val = p.to(); diff --git a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp index 0bf1eaad1ed76..7ff81943e79a0 100644 --- a/aten/src/ATen/native/cpu/SoftMaxKernel.cpp +++ b/aten/src/ATen/native/cpu/SoftMaxKernel.cpp @@ -232,10 +232,10 @@ static void softmax_lastdim_kernel_impl(Tensor& result, const Tensor& self) { static void log_softmax_lastdim_kernel_impl( Tensor& result, const Tensor& self) { - AT_DISPATCH_FLOATING_TYPES( - self.scalar_type(), "log_softmax_lastdim_kernel_impl", [&] { - vec_host_softmax_lastdim::apply(result, self); - }); + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, self.scalar_type(), + "log_softmax_lastdim_kernel_impl", + [&] { vec_host_softmax_lastdim::apply(result, self); }); } static void softmax_backward_lastdim_kernel_impl( @@ -253,8 +253,9 @@ static void log_softmax_backward_lastdim_kernel_impl( Tensor& grad_input, const Tensor& grad, const Tensor& output) { - AT_DISPATCH_FLOATING_TYPES( - grad.scalar_type(), "log_softmax_backward_lastdim_kernel_impl", [&] { + AT_DISPATCH_FLOATING_TYPES_AND( + at::ScalarType::BFloat16, grad.scalar_type(), + "log_softmax_backward_lastdim_kernel_impl", [&] { vec_host_softmax_backward_lastdim::apply( grad_input, grad, output); }); diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 23e163d33bd79..cc7e47509b277 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -184,7 +184,7 @@ static void polygamma_kernel(TensorIterator& iter, int64_t n) { switch (n) { case 0: digamma_kernel(iter); break; case 1: trigamma_kernel(iter); break; - default: AT_ERROR("polygamma(n,x) is not implemented for n>=2"); + default: TORCH_CHECK("polygamma(n,x) is not implemented for n>=2, but was ", n); } } @@ -369,5 +369,6 @@ IMPLEMENT_FLOAT_KERNEL(FLOATING, sqrt) IMPLEMENT_FLOAT_KERNEL(FLOATING, tan) IMPLEMENT_FLOAT_KERNEL(FLOATING, tanh) IMPLEMENT_FLOAT_KERNEL(FLOATING, trunc) +IMPLEMENT_FLOAT_KERNEL(FLOATING, lgamma) }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu index 17fd342878871..214e08d92bbf0 100644 --- a/aten/src/ATen/native/cuda/AveragePool3d.cu +++ b/aten/src/ATen/native/cuda/AveragePool3d.cu @@ -23,8 +23,8 @@ __device__ inline int max(int a, int b) { template __global__ void avg_pool3d_cuda_update_output( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int kT, int kH, int kW, int dT, int dH, int dW, int padT, int padH, int padW, @@ -87,8 +87,8 @@ __global__ void avg_pool3d_cuda_update_output( // template __global__ void avg_pool3d_cuda_update_output( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int kT, int kH, int dT, int dH, int dW, int padT, int padH, int padW, @@ -148,8 +148,8 @@ __global__ void avg_pool3d_cuda_update_output( template __global__ void avg_pool3d_single_backward_out_frame_stride1( - PackedTensorAccessor gradOutput, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, int kT, int kH, int kW, accscalar_t normFactor, int offsetZ) @@ -193,8 +193,8 @@ __global__ void avg_pool3d_single_backward_out_frame_stride1( template __global__ void avg_pool3d_cuda_update_grad_input_atomic( - PackedTensorAccessor gradOutput, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, int kT, int kH, int kW, int dT, int dH, int dW, int padT, int padH, int padW, @@ -251,8 +251,8 @@ __global__ void avg_pool3d_cuda_update_grad_input_atomic( template __global__ void avg_pool3d_cuda_update_grad_input( - PackedTensorAccessor gradOutput, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 gradInput, int kT, int kH, int kW, int dT, int dH, int dW, int padT, int padH, int padW, @@ -309,8 +309,8 @@ __global__ void avg_pool3d_cuda_update_grad_input( #define LAUNCH_UPDATE_OUTPUT_KERNEL_WIDTH(KW) case KW: \ avg_pool3d_cuda_update_output \ <<>>( \ - work_input.packed_accessor(), \ - work_output.packed_accessor(), \ + work_input.packed_accessor64(), \ + work_output.packed_accessor64(), \ kT, kH, \ dT, dH, dW, \ padT, padH, padW, \ @@ -425,8 +425,8 @@ void avg_pool3d_out_cuda_template( default: avg_pool3d_cuda_update_output <<>>( - work_input.packed_accessor(), - work_output.packed_accessor(), + work_input.packed_accessor64(), + work_output.packed_accessor64(), kT, kH, kW, dT, dH, dW, padT, padH, padW, @@ -567,8 +567,8 @@ void avg_pool3d_backward_out_cuda_template( avg_pool3d_single_backward_out_frame_stride1 <<>>( - work_grad_output.packed_accessor(), - work_grad_input.packed_accessor(), + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), kT, kH, kW, 1.0f/divide_factor, offsetZ); @@ -600,8 +600,8 @@ void avg_pool3d_backward_out_cuda_template( if (kernelsOverlap) { avg_pool3d_cuda_update_grad_input_atomic <<>>( - work_grad_output.packed_accessor(), - work_grad_input.packed_accessor(), + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), kT, kH, kW, dT, dH, dW, padT, padH, padW, @@ -611,8 +611,8 @@ void avg_pool3d_backward_out_cuda_template( else { avg_pool3d_cuda_update_grad_input <<>>( - work_grad_output.packed_accessor(), - work_grad_input.packed_accessor(), + work_grad_output.packed_accessor64(), + work_grad_input.packed_accessor64(), kT, kH, kW, dT, dH, dW, padT, padH, padW, diff --git a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp index 157b4fa8d2523..8f1c7ec8c4398 100644 --- a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp +++ b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp @@ -1,8 +1,7 @@ #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include namespace at { namespace native { @@ -82,8 +81,6 @@ IMPLEMENT_UNARY_OP_PREQUEL(log10) IMPLEMENT_UNARY_OP_PREQUEL(log1p) IMPLEMENT_UNARY_OP_PREQUEL(log2) IMPLEMENT_UNARY_OP_PREQUEL(reciprocal) -IMPLEMENT_UNARY_OP_PREQUEL(round) -IMPLEMENT_UNARY_OP_PREQUEL(rsqrt) IMPLEMENT_UNARY_OP_PREQUEL(sigmoid) IMPLEMENT_UNARY_OP_PREQUEL(sin) IMPLEMENT_UNARY_OP_PREQUEL(sinh) diff --git a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu index f33cced00f699..0e9dee088897d 100644 --- a/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/DilatedMaxPool3d.cu @@ -20,8 +20,8 @@ __device__ inline int min(int a, int b) { template __global__ static void max_pool3d_with_indices_single_out_frame( scalar_t* inputData, - PackedTensorAccessor output, - PackedTensorAccessor indices, + PackedTensorAccessor64 output, + PackedTensorAccessor64 indices, int itime, int iheight, int iwidth, int kT, int kH, int kW, int dT, int dH, int dW, @@ -81,8 +81,8 @@ __global__ static void max_pool3d_with_indices_single_out_frame( template __global__ static void max_pool3d_with_indices_single_out_frame( scalar_t* inputData, - PackedTensorAccessor output, - PackedTensorAccessor indices, + PackedTensorAccessor64 output, + PackedTensorAccessor64 indices, int itime, int iheight, int iwidth, int kT, int kH, int dT, int dH, int dW, @@ -143,8 +143,8 @@ __global__ static void max_pool3d_with_indices_single_out_frame( max_pool3d_with_indices_single_out_frame \ <<>>( \ input_data, \ - output.packed_accessor(), \ - indices.packed_accessor(), \ + output.packed_accessor64(), \ + indices.packed_accessor64(), \ itime, iheight, iwidth, \ kT, kH, \ dT, dH, dW, \ @@ -185,8 +185,8 @@ void max_pool3d_with_indices_out_frame( max_pool3d_with_indices_single_out_frame <<>>( input_data, - output.packed_accessor(), - indices.packed_accessor(), + output.packed_accessor64(), + indices.packed_accessor64(), itime, iheight, iwidth, kT, kH, kW, dT, dH, dW, @@ -209,8 +209,8 @@ void max_pool3d_with_indices_out_frame( template __global__ static void max_pool3d_with_indices_backward_single_out_frame( scalar_t *gradInputData, - PackedTensorAccessor gradOutput, - PackedTensorAccessor indices, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 indices, int itime, int iheight, int iwidth, int dT, int dH, int dW, int pT, int pH, int pW, @@ -255,8 +255,8 @@ void max_pool3d_with_indices_backward_out_frame( max_pool3d_with_indices_backward_single_out_frame <<>>( gradInputData, - gradOutput.packed_accessor(), - indices.packed_accessor(), + gradOutput.packed_accessor64(), + indices.packed_accessor64(), itime, iheight, iwidth, dT, dH, dW, pT, pH, pW, diff --git a/aten/src/ATen/native/cuda/DistanceKernel.cu b/aten/src/ATen/native/cuda/DistanceKernel.cu index 5242d1b9e4402..099adac66ecda 100644 --- a/aten/src/ATen/native/cuda/DistanceKernel.cu +++ b/aten/src/ATen/native/cuda/DistanceKernel.cu @@ -5,6 +5,7 @@ #include +#include namespace at { namespace native { @@ -12,12 +13,6 @@ namespace { static const int forward_threads = 256; -#ifdef __HIP_PLATFORM_HCC__ -static const int WARP_SIZE = 64; -#else -static const int WARP_SIZE = 32; -#endif - template static __forceinline__ __device__ scalar_t device_sqrt(scalar_t val); @@ -233,15 +228,15 @@ void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, doubl AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] { if (p == 0.0) { - cdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + cdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); } else if (p == 1.0) { - cdist_kernel_cuda_impl::one><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + cdist_kernel_cuda_impl::one><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); } else if (p == 2.0) { - cdist_kernel_cuda_impl::two><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + cdist_kernel_cuda_impl::two><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); } else if (std::isinf(p)) { - cdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + cdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); } else { - cdist_kernel_cuda_impl::p><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); + cdist_kernel_cuda_impl::p><<>>(result.data_ptr(), x1.data_ptr(), x2.data_ptr(), p, r1, r2, m, r_size, l1_size, l2_size); } }); AT_CUDA_CHECK(cudaGetLastError()); @@ -259,15 +254,15 @@ void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, double p) { AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda", [&] { if (p == 0.0) { - pdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + pdist_kernel_cuda_impl::zero><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); } else if (p == 1.0) { - pdist_kernel_cuda_impl::one><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + pdist_kernel_cuda_impl::one><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); } else if (p == 2.0) { - pdist_kernel_cuda_impl::two><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + pdist_kernel_cuda_impl::two><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); } else if (std::isinf(p)) { - pdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + pdist_kernel_cuda_impl::inf><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); } else { - pdist_kernel_cuda_impl::p><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); + pdist_kernel_cuda_impl::p><<>>(result.data_ptr(), self.data_ptr(), n, m, p, n2, n2_squared_minus_1); } }); AT_CUDA_CHECK(cudaGetLastError()); @@ -298,15 +293,15 @@ void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor Tensor buffer = at::empty({n - 1, result.size(0), result.size(1)}, result.options()); AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_cuda_backward", [&] { if (p == 1.0) { - pdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + pdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); } else if (p < 2.0) { - pdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + pdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); } else if (p == 2.0) { - pdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + pdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); } else if (std::isinf(p)) { - pdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + pdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); } else { - pdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); + pdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), self.data_ptr(), dist.data_ptr(), grad.stride(0), n, m, dist.numel(), p, n2, n2_squared_minus_1); } }); AT_CUDA_CHECK(cudaGetLastError()); @@ -336,19 +331,32 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor const int64_t r_size = r1 * r2; const int64_t l1_size = r1 * m; const int64_t l2_size = r2 * m; + //current implementation supports only gradient that can be collapsed to 1D. However, to avoid checking this assumption, + //we call grad.contiguous() before backward, so stride is guaranteed to be 1 + const int64_t gs = 1; Tensor buffer = (x1.dim() > 2) ? at::empty({batch, r2, r1, m}, result.options()) : at::empty({r2, r1, m}, result.options()); AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] { if (p == 1.0) { - cdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size); + cdist_backward_kernel_cuda_impl::one><<>>(buffer.data_ptr(), + grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), + gs, p, r1, r2, m, count, r_size, l1_size, l2_size); } else if (p < 2.0) { - cdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size); + cdist_backward_kernel_cuda_impl::lt_two><<>>(buffer.data_ptr(), + grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), + gs, p, r1, r2, m, count, r_size, l1_size, l2_size); } else if (p == 2.0) { - cdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size); + cdist_backward_kernel_cuda_impl::two><<>>(buffer.data_ptr(), + grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), + gs, p, r1, r2, m, count, r_size, l1_size, l2_size); } else if (std::isinf(p)) { - cdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size); + cdist_backward_kernel_cuda_impl::inf><<>>(buffer.data_ptr(), + grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), + gs, p, r1, r2, m, count, r_size, l1_size, l2_size); } else { - cdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size); + cdist_backward_kernel_cuda_impl::p><<>>(buffer.data_ptr(), + grad.data_ptr(), x1.data_ptr(), x2.data_ptr(), dist.data_ptr(), + gs, p, r1, r2, m, count, r_size, l1_size, l2_size); } }); AT_CUDA_CHECK(cudaGetLastError()); diff --git a/aten/src/ATen/native/cuda/Embedding.cu b/aten/src/ATen/native/cuda/Embedding.cu index 359512ca6d41a..2ad9af2313559 100644 --- a/aten/src/ATen/native/cuda/Embedding.cu +++ b/aten/src/ATen/native/cuda/Embedding.cu @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -21,10 +22,8 @@ namespace at { namespace native { namespace { #ifdef __HIP_PLATFORM_HCC__ -static const int WARP_SIZE = 64; static const int BLOCKDIMY = 16; #else -static const int WARP_SIZE = 32; static const int BLOCKDIMY = 32; #endif @@ -41,8 +40,8 @@ __global__ void embedding_backward_feature_kernel { extern __shared__ char buf[]; accscalar_t* smem = (accscalar_t*)buf; - accscalar_t* my_s = smem + WARP_SIZE*threadIdx.y; - int* indices_batch = (int*)(buf + sizeof(accscalar_t)*WARP_SIZE*blockDim.y); + accscalar_t* my_s = smem + C10_WARP_SIZE*threadIdx.y; + int* indices_batch = (int*)(buf + sizeof(accscalar_t)*C10_WARP_SIZE*blockDim.y); const int s = (int)stride; // OK to make int, we don't expect 2 billion+ embedding row size @@ -106,7 +105,7 @@ __global__ void embedding_backward_feature_kernel #else first_remaining_peer = __ffs(matchmask) - 1; #endif - my_s[threadIdx.x] += smem[threadIdx.x + WARP_SIZE*first_remaining_peer]; + my_s[threadIdx.x] += smem[threadIdx.x + C10_WARP_SIZE*first_remaining_peer]; matchmask ^= (1 << first_remaining_peer); } if(f < s) @@ -154,7 +153,7 @@ __global__ void embedding_backward_kernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int feature_dim = start_feature + ii * WARP_SIZE; + int feature_dim = start_feature + ii * C10_WARP_SIZE; if (feature_dim < stride) { gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); weight[ii] = static_cast(grad_weight[weight_row + feature_dim]); @@ -168,7 +167,7 @@ __global__ void embedding_backward_kernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int feature_dim = start_feature + ii * WARP_SIZE; + int feature_dim = start_feature + ii * C10_WARP_SIZE; if (feature_dim < stride) { grad_weight[weight_row + feature_dim] = static_cast(weight[ii]); } @@ -240,8 +239,8 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice auto indices_contig = indices.contiguous(); auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options()); int64_t stride = grad_weight.stride(0); - dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE)); - dim3 block(WARP_SIZE, BLOCKDIMY); + dim3 grid(THCCeilDiv(stride, (int64_t)C10_WARP_SIZE)); + dim3 block(C10_WARP_SIZE, BLOCKDIMY); AT_DISPATCH_FLOATING_TYPES_AND_HALF (grad.scalar_type(), @@ -252,7 +251,7 @@ Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indice embedding_backward_feature_kernel <<>> (indices_contig.data_ptr(), grad.data_ptr(), diff --git a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu index fb5d97f3d16e4..be0d963adfdf7 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu @@ -14,6 +14,7 @@ #include #include +#include namespace at { namespace native { @@ -33,12 +34,6 @@ constexpr int MAX_BLOCK_SIZE = 1024; */ constexpr int NROWS_PER_THREAD = 10; -#ifdef __HIP_PLATFORM_HCC__ - constexpr int WARP_SIZE = 64; -#else - constexpr int WARP_SIZE = 32; -#endif - // Fast ceil division (no overflow checking) __host__ __device__ __forceinline__ int64_t ceil_div(int64_t x, int64_t y) { @@ -266,7 +261,7 @@ Tensor embedding_backward_cuda_kernel( num_of_segments); } - const int stride_warped = ceil_div(stride, WARP_SIZE)*WARP_SIZE; + const int stride_warped = ceil_div(stride, C10_WARP_SIZE)*C10_WARP_SIZE; const int block = std::min(stride_warped, MAX_BLOCK_SIZE); const int grid = ceil_div(num_of_partial_segments*stride_warped, block); diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 88b8bd325276a..6268823dc2a1d 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -18,6 +18,8 @@ #include +#include + namespace at { namespace native { @@ -27,13 +29,6 @@ constexpr int MODE_SUM = 0; constexpr int MODE_MEAN = 1; constexpr int MODE_MAX = 2; -#ifdef __HIP_PLATFORM_HCC__ -constexpr int WARP_SIZE = 64; -#else -constexpr int WARP_SIZE = 32; -#endif - - // This kernel assumes that all input tensors except `weight` and // per_sample_weights are contiguous. template @@ -352,7 +347,7 @@ Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &ind template __inline__ __device__ static scalar_t warpReduceSum(scalar_t val) { - for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) + for (int offset = C10_WARP_SIZE/2; offset > 0; offset /= 2) val += WARP_SHFL_DOWN(val, offset); return val; } @@ -368,9 +363,9 @@ __global__ static void _embedding_bag_per_sample_weights_backward_kernel( scalar_t* output) { using accscalar_t = acc_type; const int idx = threadIdx.x + blockIdx.x * blockDim.x; - const int warp = idx / WARP_SIZE; - const int thread_in_warp = idx % WARP_SIZE; - const int num_warps = blockDim.x * gridDim.x / WARP_SIZE; + const int warp = idx / C10_WARP_SIZE; + const int thread_in_warp = idx % C10_WARP_SIZE; + const int num_warps = blockDim.x * gridDim.x / C10_WARP_SIZE; // Each warp is responsible for the accumulation of one sample. // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx]. @@ -379,7 +374,7 @@ __global__ static void _embedding_bag_per_sample_weights_backward_kernel( const int bag_idx = (int)offset2bag[sample_idx]; const int embedding_idx = (int)indices[sample_idx]; for (int feature_idx = thread_in_warp; feature_idx < embedding_features; - feature_idx += WARP_SIZE) { + feature_idx += C10_WARP_SIZE) { result += grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] * weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx]; @@ -412,7 +407,7 @@ Tensor _embedding_bag_per_sample_weights_backward_cuda( AT_ASSERT(weight.size(1) == embedding_features); const int threads_per_block = 1024; - const int warps_per_block = threads_per_block / WARP_SIZE; + const int warps_per_block = threads_per_block / C10_WARP_SIZE; dim3 block(threads_per_block); dim3 grid((num_samples + warps_per_block - 1) / warps_per_block); diff --git a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu index c44b49c004d4e..ecd7188b273fd 100644 --- a/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu +++ b/aten/src/ATen/native/cuda/FractionalMaxPool3d.cu @@ -40,10 +40,10 @@ __device__ inline int64_t get_intervals( template __global__ void fractional_max_pool3d_out_frame( - PackedTensorAccessor input, - PackedTensorAccessor output, - PackedTensorAccessor indices, - PackedTensorAccessor samples, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, + PackedTensorAccessor64 indices, + PackedTensorAccessor64 samples, int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { using accscalar_t = at::acc_type; // Output (t, h, w) point that this thread is responsible for @@ -109,9 +109,9 @@ __global__ void fractional_max_pool3d_out_frame( template __global__ void fractional_max_pool3d_backward_out_frame( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, - PackedTensorAccessor indices) { + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, + PackedTensorAccessor64 indices) { // Output (h, w) point that this thread is responsible for int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x; int64_t plane = blockIdx.y; @@ -236,10 +236,10 @@ void fractional_max_pool3d_out_cuda_template( [&]{ fractional_max_pool3d_out_frame <<>>( - input_.packed_accessor(), - output_.packed_accessor(), - indices_.packed_accessor(), - randomSamples.packed_accessor(), + input_.packed_accessor64(), + output_.packed_accessor64(), + indices_.packed_accessor64(), + randomSamples.packed_accessor64(), poolSizeT, poolSizeH, poolSizeW ); } @@ -326,9 +326,9 @@ void fractional_max_pool3d_backward_out_cuda_template( [&] { fractional_max_pool3d_backward_out_frame <<>>( - gradInput_.packed_accessor(), - gradOutput_.packed_accessor(), - indices_.packed_accessor() + gradInput_.packed_accessor64(), + gradOutput_.packed_accessor64(), + indices_.packed_accessor64() ); } ); diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 1588765d4e17c..05086f2300c15 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -15,18 +15,10 @@ #include #include +#include namespace { -#ifdef __HIP_PLATFORM_HCC__ -static const int WARP_SIZE = 64; -#else -static const int WARP_SIZE = 32; -#endif - - - - template __global__ void indexing_backward_kernel( int64_t* sorted_indices, int64_t* indices, scalar_t* grad_output, scalar_t* grad_weight, @@ -66,7 +58,7 @@ __global__ void indexing_backward_kernel( while (start_feature < stride) { #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int feature_dim = start_feature + ii * WARP_SIZE; + int feature_dim = start_feature + ii * C10_WARP_SIZE; if (feature_dim < stride) { gradient[ii] = static_cast(grad_output[grad_row + feature_dim]); weight[ii] = static_cast(grad_weight[weight_row + feature_dim]); @@ -80,7 +72,7 @@ __global__ void indexing_backward_kernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int feature_dim = start_feature + ii * WARP_SIZE; + int feature_dim = start_feature + ii * C10_WARP_SIZE; if (feature_dim < stride) { grad_weight[weight_row + feature_dim] = static_cast(weight[ii]); } @@ -229,9 +221,9 @@ void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & va const int UNROLL = 4; const int indices_per_block = 4; dim3 grid(THCCeilDiv(num_indices, (int64_t) indices_per_block), - std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], THCCeilDiv(sliceSize, (int64_t) (WARP_SIZE*UNROLL))), + std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], THCCeilDiv(sliceSize, (int64_t) (C10_WARP_SIZE*UNROLL))), std::min(std::max(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2])); - dim3 block(WARP_SIZE, indices_per_block); + dim3 block(C10_WARP_SIZE, indices_per_block); AT_DISPATCH_FLOATING_TYPES_AND_HALF(value_.scalar_type(), "embedding_backward", [&] { indexing_backward_kernel<<>>( diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp index 06a4274d51ab5..892f0d8c88a15 100644 --- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace at { namespace native { @@ -35,13 +36,14 @@ Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & val if (mask.dtype() == at::ScalarType::Byte) { AT_WARN("masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated," \ "please use a mask with dtype torch.bool instead."); - return legacy::cuda::_th_masked_fill_(self, mask, value); + legacy::cuda::_th_masked_fill_(self, mask, value); } else { - return legacy::cuda::_th_masked_fill_bool_(self, mask, value); + legacy::cuda::_th_masked_fill_bool_(self, mask, value); } #ifdef BUILD_NAMEDTENSOR namedinference::propagate_names(self, std::move(outnames), /*validate_names=*/false); #endif + return self; } Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & source) { diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index c4f06d15879d4..7d8dfb15a1f71 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -179,6 +179,7 @@ void gpu_kernel(TensorIterator& iter, const func_t& f) { } gpu_kernel_impl(iter, f); + iter.cast_outputs(); } template diff --git a/aten/src/ATen/native/cuda/LossCTC.cu b/aten/src/ATen/native/cuda/LossCTC.cu index a52bbe0c7064b..97dd081ddd3c2 100644 --- a/aten/src/ATen/native/cuda/LossCTC.cu +++ b/aten/src/ATen/native/cuda/LossCTC.cu @@ -248,7 +248,7 @@ std::tuple ctc_loss_gpu_template(const Tensor& log_probs, const int64_t max_input_length = log_probs.size(0); for (int64_t b = 0; b < batch_size; b++) { TORCH_CHECK(input_lengths[b] <= max_input_length, - "Expected tensor to have size at least ", max_input_length, " at dimension 1, but got size ", targets.size(0), " for ", targets_arg, + "Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b], " (while checking arguments for ", c, ")"); } diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh new file mode 100644 index 0000000000000..41b4c1d7959fa --- /dev/null +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -0,0 +1,92 @@ +#pragma once +#include + +namespace at { +namespace native { + +/* +* The following function was converted to CUDA form from code that comes +* with the following copyright notice. It has been released under the BSD license. +* +* Cephes Math Library Release 2.8: June, 2000 +* Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier +*/ +template +static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { + using accscalar_t = at::acc_type; + static const double PI_f64 = 3.14159265358979323846; + const accscalar_t PSI_10 = 2.25175258906672110764; + const accscalar_t A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + accscalar_t x = static_cast(in); + if (x == 0) { + return static_cast(INFINITY); + } + + bool x_is_integer = x == ::floor(x); + accscalar_t result = 0; + if (x < 0) { + if (x_is_integer) { + return static_cast(INFINITY); + } + // Rounding errors in tan's input can really affect the output + // for extreme values, so we always perform this computation in double. + result = static_cast(- PI_f64 / ::tan(PI_f64 * static_cast(x))); + x = 1 - x; + } + + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return static_cast(result + PSI_10); + } + + accscalar_t y = 0; + if (x < 1.0e17) { + accscalar_t z = 1.0 / (x * x); + + accscalar_t polevl_result = 0; + for (int i = 0; i <= 6; i++) { + polevl_result = polevl_result * z + A[i]; + } + y = z * polevl_result; + } + + return static_cast(::log(x) - (0.5 / x) - y + result); +} + +template +static inline __host__ __device__ scalar_t calc_trigamma(scalar_t in) { + using accscalar_t = at::acc_type; + const accscalar_t PI = 3.14159265358979323846; + accscalar_t x = static_cast(in); + accscalar_t sign = +1; + accscalar_t result = 0; + if (x < 0.5f) { + sign = -1; + accscalar_t sin_pi_x = ::sin(PI * x); + result -= (PI * PI) / (sin_pi_x * sin_pi_x); + x = 1 - x; + } + for (int i = 0; i < 6; ++i) { + result += 1 / (x * x); + x += 1; + } + const accscalar_t one = static_cast(1); + const accscalar_t ixx = 1 / (x*x); + result += (1 + 1 / (2*x) + ixx * (one/6 - ixx * (one/30 - ixx * (one/42)))) / x; + return static_cast(sign * result); +} + +} +} diff --git a/aten/src/ATen/native/cuda/MaxUnpooling.cu b/aten/src/ATen/native/cuda/MaxUnpooling.cu index 1db0afd8b3afe..e4131c701bbcd 100644 --- a/aten/src/ATen/native/cuda/MaxUnpooling.cu +++ b/aten/src/ATen/native/cuda/MaxUnpooling.cu @@ -38,8 +38,8 @@ __global__ void max_unpooling2d_forward_kernel( template __global__ void max_unpooling3d_forward_kernel( - PackedTensorAccessor input, - PackedTensorAccessor indices, + PackedTensorAccessor64 input, + PackedTensorAccessor64 indices, T* output, const int64_t oT, const int64_t oH, @@ -82,8 +82,8 @@ __global__ void max_unpooling3d_backward_kernel( int64_t oT, int64_t oH, int64_t oW, - PackedTensorAccessor indices, - PackedTensorAccessor gradInput, + PackedTensorAccessor64 indices, + PackedTensorAccessor64 gradInput, int offsetZ) { int iColumn = blockIdx.x * blockDim.x + threadIdx.x; int iRow = blockIdx.y * blockDim.y + threadIdx.y; @@ -339,8 +339,8 @@ Tensor& max_unpooling3d_forward_out_cuda( block, 0, at::cuda::getCurrentCUDAStream()>>>( - self.packed_accessor(), - indices.packed_accessor(), + self.packed_accessor64(), + indices.packed_accessor64(), output.data_ptr(), oT, oH, @@ -558,8 +558,8 @@ at::Tensor& max_unpooling3d_backward_out_cuda( oT, oH, oW, - indices.packed_accessor(), - grad_input_reshaped.packed_accessor(), + indices.packed_accessor64(), + grad_input_reshaped.packed_accessor64(), offsetZ); TORCH_CHECK( cudaGetLastError() == cudaSuccess, diff --git a/aten/src/ATen/native/cuda/MultinomialKernel.cu b/aten/src/ATen/native/cuda/MultinomialKernel.cu new file mode 100644 index 0000000000000..e52f9f48cdf56 --- /dev/null +++ b/aten/src/ATen/native/cuda/MultinomialKernel.cu @@ -0,0 +1,14 @@ +#include +#include +#include +#include + +namespace at { namespace native { + +void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t num_samples, const bool replacement, Generator* generator) { + legacy::cuda::_th_multinomial_out(result, self, num_samples, replacement, generator); +} + +REGISTER_DISPATCH(multinomial_stub, &multinomial_kernel_impl); + +}} \ No newline at end of file diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 966520698c93e..414db1ac9e37e 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -8,15 +8,10 @@ #include #include #include +#include namespace at { namespace native { -#if defined(__HIP_PLATFORM_HCC__) -constexpr int WARP_SIZE = 64; -#else -constexpr int WARP_SIZE = 32; -#endif - // The maximum number of threads in a block #if defined(__HIP_PLATFORM_HCC__) constexpr int MAX_BLOCK_SIZE = 256; @@ -94,8 +89,8 @@ struct GradOp { // Sum across all threads within a warp template static __device__ __forceinline__ T warpSum(T val) { - for (int i = 0; i < getMSB(WARP_SIZE); ++i) { - val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE); + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + val += WARP_SHFL_XOR(val, 1 << i, C10_WARP_SIZE); } return val; } @@ -110,12 +105,12 @@ static __device__ __forceinline__ Float2 warpSum(Float2 __device__ scalar_t reduce(Op op, PTA tensor, int plane) { // first the reductions each thread does separately @@ -131,15 +126,15 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { sum = warpSum(sum); // this writes each warps item into shared memory - // there are at most WARP_SIZE items left because - // there are at most WARP_SIZE**2 threads at the beginning - __shared__ scalar_t shared[WARP_SIZE]; + // there are at most C10_WARP_SIZE items left because + // there are at most C10_WARP_SIZE**2 threads at the beginning + __shared__ scalar_t shared[C10_WARP_SIZE]; __syncthreads(); int tid = threadIdx.x + threadIdx.y * blockDim.x; - if (tid % WARP_SIZE == 0) { - shared[tid / WARP_SIZE] = sum; + if (tid % C10_WARP_SIZE == 0) { + shared[tid / C10_WARP_SIZE] = sum; } - if (tid >= blockDim.x * blockDim.y / WARP_SIZE && tid < WARP_SIZE) { + if (tid >= blockDim.x * blockDim.y / C10_WARP_SIZE && tid < C10_WARP_SIZE) { // zero out the other entries in shared shared[tid] = (scalar_t)0; } @@ -148,7 +143,7 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { // from shared memory to a single number. The very first // thread writes it to shared memory. - if (tid / WARP_SIZE == 0) { + if (tid / C10_WARP_SIZE == 0) { sum = warpSum(shared[tid]); if (tid == 0) { shared[0] = sum; @@ -162,12 +157,12 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { template __global__ void batch_norm_transform_input_kernel( - const PackedTensorAccessor input, - PackedTensorAccessor output, - const PackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, - const PackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, - const PackedTensorAccessor weight, - const PackedTensorAccessor bias, + const GenericPackedTensorAccessor input, + GenericPackedTensorAccessor output, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> mean_, + const GenericPackedTensorAccessor::type, 1, RestrictPtrTraits, index_t> var_or_invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor bias, stat_accscalar_t epsilon) { index_t plane = blockIdx.x; @@ -219,15 +214,15 @@ struct Var { template class VarTransform, typename input_scalar_t, typename stat_scalar_t, typename stat_accscalar_t, typename index_t> __global__ void batch_norm_collect_statistics_kernel( - const PackedTensorAccessor input, + const GenericPackedTensorAccessor input, const stat_accscalar_t epsilon, const stat_accscalar_t momentum, - PackedTensorAccessor running_mean, - PackedTensorAccessor running_var, - PackedTensorAccessor save_mean, - PackedTensorAccessor save_transformed_var) { + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, + GenericPackedTensorAccessor save_mean, + GenericPackedTensorAccessor save_transformed_var) { - __shared__ int shared_n[2 * 2 * WARP_SIZE + WARP_SIZE]; + __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; int plane = blockIdx.x; int N = input.size(0) * input.size(2); @@ -239,7 +234,7 @@ __global__ void batch_norm_collect_statistics_kernel( // and the parallel algorithm on the same page. // We use two shuffles to reduce across the entire block. // https://devblogs.nvidia.com/faster-parallel-reductions-kepler/ has a description. - stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[WARP_SIZE]; + stat_accscalar_t* shared_avg_var = (stat_accscalar_t*) &shared_n[C10_WARP_SIZE]; // first the reductions each thread does separately stat_accscalar_t avg = 0; @@ -257,39 +252,39 @@ __global__ void batch_norm_collect_statistics_kernel( // first warpSum to get one value per thread to // one value per warp - for (int i = 0; i < getMSB(WARP_SIZE); ++i) { - stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); - int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); - var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; avg = (n * avg + o_n * o_avg) * factor; n += o_n; } // this writes each warps item into shared memory - // there are at most WARP_SIZE items left because - // there are at most WARP_SIZE**2 threads at the beginning + // there are at most C10_WARP_SIZE items left because + // there are at most C10_WARP_SIZE**2 threads at the beginning __syncthreads(); - if (tid % WARP_SIZE == 0) { - shared_n[tid / WARP_SIZE] = n; - shared_avg_var[tid / WARP_SIZE * 2] = avg; - shared_avg_var[tid / WARP_SIZE * 2 + 1] = var_n; + if (tid % C10_WARP_SIZE == 0) { + shared_n[tid / C10_WARP_SIZE] = n; + shared_avg_var[tid / C10_WARP_SIZE * 2] = avg; + shared_avg_var[tid / C10_WARP_SIZE * 2 + 1] = var_n; } __syncthreads(); // now have a second warpSum to reduce the intermediate values // from shared memory to a single number. The very first // thread writes it to shared memory. - if (tid < WARP_SIZE) { - n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_n[tid] : 0); - avg = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); - var_n = (tid < blockDim.x * blockDim.y / WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); + if (tid < C10_WARP_SIZE) { + n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_n[tid] : 0); + avg = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid] : stat_accscalar_t(0)); + var_n = (tid < blockDim.x * blockDim.y / C10_WARP_SIZE ? shared_avg_var[2 * tid + 1] : stat_accscalar_t(0)); } - for (int i = 0; i < getMSB(WARP_SIZE); ++i) { - stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, WARP_SIZE); - int o_n = WARP_SHFL_XOR(n, 1 << i, WARP_SIZE); + for (int i = 0; i < getMSB(C10_WARP_SIZE); ++i) { + stat_accscalar_t o_avg = WARP_SHFL_XOR(avg, 1 << i, C10_WARP_SIZE); + int o_n = WARP_SHFL_XOR(n, 1 << i, C10_WARP_SIZE); stat_accscalar_t factor = 1.0 / fmaxf(1.0, n+o_n); - var_n += WARP_SHFL_XOR(var_n, 1 << i, WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; + var_n += WARP_SHFL_XOR(var_n, 1 << i, C10_WARP_SIZE) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; avg = (n * avg + o_n * o_avg) * factor; n += o_n; } @@ -315,16 +310,16 @@ __global__ void batch_norm_collect_statistics_kernel( template __global__ void batch_norm_backward_kernel( - const PackedTensorAccessor input, - const PackedTensorAccessor grad_output, - PackedTensorAccessor grad_input, - PackedTensorAccessor grad_weight, - PackedTensorAccessor grad_bias, - const PackedTensorAccessor weight, - const PackedTensorAccessor running_mean, - const PackedTensorAccessor running_var, - const PackedTensorAccessor save_mean, - const PackedTensorAccessor save_invstd, + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor grad_input, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor running_mean, + const GenericPackedTensorAccessor running_var, + const GenericPackedTensorAccessor save_mean, + const GenericPackedTensorAccessor save_invstd, bool train, stat_accscalar_t epsilon) { @@ -346,9 +341,9 @@ __global__ void batch_norm_backward_kernel( // Compute two values across (batch, x/y/z) in one pass: // 1. Sum(grad_output) // 2. DotProduct(input - mean, grad_output) - GradOp> g(mean, input, grad_output); + GradOp> g(mean, input, grad_output); Float2 res = reduce, GradOp>>(g, grad_output, plane); + GenericPackedTensorAccessor>>(g, grad_output, plane); stat_accscalar_t grad_output_sum = res.v1; stat_accscalar_t dot_p = res.v2; @@ -386,15 +381,15 @@ __global__ void batch_norm_backward_kernel( template __global__ void batch_norm_reduce_statistics_kernel( - const PackedTensorAccessor vec_mean, - const PackedTensorAccessor vec_invstd, - PackedTensorAccessor mean, - PackedTensorAccessor invstd, - PackedTensorAccessor running_mean, - PackedTensorAccessor running_var, + const GenericPackedTensorAccessor vec_mean, + const GenericPackedTensorAccessor vec_invstd, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor running_mean, + GenericPackedTensorAccessor running_var, const accscalar_t epsilon, const accscalar_t momentum, - const PackedTensorAccessor counts) { + const GenericPackedTensorAccessor counts) { int feature_size = vec_mean.size(1); int world_size = vec_mean.size(0); @@ -432,14 +427,14 @@ __global__ void batch_norm_reduce_statistics_kernel( template __global__ void batch_norm_backward_reduce_kernel( - const PackedTensorAccessor input, - const PackedTensorAccessor grad_output, - PackedTensorAccessor mean, - PackedTensorAccessor invstd, - PackedTensorAccessor mean_dy, - PackedTensorAccessor mean_dy_xmu, - PackedTensorAccessor grad_weight, - PackedTensorAccessor grad_bias) { + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + GenericPackedTensorAccessor mean, + GenericPackedTensorAccessor invstd, + GenericPackedTensorAccessor mean_dy, + GenericPackedTensorAccessor mean_dy_xmu, + GenericPackedTensorAccessor grad_weight, + GenericPackedTensorAccessor grad_bias) { index_t plane = blockIdx.x; index_t N = input.size(0) * input.size(2); @@ -447,9 +442,9 @@ __global__ void batch_norm_backward_reduce_kernel( stat_accscalar_t r_mean = mean[plane]; stat_accscalar_t factor = invstd[plane]; - GradOp> g(r_mean, input, grad_output); + GradOp> g(r_mean, input, grad_output); Float2 res = reduce, GradOp>>(g, grad_output, plane); + GenericPackedTensorAccessor>>(g, grad_output, plane); stat_accscalar_t norm = stat_accscalar_t(1) / N; if (threadIdx.x == 0) { @@ -470,14 +465,14 @@ __global__ void batch_norm_backward_reduce_kernel( template __global__ void batch_norm_backward_elemt_kernel( - const PackedTensorAccessor input, - const PackedTensorAccessor grad_output, - const PackedTensorAccessor mean, - const PackedTensorAccessor invstd, - const PackedTensorAccessor weight, - const PackedTensorAccessor mean_dy, - const PackedTensorAccessor mean_dy_xmu, - PackedTensorAccessor grad_input) { + const GenericPackedTensorAccessor input, + const GenericPackedTensorAccessor grad_output, + const GenericPackedTensorAccessor mean, + const GenericPackedTensorAccessor invstd, + const GenericPackedTensorAccessor weight, + const GenericPackedTensorAccessor mean_dy, + const GenericPackedTensorAccessor mean_dy_xmu, + GenericPackedTensorAccessor grad_input) { index_t plane = blockIdx.x; @@ -507,12 +502,12 @@ __global__ void batch_norm_backward_elemt_kernel( } template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> -static PackedTensorAccessor packed_accessor_or_dummy(const Tensor& t) { +static GenericPackedTensorAccessor packed_accessor_or_dummy(const Tensor& t) { if (! t.defined()) { const std::vector zeros(dim); - return PackedTensorAccessor(nullptr, zeros.data(), zeros.data()); + return GenericPackedTensorAccessor(nullptr, zeros.data(), zeros.data()); } - return t.packed_accessor(); + return t.generic_packed_accessor(); } template @@ -537,7 +532,7 @@ std::tuple batch_norm_cuda_template(const Tensor& input_ auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto input_options = input_.options(); if (input_.scalar_type() == at::ScalarType::Half) { input_options = input_options.dtype(ScalarType::Float); @@ -549,13 +544,13 @@ std::tuple batch_norm_cuda_template(const Tensor& input_ save_mean_ = at::empty({0}, input_options); save_invstd_ = at::empty({0}, input_options); } - auto output = output_reshaped.packed_accessor(); + auto output = output_reshaped.generic_packed_accessor(); auto weight = packed_accessor_or_dummy(weight_); auto bias = packed_accessor_or_dummy(bias_); auto running_mean = packed_accessor_or_dummy(running_mean_); auto running_var = packed_accessor_or_dummy(running_var_); - auto save_mean = save_mean_.packed_accessor(); - auto save_invstd = save_invstd_.packed_accessor(); + auto save_mean = save_mean_.generic_packed_accessor(); + auto save_invstd = save_invstd_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, @@ -611,8 +606,8 @@ std::tuple batch_norm_backward_cuda_template(const Tenso grad_bias_ = at::empty_like(weight_); } - auto input = input_reshaped.packed_accessor(); - auto grad_output = grad_output_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); + auto grad_output = grad_output_reshaped.generic_packed_accessor(); auto grad_input = packed_accessor_or_dummy(grad_input_reshaped); auto weight = packed_accessor_or_dummy(weight_); auto grad_weight = packed_accessor_or_dummy(grad_weight_); @@ -648,7 +643,7 @@ std::tuple batch_norm_stats_cuda_template(const Tensor& input_, auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto input_options = input_.options(); dummy_mean_ = at::empty({0}, input_options); dummy_var_ = at::empty({0}, input_options); @@ -660,8 +655,8 @@ std::tuple batch_norm_stats_cuda_template(const Tensor& input_, invstd_ = at::empty({n_input}, input_options); auto mean = packed_accessor_or_dummy(mean_); auto invstd = packed_accessor_or_dummy(invstd_); - auto dummy_mean = dummy_mean_.packed_accessor(); - auto dummy_invstd = dummy_var_.packed_accessor(); + auto dummy_mean = dummy_mean_.generic_packed_accessor(); + auto dummy_invstd = dummy_var_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); dim3 blocks(input.size(1)); @@ -685,12 +680,12 @@ Tensor batch_norm_elemt_cuda_template(const Tensor& input_, const Tensor& weight auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto input_options = input_.options(); if (input_.scalar_type() == at::ScalarType::Half) { input_options = input_options.dtype(ScalarType::Float); } - auto output = output_reshaped.packed_accessor(); + auto output = output_reshaped.generic_packed_accessor(); auto weight = packed_accessor_or_dummy(weight_); auto bias = packed_accessor_or_dummy(bias_); auto mean = packed_accessor_or_dummy(mean_); @@ -735,8 +730,8 @@ std::tuple batch_norm_gather_stats_cuda_template(const Tensor& m auto running_var = packed_accessor_or_dummy(running_var_); auto counts = packed_accessor_or_dummy(counts_); - auto save_mean = save_mean_.packed_accessor(); - auto save_invstd = save_invstd_.packed_accessor(); + auto save_mean = save_mean_.generic_packed_accessor(); + auto save_invstd = save_invstd_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); int block = getNumThreads(features); @@ -772,8 +767,8 @@ std::tuple batch_norm_backward_reduce_cuda_templ grad_bias_ = at::empty({n_input}, weight_.options()); } - auto input = input_reshaped.packed_accessor(); - auto grad_output = grad_output_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); + auto grad_output = grad_output_reshaped.generic_packed_accessor(); auto grad_weight = packed_accessor_or_dummy(grad_weight_); auto grad_bias = packed_accessor_or_dummy(grad_bias_); auto mean = packed_accessor_or_dummy(mean_); @@ -811,9 +806,9 @@ Tensor batch_norm_backward_elemt_cuda_template(const Tensor& grad_out_, const Te auto bs = input_reshaped.size(0); auto features = input_reshaped.size(2); - auto input = input_reshaped.packed_accessor(); - auto grad_input = grad_input_reshaped.packed_accessor(); - auto grad_output = grad_output_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); + auto grad_input = grad_input_reshaped.generic_packed_accessor(); + auto grad_output = grad_output_reshaped.generic_packed_accessor(); auto mean = packed_accessor_or_dummy(mean_); auto invstd = packed_accessor_or_dummy(invstd_); auto weight = packed_accessor_or_dummy(weight_); @@ -853,11 +848,11 @@ std::tuple batch_norm_update_stats_cuda_template( Tensor save_mean_ = at::empty({n_channels}, input_options); Tensor save_var_ = at::empty({n_channels}, input_options); - auto input = input_reshaped.packed_accessor(); + auto input = input_reshaped.generic_packed_accessor(); auto running_mean = packed_accessor_or_dummy(running_mean_); auto running_var = packed_accessor_or_dummy(running_var_); - auto save_mean = save_mean_.packed_accessor(); - auto save_var = save_var_.packed_accessor(); + auto save_mean = save_mean_.generic_packed_accessor(); + auto save_var = save_var_.generic_packed_accessor(); auto stream = at::cuda::getCurrentCUDAStream(); // for the reduction, we cannot use blocks for the batch dim, but if we have few threads in diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh index 5d283d1d404e1..9be51dc18e4af 100644 --- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh +++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh @@ -6,6 +6,7 @@ #include #include #include +#include namespace { @@ -47,11 +48,11 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { // The template arguments have the following meaning: // One "WARP" works on one "BATCH". One "BATCH" contains "WARP_BATCH" samples. // WARP_BATCH is equal to 1 when element_count is large, and > 1 when element_count is small. -// A "WARP" contains "WARPS_SIZE" threads, these treads are guaranteed to belong to the same CUDA warp. +// A "WARP" contains "C10_WARPS_SIZE" threads, these treads are guaranteed to belong to the same warp. // This is important because it means only __shfl_ instructions are required for reductions. -// Note that this means WARP_SIZE must be a power of two and <= CUDA warp size. +// Note that this means WARP_SIZE must be a power of two and <= architecture warp size. // CUDA warp size is 32 for all existing GPU architecures, but there is no guarantee this will not change for future arch. -// This code will not work properly if warp size is < 32. +// ROCm warp size is 64 for all currently ROCm-supported GPU architectures, but this may change for future archs. // is_log_softmax is a flag indicating whether SoftMax or LogSoftMax should be computed. // The template can be instantiated with any floating point type for the type arguments input_t, output_t and acc_t. // This allows SoftMax to be fused with a cast immediately following the SoftMax. @@ -65,7 +66,7 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < 32) ? next_power_of_two : 32; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -156,7 +157,7 @@ __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < 32) ? next_power_of_two : 32; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -243,7 +244,7 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele const int next_power_of_two = 1 << log2_elements; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; @@ -318,7 +319,7 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const const int next_power_of_two = 1 << log2_elements; // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu new file mode 100644 index 0000000000000..91a96ef0a5f97 --- /dev/null +++ b/aten/src/ATen/native/cuda/PowKernel.cu @@ -0,0 +1,100 @@ +#include +#include +#include +#include +#include +#include + +namespace at { namespace native { + +namespace { + +template +static inline __host__ __device__ T powi(T a, T b) { + T result = 1; + while (b) { + if (b & 1) { + result *= a; + } + b /= 2; + a *= a; + } + return result; +} + +template +static inline __host__ __device__ T sqrt(T x) { + return std::sqrt(x); +} + +void pow_tensor_tensor_kernel(TensorIterator& iter) { + if (isFloatingType(iter.dtype())) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "pow_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { + return std::pow(base, exp); + }); + }); + } else { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t base, scalar_t exp) -> scalar_t { + return powi(base, exp); + }); + }); + } +} + +template +void pow_tensor_scalar_kernel_impl(TensorIterator& iter, + Exp_type exp) { + const auto d_exp = static_cast(exp); + if (d_exp == 0.5) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return ::sqrt(base); + }); + } else if (d_exp == 2) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return base * base; + }); + } else if (d_exp == 3) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return base * base * base; + }); + } else if (d_exp == -0.5) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return 1.0 / ::sqrt(base); + }); + } else if (d_exp == -1) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return 1.0 / base; + }); + } else if (d_exp == -2) { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return 1.0 / (base * base); + }); + } else { + gpu_kernel(iter, [=]GPU_LAMBDA(Base_type base) -> Base_type { + return std::pow(base, exp); + }); + } +} + +void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) { + if (isFloatingType(iter.dtype()) || exp_scalar.isIntegral(false)) { + AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "pow_cuda", [&]() { + const auto exp = exp_scalar.to(); + pow_tensor_scalar_kernel_impl(iter, exp); + }); + } else { + const auto exp = exp_scalar.to(); + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow_cuda", [&]() { + pow_tensor_scalar_kernel_impl(iter, exp); + }); + } +} + +} // anonymous namespace + +REGISTER_DISPATCH(pow_tensor_tensor_stub, &pow_tensor_tensor_kernel); +REGISTER_DISPATCH(pow_tensor_scalar_stub, &pow_tensor_scalar_kernel); + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu index 40f813f6881df..f8b6e9bc8e11a 100644 --- a/aten/src/ATen/native/cuda/ReduceOpsKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceOpsKernel.cu @@ -60,7 +60,7 @@ void mean_kernel_impl(TensorIterator& iter) { template void norm_kernel_cuda_impl(TensorIterator& iter, Scalar val) { float p; - if (val.isIntegral()) { + if (val.isIntegral(false)) { p = val.to(); } else if (val.isFloatingPoint()) { p = val.to(); diff --git a/aten/src/ATen/native/cuda/ReplicationPadding.cu b/aten/src/ATen/native/cuda/ReplicationPadding.cu index c9da8f440b729..ba51fc2105350 100644 --- a/aten/src/ATen/native/cuda/ReplicationPadding.cu +++ b/aten/src/ATen/native/cuda/ReplicationPadding.cu @@ -27,8 +27,8 @@ __host__ __device__ __forceinline__ int imax(int a, int b) { namespace { template __global__ void replication_pad_forward_kernel1d( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -50,8 +50,8 @@ __global__ void replication_pad_forward_kernel1d( template __global__ void replication_pad_backward_kernel( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -73,8 +73,8 @@ __global__ void replication_pad_backward_kernel( template __global__ void replication_pad_forward_kernel2d( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int padT, int padB, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -100,8 +100,8 @@ __global__ void replication_pad_forward_kernel2d( template __global__ void replication_pad_backward_kernel( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, int padT, int padB, int padL, int padR) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -127,8 +127,8 @@ __global__ void replication_pad_backward_kernel( template __global__ void replication_pad_forward_kernel3d( - PackedTensorAccessor input, - PackedTensorAccessor output, + PackedTensorAccessor64 input, + PackedTensorAccessor64 output, int pfront, int pback, int ptop, int pbottom, int pleft, int pright) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; @@ -163,8 +163,8 @@ __global__ void replication_pad_forward_kernel3d( template __global__ void replication_pad_backward_kernel( - PackedTensorAccessor gradInput, - PackedTensorAccessor gradOutput, + PackedTensorAccessor64 gradInput, + PackedTensorAccessor64 gradOutput, int pfront, int pback, int ptop, int pbottom, int pleft, int pright) { int outputPointId = threadIdx.x + blockIdx.x * blockDim.x; int plane = blockIdx.y; @@ -242,8 +242,8 @@ void replication_pad1d_out_cuda_template( output.resize_({numPlanes, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); - auto devInput = input_.packed_accessor(); - auto devOutput = output_.packed_accessor(); + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); int outputPlaneSize = devOutput.size(2); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -255,8 +255,8 @@ void replication_pad1d_out_cuda_template( at::cuda::getCurrentCUDAStream()>>>(devInput, devOutput, padL, padR); } else { output.resize_({numBatch, numPlanes, outputW}); - auto devInput = input.packed_accessor(); - auto devOutput = output.packed_accessor(); + auto devInput = input.packed_accessor64(); + auto devOutput = output.packed_accessor64(); int outputPlaneSize = devOutput.size(2); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -314,8 +314,8 @@ void replication_pad1d_backward_out_cuda_template( gradInput_ = gradInput.unsqueeze(0); gradOutput_ = gradOutput.unsqueeze(0); } - auto devGradInput = gradInput_.packed_accessor(); - auto devGradOutput = gradOutput_.packed_accessor(); + auto devGradInput = gradInput_.packed_accessor64(); + auto devGradOutput = gradOutput_.packed_accessor64(); int outputPlaneSize = devGradOutput.size(2); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -379,8 +379,8 @@ void replication_pad2d_out_cuda_template( output.resize_({numPlanes, outputH, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); - auto devInput = input_.packed_accessor(); - auto devOutput = output_.packed_accessor(); + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -393,8 +393,8 @@ void replication_pad2d_out_cuda_template( devInput, devOutput, padT, padB, padL, padR); } else { output.resize_({numBatch, numPlanes, outputH, outputW}); - auto devInput = input.packed_accessor(); - auto devOutput = output.packed_accessor(); + auto devInput = input.packed_accessor64(); + auto devOutput = output.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -462,8 +462,8 @@ void replication_pad2d_backward_out_cuda_template( gradInput_ = gradInput.unsqueeze(0); gradOutput_ = gradOutput.unsqueeze(0); } - auto devGradInput = gradInput_.packed_accessor(); - auto devGradOutput = gradOutput_.packed_accessor(); + auto devGradInput = gradInput_.packed_accessor64(); + auto devGradOutput = gradOutput_.packed_accessor64(); int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3); dim3 gridSize(THCCeilDiv(outputPlaneSize, 256), @@ -614,8 +614,8 @@ void replication_pad3d_out_cuda_template( output.resize_({numPlanes, outputD, outputH, outputW}); auto input_ = input.unsqueeze(0); auto output_ = output.unsqueeze(0); - auto devInput = input_.packed_accessor(); - auto devOutput = output_.packed_accessor(); + auto devInput = input_.packed_accessor64(); + auto devOutput = output_.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4); @@ -629,8 +629,8 @@ void replication_pad3d_out_cuda_template( devInput, devOutput, pfront, pback, ptop, pbottom, pleft, pright); } else { output.resize_({numBatch, numPlanes, outputD, outputH, outputW}); - auto devInput = input.packed_accessor(); - auto devOutput = output.packed_accessor(); + auto devInput = input.packed_accessor64(); + auto devOutput = output.packed_accessor64(); int outputPlaneSize = devOutput.size(2) * devOutput.size(3) * devOutput.size(4); @@ -689,8 +689,8 @@ void replication_pad3d_backward_out_cuda_template( gradInput_ = gradInput.unsqueeze(0); gradOutput_ = gradOutput.unsqueeze(0); } - auto devGradInput = gradInput_.packed_accessor(); - auto devGradOutput = gradOutput_.packed_accessor(); + auto devGradInput = gradInput_.packed_accessor64(); + auto devGradOutput = gradOutput_.packed_accessor64(); int outputPlaneSize = devGradOutput.size(2) * devGradOutput.size(3) * devGradOutput.size(4); diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu index ea233ad8f7dd3..ad77d95a86b4d 100644 --- a/aten/src/ATen/native/cuda/SoftMax.cu +++ b/aten/src/ATen/native/cuda/SoftMax.cu @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -136,7 +137,7 @@ inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); while (block_size < max_block_size) block_size *= 2; // Launch at least a single warp - the kernel assumes that. - block_size = std::max(block_size, static_cast(32)); + block_size = std::max(block_size, static_cast(C10_WARP_SIZE)); return dim3(block_size); } @@ -332,13 +333,13 @@ blockReduce(AccumT* smem, AccumT val, AccumT warpVal = defaultVal; // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; - if (threadIdx.x < 32) { - int lane = threadIdx.x % 32; - if (lane < blockDim.x / 32) { + uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; + if (threadIdx.x < C10_WARP_SIZE) { + int lane = threadIdx.x % C10_WARP_SIZE; + if (lane < blockDim.x / C10_WARP_SIZE) { #pragma unroll - for (int i = 0; i < 32; ++i) { - warpVal = r(warpVal, smem[lane * 32 + i]); + for (int i = 0; i < C10_WARP_SIZE; ++i) { + warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]); } #if CUDA_VERSION >= 9000 __syncwarp(mask); @@ -353,7 +354,7 @@ blockReduce(AccumT* smem, AccumT val, AccumT blockVal = defaultVal; if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / 32; ++i) { + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { blockVal = r(blockVal, smem[i]); } smem[0] = blockVal; diff --git a/aten/src/ATen/native/cuda/SortingCommon.cuh b/aten/src/ATen/native/cuda/SortingCommon.cuh index 93069eadba491..54513955e9127 100644 --- a/aten/src/ATen/native/cuda/SortingCommon.cuh +++ b/aten/src/ATen/native/cuda/SortingCommon.cuh @@ -14,11 +14,9 @@ namespace at { namespace native { #if defined(__HIP_PLATFORM_HCC__) -constexpr int WARP_SIZE = 64; constexpr int MAX_BLOCK_SIZE = 256; #else -constexpr int WARP_SIZE = 32; constexpr int MAX_BLOCK_SIZE = 1024; #endif diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu index 2c2c63cc06cc9..02350b0063b17 100644 --- a/aten/src/ATen/native/cuda/SortingKthValue.cu +++ b/aten/src/ATen/native/cuda/SortingKthValue.cu @@ -40,7 +40,7 @@ __global__ void gatherKthValue( cuda::detail::TensorInfo indices) { // Indices are limited to integer fp precision, so counts can fit in // int32, regardless of index_t - __shared__ int smem[WARP_SIZE]; // one per each warp, up to warp limit + __shared__ int smem[C10_WARP_SIZE]; // one per each warp, up to warp limit index_t slice = getLinearBlockId(); if (slice >= numInputSlices) { @@ -117,7 +117,7 @@ struct KthValueLauncher { } dim3 block( - std::min(THCRoundUp(slice_size, (int64_t)WARP_SIZE), (int64_t)1024)); + std::min(THCRoundUp(slice_size, (int64_t)C10_WARP_SIZE), (int64_t)1024)); auto stream = at::cuda::getCurrentCUDAStream(); gatherKthValue<<>>( self_info, diff --git a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh index ea340cdd9b615..af26a4e30b29d 100644 --- a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh +++ b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh @@ -229,7 +229,7 @@ __device__ scalar_t findPattern( index_t withinSliceStride, bitwise_t desired, bitwise_t desiredMask) { - if (threadIdx.x < WARP_SIZE) { + if (threadIdx.x < C10_WARP_SIZE) { smem[threadIdx.x] = static_cast(0); } __syncthreads(); diff --git a/aten/src/ATen/native/cuda/SummaryOps.cu b/aten/src/ATen/native/cuda/SummaryOps.cu index ada0dd356269b..11a195365ff1e 100644 --- a/aten/src/ATen/native/cuda/SummaryOps.cu +++ b/aten/src/ATen/native/cuda/SummaryOps.cu @@ -17,7 +17,7 @@ namespace cuda { enum class CUDAHistogramMemoryType { SHARED, MULTI_BLOCK, GLOBAL }; namespace { template - __device__ static IndexType getBin(input_t bVal, input_t minvalue, input_t maxvalue, int nbins) { + __device__ static IndexType getBin(input_t bVal, input_t minvalue, input_t maxvalue, int64_t nbins) { IndexType bin = (int)((bVal - minvalue) * nbins / (maxvalue - minvalue)); // (only applicable for histc) // while each bin is inclusive at the lower end and exclusive at the higher, i.e. [start, end) @@ -47,7 +47,7 @@ __global__ void kernelHistogram1D( detail::TensorInfo a, /* output */ detail::TensorInfo p, /* partial output */ detail::TensorInfo b, /* input */ - int nbins, + int64_t nbins, input_t minvalue, input_t maxvalue, IndexType totalElements, diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index 8cd8e17579428..e1c3e73e0c2ab 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -48,7 +48,7 @@ Tensor _s_where_cuda( const Tensor& self, const Tensor& other) { Tensor ret = at::empty(self.sizes(), self.options()); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.scalar_type(), "where_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, ret.scalar_type(), "where_cuda", [&] { where_cuda(ret, condition, self, other); }); return ret; diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index e832856c5b607..5415236a91ca7 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -3,9 +3,9 @@ #include #include #include -#include #include #include +#include namespace at { namespace native { @@ -48,6 +48,34 @@ void neg_kernel_cuda(TensorIterator& iter) { }); } +// We manually overload nearbyint because std::nearbyint does not work with ROCm. +template +__host__ __device__ static inline scalar_t nearbyint_wrapper(scalar_t a) { + return static_cast(::nearbyintf(static_cast(a))); +} + +__host__ __device__ static inline double nearbyint_wrapper(double a) { + return ::nearbyint(a); +} + +void round_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "round_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + // We do not use std::round because we would like to round midway numbers to the nearest even integer. + return nearbyint_wrapper(a); + }); + }); +} + +void rsqrt_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "rsqrt_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float. + return ::rsqrt(a); + }); + }); +} + void sign_kernel_cuda(TensorIterator& iter){ if (iter.dtype() == ScalarType::Bool) { gpu_kernel(iter, []GPU_LAMBDA(bool a){ @@ -66,15 +94,43 @@ void sign_kernel_cuda(TensorIterator& iter){ void erfinv_kernel_cuda(TensorIterator& iter) { AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "erfinv_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { - return erfinvf(a); + return ::erfinv(a); + }); + }); +} + +void digamma_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "digamma_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return calc_digamma(a); + }); + }); +} + +void trigamma_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "trigamma_cuda", [&]() { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return calc_trigamma(a); }); }); } +void polygamma_kernel_cuda(TensorIterator& iter, int64_t n) { + switch (n) { + case 0: digamma_kernel_cuda(iter); break; + case 1: trigamma_kernel_cuda(iter); break; + default: TORCH_CHECK(false, "polygamma(n,x) is not implemented for n>=2, but was ", n); + } +} + REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda); REGISTER_DISPATCH(logical_not_stub, &logical_not_kernel_cuda); REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda); REGISTER_DISPATCH(neg_stub, &neg_kernel_cuda); +REGISTER_DISPATCH(round_stub, &round_kernel_cuda); +REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda); REGISTER_DISPATCH(sign_stub, &sign_kernel_cuda); REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda); +REGISTER_DISPATCH(digamma_stub, &digamma_kernel_cuda); +REGISTER_DISPATCH(polygamma_stub, &polygamma_kernel_cuda); }} diff --git a/aten/src/ATen/native/cuda/UpSample.cuh b/aten/src/ATen/native/cuda/UpSample.cuh index 3b398e27cb6e5..0bde9149136a3 100644 --- a/aten/src/ATen/native/cuda/UpSample.cuh +++ b/aten/src/ATen/native/cuda/UpSample.cuh @@ -166,7 +166,7 @@ __device__ __forceinline__ static int nearest_neighbor_compute_source_index( /* Used by UpSampleBicubic2d.cu */ template __device__ __forceinline__ static scalar_t upsample_get_value_bounded( - const PackedTensorAccessor& data, + const PackedTensorAccessor64& data, int batch, int channel, int height, @@ -181,7 +181,7 @@ __device__ __forceinline__ static scalar_t upsample_get_value_bounded( /* Used by UpSampleBicubic2d.cu */ template __device__ __forceinline__ static void upsample_increment_value_bounded( - PackedTensorAccessor& data, + PackedTensorAccessor64& data, int batch, int channel, int height, diff --git a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu index 443e88ec078b8..cd03005172890 100644 --- a/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBicubic2d.cu @@ -18,8 +18,8 @@ __global__ void upsample_bicubic2d_out_frame( const accscalar_t height_scale, const accscalar_t width_scale, const bool align_corners, - const PackedTensorAccessor idata, - PackedTensorAccessor odata) { + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -93,8 +93,8 @@ __global__ void upsample_bicubic2d_backward_out_frame( const accscalar_t height_scale, const accscalar_t width_scale, const bool align_corners, - PackedTensorAccessor idata, - const PackedTensorAccessor odata) { + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -206,8 +206,8 @@ static void upsample_bicubic2d_out_cuda_template( input.scalar_type(), "upsample_bicubic2d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); // Get scaling factors const accscalar_t rheight = area_pixel_compute_scale( @@ -285,8 +285,8 @@ static void upsample_bicubic2d_backward_out_cuda_template( grad_output.scalar_type(), "upsample_bicubic2d_backward_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor(); - auto odata = grad_output.packed_accessor(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); const accscalar_t rheight = area_pixel_compute_scale( input_height, output_height, align_corners); diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index 1f3f566893cc6..d8a8ed8904fa3 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -197,8 +197,8 @@ static void upsample_bilinear2d_out_cuda_template( input.scalar_type(), "upsample_bilinear2d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); const accscalar_t rheight = area_pixel_compute_scale( input_height, output_height, align_corners); diff --git a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu index 0f70b57344cb6..b4fc8d5a5afd9 100644 --- a/aten/src/ATen/native/cuda/UpSampleLinear1d.cu +++ b/aten/src/ATen/native/cuda/UpSampleLinear1d.cu @@ -21,8 +21,8 @@ __global__ void upsample_linear1d_out_frame( const int n, const accscalar_t rwidth, const bool align_corners, - const PackedTensorAccessor idata, - PackedTensorAccessor odata) { + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -70,8 +70,8 @@ __global__ void upsample_linear1d_out_frame_backward( const int n, const accscalar_t rwidth, const bool align_corners, - PackedTensorAccessor idata, - const PackedTensorAccessor odata) { + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -147,8 +147,8 @@ static void upsample_linear1d_out_cuda_template( input.scalar_type(), "upsample_linear1d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners); @@ -207,8 +207,8 @@ static void upsample_linear1d_backward_out_cuda_template( grad_output.scalar_type(), "upsample_linear1d_out_frame_backward", [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor(); - auto odata = grad_output.packed_accessor(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners); diff --git a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu index 683860e8a466b..73799b088a64e 100644 --- a/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu +++ b/aten/src/ATen/native/cuda/UpSampleTrilinear3d.cu @@ -21,8 +21,8 @@ __global__ void upsample_trilinear3d_out_frame( const accscalar_t rheight, const accscalar_t rwidth, const bool align_corners, - const PackedTensorAccessor idata, - PackedTensorAccessor odata) { + const PackedTensorAccessor64 idata, + PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -105,8 +105,8 @@ __global__ void upsample_trilinear3d_backward_out_frame( const accscalar_t rheight, const accscalar_t rwidth, const bool align_corners, - PackedTensorAccessor idata, - const PackedTensorAccessor odata) { + PackedTensorAccessor64 idata, + const PackedTensorAccessor64 odata) { int index = threadIdx.x + blockIdx.x * blockDim.x; const int batchsize = idata.size(0); @@ -245,8 +245,8 @@ static void upsample_trilinear3d_out_cuda_template( input.scalar_type(), "upsample_trilinear3d_out_frame", [&] { using accscalar_t = at::acc_type; - auto idata = input.packed_accessor(); - auto odata = output.packed_accessor(); + auto idata = input.packed_accessor64(); + auto odata = output.packed_accessor64(); const accscalar_t rdepth = area_pixel_compute_scale( input_depth, output_depth, align_corners); @@ -332,8 +332,8 @@ static void upsample_trilinear3d_backward_out_cuda_template( [&] { using accscalar_t = at::acc_type; - auto idata = grad_input.packed_accessor(); - auto odata = grad_output.packed_accessor(); + auto idata = grad_input.packed_accessor64(); + auto odata = grad_output.packed_accessor64(); const accscalar_t rdepth = area_pixel_compute_scale( input_depth, output_depth, align_corners); diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index 914fd70613cf3..e21a8834fca4c 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -46,7 +46,8 @@ Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) auto dims = it.get_dims(); IDeepTensorWrapperPtr handle = c10::make_intrusive(std::move(it)); return detail::make_tensor( - TensorTypeId::MkldnnCPUTensorId, options.dtype(), options.device(), handle, + TensorTypeSet(TensorTypeId::MkldnnCPUTensorId), + options.dtype(), options.device(), handle, std::vector(dims.begin(), dims.end())); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 63942567fda55..3cb52336aae22 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -6,27 +6,35 @@ # specialized operators for each datatype. # TODO: remove when we have Type support in the IR - func: _cast_Byte(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Char(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Double(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Float(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Int(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Long(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Short(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: _cast_Half(Tensor self, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True variants: function - func: backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> void @@ -35,26 +43,56 @@ - func: set_data(Tensor(a!) self, Tensor new_data) -> void variants: method +- func: data(Tensor self) -> Tensor + use_c10_dispatcher: True + variants: method + +- func: is_leaf(Tensor self) -> bool + variants: method + +- func: output_nr(Tensor self) -> int + variants: method + - func: names_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) variants: method - named_guard: False + supports_named_tensor: True -- func: view_names(Tensor(a) self, Dimname[]? names) -> Tensor(a) +- func: renamed(Tensor(a) self, Dimname[]? names) -> Tensor(a) variants: method - named_guard: False + supports_named_tensor: True -- func: align_to(Tensor self, DimnameList names) -> Tensor +- func: align_to(Tensor(a) self, DimnameList names) -> Tensor(a) variants: function, method - named_guard: False + supports_named_tensor: True + +- func: align_as(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True + variants: method + supports_named_tensor: True - func: align_tensors(Tensor[] tensors) -> Tensor[] - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True + +- func: refine_names(Tensor(a) self, DimnameList names) -> Tensor(a) + variants: method + supports_named_tensor: True + +- func: unflatten(Tensor self, Dimname dim, int[] sizes, DimnameList names) -> Tensor + variants: method + supports_named_tensor: True + +- func: unflatten(Tensor self, int dim, int[] sizes, DimnameList names) -> Tensor + variants: method + supports_named_tensor: True - func: _cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: _cudnn_ctc_loss - func: _cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, int input_size, int mode, int hidden_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: _cudnn_rnn_flatten_weight @@ -71,15 +109,17 @@ CUDA: _cudnn_init_dropout_state - func: _debug_has_internal_overlap(Tensor self) -> int + use_c10_dispatcher: True variants: function - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) variants: function dispatch: CUDA: fused_dropout_cuda - named_guard: False + supports_named_tensor: True - func: _masked_scale(Tensor self, Tensor mask, float scale) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CUDA: masked_scale_cuda @@ -87,155 +127,190 @@ - func: _sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) - func: _sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!) + use_c10_dispatcher: True - func: _sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!) + use_c10_dispatcher: True - func: _sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!) + use_c10_dispatcher: True - func: _reshape_from_tensor(Tensor self, Tensor shape) -> Tensor + use_c10_dispatcher: True - func: _shape_as_tensor(Tensor self) -> Tensor + use_c10_dispatcher: True - func: dropout(Tensor input, float p, bool train) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True + use_c10_dispatcher: True - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True + use_c10_dispatcher: True - func: feature_dropout(Tensor input, float p, bool train) -> Tensor + use_c10_dispatcher: True - func: feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + use_c10_dispatcher: True - func: alpha_dropout(Tensor input, float p, bool train) -> Tensor + use_c10_dispatcher: True - func: alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + use_c10_dispatcher: True - func: feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor + use_c10_dispatcher: True - func: feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + use_c10_dispatcher: True - func: abs(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: abs_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True dispatch: CPU: _abs__cpu CUDA: _abs__cuda - func: abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _abs_out_cpu CUDA: _abs_out_cuda - func: acos(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: acos_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _acos__cpu CUDA: _acos__cuda - func: acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _acos_out_cpu CUDA: _acos_out_cuda - func: avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor + use_c10_dispatcher: True - func: adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor + use_c10_dispatcher: True # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: add CUDA: add - SparseCPU: add - SparseCUDA: add + SparseCPU: add_sparse + SparseCUDA: add_sparse MkldnnCPU: mkldnn_add - named_guard: False + supports_named_tensor: True - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: add_ CUDA: add_ - SparseCPU: add_ - SparseCUDA: add_ + SparseCPU: add_sparse_ + SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ - named_guard: False + supports_named_tensor: True - func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: add_out CUDA: add_out - SparseCPU: add_out - SparseCUDA: add_out + SparseCPU: add_out_sparse_cpu + SparseCUDA: add_out_sparse_cuda MkldnnCPU: mkldnn_add_out - named_guard: False + supports_named_tensor: True # For C++ only, until we have conversion from C++ numbers to Tensor - func: add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True - func: addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: legacy::cpu::_th_addmv CUDA: legacy::cuda::_th_addmv - named_guard: False + supports_named_tensor: True - func: addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: function, method dispatch: CPU: legacy::cpu::_th_addmv_ CUDA: legacy::cuda::_th_addmv_ - named_guard: False + supports_named_tensor: True - func: addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: legacy::cpu::_th_addmv_out CUDA: legacy::cuda::_th_addmv_out - named_guard: False + supports_named_tensor: True - func: addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method - func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor + use_c10_dispatcher: True variants: function - func: affine_grid_generator_backward(Tensor grad, int[] size, bool align_corners) -> Tensor + use_c10_dispatcher: True variants: function - func: all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - func: allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool + use_c10_dispatcher: True variants: function, method - func: any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) @@ -259,73 +334,85 @@ # preserve tracing. Get rid of this when arange can directly take tensors for bounds # (so that it can be traced directly). - func: _dim_arange(Tensor like, int dim) -> Tensor + use_c10_dispatcher: True - func: argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) + use_c10_dispatcher: True variants: function, method dispatch: CPU: as_strided_tensorimpl CUDA: as_strided_tensorimpl QuantizedCPU: as_strided_qtensorimpl device_guard: False - named_guard: False + supports_named_tensor: True - func: as_strided_(Tensor(a!) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a!) + use_c10_dispatcher: True variants: function, method device_guard: False - func: asin(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: asin_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _asin__cpu CUDA: _asin__cuda - func: asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _asin_out_cpu CUDA: _asin_out_cuda - func: atan(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: atan_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _atan__cpu CUDA: _atan__cuda - func: atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _atan_out_cpu CUDA: _atan_out_cuda - func: baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: baddbmm_cpu CUDA: baddbmm_cuda - func: baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: baddbmm__cpu CUDA: baddbmm__cuda - func: _baddbmm_mkl_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: function - func: baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) @@ -347,25 +434,25 @@ # Sample bernoulli with values in `self` as probability. - func: bernoulli(Tensor self, *, Generator? generator=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) variants: function - named_guard: False + supports_named_tensor: True - func: bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: bernoulli_tensor_cpu_ CUDA: bernoulli_tensor_cuda_ - named_guard: False + supports_named_tensor: True - func: bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: bernoulli_scalar_cpu_ CUDA: bernoulli_scalar_cuda_ - named_guard: False + supports_named_tensor: True # This out-of-place version isn't used explicitly, but needed by jit. # There is no default valid on `p` here because it would introduce ambiguity @@ -388,9 +475,11 @@ CUDA: _bincount_cuda - func: bitwise_not(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - func: bitwise_not_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -399,9 +488,11 @@ CUDA: bitwise_not_out - func: logical_not(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - func: logical_not_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -410,134 +501,152 @@ CUDA: logical_not_out - func: logical_xor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True - func: logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: logical_xor_out CUDA: logical_xor_out - named_guard: False + supports_named_tensor: True - func: blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: bmm(Tensor self, Tensor mat2) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: bmm_cpu CUDA: bmm_cuda - named_guard: False + supports_named_tensor: True - func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: CPU: bmm_out_cpu CUDA: bmm_out_cuda - named_guard: False + supports_named_tensor: True - func: broadcast_tensors(Tensor[] tensors) -> Tensor[] + use_c10_dispatcher: True device_guard: False - func: cat(Tensor[] tensors, int dim=0) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True + use_c10_dispatcher: True - func: cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: cat.names(Tensor[] tensors, Dimname dim) -> Tensor - named_guard: False + supports_named_tensor: True - func: cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: ceil(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: ceil_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: ceil_out CUDA: ceil_out - func: chain_matmul(Tensor[] matrices) -> Tensor + use_c10_dispatcher: True variants: function - func: chunk(Tensor(a) self, int chunks, int dim=0) -> Tensor(a)[] + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _clamp__cpu CUDA: _clamp__cuda - func: clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _clamp_out_cpu CUDA: _clamp_out_cuda - func: clamp_max(Tensor self, Scalar max) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _clamp_max__cpu CUDA: _clamp_max__cuda - func: clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _clamp_max_out_cpu CUDA: _clamp_max_out_cuda - func: clamp_min(Tensor self, Scalar min) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _clamp_min__cpu CUDA: _clamp_min__cuda - func: clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _clamp_min_out_cpu CUDA: _clamp_min_out_cuda - func: cudnn_is_acceptable(Tensor self) -> bool + use_c10_dispatcher: True device_guard: False - func: constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> Tensor + use_c10_dispatcher: True variants: function - func: contiguous(Tensor self, *, MemoryFormat memory_format=contiguous_format) -> Tensor variants: method - named_guard: False + supports_named_tensor: True - func: convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor @@ -558,8 +667,10 @@ - func: conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor - func: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + use_c10_dispatcher: True - func: conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True # NB: we inherit the goofy argument order from PyTorch torch.nn.functional - func: conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, int[1] stride=1, int[1] padding=0, int[1] output_padding=0, int groups=1, int[1] dilation=1) -> Tensor @@ -569,55 +680,64 @@ - func: conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] output_padding=0, int groups=1, int[3] dilation=1) -> Tensor - func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False - named_guard: False + supports_named_tensor: True - func: _copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor + use_c10_dispatcher: True dispatch: {} - func: cos(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: cos_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _cos__cpu CUDA: _cos__cuda - func: cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _cos_out_cpu CUDA: _cos_out_cuda - func: cosh(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: cosh_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _cosh__cpu CUDA: _cosh__cuda - func: cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _cosh_out_cpu CUDA: _cosh_out_cuda - func: cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + use_c10_dispatcher: True - func: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + use_c10_dispatcher: True dispatch: CUDA: cudnn_affine_grid_generator_forward # TODO: Why do I have to call this grad?! - func: cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta + use_c10_dispatcher: True dispatch: CUDA: cudnn_affine_grid_generator_backward @@ -635,18 +755,22 @@ CUDA: cudnn_convolution - func: cudnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_backward_input - func: cudnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_backward - func: cudnn_convolution_backward_bias(Tensor grad_output) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_backward_bias - func: cudnn_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_backward_weight @@ -657,27 +781,33 @@ # NB: output_padding not strictly needed here, but it's helpful for the float # backwards - func: cudnn_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_transpose_backward - func: cudnn_convolution_transpose_backward_bias(Tensor grad_output) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_backward_bias - func: cudnn_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_transpose_backward_input - func: cudnn_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: cudnn_convolution_transpose_backward_weight # NB: input is special cased in a way I don't quite understand - func: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + use_c10_dispatcher: True dispatch: CUDA: cudnn_grid_sampler_forward - func: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) + use_c10_dispatcher: True dispatch: CUDA: cudnn_grid_sampler_backward @@ -692,82 +822,117 @@ - func: cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - func: ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + use_c10_dispatcher: True # convenience function that converts to intlists for you - func: ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + use_c10_dispatcher: True - func: _ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CPU: ctc_loss_cpu CUDA: ctc_loss_gpu - func: _ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + use_c10_dispatcher: True dispatch: CPU: ctc_loss_backward_cpu CUDA: ctc_loss_backward_gpu - func: det(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - func: diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + use_c10_dispatcher: True variants: function, method - func: diagflat(Tensor self, int offset=0) -> Tensor + use_c10_dispatcher: True variants: function, method - func: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + use_c10_dispatcher: True variants: function, method - func: fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: div.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + dispatch: + CPU: div + CUDA: div + SparseCPU: div_sparse + SparseCUDA: div_sparse + supports_named_tensor: True - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + dispatch: + CPU: div_ + CUDA: div_ + SparseCPU: div_sparse_ + SparseCUDA: div_sparse_ + supports_named_tensor: True - func: div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + dispatch: + CPU: div_out + CUDA: div_out + SparseCPU: div_out_sparse_zerodim + SparseCUDA: div_out_sparse_zerodim + supports_named_tensor: True # For C++ only, until we have conversion from C++ numbers to Tensor - func: div.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True - func: dot(Tensor self, Tensor tensor) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: legacy::cpu::_th_dot CUDA: legacy::cuda::_th_dot - named_guard: False + supports_named_tensor: True - func: dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: einsum(str equation, Tensor[] tensors) -> Tensor + use_c10_dispatcher: True - func: embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + use_c10_dispatcher: True - func: embedding_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor + use_c10_dispatcher: True - func: embedding_dense_backward(Tensor grad_output, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor + use_c10_dispatcher: True dispatch: CPU: embedding_dense_backward_cpu CUDA: embedding_dense_backward_cuda - func: embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) + use_c10_dispatcher: True dispatch: CPU: embedding_renorm_cpu_ CUDA: embedding_renorm_cuda_ - func: embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor + use_c10_dispatcher: True # NOTE [ embedding_bag Native Functions ] # The `_embedding_bag.*` variants assume that input tensors except for `weight`, @@ -795,6 +960,7 @@ CUDA: _embedding_bag_dense_backward_cuda - func: _embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode) -> Tensor + use_c10_dispatcher: True dispatch: CPU: _embedding_bag_per_sample_weights_backward_cpu CUDA: _embedding_bag_per_sample_weights_backward_cuda @@ -825,6 +991,7 @@ QuantizedCPU: empty_per_channel_affine_quantized_cpu - func: resize_(Tensor(a!) self, int[] size) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False dispatch: @@ -835,12 +1002,13 @@ device_guard: False - func: empty_like(Tensor self) -> Tensor + use_c10_dispatcher: True device_guard: False - named_guard: False + supports_named_tensor: True - func: empty_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, MemoryFormat? memory_format=contiguous_format) -> Tensor device_guard: False - named_guard: False + supports_named_tensor: True - func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: @@ -848,79 +1016,89 @@ CUDA: empty_strided_cuda - func: erf(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: erf_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _erf__cpu CUDA: _erf__cuda - func: erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _erf_out_cpu CUDA: _erf_out_cuda - func: erfc(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: erfc_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _erfc__cpu CUDA: _erfc__cuda - func: erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _erfc_out_cpu CUDA: _erfc_out_cuda - func: exp(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: exp_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _exp__cpu CUDA: _exp__cuda - func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _exp_out_cpu CUDA: _exp_out_cuda - func: expm1(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: expm1_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _expm1__cpu CUDA: _expm1__cuda - func: expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _expm1_out_cpu CUDA: _expm1_out_cuda - func: expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a) + use_c10_dispatcher: True variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False - named_guard: False + supports_named_tensor: True - func: expand_as(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. device_guard: False @@ -938,47 +1116,67 @@ CPU: eye_out_cpu CUDA: eye_out_cuda -- func: flatten(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor +- func: flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor + use_c10_dispatcher: True + variants: function, method + supports_named_tensor: True + +- func: flatten.named_out_dim(Tensor self, int start_dim, int end_dim, Dimname out_dim) -> Tensor + variants: function, method + supports_named_tensor: True + +- func: flatten.using_names(Tensor self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor + variants: function, method + supports_named_tensor: True + +- func: flatten.DimnameList(Tensor self, DimnameList dims, Dimname out_dim) -> Tensor variants: function, method + supports_named_tensor: True - func: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: floor(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: floor_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _floor__cpu CUDA: _floor__cuda - func: floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _floor_out_cpu CUDA: _floor_out_cuda - func: frac(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: frac_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _frac__cpu CUDA: _frac__cuda - func: frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _frac_out_cpu CUDA: _frac_out_cuda @@ -991,6 +1189,7 @@ - func: full.out(int[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) - func: full_like(Tensor self, Scalar fill_value) -> Tensor + use_c10_dispatcher: True - func: full_like.dtype(Tensor self, Scalar fill_value, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor @@ -1010,23 +1209,28 @@ # Nor does it take in `align_corners` because it only supports the mode # `align_corners = True`. - func: grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + use_c10_dispatcher: True - func: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + use_c10_dispatcher: True dispatch: CPU: grid_sampler_2d_cpu CUDA: grid_sampler_2d_cuda - func: grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CPU: grid_sampler_2d_backward_cpu CUDA: grid_sampler_2d_backward_cuda - func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + use_c10_dispatcher: True dispatch: CPU: grid_sampler_3d_cpu CUDA: grid_sampler_3d_cuda - func: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CPU: grid_sampler_3d_backward_cpu CUDA: grid_sampler_3d_backward_cuda @@ -1044,8 +1248,10 @@ - func: hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor - func: hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor + use_c10_dispatcher: True - func: ger(Tensor self, Tensor vec2) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: legacy::cpu::_th_ger @@ -1061,39 +1267,48 @@ # FFT - func: fft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: ifft(Tensor self, int signal_ndim, bool normalized=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: rfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True) -> Tensor + use_c10_dispatcher: True variants: function, method - func: irfft(Tensor self, int signal_ndim, bool normalized=False, bool onesided=True, int[] signal_sizes=[]) -> Tensor + use_c10_dispatcher: True variants: function, method - func: _fft_with_size(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, bool normalized, bool onesided, int[] output_sizes) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _fft_mkl CUDA: _fft_cufft - func: _cufft_get_plan_cache_size(int device_index) -> int + use_c10_dispatcher: True - func: _cufft_get_plan_cache_max_size(int device_index) -> int + use_c10_dispatcher: True - func: _cufft_set_plan_cache_max_size(int device_index, int max_size) -> void - func: _cufft_clear_plan_cache(int device_index) -> void -- func: index(Tensor self, Tensor?[] indices) -> Tensor +- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor variants: function, method # NB: This function is special-cased in tools/autograd/gen_variable_type.py - func: index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + use_c10_dispatcher: True variants: function, method - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) @@ -1109,60 +1324,73 @@ variants: function - func: inverse(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - func: inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - func: _inverse_helper(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _inverse_helper_cpu CUDA: _inverse_helper_cuda - func: isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: isnan(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function device_guard: False - func: is_distributed(Tensor self) -> bool + use_c10_dispatcher: True variants: function, method device_guard: False - func: is_floating_point(Tensor self) -> bool + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: is_complex(Tensor self) -> bool + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: is_nonzero(Tensor self) -> bool + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: is_same_size(Tensor self, Tensor other) -> bool + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: is_signed(Tensor self) -> bool + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: kl_div(Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True - func: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True dispatch: CPU: kl_div_backward_cpu CUDA: kl_div_backward_cuda - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: function, method - func: kthvalue.values(Tensor self, int k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) @@ -1193,22 +1421,31 @@ MkldnnCPU: mkldnn_linear - func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + use_c10_dispatcher: True - func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + use_c10_dispatcher: True - func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) + use_c10_dispatcher: True - func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor + use_c10_dispatcher: True - func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + use_c10_dispatcher: True - func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + use_c10_dispatcher: True - func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor + use_c10_dispatcher: True -- func: fbgemm_pack_quantized_matrix(Tensor input, int K, int N) -> Tensor +- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor + use_c10_dispatcher: True - func: fbgemm_is_cpu_supported() -> bool + use_c10_dispatcher: True - func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -1218,45 +1455,51 @@ CUDA: linspace_cuda_out - func: log(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: log_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _log__cpu CUDA: _log__cuda - func: log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _log_out_cpu CUDA: _log_out_cuda - func: log10(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: log10_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _log10__cpu CUDA: _log10__cuda - func: log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _log10_out_cpu CUDA: _log10_out_cuda - func: log1p(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: log1p_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _log1p__cpu @@ -1265,7 +1508,7 @@ SparseCUDA: log1p_sparse_ - func: log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _log1p_out_cpu CUDA: _log1p_out_cuda @@ -1273,23 +1516,26 @@ SparseCUDA: log1p_out_sparse - func: log2(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: log2_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _log2__cpu CUDA: _log2__cuda - func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _log2_out_cpu CUDA: _log2_out_cuda - func: logdet(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - func: logspace(Scalar start, Scalar end, int steps=100, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -1302,24 +1548,27 @@ # log_softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. - func: log_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: log_softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + use_c10_dispatcher: True dispatch: CPU: log_softmax_cpu CUDA: log_softmax_cuda - named_guard: False + supports_named_tensor: True - func: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + use_c10_dispatcher: True dispatch: CPU: log_softmax_backward_cpu CUDA: log_softmax_backward_cuda - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) @@ -1330,27 +1579,34 @@ - func: logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + use_c10_dispatcher: True - func: matmul(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: false + supports_named_tensor: True - func: matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - named_guard: false + supports_named_tensor: True - func: matrix_rank.tol(Tensor self, float tol, bool symmetric=False) -> Tensor + use_c10_dispatcher: True - func: matrix_rank(Tensor self, bool symmetric=False) -> Tensor + use_c10_dispatcher: True - func: matrix_power(Tensor self, int n) -> Tensor + use_c10_dispatcher: True variants: function, method - func: max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: function, method - func: max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) - func: max_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -1363,42 +1619,49 @@ # Return: (Tensor output, Tensor indices) - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + use_c10_dispatcher: True - func: max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + use_c10_dispatcher: True - func: mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + use_c10_dispatcher: True requires_tensor: True dispatch: MkldnnCPU: mkldnn_max_pool2d - func: quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1) -> Tensor + use_c10_dispatcher: True requires_tensor: True dispatch: QuantizedCPU: quantized_max_pool2d - func: max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + use_c10_dispatcher: True - func: mean(Tensor self, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: false + supports_named_tensor: True - func: mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: false + supports_named_tensor: True - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - named_guard: false + supports_named_tensor: True - func: mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: false + supports_named_tensor: True - func: mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - named_guard: false + supports_named_tensor: True - func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: function, method - func: median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) @@ -1409,11 +1672,13 @@ - func: median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) - func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: function, method - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) - func: min_values(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -1427,10 +1692,13 @@ - func: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, int[] padding, int[] stride, int[] dilation, int groups) -> Tensor - func: mkldnn_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool bias_defined) -> Tensor + use_c10_dispatcher: True - func: mkldnn_convolution_backward_weights(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool bias_defined) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True - func: miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) dispatch: @@ -1445,18 +1713,22 @@ CUDA: miopen_convolution - func: miopen_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_backward_input - func: miopen_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_backward - func: miopen_convolution_backward_bias(Tensor grad_output) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_backward_bias - func: miopen_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_backward_weight @@ -1467,14 +1739,17 @@ # NB: output_padding not strictly needed here, but it's helpful for the float # backwards - func: miopen_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_transpose_backward - func: miopen_convolution_transpose_backward_input(Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_transpose_backward_input - func: miopen_convolution_transpose_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_convolution_transpose_backward_weight @@ -1483,14 +1758,17 @@ CUDA: miopen_depthwise_convolution - func: miopen_depthwise_convolution_backward_input(int[] self_size, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_depthwise_convolution_backward_input - func: miopen_depthwise_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: miopen_depthwise_convolution_backward - func: miopen_depthwise_convolution_backward_weight(int[] weight_size, Tensor grad_output, Tensor self, int[] padding, int[] stride, int[] dilation, int groups, bool benchmark, bool deterministic) -> Tensor + use_c10_dispatcher: True dispatch: CUDA: miopen_depthwise_convolution_backward_weight @@ -1503,13 +1781,14 @@ CUDA: miopen_rnn_backward - func: mm(Tensor self, Tensor mat2) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: legacy::cpu::_th_mm CUDA: legacy::cuda::_th_mm SparseCPU: _sparse_mm SparseCUDA: _sparse_mm - named_guard: False + supports_named_tensor: True - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -1517,72 +1796,81 @@ CUDA: legacy::cuda::_th_mm_out SparseCPU: _sparse_mm_out SparseCUDA: _sparse_mm_out - named_guard: False + supports_named_tensor: True - func: _sparse_mm(Tensor sparse, Tensor dense) -> Tensor + use_c10_dispatcher: True - func: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: function, method - func: mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) - func: mul.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: mul CUDA: mul - SparseCPU: mul - SparseCUDA: mul + SparseCPU: mul_sparse + SparseCUDA: mul_sparse MkldnnCPU: mkldnn_mul - named_guard: False - + supports_named_tensor: True - func: mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: mul_ CUDA: mul_ - SparseCPU: mul_ - SparseCUDA: mul_ + SparseCPU: mul_sparse_ + SparseCUDA: mul_sparse_ MkldnnCPU: mkldnn_mul_ - named_guard: False + supports_named_tensor: True - func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: mul_out CUDA: mul_out - SparseCPU: mul_out - SparseCUDA: mul_out + SparseCPU: mul_out_sparse_cpu + SparseCUDA: mul_out_sparse_cuda MkldnnCPU: mkldnn_mul_out - named_guard: False + supports_named_tensor: True # For C++ only, until we have conversion from C++ numbers to Tensor - func: mul.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: function, method - func: mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: mv(Tensor self, Tensor vec) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: legacy::cpu::_th_mv CUDA: legacy::cuda::_th_mv - named_guard: False + supports_named_tensor: True - func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: legacy::cpu::_th_mv_out CUDA: legacy::cuda::_th_mv_out - named_guard: False + supports_named_tensor: True - func: mvlgamma(Tensor self, int p) -> Tensor + use_c10_dispatcher: True variants: function, method - func: mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: narrow_copy(Tensor self, int dim, int start, int length) -> Tensor + use_c10_dispatcher: True variants: method dispatch: CPU: narrow_copy_dense @@ -1591,9 +1879,10 @@ SparseCUDA: narrow_copy_sparse - func: narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) dispatch: @@ -1602,6 +1891,7 @@ MkldnnCPU: mkldnn_batch_norm - func: batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: batch_norm_stats_cuda @@ -1637,17 +1927,21 @@ CUDA: batch_norm_update_stats_cuda - func: _nnpack_available() -> bool + use_c10_dispatcher: True - func: _nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, int[2] padding) -> Tensor variants: function - func: _nnpack_spatial_convolution_backward(Tensor input, Tensor grad_output, Tensor weight, int[2] padding, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function - func: _nnpack_spatial_convolution_backward_input(Tensor input, Tensor grad_output, Tensor weight, int[2] padding) -> Tensor + use_c10_dispatcher: True variants: function - func: _nnpack_spatial_convolution_backward_weight(Tensor input, int[] weightsize, Tensor grad_output, int[2] padding) -> Tensor + use_c10_dispatcher: True variants: function - func: ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -1658,25 +1952,34 @@ - func: ones.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) - func: ones_like(Tensor self) -> Tensor + use_c10_dispatcher: True - func: ones_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor - func: pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor + use_c10_dispatcher: True - func: cdist(Tensor x1, Tensor x2, float p=2) -> Tensor + use_c10_dispatcher: True - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + use_c10_dispatcher: True - func: pdist(Tensor self, float p=2) -> Tensor + use_c10_dispatcher: True - func: _pdist_forward(Tensor self, float p=2) -> Tensor + use_c10_dispatcher: True - func: _pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + use_c10_dispatcher: True - func: cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor + use_c10_dispatcher: True variants: function - func: permute(Tensor(a) self, int[] dims) -> Tensor(a) + use_c10_dispatcher: True variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. # Only exposed from C++ -- in Python, @@ -1687,21 +1990,27 @@ # behavior on Windows, for reasons I don't understand # (maybe related to capital letter collation somehow...) - func: numpy_T(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: method - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + use_c10_dispatcher: True - func: is_pinned(Tensor self) -> bool + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True - func: pin_memory(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor + use_c10_dispatcher: True variants: function, method - func: poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor + use_c10_dispatcher: True variants: function - func: scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -1721,6 +2030,7 @@ - func: rand.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) - func: rand_like(Tensor self) -> Tensor + use_c10_dispatcher: True - func: rand_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor @@ -1741,8 +2051,10 @@ - func: randint.low_generator_out(int low, int high, int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) - func: randint_like(Tensor self, int high) -> Tensor + use_c10_dispatcher: True - func: randint_like.low(Tensor self, int low, int high) -> Tensor + use_c10_dispatcher: True - func: randint_like.dtype(Tensor self, int high, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor @@ -1763,6 +2075,7 @@ - func: randn.generator_out(int[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) - func: randn_like(Tensor self) -> Tensor + use_c10_dispatcher: True - func: randn_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor @@ -1787,98 +2100,110 @@ CUDA: range_cuda_out - func: reciprocal(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: reciprocal_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _reciprocal__cpu CUDA: _reciprocal__cuda - func: reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _reciprocal_out_cpu CUDA: _reciprocal_out_cuda - func: neg(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: neg_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: neg_out CUDA: neg_out - func: repeat(Tensor self, int[] repeats) -> Tensor + use_c10_dispatcher: True variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. - func: repeat_interleave.Tensor(Tensor repeats) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: repeat_interleave_cpu CUDA: repeat_interleave_cuda - func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None) -> Tensor + use_c10_dispatcher: True variants: function, method - func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None) -> Tensor + use_c10_dispatcher: True variants: function, method - func: reshape(Tensor self, int[] shape) -> Tensor + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor + use_c10_dispatcher: True device_guard: False requires_tensor: True dispatch: MkldnnCPU: mkldnn_reshape - func: reshape_as(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method device_guard: False - func: round(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: round_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - dispatch: - CPU: _round__cpu - CUDA: _round__cuda - func: round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: - CPU: _round_out_cpu - CUDA: _round_out_cuda + CPU: round_out + CUDA: round_out - func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor - func: rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) - func: relu(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: relu CUDA: relu MkldnnCPU: mkldnn_relu QuantizedCPU: quantized_relu - named_guard: False + supports_named_tensor: True - func: relu_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: relu_ @@ -1887,78 +2212,89 @@ QuantizedCPU: quantized_relu_ - func: prelu(Tensor self, Tensor weight) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: prelu_cpu CUDA: prelu_cuda - func: prelu_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function, method dispatch: CPU: prelu_backward_cpu CUDA: prelu_backward_cuda - func: gelu(Tensor self) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: gelu_cpu CUDA: gelu_cuda - func: gelu_backward(Tensor grad, Tensor self) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: gelu_backward_cpu CUDA: gelu_backward_cuda - func: hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: hardshrink_cpu CUDA: hardshrink_cuda - func: hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: hardshrink_backward_cpu CUDA: hardshrink_backward_cuda - func: rsqrt(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: rsqrt_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - dispatch: - CPU: _rsqrt__cpu - CUDA: _rsqrt__cuda - func: rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: - CPU: _rsqrt_out_cpu - CUDA: _rsqrt_out_cuda + CPU: rsqrt_out + CUDA: rsqrt_out - func: select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: select.int(Tensor(a) self, int dim, int index) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: selu(Tensor self) -> Tensor + use_c10_dispatcher: True - func: selu_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True - func: celu(Tensor self, Scalar alpha=1.0) -> Tensor + use_c10_dispatcher: True - func: celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) + use_c10_dispatcher: True - func: sigmoid(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: sigmoid @@ -1966,7 +2302,8 @@ MkldnnCPU: mkldnn_sigmoid - func: sigmoid_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _sigmoid__cpu @@ -1974,155 +2311,139 @@ MkldnnCPU: mkldnn_sigmoid_ - func: sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _sigmoid_out_cpu CUDA: _sigmoid_out_cuda - func: sin(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: sin_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _sin__cpu CUDA: _sin__cuda - func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _sin_out_cpu CUDA: _sin_out_cuda - func: sinh(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: sinh_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _sinh__cpu CUDA: _sinh__cuda - func: sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _sinh_out_cpu CUDA: _sinh_out_cuda - func: detach(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - func: detach_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: function, method - func: size.int(Tensor self, int dim) -> int + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: size.Dimname(Tensor self, Dimname dim) -> int variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + use_c10_dispatcher: True variants: function, method - func: smm(Tensor self, Tensor mat2) -> Tensor + use_c10_dispatcher: True variants: function, method # softmax allows positional dtype, unlike most operators, because kwonly is BC-breaking when loading jit models. - func: softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: softmax(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: _softmax(Tensor self, int dim, bool half_to_float) -> Tensor + use_c10_dispatcher: True dispatch: CPU: softmax_cpu CUDA: softmax_cuda MkldnnCPU: mkldnn_softmax - named_guard: False + supports_named_tensor: True - func: _softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + use_c10_dispatcher: True dispatch: CPU: softmax_backward_cpu CUDA: softmax_backward_cuda -- func: _sparse_add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - dispatch: - SparseCPU: add_out_sparse_cpu - SparseCUDA: add_out_sparse_cuda - -- func: _sparse_dense_add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - dispatch: - CPU: add_out_dense_sparse_cpu - CUDA: add_out_dense_sparse_cuda - -- func: _sparse_div_zerodim.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - SparseCPU: div_out_sparse_zerodim - SparseCUDA: div_out_sparse_zerodim - -- func: _sparse_div_scalar.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - SparseCPU: div_out_sparse_scalar - SparseCUDA: div_out_sparse_scalar - -- func: _sparse_mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - SparseCPU: mul_out_sparse_cpu - SparseCUDA: mul_out_sparse_cuda - -- func: _sparse_mul_zerodim.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - SparseCPU: mul_out_sparse_zerodim - SparseCUDA: mul_out_sparse_zerodim - -- func: _sparse_mul_scalar.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) - dispatch: - SparseCPU: mul_out_sparse_scalar - SparseCUDA: mul_out_sparse_scalar - - func: split.Tensor(Tensor(a) self, int split_size, int dim=0) -> Tensor(a)[] + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: split_with_sizes(Tensor self, int[] split_sizes, int dim=0) -> Tensor[] + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: squeeze(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - func: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - func: squeeze_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False - func: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False - func: sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method - func: sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) @@ -2133,6 +2454,7 @@ SparseCUDA: _sspaddmm_out_cuda - func: stack(Tensor[] tensors, int dim=0) -> Tensor + use_c10_dispatcher: True - func: stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) @@ -2144,260 +2466,300 @@ variants: function, method - func: stride.int(Tensor self, int dim) -> int + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: stride.Dimname(Tensor self, Dimname dim) -> int variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: sum_to_size(Tensor self, int[] size) -> Tensor + use_c10_dispatcher: True variants: method device_guard: False - func: sqrt(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: sqrt_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _sqrt__cpu CUDA: _sqrt__cuda - func: sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _sqrt_out_cpu CUDA: _sqrt_out_cuda - func: std(Tensor self, bool unbiased=True) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: std.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function - named_guard: False + supports_named_tensor: True - func: std_mean.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function - named_guard: False + supports_named_tensor: True - func: std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) variants: function - named_guard: False + supports_named_tensor: True - func: std.out(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: t(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True device_guard: False variants: function, method - named_guard: False + supports_named_tensor: True - func: t_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True device_guard: False variants: method - func: tan(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: tan_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _tan__cpu CUDA: _tan__cuda - func: tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _tan_out_cpu CUDA: _tan_out_cuda - func: tanh(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: tanh_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _tanh__cpu CUDA: _tanh__cuda - func: tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _tanh_out_cpu CUDA: _tanh_out_cuda - func: tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor + use_c10_dispatcher: True variants: function # TODO: namespace threshold in 'nn' - func: threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + use_c10_dispatcher: True variants: function - named_guard: False + supports_named_tensor: True - func: threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + use_c10_dispatcher: True variants: function - named_guard: False + supports_named_tensor: True - func: threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + use_c10_dispatcher: True variants: function -- func: transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a) +- func: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True -- func: transpose(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) +- func: transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True - func: _mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor + use_c10_dispatcher: True device_guard: False requires_tensor: True dispatch: MkldnnCPU: mkldnn_transpose - func: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False - func: _mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + use_c10_dispatcher: True device_guard: False requires_tensor: True dispatch: MkldnnCPU: mkldnn_transpose_ - func: one_hot(Tensor self, int num_classes=-1) -> Tensor + use_c10_dispatcher: True python_module: nn variants: function - func: flip(Tensor self, int[] dims) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: flip_cpu CUDA: flip_cuda - func: roll(Tensor self, int[1] shifts, int[1] dims=[]) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: roll_cpu CUDA: roll_cuda # default int[] value [0,1] should not add space after comma, since native_parse.py uses ', ' to split args + use_c10_dispatcher: True - func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + use_c10_dispatcher: True + use_c10_dispatcher: True variants: function, method - func: trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + use_c10_dispatcher: True - func: trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor + use_c10_dispatcher: True - func: _trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + use_c10_dispatcher: True - func: triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor + use_c10_dispatcher: True - func: trunc(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method - func: trunc_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: function, method dispatch: CPU: _trunc__cpu CUDA: _trunc__cuda - func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: _trunc_out_cpu CUDA: _trunc_out_cuda - func: type_as(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method - func: _has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool + use_c10_dispatcher: True variants: function - func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _unique_cpu CUDA: _unique_cuda - func: unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: unique_dim_cpu CUDA: unique_dim_cuda - func: unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: unique_consecutive_cpu CUDA: unique_consecutive_cuda - func: unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: unique_dim_consecutive_cpu @@ -2408,52 +2770,61 @@ # Please don't rely on these two operators, they will be removed soon - func: _unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _unique2_cpu CUDA: _unique2_cuda - func: _unsafe_view(Tensor self, int[] size) -> Tensor + use_c10_dispatcher: True - func: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + use_c10_dispatcher: True variants: function, method device_guard: False - func: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False - func: var(Tensor self, bool unbiased=True) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: var.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: var.out(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor variants: function, method - named_guard: False + supports_named_tensor: True - func: var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True - func: var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function - named_guard: false + supports_named_tensor: True - func: var_mean.dim(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function - named_guard: False + supports_named_tensor: True - func: var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) variants: function - named_guard: False + supports_named_tensor: True - func: view_as(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method device_guard: False @@ -2461,36 +2832,44 @@ # this allows us to implicitly calculate the broadcast derivative, while only dealing with the # _s_where derivative. - func: where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: function, method - func: where(Tensor condition) -> Tensor[] + use_c10_dispatcher: True variants: function - func: _s_where(Tensor condition, Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _s_where_cpu CUDA: _s_where_cuda - func: norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor + use_c10_dispatcher: True variants: function # VariableType::_weight_norm does not want to be given a gap in the autograd graph, # so we don't define "dispatch" variants for it. - func: _weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor + use_c10_dispatcher: True variants: function - func: _weight_norm_cuda_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CUDA: weight_norm_cuda - func: _weight_norm_cuda_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CUDA: weight_norm_cuda_backward - func: _weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function - func: zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -2501,10 +2880,12 @@ - func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!) - func: zeros_like(Tensor self) -> Tensor + use_c10_dispatcher: True - func: zeros_like.dtype(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False) -> Tensor - func: _standard_gamma_grad(Tensor self, Tensor output) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _standard_gamma_grad_cpu @@ -2517,6 +2898,7 @@ CUDA: _s_gamma_cuda - func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor + use_c10_dispatcher: True dispatch: CPU: _dirichlet_grad_cpu CUDA: _dirichlet_grad_cuda @@ -2536,20 +2918,24 @@ # complicated - func: native_norm(Tensor self, Scalar p=2) -> Tensor + use_c10_dispatcher: True dispatch: SparseCPU: norm_sparse SparseCUDA: norm_sparse # TODO: reduce signatures down to one when optional args is available - func: _sparse_sum(Tensor self) -> Tensor + use_c10_dispatcher: True - func: _sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor - func: _sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + use_c10_dispatcher: True - func: _sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor - func: _sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor + use_c10_dispatcher: True dispatch: SparseCPU: _sparse_sum_backward_cpu SparseCUDA: _sparse_sum_backward_cuda @@ -2558,12 +2944,14 @@ variants: function, method - func: norm.Scalar(Tensor self, Scalar p=2) -> Tensor + use_c10_dispatcher: True variants: function, method - func: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor variants: function, method - func: norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function, method - func: norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) @@ -2582,27 +2970,32 @@ - func: frobenius_norm(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function - func: frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function - func: frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) variants: function - func: nuclear_norm(Tensor self, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function - func: nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) variants: function - func: nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor + use_c10_dispatcher: True variants: function - func: nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) variants: function - func: clone(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: clone @@ -2611,9 +3004,10 @@ SparseCUDA: clone_sparse MkldnnCPU: mkldnn_clone QuantizedCPU: quantized_clone - named_guard: False + supports_named_tensor: True - func: resize_as_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) + use_c10_dispatcher: True variants: function, method dispatch: CPU: resize_as_cpu_ @@ -2624,20 +3018,22 @@ - func: pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: pow_out - CUDA: legacy::cuda::_th_pow_out + CUDA: pow_out SparseCPU: pow_out_sparse_scalar SparseCUDA: pow_out_sparse_scalar - func: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: CPU: pow - CUDA: legacy::cuda::_th_pow + CUDA: pow SparseCPU: pow_sparse_scalar SparseCUDA: pow_sparse_scalar - func: zero_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method, function dispatch: CPU: legacy::cpu::_th_zero_ @@ -2647,61 +3043,90 @@ MkldnnCPU: mkldnn_zero_ - func: sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - named_guard: False + dispatch: + CPU: sub_out + CUDA: sub_out + SparseCPU: sub_out_sparse + SparseCUDA: sub_out_sparse + supports_named_tensor: True - func: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + dispatch: + CPU: sub + CUDA: sub + SparseCPU: sub_sparse + SparseCUDA: sub_sparse + supports_named_tensor: True - func: sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + dispatch: + CPU: sub_ + CUDA: sub_ + SparseCPU: sub_sparse_ + SparseCUDA: sub_sparse_ + supports_named_tensor: True # For C++ only, until we have conversion from C++ numbers to Tensor - func: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True - func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function - named_guard: False + supports_named_tensor: True # For C++ only, until we have conversion from C++ numbers to Tensor - func: rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function - named_guard: False - -- func: s_native_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - dispatch: - CPU: s_addmm_out_sparse_dense_cpu - CUDA: s_addmm_out_sparse_dense_cuda - -- func: s_native_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - dispatch: - CPU: s_addmm_sparse_dense_cpu - CUDA: s_addmm_sparse_dense_cuda - -- func: s_native_addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - dispatch: - CPU: s_addmm_sparse_dense_cpu_ - CUDA: s_addmm_sparse_dense_cuda_ + supports_named_tensor: True +# Functionally the same as addmm, but we give it a different derivative formula +# that doesn't propagate gradients to non-present entries on sparse. - func: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True + named_guard: False - func: addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - named_guard: False + dispatch: + CPU: legacy::cpu::_th_addmm_out + CUDA: legacy::cuda::_th_addmm_out + SparseCPU: addmm_out_sparse_dense_cpu + SparseCUDA: addmm_out_sparse_dense_cuda + supports_named_tensor: True - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + dispatch: + CPU: legacy::cpu::_th_addmm + CUDA: legacy::cuda::_th_addmm + SparseCPU: addmm_sparse_dense_cpu + SparseCUDA: addmm_sparse_dense_cuda + supports_named_tensor: True - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + dispatch: + CPU: legacy::cpu::_th_addmm_ + CUDA: legacy::cuda::_th_addmm_ + # Warning! For whatever reason, the inplace sparse addmm is NON + # broadcasting + SparseCPU: s_addmm_sparse_dense_cpu_ + SparseCUDA: s_addmm_sparse_dense_cuda_ + supports_named_tensor: True # NOTE [ Sparse: autograd and API ] @@ -2843,6 +3268,7 @@ requires_tensor: True - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: sparse_resize_ @@ -2850,6 +3276,7 @@ requires_tensor: True - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: sparse_resize_and_clear_ @@ -2858,14 +3285,16 @@ - func: sparse_mask(Tensor self, Tensor mask) -> Tensor + use_c10_dispatcher: True variants: method dispatch: - CPU: sparse_mask_cpu - CUDA: sparse_mask_cuda + SparseCPU: sparse_mask_cpu + SparseCUDA: sparse_mask_cuda requires_tensor: True - func: to_dense(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method dispatch: SparseCPU: sparse_to_dense @@ -2874,8 +3303,10 @@ requires_tensor: True - func: to_dense_backward(Tensor grad, Tensor input) -> Tensor + use_c10_dispatcher: True - func: sparse_dim(Tensor self) -> int + use_c10_dispatcher: True variants: method dispatch: SparseCPU: sparse_dim_sparse @@ -2885,6 +3316,7 @@ # legacy method - func: _dimI(Tensor self) -> int + use_c10_dispatcher: True variants: method dispatch: SparseCPU: sparse_dim_sparse @@ -2894,6 +3326,7 @@ - func: dense_dim(Tensor self) -> int + use_c10_dispatcher: True variants: method dispatch: SparseCPU: dense_dim_sparse @@ -2903,6 +3336,7 @@ # legacy method - func: _dimV(Tensor self) -> int + use_c10_dispatcher: True variants: method dispatch: SparseCPU: dense_dim_sparse @@ -2912,6 +3346,7 @@ - func: _nnz(Tensor self) -> int + use_c10_dispatcher: True variants: method dispatch: SparseCPU: _nnz_sparse @@ -2921,6 +3356,7 @@ - func: coalesce(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method dispatch: SparseCPU: coalesce_sparse_cpu @@ -2929,16 +3365,18 @@ - func: is_coalesced(Tensor self) -> bool + use_c10_dispatcher: True variants: method dispatch: SparseCPU: is_coalesced_sparse SparseCUDA: is_coalesced_sparse requires_tensor: True device_guard: False - named_guard: False + supports_named_tensor: True - func: _indices(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: _indices_sparse @@ -2947,6 +3385,7 @@ device_guard: False - func: _values(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: _values_sparse @@ -2958,6 +3397,7 @@ # a bit unsafe. Similar to _indices and _values, this is useful for implementing # custom sparse operations in Python/C++ extension. - func: _coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: _coalesced_sparse_ @@ -2966,6 +3406,7 @@ device_guard: False - func: indices(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: indices_sparse @@ -2974,6 +3415,7 @@ device_guard: False - func: values(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: method dispatch: SparseCPU: values_sparse @@ -2989,12 +3431,14 @@ requires_tensor: True - func: hspmm(Tensor mat1, Tensor mat2) -> Tensor + use_c10_dispatcher: True dispatch: SparseCPU: hspmm_sparse_cpu SparseCUDA: hspmm_sparse_cuda requires_tensor: True - func: copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + use_c10_dispatcher: True variants: function dispatch: SparseCPU: copy_sparse_ @@ -3002,37 +3446,49 @@ requires_tensor: True - func: numel(Tensor self) -> int + use_c10_dispatcher: True variants: function, method device_guard: False - named_guard: False + supports_named_tensor: True -- func: unbind(Tensor(a) self, int dim=0) -> Tensor(a)[] +- func: unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[] + use_c10_dispatcher: True variants: function, method + supports_named_tensor: True + +- func: unbind.Dimname(Tensor(a) self, Dimname dim) -> Tensor(a)[] + variants: function, method + supports_named_tensor: True - func: to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + use_c10_dispatcher: True variants: method dispatch: CPU: dense_to_sparse CUDA: dense_to_sparse - func: to_sparse(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method dispatch: CPU: dense_to_sparse CUDA: dense_to_sparse - func: to_mkldnn(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method dispatch: CPU: dense_to_mkldnn - func: mkldnn_reorder_conv2d_weight(Tensor self, int[2] padding=0, int[2] stride=1, int[2] dilation=1, int groups=1) -> Tensor + use_c10_dispatcher: True variants: function python_module: nn dispatch: MkldnnCPU: mkldnn_reorder_conv2d_weight - func: to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor + use_c10_dispatcher: True - func: quantize_linear(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor variants: function @@ -3045,6 +3501,7 @@ CPU: quantize_linear_per_channel_cpu - func: dequantize(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: QuantizedCPU: dequantize_quant @@ -3055,35 +3512,47 @@ CPU: dequantize_linear_cpu - func: q_scale(Tensor self) -> float + use_c10_dispatcher: True variants: function, method dispatch: QuantizedCPU: q_scale_quant - func: q_zero_point(Tensor self) -> int + use_c10_dispatcher: True variants: function, method dispatch: QuantizedCPU: q_zero_point_quant - func: q_per_channel_scales(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: QuantizedCPU: q_per_channel_scales_quant - func: q_per_channel_zero_points(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: QuantizedCPU: q_per_channel_zero_points_quant +- func: q_per_channel_axis(Tensor self) -> int[] + variants: function, method + dispatch: + QuantizedCPU: q_per_channel_axis_quant + - func: int_repr(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method dispatch: QuantizedCPU: int_repr_quant - func: _per_tensor_affine_qtensor(Tensor self, float scale, int zero_point) -> Tensor + use_c10_dispatcher: True dispatch: CPU: per_tensor_affine_qtensor_cpu - func: _per_channel_affine_qtensor(Tensor self, Tensor scale, Tensor zero_point, int[] axis) -> Tensor + use_c10_dispatcher: True dispatch: CPU: per_channel_affine_qtensor_cpu @@ -3093,12 +3562,14 @@ QuantizedCPU: qscheme_quant - func: fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: fake_quantize_per_tensor_affine_cpu CUDA: fake_quantize_per_tensor_affine_cuda - func: fake_quantize_per_tensor_affine_backward(Tensor grad, Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: fake_quantize_per_tensor_affine_backward_cpu @@ -3110,41 +3581,47 @@ - func: to.dtype_layout(Tensor self, *, ScalarType dtype, Layout layout, Device device, bool pin_memory=False, bool non_blocking=False, bool copy=False) -> Tensor variants: method device_guard: False - named_guard: False + supports_named_tensor: True - func: to.device(Tensor self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor variants: method device_guard: False - named_guard: False + supports_named_tensor: True - func: to.dtype(Tensor self, ScalarType dtype, bool non_blocking=False, bool copy=False) -> Tensor variants: method device_guard: False - named_guard: False + supports_named_tensor: True - func: to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False) -> Tensor + use_c10_dispatcher: True variants: method device_guard: False - func: meshgrid(Tensor[] tensors) -> Tensor[] + use_c10_dispatcher: True - func: cartesian_prod(Tensor[] tensors) -> Tensor + use_c10_dispatcher: True variants: function - func: combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor + use_c10_dispatcher: True variants: function - func: item(Tensor self) -> Scalar + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True # NB: Does NOT check precondition that numel == 1 - func: _local_scalar_dense(Tensor self) -> Scalar + use_c10_dispatcher: True dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda variants: function - named_guard: False + supports_named_tensor: True # Fused RNN kernels - func: _thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) @@ -3160,25 +3637,34 @@ CUDA: _thnn_fused_gru_cell_cuda - func: _thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + use_c10_dispatcher: True dispatch: CUDA: _thnn_fused_gru_cell_backward_cuda # RNN cells and layers - func: lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True - func: lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True - func: gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor) @@ -3194,25 +3680,33 @@ # Quantized GRU layers - func: quantized_gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + use_c10_dispatcher: True # Quantized RNN cells - func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + use_c10_dispatcher: True - func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + use_c10_dispatcher: True - func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor - + use_c10_dispatcher: True # PackedSequence utilities - func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + use_c10_dispatcher: True - func: _pack_padded_sequence_backward(Tensor grad, int[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor + use_c10_dispatcher: True - func: _pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor) + use_c10_dispatcher: True # wrappers for legacy TH methods @@ -3232,6 +3726,7 @@ QuantizedCPU: set_storage - func: set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) + use_c10_dispatcher: True variants: method device_guard: False dispatch: @@ -3239,6 +3734,7 @@ CUDA: legacy::cuda::_th_set_ - func: set_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_set_ @@ -3250,6 +3746,7 @@ QuantizedCPU: set_quantizer_ - func: is_set_to(Tensor self, Tensor tensor) -> bool + use_c10_dispatcher: True variants: method device_guard: False dispatch: @@ -3257,37 +3754,44 @@ CUDA: legacy::cuda::_th_is_set_to - func: masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda - named_guard: False + supports_named_tensor: True - func: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: masked_fill__cpu CUDA: masked_fill__cuda - named_guard: False + supports_named_tensor: True - func: masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: masked_scatter__cpu CUDA: masked_scatter__cuda - func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + use_c10_dispatcher: True variants: function, method - func: view(Tensor(a) self, int[] size) -> Tensor(a) + use_c10_dispatcher: True variants: method device_guard: False dispatch: @@ -3297,348 +3801,403 @@ QuantizedCPU: view - func: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_put_ CUDA: legacy::cuda::_th_put_ - func: index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_index_add_ CUDA: legacy::cuda::_th_index_add_ - func: index_add(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + use_c10_dispatcher: True variants: function, method - func: index_fill_.Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_index_fill_ CUDA: legacy::cuda::_th_index_fill_ - func: index_fill.Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + use_c10_dispatcher: True variants: function, method - func: index_fill_.Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_index_fill_ CUDA: legacy::cuda::_th_index_fill_ - func: index_fill.Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + use_c10_dispatcher: True variants: function, method - func: scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_scatter_ CUDA: legacy::cuda::_th_scatter_ - func: scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + use_c10_dispatcher: True variants: function, method - func: scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_scatter_ CUDA: legacy::cuda::_th_scatter_ - func: scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + use_c10_dispatcher: True variants: function, method - func: scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_scatter_add_ CUDA: legacy::cuda::_th_scatter_add_ - func: scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + use_c10_dispatcher: True variants: function, method - func: lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_lt_ CUDA: legacy::cuda::_th_lt_ - func: lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_lt_ CUDA: legacy::cuda::_th_lt_ - func: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_gt_ CUDA: legacy::cuda::_th_gt_ - func: gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_gt_ CUDA: legacy::cuda::_th_gt_ - func: le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_le_ CUDA: legacy::cuda::_th_le_ - func: le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_le_ CUDA: legacy::cuda::_th_le_ - func: ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ge_ CUDA: legacy::cuda::_th_ge_ - func: ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ge_ CUDA: legacy::cuda::_th_ge_ - func: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_eq_ CUDA: legacy::cuda::_th_eq_ - func: eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_eq_ CUDA: legacy::cuda::_th_eq_ - func: ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ne_ CUDA: legacy::cuda::_th_ne_ - func: ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ne_ CUDA: legacy::cuda::_th_ne_ - func: __and__.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_and CUDA: legacy::cuda::_th_and - func: __and__.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_and CUDA: legacy::cuda::_th_and - func: __iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_iand_ CUDA: legacy::cuda::_th_iand_ - func: __iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_iand_ CUDA: legacy::cuda::_th_iand_ - func: __or__.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_or CUDA: legacy::cuda::_th_or - func: __or__.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_or CUDA: legacy::cuda::_th_or - func: __ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ior_ CUDA: legacy::cuda::_th_ior_ - func: __ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ior_ CUDA: legacy::cuda::_th_ior_ - func: __xor__.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_xor CUDA: legacy::cuda::_th_xor - func: __xor__.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_xor CUDA: legacy::cuda::_th_xor - func: __ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ixor_ CUDA: legacy::cuda::_th_ixor_ - func: __ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ixor_ CUDA: legacy::cuda::_th_ixor_ - func: __lshift__.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_lshift CUDA: legacy::cuda::_th_lshift - func: __lshift__.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_lshift CUDA: legacy::cuda::_th_lshift - func: __ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ilshift_ CUDA: legacy::cuda::_th_ilshift_ - func: __ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_ilshift_ CUDA: legacy::cuda::_th_ilshift_ - func: __rshift__.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_rshift CUDA: legacy::cuda::_th_rshift - func: __rshift__.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_rshift CUDA: legacy::cuda::_th_rshift - func: __irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_irshift_ CUDA: legacy::cuda::_th_irshift_ - func: __irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_irshift_ CUDA: legacy::cuda::_th_irshift_ - func: lgamma_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method dispatch: - CPU: legacy::cpu::_th_lgamma_ + CPU: _lgamma__cpu CUDA: legacy::cuda::_th_lgamma_ - func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: tril_cpu_ CUDA: tril_cuda_ - func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: triu_cpu_ CUDA: triu_cuda_ - func: digamma_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method - dispatch: - CPU: _digamma__cpu - CUDA: legacy::cuda::_th_digamma_ - func: polygamma_(Tensor(a!) self, int n) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method - dispatch: - CPU: polygamma_ - CUDA: legacy::cuda::_th_polygamma_ - func: renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_renorm_ CUDA: legacy::cuda::_th_renorm_ - func: pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: pow_ - CUDA: legacy::cuda::_th_pow_ + CUDA: pow_ - func: pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: pow_ - CUDA: legacy::cuda::_th_pow_ + CUDA: pow_ - func: lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: lerp_cpu_scalar_ CUDA: lerp_cuda_scalar_ - func: lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: lerp_cpu_tensor_ CUDA: lerp_cuda_tensor_ - func: fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_fmod_ CUDA: legacy::cuda::_th_fmod_ - func: fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_fmod_ CUDA: legacy::cuda::_th_fmod_ - func: remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_remainder_ CUDA: legacy::cuda::_th_remainder_ - func: remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_remainder_ CUDA: legacy::cuda::_th_remainder_ - func: addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_addbmm_ @@ -3650,12 +4209,14 @@ CUDA: legacy::cuda::_th_addbmm_out - func: addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_addbmm CUDA: legacy::cuda::_th_addbmm - func: addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: random_.from(Tensor(a!) self, int from, int to, *, Generator? generator=None) -> Tensor(a!) @@ -3663,63 +4224,63 @@ dispatch: CPU: legacy::cpu::_th_random_ CUDA: clamped_random_cuda_ - named_guard: False + supports_named_tensor: True - func: random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_random_ CUDA: capped_random_cuda_ - named_guard: False + supports_named_tensor: True - func: random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_random_ CUDA: random_cuda_ - named_guard: False + supports_named_tensor: True - func: uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_uniform_ CUDA: uniform_cuda_ - named_guard: False + supports_named_tensor: True - func: normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_normal_ CUDA: normal_cuda_ - named_guard: False + supports_named_tensor: True - func: cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_cauchy_ CUDA: cauchy_cuda_ - named_guard: False + supports_named_tensor: True - func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_log_normal_ CUDA: log_normal_cuda_ - named_guard: False + supports_named_tensor: True - func: exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_exponential_ CUDA: exponential_cuda_ - named_guard: False + supports_named_tensor: True - func: geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) variants: method dispatch: CPU: legacy::cpu::_th_geometric_ CUDA: geometric_cuda_ - named_guard: False + supports_named_tensor: True # wrappers for TH functions @@ -3729,6 +4290,7 @@ CUDA: legacy::cuda::_th_diag_out - func: diag(Tensor self, int diagonal=0) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_diag @@ -3737,6 +4299,7 @@ - func: cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) - func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor + use_c10_dispatcher: True variants: method, function - func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) @@ -3745,6 +4308,7 @@ CUDA: triu_cuda_out - func: triu(Tensor self, int diagonal=0) -> Tensor + use_c10_dispatcher: True variants: method, function - func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) @@ -3753,6 +4317,7 @@ CUDA: tril_cuda_out - func: tril(Tensor self, int diagonal=0) -> Tensor + use_c10_dispatcher: True variants: method, function - func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -3766,6 +4331,7 @@ CUDA: triu_indices_cuda - func: trace(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_trace @@ -3778,6 +4344,7 @@ QuantizedCPU: ne_out_quantized_cpu - func: ne.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_ne @@ -3791,6 +4358,7 @@ QuantizedCPU: ne_out_quantized_cpu - func: ne.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_ne @@ -3804,6 +4372,7 @@ QuantizedCPU: eq_out_quantized_cpu - func: eq.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_eq @@ -3817,6 +4386,7 @@ QuantizedCPU: eq_out_quantized_cpu - func: eq.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_eq @@ -3830,6 +4400,7 @@ QuantizedCPU: ge_out_quantized_cpu - func: ge.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_ge @@ -3843,6 +4414,7 @@ QuantizedCPU: ge_out_quantized_cpu - func: ge.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_ge @@ -3856,6 +4428,7 @@ QuantizedCPU: le_out_quantized_cpu - func: le.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_le @@ -3869,6 +4442,7 @@ QuantizedCPU: le_out_quantized_cpu - func: le.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_le @@ -3882,6 +4456,7 @@ QuantizedCPU: gt_out_quantized_cpu - func: gt.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_gt @@ -3895,6 +4470,7 @@ QuantizedCPU: gt_out_quantized_cpu - func: gt.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_gt @@ -3908,6 +4484,7 @@ QuantizedCPU: lt_out_quantized_cpu - func: lt.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_lt @@ -3921,6 +4498,7 @@ QuantizedCPU: lt_out_quantized_cpu - func: lt.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_lt @@ -3933,6 +4511,7 @@ CUDA: legacy::cuda::_th_take_out - func: take(Tensor self, Tensor index) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_take @@ -3944,6 +4523,7 @@ CUDA: legacy::cuda::_th_index_select_out - func: index_select(Tensor self, int dim, Tensor index) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_index_select @@ -3955,14 +4535,15 @@ dispatch: CPU: masked_select_out_cpu CUDA: masked_select_out_cuda - named_guard: False + supports_named_tensor: True - func: masked_select(Tensor self, Tensor mask) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: masked_select_cpu CUDA: masked_select_cuda - named_guard: False + supports_named_tensor: True - func: nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -3970,12 +4551,14 @@ CUDA: legacy::cuda::_th_nonzero_out - func: nonzero(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_nonzero CUDA: legacy::cuda::_th_nonzero - func: nonzero_numpy(Tensor self) -> Tensor[] + use_c10_dispatcher: True variants: method, function - func: gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) @@ -3984,24 +4567,29 @@ CUDA: gather_out_cuda - func: gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: gather_cpu CUDA: gather_cuda - func: _gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor + use_c10_dispatcher: True - func: addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) - func: addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + use_c10_dispatcher: True variants: method, function - func: addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + use_c10_dispatcher: True variants: method - func: addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) - func: addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + use_c10_dispatcher: True variants: method, function - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR) @@ -4010,6 +4598,7 @@ CUDA: legacy::cuda::_th_gels_out - func: lstsq(Tensor self, Tensor A) -> (Tensor solution, Tensor QR) + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_gels @@ -4018,9 +4607,11 @@ - func: triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) - func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + use_c10_dispatcher: True variants: method, function - func: _triangular_solve_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _triangular_solve_helper_cpu @@ -4029,9 +4620,11 @@ - func: symeig.e(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) - func: symeig(Tensor self, bool eigenvectors=False, bool upper=True) -> (Tensor eigenvalues, Tensor eigenvectors) + use_c10_dispatcher: True variants: method, function - func: _symeig_helper(Tensor self, bool eigenvectors, bool upper) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _symeig_helper_cpu @@ -4043,6 +4636,7 @@ CUDA: legacy::cuda::_th_eig_out - func: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors) + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_eig @@ -4051,9 +4645,11 @@ - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + use_c10_dispatcher: True variants: method, function - func: _svd_helper(Tensor self, bool some, bool compute_uv) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _svd_helper_cpu @@ -4062,9 +4658,11 @@ - func: cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) - func: cholesky(Tensor self, bool upper=False) -> Tensor + use_c10_dispatcher: True variants: method, function - func: _cholesky_helper(Tensor self, bool upper) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _cholesky_helper_cpu @@ -4073,20 +4671,24 @@ - func: cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) - func: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + use_c10_dispatcher: True variants: method, function - func: _cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _cholesky_solve_helper_cpu CUDA: _cholesky_solve_helper_cuda - func: solve(Tensor self, Tensor A) -> (Tensor solution, Tensor LU) + use_c10_dispatcher: True variants: function, method - func: solve.solution(Tensor self, Tensor A, *, Tensor(a!) solution, Tensor(b!) lu) -> (Tensor(a!) solution, Tensor(b!) LU) - func: _solve_helper(Tensor self, Tensor A) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _solve_helper_cpu @@ -4098,6 +4700,7 @@ CUDA: legacy::cuda::_th_potri_out - func: cholesky_inverse(Tensor self, bool upper=False) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_potri @@ -4106,9 +4709,11 @@ - func: qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) - func: qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) + use_c10_dispatcher: True variants: method, function - func: _qr_helper(Tensor self, bool some) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _qr_helper_cpu @@ -4120,6 +4725,7 @@ CUDA: legacy::cuda::_th_geqrf_out - func: geqrf(Tensor self) -> (Tensor a, Tensor tau) + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_geqrf @@ -4130,6 +4736,7 @@ CPU: legacy::cpu::_th_orgqr_out - func: orgqr(Tensor self, Tensor input2) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_orgqr @@ -4139,11 +4746,13 @@ CPU: legacy::cpu::_th_ormqr_out - func: ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_ormqr - func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: _lu_with_info_cpu @@ -4152,26 +4761,30 @@ - func: lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) - func: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor + use_c10_dispatcher: True variants: method, function - func: _lu_solve_helper(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor + use_c10_dispatcher: True variants: function dispatch: CPU: _lu_solve_helper_cpu CUDA: _lu_solve_helper_cuda +# TODO: remove dispatch section when porting TH CUDA to ATen - func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU: multinomial_out_cpu - CUDA: legacy::cuda::_th_multinomial_out + CPU: multinomial_out + CUDA: multinomial_out - func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor variants: method, function dispatch: - CPU: multinomial_cpu - CUDA: legacy::cuda::_th_multinomial + CPU: multinomial + CUDA: multinomial - func: _multinomial_alias_setup(Tensor probs) -> (Tensor, Tensor) + use_c10_dispatcher: True variants: function dispatch: CPU: legacy::cpu::_th_multinomial_alias_setup @@ -4184,73 +4797,69 @@ CUDA: legacy::cuda::_th_multinomial_alias_draw - func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: - CPU: legacy::cpu::_th_lgamma_out + CPU: _lgamma_out_cpu CUDA: legacy::cuda::_th_lgamma_out - func: lgamma(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method, function dispatch: - CPU: legacy::cpu::_th_lgamma + CPU: lgamma CUDA: legacy::cuda::_th_lgamma - func: digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False - dispatch: - CPU: _digamma_out_cpu - CUDA: legacy::cuda::_th_digamma_out + supports_named_tensor: True - func: digamma(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method, function - dispatch: - CPU: digamma - CUDA: legacy::cuda::_th_digamma - func: polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False - dispatch: - CPU: polygamma_out - CUDA: legacy::cuda::_th_polygamma_out + supports_named_tensor: True - func: polygamma(int n, Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method, function - dispatch: - CPU: polygamma - CUDA: legacy::cuda::_th_polygamma - func: erfinv(Tensor self) -> Tensor - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method, function - func: erfinv_(Tensor(a!) self) -> Tensor(a!) - named_guard: False + use_c10_dispatcher: True + supports_named_tensor: True variants: method - func: erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: erfinv_out CUDA: erfinv_out - func: sign(Tensor self) -> Tensor + use_c10_dispatcher: True variants: function, method - named_guard: False + supports_named_tensor: True - func: sign_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: True variants: method - named_guard: False + supports_named_tensor: True - func: sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) - named_guard: False + supports_named_tensor: True dispatch: CPU: sign_out CUDA: sign_out - func: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_dist @@ -4259,6 +4868,7 @@ - func: atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - func: atan2(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function - func: lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) @@ -4272,12 +4882,14 @@ CUDA: lerp_cuda_tensor_out - func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: lerp_cpu_scalar CUDA: lerp_cuda_scalar - func: lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: lerp_cpu_tensor @@ -4289,6 +4901,7 @@ CUDA: _histc_out_cuda - func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_histc @@ -4300,6 +4913,7 @@ CUDA: legacy::cuda::_th_fmod_out - func: fmod.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_fmod @@ -4311,6 +4925,7 @@ CUDA: legacy::cuda::_th_fmod_out - func: fmod.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_fmod @@ -4322,6 +4937,7 @@ CUDA: legacy::cuda::_th_remainder_out - func: remainder.Scalar(Tensor self, Scalar other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_remainder @@ -4333,6 +4949,7 @@ CUDA: legacy::cuda::_th_remainder_out - func: remainder.Tensor(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_remainder @@ -4344,12 +4961,14 @@ CUDA: legacy::cuda::_th_min_out - func: min.other(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_min CUDA: legacy::cuda::_th_min - func: min(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_min @@ -4362,12 +4981,14 @@ CUDA: legacy::cuda::_th_max_out - func: max.other(Tensor self, Tensor other) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_max CUDA: legacy::cuda::_th_max - func: max(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_max @@ -4375,6 +4996,7 @@ QuantizedCPU: max_quant - func: median(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: median_cpu @@ -4386,6 +5008,7 @@ CUDA: legacy::cuda::_th_sort_out - func: sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_sort @@ -4393,6 +5016,7 @@ QuantizedCPU: sort_quant - func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor + use_c10_dispatcher: True variants: method, function - func: topk.values(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!) values, Tensor(b!) indices) @@ -4401,12 +5025,15 @@ CUDA: legacy::cuda::_th_topk_out - func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + use_c10_dispatcher: True variants: method, function - func: all(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function - func: any(Tensor self) -> Tensor + use_c10_dispatcher: True variants: method, function - func: renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) @@ -4415,18 +5042,21 @@ CUDA: legacy::cuda::_th_renorm_out - func: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_renorm CUDA: legacy::cuda::_th_renorm - func: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + use_c10_dispatcher: True variants: method dispatch: CPU: legacy::cpu::_th_unfold CUDA: legacy::cuda::_th_unfold - func: equal(Tensor self, Tensor other) -> bool + use_c10_dispatcher: True variants: method, function dispatch: CPU: legacy::cpu::_th_equal @@ -4436,23 +5066,25 @@ - func: pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: pow_out - CUDA: legacy::cuda::_th_pow_out + CUDA: pow_out - func: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + use_c10_dispatcher: True variants: method, function dispatch: CPU: pow - CUDA: legacy::cuda::_th_pow + CUDA: pow - func: pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) dispatch: CPU: pow_out - CUDA: legacy::cuda::_th_pow_out + CUDA: pow_out - func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor + use_c10_dispatcher: True dispatch: CPU: pow - CUDA: legacy::cuda::_th_pow + CUDA: pow - func: normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) dispatch: @@ -4489,15 +5121,18 @@ - func: normal.float_float_out(float mean, float std, int[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) - func: alias(Tensor(a) self) -> Tensor(a) + use_c10_dispatcher: True variants: method, function - named_guard: False + supports_named_tensor: True - func: _addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_addr CUDA: legacy::cuda::_th_addr - func: _addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_addr_ CUDA: legacy::cuda::_th_addr_ @@ -4508,11 +5143,13 @@ CUDA: legacy::cuda::_th_addr_out - func: _index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_index_copy_ CUDA: legacy::cuda::_th_index_copy_ - func: _cumsum(Tensor self, int dim) -> Tensor + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_cumsum CUDA: legacy::cuda::_th_cumsum @@ -4523,6 +5160,7 @@ CUDA: legacy::cuda::_th_cumsum_out - func: _cumprod(Tensor self, int dim) -> Tensor + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_cumprod CUDA: legacy::cuda::_th_cumprod @@ -4533,36 +5171,21 @@ CUDA: legacy::cuda::_th_cumprod_out - func: _var(Tensor self, bool unbiased=True) -> Tensor + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_var CUDA: legacy::cuda::_th_var - named_guard: False + supports_named_tensor: True - func: _std(Tensor self, bool unbiased=True) -> Tensor + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_std CUDA: legacy::cuda::_th_std - named_guard: False - -- func: _addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) - dispatch: - CPU: legacy::cpu::_th_addmm_out - CUDA: legacy::cuda::_th_addmm_out - named_guard: False - -- func: _addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor - dispatch: - CPU: legacy::cpu::_th_addmm - CUDA: legacy::cuda::_th_addmm - named_guard: False - -- func: _addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) - dispatch: - CPU: legacy::cpu::_th_addmm_ - CUDA: legacy::cuda::_th_addmm_ - named_guard: False + supports_named_tensor: True - func: _cat(Tensor[] tensors, int dim=0) -> Tensor + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_cat CUDA: legacy::cuda::_th_cat @@ -4573,6 +5196,7 @@ CUDA: legacy::cuda::_th_cat_out - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_mode CUDA: legacy::cuda::_th_mode @@ -4583,6 +5207,7 @@ CUDA: legacy::cuda::_th_mode_out - func: _max(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_max CUDA: legacy::cuda::_th_max @@ -4593,6 +5218,7 @@ CUDA: legacy::cuda::_th_max_out - func: _min(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) + use_c10_dispatcher: True dispatch: CPU: legacy::cpu::_th_min CUDA: legacy::cuda::_th_min @@ -4635,6 +5261,7 @@ CUDA: legacy::cuda::_thnn_mse_loss_forward_out - func: mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_mse_loss_forward @@ -4647,6 +5274,7 @@ CUDA: legacy::cuda::_thnn_mse_loss_backward_out - func: mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_mse_loss_backward @@ -4659,6 +5287,7 @@ CUDA: legacy::cuda::_thnn_l1_loss_forward_out - func: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_l1_loss_forward @@ -4671,6 +5300,7 @@ CUDA: legacy::cuda::_thnn_l1_loss_backward_out - func: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_l1_loss_backward @@ -4704,6 +5334,7 @@ python_module: nn - func: multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True python_module: nn - func: multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) @@ -4713,6 +5344,7 @@ CUDA: legacy::cuda::_thnn_multilabel_margin_loss_forward_out - func: multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_multilabel_margin_loss_forward @@ -4725,6 +5357,7 @@ CUDA: legacy::cuda::_thnn_multilabel_margin_loss_backward_out - func: multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_multilabel_margin_loss_backward @@ -4797,6 +5430,7 @@ CUDA: legacy::cuda::_thnn_smooth_l1_loss_forward_out - func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_smooth_l1_loss_forward @@ -4809,6 +5443,7 @@ CUDA: legacy::cuda::_thnn_smooth_l1_loss_backward_out - func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_smooth_l1_loss_backward @@ -4821,6 +5456,7 @@ CUDA: legacy::cuda::_thnn_soft_margin_loss_forward_out - func: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_soft_margin_loss_forward @@ -4833,6 +5469,7 @@ CUDA: legacy::cuda::_thnn_soft_margin_loss_backward_out - func: soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_soft_margin_loss_backward @@ -4845,6 +5482,7 @@ CUDA: legacy::cuda::_thnn_elu_forward_out - func: elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_elu_forward @@ -4857,12 +5495,14 @@ CUDA: legacy::cuda::_thnn_elu_backward_out - func: elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, Tensor output) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_elu_backward CUDA: legacy::cuda::_thnn_elu_backward - func: elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_elu_forward_ @@ -4875,6 +5515,7 @@ CUDA: legacy::cuda::_thnn_glu_forward_out - func: glu(Tensor self, int dim=-1) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_glu_forward @@ -4887,6 +5528,7 @@ CUDA: legacy::cuda::_thnn_glu_backward_out - func: glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_glu_backward @@ -4899,6 +5541,7 @@ CUDA: legacy::cuda::_thnn_hardtanh_forward_out - func: hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_hardtanh_forward @@ -4911,12 +5554,14 @@ CUDA: legacy::cuda::_thnn_hardtanh_backward_out - func: hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_hardtanh_backward CUDA: legacy::cuda::_thnn_hardtanh_backward - func: hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_hardtanh_forward_ @@ -4929,6 +5574,7 @@ CUDA: legacy::cuda::_thnn_leaky_relu_forward_out - func: leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_leaky_relu_forward @@ -4941,12 +5587,14 @@ CUDA: legacy::cuda::_thnn_leaky_relu_backward_out - func: leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_leaky_relu_backward CUDA: legacy::cuda::_thnn_leaky_relu_backward - func: leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_leaky_relu_forward_ @@ -4956,6 +5604,7 @@ python_module: nn - func: log_sigmoid(Tensor self) -> Tensor + use_c10_dispatcher: True python_module: nn - func: log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) @@ -4965,6 +5614,7 @@ CUDA: legacy::cuda::_thnn_log_sigmoid_forward_out - func: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_log_sigmoid_forward @@ -4977,6 +5627,7 @@ CUDA: legacy::cuda::_thnn_log_sigmoid_backward_out - func: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_log_sigmoid_backward @@ -5001,6 +5652,7 @@ CUDA: legacy::cuda::_thnn_rrelu_with_noise_backward_out - func: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_rrelu_with_noise_backward @@ -5019,6 +5671,7 @@ CUDA: legacy::cuda::_thnn_softplus_forward_out - func: softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_softplus_forward @@ -5031,6 +5684,7 @@ CUDA: legacy::cuda::_thnn_softplus_backward_out - func: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_softplus_backward @@ -5043,6 +5697,7 @@ CUDA: legacy::cuda::_thnn_softshrink_forward_out - func: softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_softshrink_forward @@ -5055,6 +5710,7 @@ CUDA: legacy::cuda::_thnn_softshrink_backward_out - func: softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_softshrink_backward @@ -5069,20 +5725,24 @@ QuantizedCPU: quantized_adaptive_avg_pool2d_out - func: adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn - func: mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + use_c10_dispatcher: True dispatch: MkldnnCPU: mkldnn_adaptive_avg_pool2d requires_tensor: True - func: _adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + use_c10_dispatcher: True dispatch: CPU: adaptive_avg_pool2d_cpu CUDA: adaptive_avg_pool2d_cuda QuantizedCPU: quantized_adaptive_avg_pool2d - func: _adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_avg_pool2d_backward_cpu @@ -5095,6 +5755,7 @@ CUDA: adaptive_avg_pool3d_out_cuda - func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_avg_pool3d_cpu @@ -5107,6 +5768,7 @@ CUDA: adaptive_avg_pool3d_backward_out_cuda - func: adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_avg_pool3d_backward_cpu @@ -5121,6 +5783,7 @@ # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_max_pool2d_cpu @@ -5133,6 +5796,7 @@ CUDA: adaptive_max_pool2d_backward_out_cuda - func: adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_max_pool2d_backward_cpu @@ -5147,6 +5811,7 @@ # Return: (Tensor output, Tensor indices) - func: adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_max_pool3d_cpu @@ -5159,6 +5824,7 @@ CUDA: adaptive_max_pool3d_backward_out_cuda - func: adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: adaptive_max_pool3d_backward_cpu @@ -5172,6 +5838,7 @@ MkldnnCPU: mkldnn_avg_pool2d_out - func: avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: avg_pool2d_cpu @@ -5185,6 +5852,7 @@ CUDA: avg_pool2d_backward_out_cuda - func: avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: avg_pool2d_backward_cpu @@ -5197,6 +5865,7 @@ CUDA: avg_pool3d_out_cuda - func: avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: avg_pool3d_cpu @@ -5209,6 +5878,7 @@ CUDA: avg_pool3d_backward_out_cuda - func: avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: avg_pool3d_backward_cpu @@ -5223,6 +5893,7 @@ # Return: (Tensor output, Tensor indices) - func: fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + use_c10_dispatcher: True python_module: nn dispatch: CPU: fractional_max_pool2d_cpu @@ -5235,6 +5906,7 @@ CUDA: fractional_max_pool2d_backward_out_cuda - func: fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: fractional_max_pool2d_backward_cpu @@ -5249,6 +5921,7 @@ # Return: (Tensor output, Tensor indices) - func: fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + use_c10_dispatcher: True python_module: nn dispatch: CPU: fractional_max_pool3d_cpu @@ -5261,6 +5934,7 @@ CUDA: fractional_max_pool3d_backward_out_cuda - func: fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: fractional_max_pool3d_backward_cpu @@ -5275,6 +5949,7 @@ # Return: (Tensor output, Tensor indices) - func: max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_pool2d_with_indices_cpu @@ -5287,6 +5962,7 @@ CUDA: max_pool2d_with_indices_backward_out_cuda - func: max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_pool2d_with_indices_backward_cpu @@ -5301,6 +5977,7 @@ # Return: (Tensor output, Tensor indices) - func: max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_pool3d_with_indices_cpu @@ -5313,6 +5990,7 @@ CUDA: max_pool3d_with_indices_backward_out_cuda - func: max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_pool3d_with_indices_backward_cpu @@ -5325,6 +6003,7 @@ CUDA: max_unpooling2d_forward_out_cuda - func: max_unpool2d(Tensor self, Tensor indices, int[2] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_unpooling2d_forward_cpu @@ -5337,6 +6016,7 @@ CUDA: max_unpooling2d_backward_out_cuda - func: max_unpool2d_backward(Tensor grad_output, Tensor self, Tensor indices, int[2] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_unpooling2d_backward_cpu @@ -5349,6 +6029,7 @@ CUDA: max_unpooling3d_forward_out_cuda - func: max_unpool3d(Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_unpooling3d_forward_cpu @@ -5361,6 +6042,7 @@ CUDA: max_unpooling3d_backward_out_cuda - func: max_unpool3d_backward(Tensor grad_output, Tensor self, Tensor indices, int[3] output_size, int[3] stride, int[3] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: max_unpooling3d_backward_cpu @@ -5373,6 +6055,7 @@ CUDA: reflection_pad1d_out_cuda - func: reflection_pad1d(Tensor self, int[2] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: reflection_pad1d_cpu @@ -5385,6 +6068,7 @@ CUDA: reflection_pad1d_backward_out_cuda - func: reflection_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: reflection_pad1d_backward_cpu @@ -5397,6 +6081,7 @@ CUDA: reflection_pad2d_out_cuda - func: reflection_pad2d(Tensor self, int[4] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: reflection_pad2d_cpu @@ -5409,6 +6094,7 @@ CUDA: reflection_pad2d_backward_out_cuda - func: reflection_pad2d_backward(Tensor grad_output, Tensor self, int[4] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: reflection_pad2d_backward_cpu @@ -5421,6 +6107,7 @@ CUDA: replication_pad1d_out_cuda - func: replication_pad1d(Tensor self, int[2] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: replication_pad1d_cpu @@ -5433,6 +6120,7 @@ CUDA: replication_pad1d_backward_out_cuda - func: replication_pad1d_backward(Tensor grad_output, Tensor self, int[2] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: replication_pad1d_backward_cpu @@ -5445,6 +6133,7 @@ CUDA: replication_pad2d_out_cuda - func: replication_pad2d(Tensor self, int[4] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: replication_pad2d_cpu @@ -5457,6 +6146,7 @@ CUDA: replication_pad2d_backward_out_cuda - func: replication_pad2d_backward(Tensor grad_output, Tensor self, int[4] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: replication_pad2d_backward_cpu @@ -5469,6 +6159,7 @@ CUDA: replication_pad3d_out_cuda - func: replication_pad3d(Tensor self, int[6] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: replication_pad3d_cpu @@ -5481,6 +6172,7 @@ CUDA: replication_pad3d_backward_out_cuda - func: replication_pad3d_backward(Tensor grad_output, Tensor self, int[6] padding) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: replication_pad3d_backward_cpu @@ -5493,6 +6185,7 @@ CUDA: upsample_linear1d_out_cuda - func: upsample_linear1d(Tensor self, int[1] output_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_linear1d_cpu @@ -5505,6 +6198,7 @@ CUDA: upsample_linear1d_backward_out_cuda - func: upsample_linear1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_linear1d_backward_cpu @@ -5517,6 +6211,7 @@ CUDA: upsample_bilinear2d_out_cuda - func: upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_bilinear2d_cpu @@ -5529,6 +6224,7 @@ CUDA: upsample_bilinear2d_backward_out_cuda - func: upsample_bilinear2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_bilinear2d_backward_cpu @@ -5541,6 +6237,7 @@ CUDA: upsample_bicubic2d_out_cuda - func: upsample_bicubic2d(Tensor self, int[2] output_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_bicubic2d_cpu @@ -5553,6 +6250,7 @@ CUDA: upsample_bicubic2d_backward_out_cuda - func: upsample_bicubic2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_bicubic2d_backward_cpu @@ -5565,6 +6263,7 @@ CUDA: upsample_trilinear3d_out_cuda - func: upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_trilinear3d_cpu @@ -5577,6 +6276,7 @@ CUDA: upsample_trilinear3d_backward_out_cuda - func: upsample_trilinear3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size, bool align_corners) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_trilinear3d_backward_cpu @@ -5589,6 +6289,7 @@ CUDA: upsample_nearest1d_out_cuda - func: upsample_nearest1d(Tensor self, int[1] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_nearest1d_cpu @@ -5601,6 +6302,7 @@ CUDA: upsample_nearest1d_backward_out_cuda - func: upsample_nearest1d_backward(Tensor grad_output, int[1] output_size, int[3] input_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_nearest1d_backward_cpu @@ -5613,6 +6315,7 @@ CUDA: upsample_nearest2d_out_cuda - func: upsample_nearest2d(Tensor self, int[2] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_nearest2d_cpu @@ -5625,6 +6328,7 @@ CUDA: upsample_nearest2d_backward_out_cuda - func: upsample_nearest2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_nearest2d_backward_cpu @@ -5637,6 +6341,7 @@ CUDA: upsample_nearest3d_out_cuda - func: upsample_nearest3d(Tensor self, int[3] output_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_nearest3d_cpu @@ -5649,6 +6354,7 @@ CUDA: upsample_nearest3d_backward_out_cuda - func: upsample_nearest3d_backward(Tensor grad_output, int[3] output_size, int[5] input_size) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: upsample_nearest3d_backward_cpu @@ -5661,6 +6367,7 @@ CUDA: legacy::cuda::_thnn_sigmoid_backward_out - func: sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_sigmoid_backward @@ -5673,6 +6380,7 @@ CUDA: legacy::cuda::_thnn_tanh_backward_out - func: tanh_backward(Tensor grad_output, Tensor output) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_tanh_backward @@ -5715,6 +6423,7 @@ CUDA: slow_conv_transpose2d_backward_out_cuda - func: slow_conv_transpose2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, Tensor columns, Tensor ones, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + use_c10_dispatcher: True python_module: nn dispatch: CPU: slow_conv_transpose2d_backward_cpu @@ -5739,6 +6448,7 @@ CUDA: slow_conv_transpose3d_backward_out_cuda - func: slow_conv_transpose3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] output_padding, int[3] dilation, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + use_c10_dispatcher: True python_module: nn dispatch: CPU: slow_conv_transpose3d_backward_cpu @@ -5769,6 +6479,7 @@ CUDA: legacy::cuda::_thnn_conv2d_backward_out - func: thnn_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_conv2d_backward @@ -5796,6 +6507,7 @@ CUDA: legacy::cuda::_thnn_conv_depthwise2d_backward_out - func: thnn_conv_depthwise2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[2] output_mask) -> (Tensor grad_input, Tensor grad_weight) + use_c10_dispatcher: True python_module: nn dispatch: CUDA: legacy::cuda::_thnn_conv_depthwise2d_backward @@ -5822,6 +6534,7 @@ CPU: legacy::cpu::_thnn_conv3d_backward_out - func: thnn_conv3d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, Tensor finput, Tensor fgrad_input, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + use_c10_dispatcher: True python_module: nn dispatch: CPU: legacy::cpu::_thnn_conv3d_backward @@ -5833,6 +6546,7 @@ CUDA: slow_conv_dilated2d_cuda - func: slow_conv_dilated2d_backward(Tensor grad_output, Tensor self, Tensor weight, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + use_c10_dispatcher: True python_module: nn dispatch: CPU: slow_conv_dilated2d_backward_cpu @@ -5845,6 +6559,7 @@ CUDA: slow_conv_dilated3d_cuda - func: slow_conv_dilated3d_backward(Tensor grad_output, Tensor self, Tensor weight, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + use_c10_dispatcher: True python_module: nn dispatch: CPU: slow_conv_dilated3d_backward_cpu @@ -5857,6 +6572,7 @@ CUDA: col2im_out_cuda - func: col2im(Tensor self, int[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: col2im_cpu @@ -5869,6 +6585,7 @@ CUDA: col2im_backward_out_cuda - func: col2im_backward(Tensor grad_output, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: col2im_backward_cpu @@ -5881,6 +6598,7 @@ CUDA: im2col_out_cuda - func: im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: im2col_cpu @@ -5893,6 +6611,7 @@ CUDA: im2col_backward_out_cuda - func: im2col_backward(Tensor grad_output, int[2] input_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + use_c10_dispatcher: True python_module: nn dispatch: CPU: im2col_backward_cpu diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp index 8a99c2b64a174..9d2fb454ea5d8 100644 --- a/aten/src/ATen/native/quantized/QTensor.cpp +++ b/aten/src/ATen/native/quantized/QTensor.cpp @@ -94,8 +94,10 @@ Tensor q_per_channel_zero_points_quant(const Tensor& self) { self.options().dtype(at::kLong)); } -Quantizer* quantizer(const Tensor& self) { - return get_qtensorimpl(self)->quantizer().get(); +IntArrayRef q_per_channel_axis_quant(const Tensor& self) { + auto quantizer = get_qtensorimpl(self)->quantizer(); + TORCH_CHECK(quantizer->qscheme() == kPerChannelAffine); + return static_cast(quantizer.get())->axis(); } Tensor int_repr_quant(const Tensor& self) { @@ -183,6 +185,10 @@ Tensor& set_quantizer_(Tensor& self, ConstQuantizerPtr quantizer) { } Tensor quantized_clone(const Tensor& self) { + // TODO: add per channel support + TORCH_INTERNAL_ASSERT( + self.qscheme() == at::kPerTensorAffine, + "clone for quantized Tensor only works for PerTensorAffine scheme right now"); Tensor dst = at::_empty_affine_quantized( self.sizes(), self.options(), self.q_scale(), self.q_zero_point()); diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h index 9cb681cec2b4c..f0f4aafcf3b2a 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.h @@ -1,6 +1,7 @@ #pragma once #include +#include #ifdef USE_FBGEMM #include "fbgemm/Fbgemm.h" #include "fbgemm/QuantUtils.h" @@ -16,6 +17,7 @@ // Note that in JIT mode we can think of a way to fuse col_offsets with bias. struct FBGEMM_API PackedLinearWeight { std::unique_ptr> w; + c10::optional bias; std::vector col_offsets; std::vector w_scale; std::vector w_zp; @@ -24,6 +26,7 @@ struct FBGEMM_API PackedLinearWeight { struct FBGEMM_API PackedConvWeight { std::unique_ptr> w; + c10::optional bias; std::vector col_offsets; std::vector kernel; std::vector w_scale; diff --git a/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp b/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp index ea9e123dad47f..abaa2000ecccf 100644 --- a/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/init_qnnpack.cpp @@ -1,19 +1,23 @@ -#ifdef USE_QNNPACK +#ifdef USE_PYTORCH_QNNPACK #include "init_qnnpack.h" #include #include -#include +#include namespace at { namespace native { + void initQNNPACK() { static std::once_flag once; - static enum qnnp_status qnnpackStatus = qnnp_status_uninitialized; - std::call_once(once, []() { qnnpackStatus = qnnp_initialize(); }); + static enum pytorch_qnnp_status qnnpackStatus = + pytorch_qnnp_status_uninitialized; + std::call_once(once, []() { qnnpackStatus = pytorch_qnnp_initialize(); }); TORCH_CHECK( - qnnpackStatus == qnnp_status_success, "failed to initialize QNNPACK"); + qnnpackStatus == pytorch_qnnp_status_success, + "failed to initialize QNNPACK"); } + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/init_qnnpack.h b/aten/src/ATen/native/quantized/cpu/init_qnnpack.h index f93a36c8b13e9..dbfb406ea55db 100644 --- a/aten/src/ATen/native/quantized/cpu/init_qnnpack.h +++ b/aten/src/ATen/native/quantized/cpu/init_qnnpack.h @@ -1,7 +1,6 @@ #pragma once -#ifdef USE_QNNPACK -#include "qnnpack.h" +#ifdef USE_PYTORCH_QNNPACK namespace at { namespace native { @@ -10,4 +9,5 @@ void initQNNPACK(); } // namespace native } // namespace at + #endif diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 00a1585c39923..b2db3aba8426f 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -64,11 +64,12 @@ void qrelu6_kernel(const Tensor& qx, Tensor& qy) { template void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) { int64_t zero_point = out.q_zero_point(); - double scale = out.q_scale(); + float scale = out.q_scale(); + float inv_scale = 1.0f / scale; int64_t self_zero_point = self.q_zero_point(); - double self_scale = self.q_scale(); + float self_scale = self.q_scale(); int64_t other_zero_point = other.q_zero_point(); - double other_scale = other.q_scale(); + float other_scale = other.q_scale(); // Broadcast out the parameters here to amortize out that cost across // loop iterations. @@ -79,6 +80,9 @@ void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) { auto other_zero_point_vec = Vec256((float)other_zero_point); auto other_scale_vec = Vec256(other_scale); + auto self_scale_zp_premul_vec = self_scale_vec * self_zero_point_vec.neg(); + auto other_scale_zp_premul_vec = other_scale_vec * other_zero_point_vec.neg(); + auto iter = TensorIterator::binary_op(out, self, other); AT_DISPATCH_QINT_TYPES(out.scalar_type(), "qadd", [&]() { @@ -95,8 +99,10 @@ void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) { return at::quantize_val(scale, zero_point, c); }, [&](Vec a, Vec b) -> Vec { - const auto da = a.dequantize(self_scale_vec, self_zero_point_vec); - const auto db = b.dequantize(other_scale_vec, other_zero_point_vec); + const auto da = a.dequantize( + self_scale_vec, self_zero_point_vec, self_scale_zp_premul_vec); + const auto db = b.dequantize( + other_scale_vec, other_zero_point_vec, other_scale_zp_premul_vec); Vec::float_vec_return_type retvals; for (int i = 0; i < Vec::float_num_vecs(); ++i) { auto c = da[i] + db[i]; @@ -111,18 +117,116 @@ void qadd_kernel(Tensor& out, const Tensor& self, const Tensor& other) { // TODO: specialize fbgemm::Quantize for a single vector and make it // inlineable. This could help with interleaving as suggested by the // TensorIterator implementations - auto rv = Vec::quantize(retvals, scale, zero_point); + auto rv = Vec::quantize(retvals, scale, zero_point, inv_scale); return rv; }); }); } +void qmaxpool_2d_nhwc_kernel(const Tensor &qx, + int64_t iC, // input/output channels + int64_t iH, + int64_t iW, // input sizes + int64_t oH, + int64_t oW, // output sizes + int64_t kH, + int64_t kW, // kernel size + int64_t sH, + int64_t sW, // strides + int64_t pH, + int64_t pW, // padding + int64_t dH, + int64_t dW, // dilation + Tensor &qy) { + AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "max_pool2d_nhwc", [&]() { + scalar_t *idata = static_cast(qx.data_ptr()); + scalar_t *odata = static_cast(qy.data_ptr()); + + // Loop over N + for (int64_t b = 0; b < qx.size(0); ++b) { + // Loop over H + auto *i_p = reinterpret_cast(idata + b * iW * iH * iC); + for (int64_t row = 0; row < oH; ++row) { + // Loop over W + for (int64_t col = 0; col < oW; ++col) { + // Pointer to output data for this specific N,H,W position + auto *o_p = reinterpret_cast(odata + b * oH * oW * iC + row * oW * iC + col * iC); + + // Loop over reduction block + int64_t h_start = row * sH - pH; + int64_t w_start = col * sW - pW; + int64_t h_end = std::min(h_start + (kH - 1) * dH + 1, iH); + int64_t w_end = std::min(w_start + (kW - 1) * dW + 1, iW); + while (h_start < 0) + h_start += dH; + while (w_start < 0) + w_start += dW; + + int64_t c = 0; + + // Interleaved vector loop 4x + constexpr auto vec_width = Vec256::size(); + for (; c + 4 * vec_width <= iC; c+= 4 * vec_width) { + Vec256 acc{scalar_t(std::numeric_limits::lowest())}; + Vec256 accs[4] = {acc, acc, acc, acc}; + int64_t tcntr = 0; + int64_t x, y; + for (y = h_start; y < h_end; y += dH) { + for (x = w_start; x < w_end; x += dW) { + for (int i = 0; i < 4; ++i) { + tcntr = y * iW + x; + auto vals = Vec256::loadu(i_p + tcntr * iC + c + Vec256::size() * i); + accs[i] = vec256::maximum(accs[i], vals); + } + } // for x + } // for y + for (int i = 0; i < 4; ++i) { + accs[i].store(o_p + c + Vec256::size() * i); + } + } // for c + + // Vector loop + for (; c + vec_width <= iC; c+= vec_width) { + Vec256 acc{scalar_t(std::numeric_limits::lowest())}; + int64_t tcntr = 0; + int64_t x, y; + for (y = h_start; y < h_end; y += dH) { + for (x = w_start; x < w_end; x += dW) { + tcntr = y * iW + x; + auto vals = Vec256::loadu(i_p + tcntr * iC + c); + acc = vec256::maximum(acc, vals); + } // for x + } // for y + acc.store(o_p + c); + } // for c + + for (; c < iC; ++c) { + auto max_val = std::numeric_limits::lowest(); + int64_t tcntr = 0; + int64_t x, y; + for (y = h_start; y < h_end; y += dH) { + for (x = w_start; x < w_end; x += dW) { + tcntr = y * iW + x; + auto val = *(i_p + tcntr * iC + c); + max_val = std::max(max_val, val); + } // for x + } // for y + + o_p[c] = max_val; + } // for c + } // for col + } // for row + } // for b + }); +} + } // namespace REGISTER_DISPATCH(qrelu_stub, &qrelu_kernel); REGISTER_DISPATCH(qrelu6_stub, &qrelu6_kernel); REGISTER_DISPATCH(qadd_relu_stub, &qadd_kernel); REGISTER_DISPATCH(qadd_stub, &qadd_kernel); +REGISTER_DISPATCH(qmaxpool_2d_nhwc_stub, &qmaxpool_2d_nhwc_kernel); } // namespace native } // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 8899d65eeaf8b..b1d4f6e13be23 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace at { @@ -11,9 +12,9 @@ namespace { SmallVector convOutputShape( int N, // mini-batch + int K, // output channels int H, // input height int W, // input width - int K, // output channels const std::vector& kernel, const torch::List& stride, const torch::List& padding, @@ -27,6 +28,7 @@ SmallVector convOutputShape( (W + 2 * padding[1] - dilation[1] * (kernel[1] - 1) - 1) / stride[1] + 1); out_shape.push_back(H_out); out_shape.push_back(W_out); + // TODO: reorder it to NCHW order once the memory format regression is fixed out_shape.push_back(K); return out_shape; @@ -64,36 +66,56 @@ SmallVector convOutputShape( template class QConv2dInt8 final : public c10::OperatorKernel { public: + void conv_checks( + int64_t act_dims, + int64_t stride_dims, + int64_t padding_dims, + int64_t dilation_dims) { + TORCH_CHECK( + act_dims == 4, + "quantized::conv2d(): Expected activation tensor to have 4 dimensions."); + TORCH_CHECK( + stride_dims == 2, "quantized::conv2d(): Supports 2D convolution only"); + TORCH_CHECK( + padding_dims == 2, "quantized::conv2d(): Supports 2D convolution only"); + TORCH_CHECK( + dilation_dims == 2, + "quantized::conv2d(): Supports 2D convolution only"); + } #ifdef USE_FBGEMM - Tensor operator()( + at::Tensor fbgemm_conv( Tensor act, Tensor packed_weight, - c10::optional bias, torch::List stride, torch::List padding, torch::List dilation, int64_t groups, double output_scale, int64_t output_zero_point) { + // Quantized kernels are all written with NHWC (channels last) layout in + // mind. Ideally, we'd be compatible with conv2d behavior and preserve the + // inputs layout as is (doing necessary upconversions). + // + // However, to be more robust, for now we just force output layout to always + // be NHWC (channels last), thus opportunistically improving perf. + // + // This might change when full memory format support lands + // See https://github.com/pytorch/pytorch/issues/23403 TORCH_CHECK( fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM."); - TORCH_CHECK( - act.ndimension() == 4, - "Activations are supposed to have 4 dimensions."); - TORCH_CHECK(stride.size() == 2, "2D convolution only"); - TORCH_CHECK(padding.size() == 2, "2D convolution only"); - TORCH_CHECK(dilation.size() == 2, "2D convolution only"); - TORCH_CHECK( - (dilation[0] == 1 && dilation[1] == 1), - "Currently dilation should be 1"); + conv_checks( + act.ndimension(), stride.size(), padding.size(), dilation.size()); - // inputs are in NHWC format int N = act.size(0); - int H = act.size(1); - int W = act.size(2); - int C = act.size(3); + int C = act.size(1); + int H = act.size(2); + int W = act.size(3); - Tensor act_contig = act.contiguous(); + // FBGEMM requires NHWC + // TODO: change it to contiguous(MemoryFormat::ChannelsLast) once a perf + // regression of it is fixed. Today it's equivalent because `act` sizes + // are not used below + Tensor act_contig = act.permute({0, 2, 3, 1}).contiguous(); const uint8_t* act_ptr = reinterpret_cast(act_contig.data_ptr()); @@ -111,7 +133,14 @@ class QConv2dInt8 final : public c10::OperatorKernel { int stride_w = stride[1]; int kernel_h = kernel[0]; int kernel_w = kernel[1]; - + // clang-format off + TORCH_CHECK(C == (packB->inputChannels()), + "[QConv2D] Given groups=", groups, ", weight of size ", + K, ", ", kernel_h, ", ", kernel_w, ", ", packB->inputChannels(), + ", expected input (NCHW) ", N, ", ", C, ", ", H, ", ", W, + " to have ", (packB->inputChannels() * groups), + " channels, but got ", C, " channels instead"); + // clang-format on fbgemm::conv_param_t<> conv_p( N, // Batch size C, // Number of input channels @@ -120,65 +149,83 @@ class QConv2dInt8 final : public c10::OperatorKernel { groups, {kernel_h, kernel_w}, {stride_h, stride_w}, - {pad_l, pad_t, pad_l, pad_t}); + {pad_l, pad_t, pad_l, pad_t}, + {static_cast(dilation[0]), static_cast(dilation[1])}); fbgemm::DoNothing<> NoOpObj{}; - const int32_t* bias_ptr = nullptr; - if (bias.has_value()) { - Tensor bias_vec = bias.value(); - TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)"); + float act_scale = act.q_scale(); + int32_t act_zero_point = act.q_zero_point(); + + const float* bias_ptr = nullptr; + at::Tensor bias; + if (pack_ptr.bias.has_value()) { + bias = pack_ptr.bias.value(); + TORCH_CHECK( + bias.dtype() == at::kFloat, + "[QConv2D] The 'bias' tensor must have 'torch.float' dtype"); + bias = bias.contiguous(); + TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)"); TORCH_CHECK( - bias_vec.size(0) == K, + bias.size(0) == K, "bias should have K elements: " + std::to_string(K)); - auto bias_contig = bias_vec.contiguous(); - bias_ptr = reinterpret_cast(bias_contig.data_ptr()); + bias_ptr = bias.data_ptr(); } - float act_scale = act.q_scale(); - int32_t act_zero_point = act.q_zero_point(); - std::vector output_multiplier_float(1, 0.0); + std::vector act_times_w_scale(1, 1.0); TORCH_CHECK( pack_ptr.w_scale.size() == pack_ptr.w_zp.size(), "Weight scales and zero points vectors should have the same size."); - // quantization scheme is PerTensorAffine if the number of scales is 1 and - // it's kPerChannelAffine if the number of scales is equal to K (output - // channels) - if (pack_ptr.w_scale.size() == 1) { + + if (pack_ptr.q_scheme == kPerTensorAffine) { + act_times_w_scale[0] = (act_scale * pack_ptr.w_scale[0]); output_multiplier_float[0] = - (act_scale * pack_ptr.w_scale[0]) / static_cast(output_scale); - } else if (pack_ptr.w_scale.size() == K) { + act_times_w_scale[0] / static_cast(output_scale); + } else if (pack_ptr.q_scheme == kPerChannelAffine) { output_multiplier_float.resize(K, 0.0); + act_times_w_scale.resize(K, 1.0); for (int i = 0; i < K; ++i) { - output_multiplier_float[i] = (act_scale * pack_ptr.w_scale[i]) / - static_cast(output_scale); + act_times_w_scale[i] = (act_scale * pack_ptr.w_scale[i]); + output_multiplier_float[i] = + act_times_w_scale[i] / static_cast(output_scale); } + } else { + TORCH_CHECK(false, "[QConv2D] Unknown quantization scheme"); } + // TODO: change convOutputShape to return NCHW sizes once perf is fixed auto outShape = - convOutputShape(N, H, W, K, kernel, stride, padding, dilation); + convOutputShape(N, K, H, W, kernel, stride, padding, dilation); TORCH_CHECK( std::all_of( outShape.begin(), outShape.end(), [](int64_t i) { return i > 0; }), "[QConv2D] each dimension of output tensor should be greater than 0") + // Force output format to be NHWC + // TODO: consider preserving input format + // TODO: add MemoryFormat::ChannelsLast here once perf is fixed Tensor output = _empty_affine_quantized( outShape, device(kCPU).dtype(kQUInt8), output_scale, output_zero_point); auto buffer = at::zeros_like(output, output.options().dtype(at::kInt)); if (pack_ptr.q_scheme == kPerTensorAffine) { - fbgemm::ReQuantizeOutput outputProcObj( - NoOpObj, - output_multiplier_float.data(), - output_zero_point, - act_zero_point, - pack_ptr.w_zp.data(), - nullptr, /* row offset buffer */ - col_offsets.data(), - bias_ptr, - K, - groups); + fbgemm::ReQuantizeOutput< + ReluFused, + fbgemm::QuantizationGranularity::TENSOR, + float> + outputProcObj( + NoOpObj, + output_multiplier_float.data(), + output_zero_point, + act_zero_point, + pack_ptr.w_zp.data(), + nullptr, /* row offset buffer */ + col_offsets.data(), + bias_ptr, + K, + groups, + act_times_w_scale.data()); fbgemm::fbgemmConv( conv_p, act_ptr, @@ -192,7 +239,8 @@ class QConv2dInt8 final : public c10::OperatorKernel { } else if (pack_ptr.q_scheme == kPerChannelAffine) { fbgemm::ReQuantizeOutput< ReluFused, - fbgemm::QuantizationGranularity::OUT_CHANNEL> + fbgemm::QuantizationGranularity::OUT_CHANNEL, + float> outputProcObj( NoOpObj, output_multiplier_float.data(), @@ -203,47 +251,180 @@ class QConv2dInt8 final : public c10::OperatorKernel { col_offsets.data(), bias_ptr, K, - groups); + groups, + act_times_w_scale.data()); fbgemm::fbgemmConv( conv_p, act_ptr, *packB, - reinterpret_cast(output.data()), - buffer.data(), + reinterpret_cast(output.data_ptr()), + buffer.data_ptr(), outputProcObj, 0 /* thread_id*/, 1 /* num_threads */); } - return output; + //TODO: remove permute once MemoryLayout is added above + return output.permute({0, 3, 1, 2}); } -#else // USE_FBGEMM - Tensor operator()( - Tensor /* activation */, - Tensor /* packed_weight */, - c10::optional /* bias */, - torch::List /* stride */, - torch::List /* padding */, - torch::List /* dilation */, - torch::List /* output padding */, - int64_t /* groups */, - double /* output scale */, - int64_t /* output_zero_point */) { +#endif +#ifdef USE_PYTORCH_QNNPACK + at::Tensor qnnpack_conv( + Tensor act, + Tensor packed_weight, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point) { + conv_checks( + act.ndimension(), stride.size(), padding.size(), dilation.size()); + + PackedConvWeightsQnnp& pack_ptr = + cpp_custom_type_hack::cast(packed_weight); + auto packB = pack_ptr.w.get(); + auto& kernel = pack_ptr.kernel; + auto kernel_zp = pack_ptr.w_zp; + auto kernel_scale = pack_ptr.w_scale; + + const uint32_t kernel_h = kernel[0]; + const uint32_t kernel_w = kernel[1]; + const auto out_ch = packB->getOutputChannels(); + // inputs are in semantic NCHW format + int N = act.size(0); + int in_ch = act.size(1); + int H = act.size(2); + int W = act.size(3); + int K = out_ch; // output channels + // TODO: change it to contiguous(MemoryFormat::ChannelsLast) once a perf + // regression of it is fixed. Today it's equivalent because `act` sizes + // are not used below + Tensor input_contig = act.permute({0, 2, 3, 1}).contiguous(); + + uint32_t stride_h = stride[0]; + uint32_t stride_w = stride[1]; + uint32_t pad_t = padding[0]; + uint32_t pad_l = padding[1]; + uint32_t dilation_h = dilation[0]; + uint32_t dilation_w = dilation[1]; + + auto output_min = ReluFused + ? activationLimits(output_scale, output_zero_point, Activation::RELU) + .first + : std::numeric_limits::min(); + auto output_max = ReluFused + ? activationLimits(output_scale, output_zero_point, Activation::RELU) + .second + : std::numeric_limits::max(); + qnnpack::conv_param_t conv_p( + {kernel_w, kernel_h}, + {stride_w, stride_h}, + {dilation_w, dilation_h}, + {pad_t, pad_l, pad_t, pad_l}, + groups, + in_ch, + out_ch, + kernel_zp, + kernel_scale, + output_min, + output_max); + + // TODO: change convOutputShape to return NCHW sizes once perf is fixed + // Force output format to be NHWC + // TODO: consider preserving input format + // TODO: add MemoryFormat::ChannelsLast here once perf is fixed + auto outShape = + convOutputShape(N, K, H, W, kernel, stride, padding, dilation); TORCH_CHECK( - false, - "This PyTorch installation was not built " - "with FBGEMM operators"); + std::all_of( + outShape.begin(), outShape.end(), [](int64_t i) { return i > 0; }), + "quantized::conv2d (qnnpack): each dimension of output tensor should be greater " + "than 0") + TORCH_CHECK( + (outShape[3] == out_ch), + "quantized::conv2d (qnnpack): Number of filters must be equal to number of " + "output channels") + + // Allocate output Tensor and a buffer for QNNPACK to use + Tensor output = at::_empty_affine_quantized( + outShape, + at::device(kCPU).dtype(kQUInt8), + output_scale, + output_zero_point); + + const pytorch_qnnp_status runStatus = qnnpack::qnnpackConv( + conv_p, + packB->getPackedWeights(), + N, + H, + W, + input_contig.q_scale(), + input_contig.q_zero_point(), + (uint8_t*)input_contig.data_ptr(), + output.q_scale(), + output.q_zero_point(), + (uint8_t*)output.data_ptr(), + nullptr); + + TORCH_INTERNAL_ASSERT( + runStatus == pytorch_qnnp_status_success, + "failed to run quantized::conv2d (qnnpack) operator"); + + //TODO: remove permute once MemoryLayout is added above + return output.permute({0, 3, 1, 2}); + } +#endif + Tensor operator()( + Tensor act, + Tensor packed_weight, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups, + double output_scale, + int64_t output_zero_point) { + auto& ctx = at::globalContext(); +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return fbgemm_conv( + act, + packed_weight, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point); + } +#endif +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + return qnnpack_conv( + act, + packed_weight, + stride, + padding, + dilation, + groups, + output_scale, + output_zero_point); + } +#endif + TORCH_INTERNAL_ASSERT( + "Didn't find engine for operation quantized::conv ", + toString(ctx.qEngine())); + return at::Tensor(); } -#endif // USE_FBGEMM }; static auto registry = c10::RegisterOperators() - .op("quantized::fbgemm_conv2d", + .op("quantized::conv2d", c10::RegisterOperators::options().kernel>( TensorTypeId::QuantizedCPUTensorId)) - .op("quantized::fbgemm_conv2d_relu", + .op("quantized::conv2d_relu", c10::RegisterOperators::options().kernel>( TensorTypeId::QuantizedCPUTensorId)); diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 27a1b0342ef94..460503e586f1b 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -2,6 +2,8 @@ #include #include #include +#include +#include #include namespace caffe2 { @@ -9,6 +11,10 @@ namespace caffe2 { // Required for cpp_custom_type_hack to work CAFFE_KNOWN_TYPE(PackedConvWeight); #endif +#ifdef USE_PYTORCH_QNNPACK +// Required for cpp_custom_type_hack to work +CAFFE_KNOWN_TYPE(PackedConvWeightsQnnp); +#endif // USE_PYTORCH_QNNPACK } // namespace caffe2 namespace at { @@ -17,8 +23,9 @@ namespace { class QConvPackWeightInt8 final : public c10::OperatorKernel { public: #ifdef USE_FBGEMM - Tensor operator()( + Tensor fbgemm_conv_prepack( Tensor weight, + c10::optional bias, torch::List stride, torch::List padding, torch::List dilation, @@ -30,16 +37,12 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel { TORCH_CHECK( padding.size() == 2, "Specify top/left padding only. \ - bottom/right padding assumed to be equal to top/left"); + bottom/right padding assumed to be equal to top/left"); TORCH_CHECK(dilation.size() == 2, "2D convolution only"); - TORCH_CHECK( - (dilation[0] == 1 && dilation[1] == 1), - "Currently dilation should be 1"); - // weights in KRS(C/G) format int output_channels = weight.size(0); - int kernel_h = weight.size(1); - int kernel_w = weight.size(2); - int input_channels_per_group = weight.size(3); + int input_channels_per_group = weight.size(1); + int kernel_h = weight.size(2); + int kernel_w = weight.size(3); // mini-batch doesn't have any impact on how we pack weights // so we pass it as 1 @@ -56,18 +59,26 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel { {static_cast(padding[0]), static_cast(padding[1]), static_cast(padding[0]), - static_cast(padding[1])}); + static_cast(padding[1])}, + {static_cast(dilation[0]), static_cast(dilation[1])}); - auto weight_contig = weight.contiguous(); + // FBGEMM expects weights to be in channels last + auto weight_contig = weight.contiguous(MemoryFormat::ChannelsLast); const auto qtype = weight.qscheme(); std::vector zero_points(1, 0); if (qtype == kPerTensorAffine) { zero_points[0] = weight.q_zero_point(); } else if (qtype == kPerChannelAffine) { + auto axis = weight.q_per_channel_axis(); + TORCH_CHECK( + axis.size() == 1 && axis[0] == 0, + "Only per output channel quantization is supported for the weights"); zero_points.resize(output_channels, 0); for (int i = 0; i < output_channels; ++i) { zero_points[i] = weight.q_per_channel_zero_points()[i].item(); } + } else { + TORCH_CHECK(false, "Unsupported qscheme: ", toString(qtype)); } const int8_t* weight_ptr_int8 = @@ -104,10 +115,19 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel { scales[i] = weight.q_per_channel_scales()[i].item(); } } - + c10::optional bias_contig; + if (bias.has_value()) { + Tensor bias_vec = bias.value(); + TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)"); + TORCH_CHECK( + bias_vec.size(0) == output_channels, + "bias should have K elements: " + std::to_string(output_channels)); + bias_contig = bias->contiguous(); + } auto ret_ptr = guts::make_unique( PackedConvWeight{guts::make_unique>( conv_p, weight_ptr_int8), + bias_contig, col_offsets, {kernel_h, kernel_w}, scales, @@ -117,22 +137,124 @@ class QConvPackWeightInt8 final : public c10::OperatorKernel { // point. return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options()); } -#else // USE_FBGEMM - Tensor operator()( - Tensor, /* weight */ - torch::List, /* stride */ - torch::List, /* padding */ - torch::List, /* dilation */ - int64_t /* groups */ - ) { +#endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK + at::Tensor qnnpack_conv_prepack( + Tensor weight, + c10::optional bias_in, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups) { + TORCH_CHECK( + weight.ndimension() == 4, + "quantized::conv_prepack (qnnpack): Weights are expected to have 4 dimensions"); + const auto qtype = weight.qscheme(); + TORCH_CHECK( + weight.qscheme() == kPerTensorAffine, + "quantized::conv_prepack (qnnpack): only supports Per Tensor Quantization Scheme") + TORCH_CHECK( + stride.size() == 2, + "quantized::conv_prepack (qnnpack): 2D convolution only"); TORCH_CHECK( - false, "This PyTorch installation was not built with FBGEMM operators"); + padding.size() == 2, + "quantized::conv_prepack (qnnpack): Specify top/left padding only. \ + bottom/right padding assumed to be equal to top/left"); + TORCH_CHECK( + dilation.size() == 2, + " quantized::conv_prepack (qnnpack): 2D convolution only"); + + initQNNPACK(); + + // QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups}, + // but PyTorch lays them out as {out_c, in_c/groups, kH, kW} + const size_t out_ch = weight.size(0); + const size_t in_ch = weight.size(1) * groups; + const uint32_t kernel_h = weight.size(2); + const uint32_t kernel_w = weight.size(3); + + Tensor bias; + if (bias_in.has_value()) { + bias = bias_in.value(); + } else { + bias = at::empty(out_ch, at::kFloat); + bias = at::quantize_linear(bias, 1.0, 0, kQInt32); + } + TORCH_CHECK( + !bias.defined() || (bias.ndimension() == 1 && bias.size(0) == out_ch), + "quantized::conv_prepack (qnnpack): expected bias to be 1-dimensional with ", + out_ch, + " elements", + ", but got bias of size ", + bias.sizes(), + " instead"); + + uint32_t stride_h = stride[0]; + uint32_t stride_w = stride[1]; + uint32_t pad_t = padding[0]; + uint32_t pad_l = padding[1]; + uint32_t dilation_h = dilation[0]; + uint32_t dilation_w = dilation[1]; + + qnnpack::conv_param_t conv_p( + {kernel_w, kernel_h}, + {stride_w, stride_h}, + {dilation_w, dilation_h}, + {pad_t, pad_l, pad_t, pad_l}, + groups, + in_ch, + out_ch, + weight.q_zero_point(), + weight.q_scale(), + std::numeric_limits::min(), + std::numeric_limits::max()); + + auto weight_contig = weight.contiguous(MemoryFormat::ChannelsLast); + auto bias_contig = bias.contiguous(); + auto wt_ptr = + guts::make_unique(PackedConvWeightsQnnp{ + guts::make_unique( + conv_p, + (uint8_t*)weight_contig.data_ptr(), + (int32_t*)bias_contig.data_ptr()), + weight_contig, + bias_contig, + {kernel_h, kernel_w}, + weight.q_scale(), + weight.q_zero_point()}); + + return cpp_custom_type_hack::create(std::move(wt_ptr), weight.options()); + } +#endif // USE_PYTORCH_QNNPACK + Tensor operator()( + Tensor weight, + c10::optional bias, + torch::List stride, + torch::List padding, + torch::List dilation, + int64_t groups) { + auto& ctx = at::globalContext(); +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return fbgemm_conv_prepack( + weight, bias, stride, padding, dilation, groups); + } +#endif +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + return qnnpack_conv_prepack( + weight, bias, stride, padding, dilation, groups); + } +#endif + TORCH_INTERNAL_ASSERT( + "Didn't find engine for operation quantized::conv_prepack ", + toString(ctx.qEngine())); + return at::Tensor(); } -#endif // USE_FBGEMM }; static auto registry = c10::RegisterOperators().op( - "quantized::fbgemm_conv_prepack", + "quantized::conv_prepack", c10::RegisterOperators::options().kernel( TensorTypeId::QuantizedCPUTensorId)); diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index 80de92aef9c07..f262fb04611ec 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace at { namespace native { @@ -16,7 +17,8 @@ namespace { class QConvUnpackWeightsInt8 final : public c10::OperatorKernel { public: #ifdef USE_FBGEMM - Tensor operator()(Tensor packed_weights) { + std::tuple> fbgemm_conv_unpack( + at::Tensor packed_weights) { // Pull out the packed weight instance from the owning tensor. auto& pack_ptr = cpp_custom_type_hack::cast(packed_weights); @@ -34,53 +36,77 @@ class QConvUnpackWeightsInt8 final : public c10::OperatorKernel { int C_per_G = input_channels / groups; // Tensor for unpacked weights - // Unpacked format would be KRS(C/G) - Tensor unpacked_weights; - if (pack_ptr.q_scheme == kPerTensorAffine) { - unpacked_weights = _empty_affine_quantized( - {output_channels, kernel_h, kernel_w, C_per_G}, - device(kCPU).dtype(kQInt8), - pack_ptr.w_scale[0], - pack_ptr.w_zp[0]); - } else if (pack_ptr.q_scheme == kPerChannelAffine) { - auto scales = from_blob( - pack_ptr.w_scale.data(), - pack_ptr.w_scale.size(), - device(kCPU).dtype(kFloat)); - auto zero_points = from_blob( - pack_ptr.w_zp.data(), - pack_ptr.w_zp.size(), - device(kCPU).dtype(kInt)); + // Unpacked format would be physical KRS(C/G) but logical KCRS (channels first) + // because that's how FBGEMM stores the weights + Tensor unpacked_weights; + if (pack_ptr.q_scheme == kPerTensorAffine) { + unpacked_weights = _empty_affine_quantized( + {output_channels, C_per_G, kernel_h, kernel_w}, + device(kCPU).dtype(kQInt8), + pack_ptr.w_scale[0], + pack_ptr.w_zp[0], + MemoryFormat::ChannelsLast); + } else if (pack_ptr.q_scheme == kPerChannelAffine) { + auto scales = from_blob( + pack_ptr.w_scale.data(), + pack_ptr.w_scale.size(), + device(kCPU).dtype(kFloat)); + auto zero_points = from_blob( + pack_ptr.w_zp.data(), pack_ptr.w_zp.size(), device(kCPU).dtype(kInt)); - unpacked_weights = _empty_per_channel_affine_quantized_like( - scales.toType(kDouble), - zero_points.toType(kLong), - {output_channels, kernel_h, kernel_w, C_per_G}, - {0}, /* The output channel axis is 0 */ - device(kCPU).dtype(kQInt8)); - } + unpacked_weights = _empty_per_channel_affine_quantized_like( + scales.toType(kDouble), + zero_points.toType(kLong), + {output_channels, C_per_G, kernel_h, kernel_w}, + {0}, /* The output channel axis is 0 */ + device(kCPU).dtype(kQInt8), + MemoryFormat::ChannelsLast); + } else { + TORCH_CHECK(false, "Unsupported qscheme: ", toString(pack_ptr.q_scheme)); + } int8_t* unpacked_weights_p = reinterpret_cast(unpacked_weights.data_ptr()); packed_weights_p->unpack(unpacked_weights_p); - return unpacked_weights; + return std::tuple>( + unpacked_weights, pack_ptr.bias); } -#else // USE_FBGEMM - Tensor operator()(Tensor /* weight */ - ) { - // We make a strong guarantee that models using these operators will have - // the same numerics across different machines. Therefore, we do not provide - // a fallback path and rather fail loudly if we cannot run FBGEMM. - TORCH_CHECK( - false, "This PyTorch installation was not built with FBGEMM operators"); +#endif +#ifdef USE_PYTORCH_QNNPACK + std::tuple> qnnpack_conv_unpack( + at::Tensor packed_weight) { + auto& pack_ptr = + cpp_custom_type_hack::cast(packed_weight); + return std::tuple>( + pack_ptr.orig_weight, pack_ptr.bias); + } +#endif + std::tuple> operator()( + Tensor packed_weights) { + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return fbgemm_conv_unpack(packed_weights); + } +#endif +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + return qnnpack_conv_unpack(packed_weights); + } +#endif + TORCH_INTERNAL_ASSERT( + "Didn't find engine for operation quantized::conv_unpack ", + toString(ctx.qEngine())); + return std::tuple>( + at::Tensor(), at::Tensor()); } -#endif // USE_FBGEMM }; static auto registry = c10::RegisterOperators().op( - "quantized::fbgemm_conv_unpack(Tensor packed_weights)" - " -> Tensor unpacked_weights", + "quantized::conv_unpack(Tensor packed_weights)" + " -> (Tensor unpacked_weights, Tensor? B_origin)", c10::RegisterOperators::options().kernel( TensorTypeId::CPUTensorId)); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 74142233ed8d7..38f26707ea766 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -14,10 +15,9 @@ template class QLinearInt8 final : public torch::OperatorKernel { public: #ifdef USE_FBGEMM - at::Tensor operator()( + at::Tensor fbgemm_linear( at::Tensor input, at::Tensor packed_weight, - c10::optional bias, double output_scale, int64_t output_zero_point) { // uint8 * int8 -> uint8 (no quantization/dequantization) @@ -58,22 +58,23 @@ class QLinearInt8 final : public torch::OperatorKernel { int32_t input_zero_point_int32 = input.q_zero_point(); std::vector output_multiplier_float(1, 0.0); + std::vector act_times_w_scale(1, 0.0); TORCH_CHECK( pack_ptr.w_scale.size() == pack_ptr.w_zp.size(), "Weight scales and zero points vectors should have the same size."); - // quantization scheme is PerTensorAffine if the number of scales is - // 1 and it's kPerChannelAffine if the number of scales is equal to - // N (output channels) if (pack_ptr.q_scheme == kPerTensorAffine) { // Process the per tensor quantization. - output_multiplier_float[0] = (input_scale_float * pack_ptr.w_scale[0]) / - static_cast(output_scale); + act_times_w_scale[0] = (input_scale_float * pack_ptr.w_scale[0]); + output_multiplier_float[0] = + act_times_w_scale[0] / static_cast(output_scale); } else if (pack_ptr.q_scheme == kPerChannelAffine) { // Process the per channel quantization. output_multiplier_float.resize(N, 0.0); + act_times_w_scale.resize(N, 1.0f); for (int i = 0; i < N; ++i) { - output_multiplier_float[i] = (input_scale_float * pack_ptr.w_scale[i]) / - static_cast(output_scale); + act_times_w_scale[i] = (input_scale_float * pack_ptr.w_scale[i]); + output_multiplier_float[i] = + act_times_w_scale[i] / static_cast(output_scale); } } int32_t output_zero_point_int32 = static_cast(output_zero_point); @@ -107,17 +108,16 @@ class QLinearInt8 final : public torch::OperatorKernel { // This is the end of the pipeline, pass the resulting matrix through. fbgemm::DoNothing<> doNothingObj{}; - const int32_t* bias_ptr = nullptr; - if (bias.has_value()) { - Tensor bias_vec = bias.value(); - TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)"); + const float* bias_ptr = nullptr; + at::Tensor bias; + if (pack_ptr.bias.has_value()) { + bias = pack_ptr.bias.value(); + bias = bias.contiguous(); + TORCH_CHECK(bias.dim() == 1, "bias should be a vector (1D Tensor)"); TORCH_CHECK( - bias_vec.size(0) == N, + bias.size(0) == N, "bias should have N elements: " + std::to_string(N)); - // TODO: contiguous is called for further jit optimizations. - auto bias_contig = bias_vec.contiguous(); - bias_ptr = - reinterpret_cast(bias_contig.data_ptr()); + bias_ptr = reinterpret_cast(bias.data_ptr()); } // The resulting matrix here is 2-D, let's view it with the original @@ -143,16 +143,22 @@ class QLinearInt8 final : public torch::OperatorKernel { // 1) Add in row and column offsets to the rows and columns, // respectively. // 2) Add in the bias term. - fbgemm::ReQuantizeOutput outputProcObj( - /*nextop=*/doNothingObj, - /*C_multiplier=*/output_multiplier_float.data(), - /*C_zero_point=*/output_zero_point_int32, - /*Aq_zero_point=*/input_zero_point_int32, - /*Bq_zero_point=*/pack_ptr.w_zp.data(), - /*row_offsets=*/packA.getRowOffsetBuffer(), - /*col_offsets=*/col_offsets.data(), - /*bias=*/bias_ptr, - /*nCol=*/N); + fbgemm::ReQuantizeOutput< + ReluFused, + fbgemm::QuantizationGranularity::TENSOR, + float> + outputProcObj( + doNothingObj, + output_multiplier_float.data(), + output_zero_point_int32, + input_zero_point_int32, + pack_ptr.w_zp.data(), + packA.getRowOffsetBuffer(), + col_offsets.data(), + bias_ptr, + N, /* nCol */ + 1 /* groups */, + act_times_w_scale.data()); // Do the GEMM fbgemm::fbgemmPacked( @@ -174,17 +180,20 @@ class QLinearInt8 final : public torch::OperatorKernel { // 2) Add in the bias term. fbgemm::ReQuantizeOutput< ReluFused, - fbgemm::QuantizationGranularity::OUT_CHANNEL> + fbgemm::QuantizationGranularity::OUT_CHANNEL, + float> outputProcObj( - /*nextop=*/doNothingObj, - /*C_multiplier=*/output_multiplier_float.data(), - /*C_zero_point=*/output_zero_point_int32, - /*Aq_zero_point=*/input_zero_point_int32, - /*Bq_zero_point=*/pack_ptr.w_zp.data(), - /*row_offsets=*/packA.getRowOffsetBuffer(), - /*col_offsets=*/col_offsets.data(), - /*bias=*/bias_ptr, - /*nCol=*/N); + doNothingObj, + output_multiplier_float.data(), + output_zero_point_int32, + input_zero_point_int32, + pack_ptr.w_zp.data(), + packA.getRowOffsetBuffer(), + col_offsets.data(), + bias_ptr, + N, /*nCol=*/ + 1, /* groups*/ + act_times_w_scale.data()); // Do the GEMM fbgemm::fbgemmPacked( @@ -197,31 +206,116 @@ class QLinearInt8 final : public torch::OperatorKernel { /*thread_id=*/0, /*num_threads=*/1); } + return output; + } +#endif +#ifdef USE_PYTORCH_QNNPACK + at::Tensor qnnpack_linear( + at::Tensor input, + at::Tensor packed_weight, + double output_scale, + int64_t output_zero_point) { + TORCH_CHECK( + input.dim() >= 2, + "quantized::linear(): Input tensor rank should be >= 2"); + auto input_contig = input.contiguous(); + + auto& pack_ptr = + cpp_custom_type_hack::cast(packed_weight); + auto packB = pack_ptr.w.get(); + auto kernel_zp = pack_ptr.w_zp; + auto kernel_scale = pack_ptr.w_scale; + + size_t rows_input = 1; + size_t cols_input = input_contig.size(input_contig.dim() - 1); + for (size_t i = 0; i < input_contig.dim() - 1; ++i) { + rows_input *= input_contig.size(i); + } + + size_t rows_w = packB->getOutputChannels(); + size_t cols_w = packB->getInputChannels(); + + TORCH_CHECK( + cols_input == cols_w, + "quantized::linear(): input size does not match weight dimension 1 size: \ + got ", + cols_input, + " but expected ", + cols_w); + + // Allocate output Tensor and a buffer for QNNPACK to use + Tensor output = at::_empty_affine_quantized( + {static_cast(rows_input), static_cast(rows_w)}, + input.options(), + output_scale, + output_zero_point); + + auto output_min = ReluFused + ? activationLimits(output_scale, output_zero_point, Activation::RELU) + .first + : std::numeric_limits::min(); + auto output_max = ReluFused + ? activationLimits(output_scale, output_zero_point, Activation::RELU) + .second + : std::numeric_limits::max(); + const pytorch_qnnp_status runStatus = qnnpack::qnnpackLinear( + rows_input /* batch_size */, + cols_input /* input_channels */, + rows_w /* output_channels */, + input_contig.q_zero_point(), + input_contig.q_scale(), + kernel_zp, + kernel_scale, + output_zero_point, + output_scale, + output_min, + output_max, + (uint8_t*)input_contig.data_ptr(), + cols_input /* input_stride */, + packB->getPackedWeights(), + (uint8_t*)output.data_ptr(), + rows_w /* output_stride */, + nullptr /* threadpool */); + + TORCH_INTERNAL_ASSERT( + runStatus == pytorch_qnnp_status_success, + "failed to run QNNPACK Linear operator"); return output; } -#else // USE_FBGEMM +#endif at::Tensor operator()( - at::Tensor /* input */, - at::Tensor /* packed_weight */, - c10::optional /* bias */, - double /* output_scale */, - int64_t /* output_zero_point */) { - // We make a strong guarantee that models using these operators will have - // the same numerics across different machines. Therefore, we do not provide - // a fallback path and rather fail loudly if we cannot run FBGEMM. - TORCH_CHECK( - false, "This PyTorch installation was not built with FBGEMM operators"); + at::Tensor input, + at::Tensor packed_weight, + double output_scale, + int64_t output_zero_point) { + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return fbgemm_linear( + input, packed_weight, output_scale, output_zero_point); + } +#endif +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + return qnnpack_linear( + input, packed_weight, output_scale, output_zero_point); + } +#endif + TORCH_INTERNAL_ASSERT( + "Didn't find engine for operation quantized::linear ", + toString(ctx.qEngine())); + return at::Tensor(); } -#endif // USE_FBGEMM }; static auto registry = torch::RegisterOperators() - .op("quantized::fbgemm_linear(Tensor X, Tensor W_prepack, Tensor? b, float Y_scale_i, int Y_zero_point_i) -> Tensor Y", + .op("quantized::linear(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y", torch::RegisterOperators::options().kernel>( TensorTypeId::QuantizedCPUTensorId)) - .op("quantized::fbgemm_linear_relu(Tensor X, Tensor W_prepack, Tensor? b, float Y_scale_i, int Y_zero_point_i) -> Tensor Y", + .op("quantized::linear_relu(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y", torch::RegisterOperators::options().kernel>( TensorTypeId::QuantizedCPUTensorId)); } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index bd5eb2f4387b0..c4491e6fe95d3 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -16,8 +16,7 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel { #ifdef USE_FBGEMM at::Tensor operator()( at::Tensor input, - at::Tensor packed_weight, - c10::optional bias) { + at::Tensor packed_weight) { // fp32 * int8 -> fp32 (with quantization on activation, and dequantization // on the result). @@ -110,8 +109,9 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel { fbgemm::DoNothing doNothingObj{}; const float* bias_ptr = nullptr; - if (bias.has_value()) { - Tensor bias_vec = bias.value(); + at::Tensor bias_vec; + if (pack_ptr.bias.has_value()) { + bias_vec = pack_ptr.bias.value(); TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)"); TORCH_CHECK( bias_vec.size(0) == N, @@ -120,7 +120,6 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel { auto bias_contig = bias_vec.contiguous(); bias_ptr = bias_contig.data_ptr(); } - // The resulting matrix here is 2-D, let's view it with the original // left hand dimensions of the input. Here are two examples: // 1. If the input tensor is {M, K}, the output tensor is {M, N}. @@ -200,8 +199,7 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel { #else // USE_FBGEMM at::Tensor operator()( at::Tensor /* input */, - at::Tensor /* packed_weight */, - c10::optional /* bias */) { + at::Tensor /* packed_weight */) { // We make a strong guarantee that models using these operators will have // the same numerics across different machines. Therefore, we do not provide // a fallback path and rather fail loudly if we cannot run FBGEMM. @@ -213,10 +211,10 @@ class QLinearDynamicInt8 final : public torch::OperatorKernel { static auto registry = torch::RegisterOperators() - .op("quantized::fbgemm_linear_dynamic(Tensor X, Tensor W_prepack, Tensor? b) -> Tensor Y", + .op("quantized::linear_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y", torch::RegisterOperators::options() .kernel>(TensorTypeId::CPUTensorId)) - .op("quantized::fbgemm_linear_relu_dynamic(Tensor X, Tensor W_prepack, Tensor? b) -> Tensor Y", + .op("quantized::linear_relu_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y", torch::RegisterOperators::options() .kernel>(TensorTypeId::CPUTensorId)); } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index a6f8d59053b9d..fa1b3a0867d73 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -2,8 +2,9 @@ #include #include #include +#include +#include #include - #include #include @@ -12,6 +13,10 @@ namespace caffe2 { // Required for cpp_custom_type_hack to work CAFFE_KNOWN_TYPE(PackedLinearWeight); #endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK +// Required for cpp_custom_type_hack to work +CAFFE_KNOWN_TYPE(PackedLinearWeightsQnnp); +#endif // USE_PYTORCH_QNNPACK } // namespace caffe2 namespace at { @@ -44,11 +49,13 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { } } } - - at::Tensor operator()(at::Tensor weight) { + at::Tensor fbgemm_linear_prepack( + at::Tensor weight, + c10::optional bias) { TORCH_CHECK( weight.dim() == 2, - "The weight tensor for quantized::fbgemm_linear_prepack should be 2-dimensional."); + "The weight tensor for quantized::linear_prepack (fbgemm) should" + " be 2-dimensional."); auto N = weight.size(0); auto K = weight.size(1); @@ -88,6 +95,15 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { /*col_offsets=*/col_offsets.data(), /*qtype=*/qtype); + c10::optional bias_contig; + if (bias.has_value()) { + Tensor bias_vec = bias.value(); + TORCH_CHECK(bias_vec.dim() == 1, "bias should be a vector (1D Tensor)"); + TORCH_CHECK( + bias_vec.size(0) == N, + "bias should have N elements: " + std::to_string(N)); + bias_contig = bias->contiguous(); + } auto ret_ptr = guts::make_unique(PackedLinearWeight{ guts::make_unique>( /*trans=*/fbgemm::matrix_op_t::Transpose, @@ -97,6 +113,7 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { /*ld=*/K, /*pmat=*/nullptr, // PackBMatrix manages ownership of pmat /*groups=*/1), + bias_contig, col_offsets, weight_scales_float, weight_zero_points_int32, @@ -106,20 +123,81 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { // point. return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options()); } -#else // USE_FBGEMM - at::Tensor operator()(at::Tensor /* weight */ - ) { - // We make a strong guarantee that models using these operators will have - // the same numerics across different machines. Therefore, we do not provide - // a fallback path and rather fail loudly if we cannot run FBGEMM. +#endif +#ifdef USE_PYTORCH_QNNPACK + at::Tensor qnnpack_linear_prepack( + at::Tensor weight, + c10::optional bias_in) { + TORCH_CHECK( + weight.dim() == 2, + "quantized::linear_prepack (qnnpack): Weight tensor rank should be == 2"); + TORCH_CHECK( + weight.qscheme() == kPerTensorAffine, + "quantized::linear_prepack (qnnpack) only supports Per Tensor Quantization Scheme") + + int64_t rows_w = weight.size(0); + int64_t cols_w = weight.size(1); + Tensor bias; + if (bias_in.has_value()) { + bias = bias_in.value(); + } else { + bias = at::zeros(rows_w, at::kFloat); + bias = at::quantize_linear(bias, 1.0, 0, kQInt32); + } TORCH_CHECK( - false, "This PyTorch installation was not built with FBGEMM operators"); + !bias.defined() || (bias.ndimension() == 1 && bias.size(0) == rows_w), + "quantized::linear_prepack (qnnpack): Given weight of size ", + weight.sizes(), + ", expected bias to be 1-dimensional with ", + rows_w, + " elements", + ", but got bias of size ", + bias.sizes(), + " instead"); + + Tensor weight_contig = weight.contiguous(); + Tensor bias_contig = bias.contiguous(); + + initQNNPACK(); + + auto wt_ptr = + guts::make_unique(PackedLinearWeightsQnnp{ + guts::make_unique( + cols_w /* input_channels */, + rows_w /* output_channels */, + weight.q_zero_point(), + weight.q_scale(), + (uint8_t*)weight_contig.data_ptr(), + (int32_t*)bias_contig.data_ptr()), + weight_contig, + bias_contig, + weight.q_scale(), + weight.q_zero_point()}); + return cpp_custom_type_hack::create(std::move(wt_ptr), weight.options()); + } +#endif + at::Tensor operator()(at::Tensor weight, c10::optional bias) { + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return fbgemm_linear_prepack(weight, bias); + } +#endif +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + return qnnpack_linear_prepack(weight, bias); + } +#endif + TORCH_INTERNAL_ASSERT( + "Didn't find engine for operation quantized::linear_prepack ", + toString(ctx.qEngine())); + return at::Tensor(); } -#endif // USE_FBGEMM }; static auto registry = c10::RegisterOperators().op( - "quantized::fbgemm_linear_prepack(Tensor W) -> Tensor W_prepack", + "quantized::linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack", c10::RegisterOperators::options().kernel( TensorTypeId::QuantizedCPUTensorId)); } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp index 9579056284691..d2617ce93309c 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp @@ -2,6 +2,7 @@ #include #include #include +#include namespace at { namespace native { @@ -10,7 +11,8 @@ namespace { class QLinearUnpackWeightInt8 final : public c10::OperatorKernel { public: #ifdef USE_FBGEMM - at::Tensor operator()(at::Tensor packed_weight) { + std::tuple> fbgemm_linear_unpack( + at::Tensor packed_weight) { // Pull out the PackBMatrix instance from the owning tensor. auto& pack_ptr = cpp_custom_type_hack::cast(packed_weight); @@ -49,22 +51,43 @@ class QLinearUnpackWeightInt8 final : public c10::OperatorKernel { // (QLinearUnpackWeightInt8): "); packB->unpack(weight_ptr_int8); - return weight_origin; - } -#else // USE_FBGEMM - at::Tensor operator()(at::Tensor /* weight */ - ) { - // We make a strong guarantee that models using these operators will have - // the same numerics across different machines. Therefore, we do not provide - // a fallback path and rather fail loudly if we cannot run FBGEMM. - TORCH_CHECK( - false, "This PyTorch installation was not built with FBGEMM operators"); + return std::tuple>( + weight_origin, pack_ptr.bias); } #endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK + std::tuple> qnnpack_linear_unpack( + at::Tensor packed_weight) { + auto& pack_ptr = + cpp_custom_type_hack::cast(packed_weight); + return std::tuple>( + pack_ptr.orig_weight, pack_ptr.bias); + } +#endif // USE_PYTORCH_QNNPACK + std::tuple> operator()( + at::Tensor packed_weight) { + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return fbgemm_linear_unpack(packed_weight); + } +#endif +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + return qnnpack_linear_unpack(packed_weight); + } +#endif + TORCH_INTERNAL_ASSERT( + "Didn't find engine for operation quantized::linear_unpack ", + toString(ctx.qEngine())); + return std::tuple>( + at::Tensor(), at::Tensor()); + } }; static auto registry = c10::RegisterOperators().op( - "quantized::fbgemm_linear_unpack(Tensor W_prepack) -> Tensor W_origin", + "quantized::linear_unpack(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)", c10::RegisterOperators::options().kernel( TensorTypeId::CPUTensorId)); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/.gitignore b/aten/src/ATen/native/quantized/cpu/qnnpack/.gitignore new file mode 100644 index 0000000000000..3af7dcc3aa24c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/.gitignore @@ -0,0 +1,24 @@ +# Ninja files +build.ninja + +# Build objects and artifacts +deps/ +build/ +build-*/ +bin/ +lib/ +out/ +obj/ +libs/ +*.pyc +*.pyo +*.log + +# System files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt new file mode 100644 index 0000000000000..1232e889977ad --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt @@ -0,0 +1,694 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) + +INCLUDE(GNUInstallDirs) + +# ---[ Project and semantic versioning. +PROJECT(PYTORCH_QNNPACK C CXX ASM) + +# ---[ Options. +OPTION(PYTORCH_QNNPACK_CUSTOM_THREADPOOL "Build QNNPACK for custom thread pool" OFF) +SET(PYTORCH_QNNPACK_LIBRARY_TYPE "default" CACHE STRING "Type of library (shared, static, or default) to build") +SET_PROPERTY(CACHE PYTORCH_QNNPACK_LIBRARY_TYPE PROPERTY STRINGS default static shared) +OPTION(PYTORCH_QNNPACK_BUILD_TESTS "Build QNNPACK unit tests" ON) +OPTION(PYTORCH_QNNPACK_BUILD_BENCHMARKS "Build QNNPACK benchmarks" ON) + +# Enable runtime requantization. +ADD_DEFINITIONS(-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION=1) + +# ---[ CMake options +IF(PYTORCH_QNNPACK_BUILD_TESTS) + ENABLE_TESTING() +ENDIF() + +# ---[ Build flags +IF(NOT CMAKE_SYSTEM_PROCESSOR) + IF(IOS) + LIST(LENGTH IOS_ARCH IOS_ARCH_COUNT) + IF(IOS_ARCH_COUNT GREATER 1) + MESSAGE(FATAL_ERROR "Unsupported QNNPACK build with multiple iOS architectures (${IOS_ARCH}). " + "Specify a single architecture in IOS_ARCH and re-configure. ") + ENDIF() + IF(NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$") + MESSAGE(FATAL_ERROR "Unrecognized IOS_ARCH = ${IOS_ARCH}") + ENDIF() + ELSE() + MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_PROCESSOR is not defined") + ENDIF() +ELSEIF(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64|armv[5-8].*|aarch64)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_PROCESSOR = ${CMAKE_SYSTEM_PROCESSOR}") +ENDIF() + +IF(NOT CMAKE_SYSTEM_NAME) + MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") +ELSEIF(NOT CMAKE_SYSTEM_NAME MATCHES "^(Darwin|Linux|Android)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME = ${CMAKE_SYSTEM_NAME}") +ENDIF() + +# ---[ Download deps +SET(CONFU_DEPENDENCIES_SOURCE_DIR "${CMAKE_SOURCE_DIR}/deps" + CACHE PATH "Confu-style dependencies source directory") +SET(CONFU_DEPENDENCIES_BINARY_DIR "${CMAKE_BINARY_DIR}/deps" + CACHE PATH "Confu-style dependencies binary directory") + +IF(NOT DEFINED CLOG_SOURCE_DIR) + SET(CLOG_SOURCE_DIR "${PROJECT_SOURCE_DIR}/deps/clog") +ENDIF() + +IF(NOT DEFINED CPUINFO_SOURCE_DIR) + MESSAGE(STATUS "Downloading cpuinfo to ${CONFU_DEPENDENCIES_SOURCE_DIR}/cpuinfo (define CPUINFO_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadCpuinfo.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/cpuinfo-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/cpuinfo-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/cpuinfo-download") + SET(CPUINFO_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/cpuinfo" CACHE STRING "cpuinfo source directory") +ENDIF() + +IF(NOT DEFINED FP16_SOURCE_DIR) + MESSAGE(STATUS "Downloading FP16 to ${CONFU_DEPENDENCIES_SOURCE_DIR}/fp16 (define FP16_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadFP16.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/fp16-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/fp16-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/fp16-download") + SET(FP16_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/fp16" CACHE STRING "FP16 source directory") +ENDIF() + +IF(NOT DEFINED FXDIV_SOURCE_DIR) + MESSAGE(STATUS "Downloading FXdiv to ${CONFU_DEPENDENCIES_SOURCE_DIR}/fxdiv (define FXDIV_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/fxdiv-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/fxdiv-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/fxdiv-download") + SET(FXDIV_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/fxdiv" CACHE STRING "FXdiv source directory") +ENDIF() + +IF(NOT DEFINED PSIMD_SOURCE_DIR) + MESSAGE(STATUS "Downloading PSimd to ${CONFU_DEPENDENCIES_SOURCE_DIR}/psimd (define PSIMD_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadPSimd.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/psimd-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/psimd-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/psimd-download") + SET(PSIMD_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/psimd" CACHE STRING "PSimd source directory") +ENDIF() + +IF(NOT DEFINED PTHREADPOOL_SOURCE_DIR) + MESSAGE(STATUS "Downloading pthreadpool to ${CONFU_DEPENDENCIES_SOURCE_DIR}/pthreadpool (define PTHREADPOOL_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadPThreadPool.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool-download") + SET(PTHREADPOOL_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/pthreadpool" CACHE STRING "pthreadpool source directory") +ENDIF() + +IF(PYTORCH_QNNPACK_BUILD_TESTS AND NOT DEFINED GOOGLETEST_SOURCE_DIR) + MESSAGE(STATUS "Downloading Google Test to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest (define GOOGLETEST_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadGoogleTest.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") + SET(GOOGLETEST_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" CACHE STRING "Google Test source directory") +ENDIF() + +IF(PYTORCH_QNNPACK_BUILD_BENCHMARKS AND NOT DEFINED GOOGLEBENCHMARK_SOURCE_DIR) + MESSAGE(STATUS "Downloading Google Benchmark to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark (define GOOGLEBENCHMARK_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadGoogleBenchmark.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark-download") + SET(GOOGLEBENCHMARK_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark" CACHE STRING "Google Benchmark source directory") +ENDIF() + +# ---[ QNNPACK library +SET(PYTORCH_QNNPACK_INIT_SRCS + src/init.c + src/add.c + src/average-pooling.c + src/channel-shuffle.c + src/clamp.c + src/conv-prepack.cc + src/convolution.c + src/deconvolution.c + src/fc-prepack.cc + src/fully-connected.c + src/global-average-pooling.c + src/leaky-relu.c + src/max-pooling.c + src/sigmoid.c + src/softargmax.c + src/operator-delete.c) + +SET(PYTORCH_QNNPACK_EXEC_SRCS + src/conv-run.cc + src/fc-run.cc + src/indirection.c + src/operator-run.c) + +SET(PYTORCH_QNNPACK_SCALAR_UKERNELS + src/u8lut32norm/scalar.c + src/x8lut/scalar.c) + +SET(PYTORCH_QNNPACK_PSIMD_UKERNELS + src/sgemm/6x8-psimd.c) + +SET(PYTORCH_QNNPACK_ARM_NEON_UKERNELS + src/q8avgpool/mp8x9p8q-neon.c + src/q8avgpool/up8x9-neon.c + src/q8avgpool/up8xm-neon.c + src/q8conv/4x8-neon.c + src/q8conv/8x8-neon.c + src/q8dwconv/mp8x25-neon.c + src/q8dwconv/up8x9-neon.c + src/q8gavgpool/mp8x7p7q-neon.c + src/q8gavgpool/up8x7-neon.c + src/q8gavgpool/up8xm-neon.c + src/q8gemm/4x-sumrows-neon.c + src/q8gemm/4x8-neon.c + src/q8gemm/4x8c2-xzp-neon.c + src/q8gemm/6x4-neon.c + src/q8gemm/8x8-neon.c + src/q8vadd/neon.c + src/sgemm/5x8-neon.c + src/sgemm/6x8-neon.c + src/u8clamp/neon.c + src/u8maxpool/16x9p8q-neon.c + src/u8maxpool/sub16-neon.c + src/u8rmax/neon.c + src/x8zip/x2-neon.c + src/x8zip/x3-neon.c + src/x8zip/x4-neon.c + src/x8zip/xm-neon.c) + +SET(PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS + src/hgemm/8x8-aarch32-neonfp16arith.S + src/q8conv/4x8-aarch32-neon.S + src/q8dwconv/up8x9-aarch32-neon.S + src/q8gemm/4x8-aarch32-neon.S + src/q8gemm/4x8c2-xzp-aarch32-neon.S) + +SET(PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS + src/q8conv/8x8-aarch64-neon.S + src/q8gemm/8x8-aarch64-neon.S) + +SET(PYTORCH_QNNPACK_X86_SSE2_UKERNELS + src/q8avgpool/mp8x9p8q-sse2.c + src/q8avgpool/up8x9-sse2.c + src/q8avgpool/up8xm-sse2.c + src/q8conv/4x4c2-sse2.c + src/q8dwconv/mp8x25-sse2.c + src/q8dwconv/up8x9-sse2.c + src/q8gavgpool/mp8x7p7q-sse2.c + src/q8gavgpool/up8x7-sse2.c + src/q8gavgpool/up8xm-sse2.c + src/q8gemm/2x4c8-sse2.c + src/q8gemm/4x4c2-sse2.c + src/q8vadd/sse2.c + src/u8clamp/sse2.c + src/u8maxpool/16x9p8q-sse2.c + src/u8maxpool/sub16-sse2.c + src/u8rmax/sse2.c + src/x8zip/x2-sse2.c + src/x8zip/x3-sse2.c + src/x8zip/x4-sse2.c + src/x8zip/xm-sse2.c) + +SET(PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_SCALAR_UKERNELS} ${PYTORCH_QNNPACK_PSIMD_UKERNELS}) +IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") + LIST(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS}) + LIST(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS}) +ENDIF() +IF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR IOS_ARCH MATCHES "^arm64.*") + LIST(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS}) + LIST(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS}) +ENDIF() +IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") + LIST(APPEND PYTORCH_QNNPACK_UKERNELS ${PYTORCH_QNNPACK_X86_SSE2_UKERNELS}) +ENDIF() + +IF(PYTORCH_QNNPACK_LIBRARY_TYPE STREQUAL "default") + ADD_LIBRARY(pytorch_qnnpack ${PYTORCH_QNNPACK_INIT_SRCS} ${PYTORCH_QNNPACK_EXEC_SRCS} ${PYTORCH_QNNPACK_UKERNELS}) +ELSEIF(PYTORCH_QNNPACK_LIBRARY_TYPE STREQUAL "shared") + ADD_LIBRARY(pytorch_qnnpack SHARED ${PYTORCH_QNNPACK_INIT_SRCS} ${PYTORCH_QNNPACK_EXEC_SRCS} ${PYTORCH_QNNPACK_UKERNELS}) +ELSEIF(PYTORCH_QNNPACK_LIBRARY_TYPE STREQUAL "static") + ADD_LIBRARY(pytorch_qnnpack STATIC ${PYTORCH_QNNPACK_INIT_SRCS} ${PYTORCH_QNNPACK_EXEC_SRCS} ${PYTORCH_QNNPACK_UKERNELS}) +ELSE() + MESSAGE(FATAL_ERROR "Unsupported QNNPACK library type \"${PYTORCH_QNNPACK_LIBRARY_TYPE}\". Must be \"static\", \"shared\", or \"default\"") +ENDIF() +SET_TARGET_PROPERTIES(pytorch_qnnpack PROPERTIES + CXX_STANDARD 11 + C_STANDARD 99 + C_EXTENSIONS YES) +IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 -marm -mfpu=neon ") + IF(IOS) + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_AARCH32_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") + ENDIF() +ENDIF() +IF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR IOS_ARCH MATCHES "^arm64.*") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_ARM_NEON_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") + IF(IOS) + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_AARCH64_ASM_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -arch ${IOS_ARCH} ") + ENDIF() +ENDIF() +IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|x86_64)$" OR IOS_ARCH MATCHES "^(i386|x86_64)$") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_X86_SSE2_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 -msse2 ") +ENDIF() +IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]" OR IOS_ARCH MATCHES "^armv7") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_PSIMD_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 -marm -mfpu=neon ") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_SCALAR_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 -marm ") +ELSE() + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_PSIMD_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_SCALAR_UKERNELS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") +ENDIF() +SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_INIT_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -Os ") +IF(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + SET_PROPERTY(SOURCE ${PYTORCH_QNNPACK_OPERATOR_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -O2 ") +ENDIF() +TARGET_INCLUDE_DIRECTORIES(pytorch_qnnpack PUBLIC include) +TARGET_INCLUDE_DIRECTORIES(pytorch_qnnpack PUBLIC src) +SET_TARGET_PROPERTIES(pytorch_qnnpack PROPERTIES PUBLIC_HEADER include/pytorch_qnnpack.h) +SET_TARGET_PROPERTIES(pytorch_qnnpack PROPERTIES PUBLIC_HEADER include/conv_utils.h) +SET_TARGET_PROPERTIES(pytorch_qnnpack PROPERTIES PUBLIC_HEADER include/qnnpack_func.h) + +# ---[ Configure clog +IF(NOT TARGET clog) + SET(CLOG_BUILD_TESTS OFF CACHE BOOL "") + SET(CLOG_RUNTIME_TYPE "${CPUINFO_RUNTIME_TYPE}" CACHE STRING "") + ADD_SUBDIRECTORY( + "${CLOG_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/clog") + # We build static version of clog but a dynamic library may indirectly depend on it + SET_PROPERTY(TARGET clog PROPERTY POSITION_INDEPENDENT_CODE ON) +ENDIF() +TARGET_LINK_LIBRARIES(pytorch_qnnpack PUBLIC clog) + +# ---[ Configure cpuinfo +IF(NOT TARGET cpuinfo) + SET(CPUINFO_BUILD_TOOLS OFF CACHE BOOL "") + SET(CPUINFO_BUILD_UNIT_TESTS OFF CACHE BOOL "") + SET(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "") + SET(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${CPUINFO_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/cpuinfo") +ENDIF() +TARGET_LINK_LIBRARIES(pytorch_qnnpack PRIVATE cpuinfo) + +# ---[ Configure pthreadpool +IF(NOT TARGET pthreadpool) + SET(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") + SET(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${PTHREADPOOL_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool") +ENDIF() +IF(PYTORCH_QNNPACK_CUSTOM_THREADPOOL) + # Depend on pthreadpool interface, but not on implementation. + # This is used when QNNPACK user (e.g. Caffe2) provides its own threadpool implementation. + TARGET_LINK_LIBRARIES(pytorch_qnnpack PUBLIC pthreadpool_interface) +ELSE() + TARGET_LINK_LIBRARIES(pytorch_qnnpack PUBLIC pthreadpool) +ENDIF() + +# ---[ Configure FXdiv +IF(NOT TARGET fxdiv) + SET(FXDIV_BUILD_TESTS OFF CACHE BOOL "") + SET(FXDIV_BUILD_BENCHMARKS OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${FXDIV_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/fxdiv") +ENDIF() +TARGET_LINK_LIBRARIES(pytorch_qnnpack PRIVATE fxdiv) + +# ---[ Configure psimd +IF(NOT TARGET psimd) + ADD_SUBDIRECTORY( + "${PSIMD_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/psimd") +ENDIF() +TARGET_LINK_LIBRARIES(pytorch_qnnpack PRIVATE psimd) + +# ---[ Configure FP16 +IF(NOT TARGET fp16) + SET(FP16_BUILD_TESTS OFF CACHE BOOL "") + SET(FP16_BUILD_BENCHMARKS OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${FP16_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/fp16") +ENDIF() +TARGET_LINK_LIBRARIES(pytorch_qnnpack PRIVATE fp16) + +INSTALL(TARGETS pytorch_qnnpack + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +# ---[ QNNPACK unit tests +IF(PYTORCH_QNNPACK_BUILD_TESTS) + # ---[ Build google test + IF(NOT TARGET gtest) + SET(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + ADD_SUBDIRECTORY( + "${GOOGLETEST_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest") + ENDIF() + + # ---[ Build unit tests for high-level functionality + ADD_EXECUTABLE(convolution-test test/convolution.cc) + SET_TARGET_PROPERTIES(convolution-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(convolution-test PRIVATE src test) + TARGET_LINK_LIBRARIES(convolution-test PRIVATE pytorch_qnnpack clog cpuinfo fp16 gtest gtest_main) + ADD_TEST(convolution-test convolution-test) + + ADD_EXECUTABLE(deconvolution-test test/deconvolution.cc) + SET_TARGET_PROPERTIES(deconvolution-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(deconvolution-test PRIVATE src test) + TARGET_LINK_LIBRARIES(deconvolution-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(deconvolution-test deconvolution-test) + + ADD_EXECUTABLE(fully-connected-test test/fully-connected.cc) + SET_TARGET_PROPERTIES(fully-connected-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(fully-connected-test PRIVATE src test) + TARGET_LINK_LIBRARIES(fully-connected-test PRIVATE pytorch_qnnpack clog cpuinfo fp16 gtest gtest_main) + ADD_TEST(fully-connected-test fully-connected-test) + + ADD_EXECUTABLE(channel-shuffle-test test/channel-shuffle.cc) + SET_TARGET_PROPERTIES(channel-shuffle-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(channel-shuffle-test PRIVATE src test) + TARGET_LINK_LIBRARIES(channel-shuffle-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(channel-shuffle-test channel-shuffle-test) + + ADD_EXECUTABLE(add-test test/add.cc) + SET_TARGET_PROPERTIES(add-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(add-test PRIVATE src test) + TARGET_LINK_LIBRARIES(add-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(add-test add-test) + + ADD_EXECUTABLE(leaky-relu-test test/leaky-relu.cc) + SET_TARGET_PROPERTIES(leaky-relu-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(leaky-relu-test PRIVATE src test) + TARGET_LINK_LIBRARIES(leaky-relu-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(leaky-relu-test leaky-relu-test) + + ADD_EXECUTABLE(sigmoid-test test/sigmoid.cc) + SET_TARGET_PROPERTIES(sigmoid-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(sigmoid-test PRIVATE src test) + TARGET_LINK_LIBRARIES(sigmoid-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(sigmoid-test sigmoid-test) + + ADD_EXECUTABLE(clamp-test test/clamp.cc) + SET_TARGET_PROPERTIES(clamp-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(clamp-test PRIVATE src test) + TARGET_LINK_LIBRARIES(clamp-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(clamp-test clamp-test) + + ADD_EXECUTABLE(softargmax-test test/softargmax.cc) + SET_TARGET_PROPERTIES(softargmax-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(softargmax-test PRIVATE src test) + TARGET_LINK_LIBRARIES(softargmax-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(softargmax-test softargmax-test) + + ADD_EXECUTABLE(max-pooling-test test/max-pooling.cc) + SET_TARGET_PROPERTIES(max-pooling-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(max-pooling-test PRIVATE src test) + TARGET_LINK_LIBRARIES(max-pooling-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(max-pooling-test max-pooling-test) + + ADD_EXECUTABLE(average-pooling-test test/average-pooling.cc) + SET_TARGET_PROPERTIES(average-pooling-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(average-pooling-test PRIVATE src test) + TARGET_LINK_LIBRARIES(average-pooling-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(average-pooling-test average-pooling-test) + + ADD_EXECUTABLE(global-average-pooling-test test/global-average-pooling.cc) + SET_TARGET_PROPERTIES(global-average-pooling-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(global-average-pooling-test PRIVATE src test) + TARGET_LINK_LIBRARIES(global-average-pooling-test PRIVATE pytorch_qnnpack cpuinfo gtest gtest_main) + ADD_TEST(global-average-pooling-test global-average-pooling-test) + + # ---[ Build unit tests for micro-kernels + ADD_EXECUTABLE(q8gemm-test test/q8gemm.cc) + SET_TARGET_PROPERTIES(q8gemm-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8gemm-test PRIVATE src test) + TARGET_LINK_LIBRARIES(q8gemm-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(q8gemm-test q8gemm-test) + + ADD_EXECUTABLE(q8conv-test test/q8conv.cc) + SET_TARGET_PROPERTIES(q8conv-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8conv-test PRIVATE src test) + TARGET_LINK_LIBRARIES(q8conv-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(q8conv-test q8conv-test) + + ADD_EXECUTABLE(q8dwconv-test test/q8dwconv.cc) + SET_TARGET_PROPERTIES(q8dwconv-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8dwconv-test PRIVATE src test) + TARGET_LINK_LIBRARIES(q8dwconv-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(q8dwconv-test q8dwconv-test) + + ADD_EXECUTABLE(q8vadd-test test/q8vadd.cc) + SET_TARGET_PROPERTIES(q8vadd-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8vadd-test PRIVATE src test) + TARGET_LINK_LIBRARIES(q8vadd-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(q8vadd-test q8vadd-test) + + ADD_EXECUTABLE(q8avgpool-test test/q8avgpool.cc) + SET_TARGET_PROPERTIES(q8avgpool-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8avgpool-test PRIVATE src test) + TARGET_LINK_LIBRARIES(q8avgpool-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(q8avgpool-test q8avgpool-test) + + ADD_EXECUTABLE(q8gavgpool-test test/q8gavgpool.cc) + SET_TARGET_PROPERTIES(q8gavgpool-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8gavgpool-test PRIVATE src test) + TARGET_LINK_LIBRARIES(q8gavgpool-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(q8gavgpool-test q8gavgpool-test) + + ADD_EXECUTABLE(u8maxpool-test test/u8maxpool.cc) + SET_TARGET_PROPERTIES(u8maxpool-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(u8maxpool-test PRIVATE src test) + TARGET_LINK_LIBRARIES(u8maxpool-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(u8maxpool-test u8maxpool-test) + + ADD_EXECUTABLE(u8clamp-test test/u8clamp.cc) + SET_TARGET_PROPERTIES(u8clamp-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(u8clamp-test PRIVATE src test) + TARGET_LINK_LIBRARIES(u8clamp-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(u8clamp-test u8clamp-test) + + ADD_EXECUTABLE(u8rmax-test test/u8rmax.cc) + SET_TARGET_PROPERTIES(u8rmax-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(u8rmax-test PRIVATE src test) + TARGET_LINK_LIBRARIES(u8rmax-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(u8rmax-test u8rmax-test) + + ADD_EXECUTABLE(u8lut32norm-test test/u8lut32norm.cc) + SET_TARGET_PROPERTIES(u8lut32norm-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(u8lut32norm-test PRIVATE src test) + TARGET_LINK_LIBRARIES(u8lut32norm-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(u8lut32norm-test u8lut32norm-test) + + ADD_EXECUTABLE(x8lut-test test/x8lut.cc) + SET_TARGET_PROPERTIES(x8lut-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(x8lut-test PRIVATE src test) + TARGET_LINK_LIBRARIES(x8lut-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(x8lut-test x8lut-test) + + ADD_EXECUTABLE(x8zip-test test/x8zip.cc) + SET_TARGET_PROPERTIES(x8zip-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(x8zip-test PRIVATE src test) + TARGET_LINK_LIBRARIES(x8zip-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(x8zip-test x8zip-test) + + ADD_EXECUTABLE(hgemm-test test/hgemm.cc) + SET_TARGET_PROPERTIES(hgemm-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(hgemm-test PRIVATE src test) + TARGET_LINK_LIBRARIES(hgemm-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(hgemm-test hgemm-test) + + ADD_EXECUTABLE(sgemm-test test/sgemm.cc) + SET_TARGET_PROPERTIES(sgemm-test PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(sgemm-test PRIVATE src test) + TARGET_LINK_LIBRARIES(sgemm-test PRIVATE pytorch_qnnpack cpuinfo fp16 gtest gtest_main) + ADD_TEST(sgemm-test sgemm-test) +ENDIF() + +# ---[ QNNPACK micro-benchmarks +IF(PYTORCH_QNNPACK_BUILD_BENCHMARKS) + # ---[ Build google benchmark + IF(NOT TARGET benchmark) + SET(BENCHMARK_ENABLE_TESTING OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${GOOGLEBENCHMARK_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark") + ENDIF() + + ADD_EXECUTABLE(add-bench bench/add.cc) + SET_TARGET_PROPERTIES(add-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(add-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(average-pooling-bench bench/average-pooling.cc) + SET_TARGET_PROPERTIES(average-pooling-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(average-pooling-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(channel-shuffle-bench bench/channel-shuffle.cc) + SET_TARGET_PROPERTIES(channel-shuffle-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(channel-shuffle-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(convolution-bench bench/convolution.cc) + SET_TARGET_PROPERTIES(convolution-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(convolution-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(global-average-pooling-bench bench/global-average-pooling.cc) + SET_TARGET_PROPERTIES(global-average-pooling-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(global-average-pooling-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(max-pooling-bench bench/max-pooling.cc) + SET_TARGET_PROPERTIES(max-pooling-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(max-pooling-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(sigmoid-bench bench/sigmoid.cc) + SET_TARGET_PROPERTIES(sigmoid-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(sigmoid-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(softargmax-bench bench/softargmax.cc) + SET_TARGET_PROPERTIES(softargmax-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_LINK_LIBRARIES(softargmax-bench PRIVATE pytorch_qnnpack benchmark) + + ADD_EXECUTABLE(q8gemm-bench bench/q8gemm.cc) + SET_TARGET_PROPERTIES(q8gemm-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(q8gemm-bench PRIVATE src) + TARGET_COMPILE_DEFINITIONS(q8gemm-bench PRIVATE pytorch_PYTORCH_QNNPACK_BENCHMARK_GEMMLOWP=0) + TARGET_LINK_LIBRARIES(q8gemm-bench PRIVATE pytorch_qnnpack cpuinfo fp16 benchmark) + + ADD_EXECUTABLE(hgemm-bench bench/hgemm.cc) + SET_TARGET_PROPERTIES(hgemm-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(hgemm-bench PRIVATE src) + TARGET_LINK_LIBRARIES(hgemm-bench PRIVATE pytorch_qnnpack cpuinfo fp16 benchmark) + + ADD_EXECUTABLE(sgemm-bench bench/sgemm.cc) + SET_TARGET_PROPERTIES(sgemm-bench PROPERTIES + CXX_STANDARD 11 + CXX_STANDARD_REQUIRED YES + CXX_EXTENSIONS NO) + TARGET_INCLUDE_DIRECTORIES(sgemm-bench PRIVATE src) + TARGET_LINK_LIBRARIES(sgemm-bench PRIVATE pytorch_qnnpack cpuinfo benchmark) +ENDIF() diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CODE_OF_CONDUCT.md b/aten/src/ATen/native/quantized/cpu/qnnpack/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000..0f7ad8bfc173e --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +Facebook has adopted a Code of Conduct that we expect project participants to adhere to. +Please read the [full text](https://code.fb.com/codeofconduct/) +so that you can understand what actions will and will not be tolerated. diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CONTRIBUTING.md b/aten/src/ATen/native/quantized/cpu/qnnpack/CONTRIBUTING.md new file mode 100644 index 0000000000000..cd6b1221a0717 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CONTRIBUTING.md @@ -0,0 +1,34 @@ +# Contributing to QNNPACK +We want to make contributing to this project as easy and transparent as +possible. + +## Code of Conduct +The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md). + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `master`. +2. If you've added code that should be tested, add tests. +3. If you've added new micro-kernels, update or add micro-benchmarks. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to QNNPACK, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/LICENSE b/aten/src/ATen/native/quantized/cpu/qnnpack/LICENSE new file mode 100644 index 0000000000000..fc96febb92e1c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/LICENSE @@ -0,0 +1,30 @@ +BSD License + +For QNNPACK software + +Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/README.md b/aten/src/ATen/native/quantized/cpu/qnnpack/README.md new file mode 100644 index 0000000000000..84f26ee59c88c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/README.md @@ -0,0 +1,186 @@ +# QNNPACK +QNNPACK (Quantized Neural Networks PACKage) is a mobile-optimized library for low-precision high-performance neural network inference. QNNPACK provides implementation of common neural network operators on quantized 8-bit tensors. + +QNNPACK is not intended to be directly used by machine learning researchers; instead it provides low-level performance primitives for high-level deep learning frameworks. As of today, QNNPACK is integrated in [PyTorch 1.0](https://github.com/pytorch/pytorch) with Caffe2 graph representation. + +## Operator Coverage + +Currently implemented and planned for implementation operators are below: + +- [x] 2D Convolution +- [x] 2D Deconvolution +- [x] Channel Shuffle +- [x] Fully Connected +- [ ] Locally Connected +- [x] 2D Max Pooling +- [x] 2D Average Pooling +- [x] Global Average Pooling +- [x] Sigmoid +- [x] Leaky ReLU +- [x] Clamp (can be used for ReLU, ReLU6 if it is not fused in another operator) +- [x] SoftArgMax (aka SoftMax) +- [ ] Group Normalization + +## Building + +QNNPACK provides standard CMake-based build scripts. + +### Native compilation + +Users are recommended to use `scripts/build-local.sh` script to build QNNPACK for the host machine. + +### Cross-compilation for Android + +To cross-compile for Android, set `$ANDROID_NDK` environment variable (where `$ANDROID_NDK` is the path to Android NDK directory, e.g. `/opt/android-ndk-r15c`) and use one of the scripts from the table below: + +| ABI | Build script | Restrictions | +| ----------- | ---------------------------------| -------------------------- | +| armeabi-v7a | `scripts/build-android-armv7.sh` | Requires CPU with ARM NEON | +| arm64-v8a | `scripts/build-android-arm64.sh` | | +| x86 | `scripts/build-android-x86.sh` | | + +Notes: +- On **armeabi-v7a** `pytorch_qnnp_initialize` will fail with `pytorch_qnnp_status_unsupported_hardware` if the mobile CPU does not support ARM NEON. Don't set `-DANDROID_ARM_NEON=1` for QNNPACK compilation as it can make `pytorch_qnnp_initialize` crash on CPUs without ARM NEON. + +### Cross-compilation for iOS + +To cross-compile for iOS, clone [ios-cmake](https://github.com/leetal/ios-cmake), and set `$IOS_CMAKE_TOOLCHAIN_FILE` environment variable (where `$IOS_CMAKE_TOOLCHAIN_FILE` is the path to `ios.toolchain.cmake` file in [ios-cmake](https://github.com/leetal/ios-cmake)), and use one of the scripts from the table below: + +| Architecture | Build script | Notes | +| ------------ | ----------------------------- | ------------------------- | +| armv7 | `scripts/build-ios-armv7.sh` | iPhone 3GS/4/4S | +| armv7 | `scripts/build-ios-armv7s.sh` | iPhone 5 and newer | +| arm64 | `scripts/build-ios-arm64.sh` | iPhone 5S and newer | +| arm64e | `scripts/build-ios-arm64e.sh` | iPhone XS/XR | +| i386 | `scripts/build-ios-i386.sh` | iPhone Simulator (32-bit) | +| x86_64 | `scripts/build-ios-x86_64.sh` | iPhone Simulator (64-bit) | + +## End-to-End Benchmarking + +Caffe2 backend of PyTorch 1.0 natively integrates QNNPACK, and provides a [pre-trained quantized MobileNet v2 model](https://github.com/caffe2/models/tree/master/mobilenet_v2_quantized). Below are instructions for benchmarking this model end-to-end with QNNPACK. + +### Raspberry Pi 2 or 3 + +```bash +# Clone PyTorch 1.0 repo +git clone --recursive https://github.com/pytorch/pytorch.git +cd pytorch + +# Optional: update QNNPACK submodule to latest revision +git submodule update --remote third_party/QNNPACK + +# Build Caffe2 (including binaries) for the host system +# Use only 1 thread for build to avoid out-of-memory failures +MAX_JOBS=1 scripts/build_local.sh -DBUILD_BINARY=ON -DBUILD_PYTHON=OFF \ + -DUSE_OBSERVERS=OFF -DUSE_DISTRIBUTED=OFF + +# Download model weights +wget https://s3.amazonaws.com/download.caffe2.ai/models/mobilenet_v2_1.0_224_quant/init_net.pb + +# Download model graph +wget https://s3.amazonaws.com/download.caffe2.ai/models/mobilenet_v2_1.0_224_quant/predict_net.pb + +# Run speed benchmark with 50 warm-up iterations and 10 measurement iterations +build/bin/speed_benchmark --net predict_net.pb --init_net init_net.pb \ + --input data --input_dims 1,3,224,224 --input_type float \ + --warmup 50 --iter 10 +``` + +### ARMv7 (32-bit) Android + +```bash +# Clone PyTorch 1.0 repo +git clone --recursive https://github.com/pytorch/pytorch.git +cd pytorch + +# Optional: update QNNPACK submodule to latest revision +git submodule update --remote third_party/QNNPACK + +# Build Caffe2 (including binaries) for Android, and push to device +scripts/build_android.sh -DANDROID_TOOLCHAIN=clang -DBUILD_BINARY=ON +adb push build_android/bin/speed_benchmark /data/local/tmp/speed_benchmark + +# Download model weights and copy them to Android device +wget https://s3.amazonaws.com/download.caffe2.ai/models/mobilenet_v2_1.0_224_quant/init_net.pb +adb push init_net.pb /data/local/tmp/init_net.pb + +# Download model graph and copy it to Android device +wget https://s3.amazonaws.com/download.caffe2.ai/models/mobilenet_v2_1.0_224_quant/predict_net.pb +adb push predict_net.pb /data/local/tmp/predict_net.pb + +# Run speed benchmark with 50 warm-up iterations and 10 measurement iterations +adb shell /data/local/tmp/speed_benchmark \ + --net /data/local/tmp/predict_net.pb \ + --init_net /data/local/tmp/init_net.pb \ + --input data --input_dims 1,3,224,224 --input_type float \ + --warmup 50 --iter 10 +``` + +### ARM64 (64-bit) Android + +```bash +# Clone PyTorch 1.0 repo +git clone --recursive https://github.com/pytorch/pytorch.git +cd pytorch + +# Optional: update QNNPACK submodule to latest revision +git submodule update --remote third_party/QNNPACK + +# Build Caffe2 (including binaries) for Android, and push to device +scripts/build_android.sh -DANDROID_ABI=arm64-v8a -DANDROID_TOOLCHAIN=clang -DBUILD_BINARY=ON +adb push build_android/bin/speed_benchmark /data/local/tmp/speed_benchmark + +# Download model weights and copy them to Android device +wget https://s3.amazonaws.com/download.caffe2.ai/models/mobilenet_v2_1.0_224_quant/init_net.pb +adb push init_net.pb /data/local/tmp/init_net.pb + +# Download model graph and copy it to Android device +wget https://s3.amazonaws.com/download.caffe2.ai/models/mobilenet_v2_1.0_224_quant/predict_net.pb +adb push predict_net.pb /data/local/tmp/predict_net.pb + +# Run speed benchmark with 50 warm-up iterations and 10 measurement iterations +adb shell /data/local/tmp/speed_benchmark \ + --net /data/local/tmp/predict_net.pb \ + --init_net /data/local/tmp/init_net.pb \ + --input data --input_dims 1,3,224,224 --input_type float \ + --warmup 50 --iter 10 +``` + +### PEP (Performance Evaluation Platform) Method + +[Facebook AI Performance Evaluation Platform](https://github.com/facebook/FAI-PEP) is a framework and backend agnostic benchmarking platform to compare machine learning inferencing runtime metrics on a set of models and a variety of backends. + +We use PEP to produce the results we have in our [blog](https://code.fb.com/ml-applications/qnnpack/) + +With an ARMv7 device connected: + +```bash +# Clone PyTorch 1.0 repo +mkdir ~/Code && cd ~/Code +git clone --recursive https://github.com/pytorch/pytorch.git +cd pytorch + +# Optional: update QNNPACK submodule to latest revision +git submodule update --remote third_party/QNNPACK + +# Clone PEP repo +cd ~/Code +git clone --recursive https://github.com/facebook/FAI-PEP.git aibench +cd aibench + +# Run PEP benchmark with cool specifications. Try changing that cmd with more specifications! +# First time compile could take 20+ minutes +./benchmarking/run_bench.py \ + --platform android \ + -b ~/Code/aibench/specifications/models/caffe2/mobilenet_v2/mobilenet_v2_quant.json \ + --platform android --repo_dir ~/Code/pytorch \ + --frameworks_dir ~/Code/aibench/specifications/frameworks --framework caffe2 +``` + +## Acknowledgements + +QNNPACK is developed by Marat Dukhan, Yiming Wu, Hao Lu, and Bert Maher. We thank Andrew Tulloch and Yangqing Jia for advice during the development of QNNPACK. + +## License + +QNNPACK is BSD licensed, as found in the [`LICENSE`](LICENSE) file. diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/add.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/add.cc new file mode 100644 index 0000000000000..03e6c2be85027 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/add.cc @@ -0,0 +1,172 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +static void add_nc_q8(benchmark::State& state) { + const size_t batchSize = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a(batchSize * channels); + std::vector b(batchSize * channels); + std::vector y(batchSize * channels); + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t addOperator = nullptr; + status = pytorch_qnnp_create_add_nc_q8( + channels, + 127 /* a:zero point */, + 1.0f /* a:scale */, + 127 /* b:zero point */, + 1.0f /* b:scale */, + 127 /* y:zero point */, + 1.0f /* y:scale */, + 1 /* y:min */, + 254 /* y:max */, + 0 /* flags */, + &addOperator); + if (status != pytorch_qnnp_status_success || addOperator == nullptr) { + state.SkipWithError("failed to create Q8 Add operator"); + } + + status = pytorch_qnnp_setup_add_nc_q8( + addOperator, + batchSize, + a.data(), + channels /* a:stride */, + b.data(), + channels /* b:stride */, + y.data(), + channels /* y:stride */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Q8 Add operator"); + } + + for (auto _ : state) { + status = pytorch_qnnp_run_operator(addOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run Q8 Add operator"); + } + } + + const size_t itemsPerIteration = batchSize * channels; + state.SetItemsProcessed( + int64_t(state.iterations()) * int64_t(itemsPerIteration)); + + const size_t bytesPerIteration = 3 * itemsPerIteration * sizeof(uint8_t); + state.SetBytesProcessed( + int64_t(state.iterations()) * int64_t(bytesPerIteration)); + + status = pytorch_qnnp_delete_operator(addOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Q8 Add operator"); + } +} + +static void add_nc_q8_inplace(benchmark::State& state) { + const size_t batchSize = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a(batchSize * channels); + std::vector y(batchSize * channels); + std::generate(a.begin(), a.end(), std::ref(u8rng)); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t addOperator = nullptr; + status = pytorch_qnnp_create_add_nc_q8( + channels, + 127 /* a:zero point */, + 1.0f /* a:scale */, + 127 /* b:zero point */, + 1.0f /* b:scale */, + 127 /* y:zero point */, + 1.0f /* y:scale */, + 1 /* y:min */, + 254 /* y:max */, + 0 /* flags */, + &addOperator); + if (status != pytorch_qnnp_status_success || addOperator == nullptr) { + state.SkipWithError("failed to create Q8 Add operator"); + } + + status = pytorch_qnnp_setup_add_nc_q8( + addOperator, + batchSize, + a.data(), + channels /* a:stride */, + y.data(), + channels /* b:stride */, + y.data(), + channels /* y:stride */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Q8 Add operator"); + } + + for (auto _ : state) { + status = pytorch_qnnp_run_operator(addOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run Q8 Add operator"); + } + } + + const size_t itemsPerIteration = batchSize * channels; + state.SetItemsProcessed( + int64_t(state.iterations()) * int64_t(itemsPerIteration)); + + const size_t bytesPerIteration = 3 * itemsPerIteration * sizeof(uint8_t); + state.SetBytesProcessed( + int64_t(state.iterations()) * int64_t(bytesPerIteration)); + + status = pytorch_qnnp_delete_operator(addOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Q8 Add operator"); + } +} + +static void CharacteristicArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "C"}); + + int32_t c = 16; + for (int32_t n = 224; n >= 7; n /= 2) { + b->Args({n * n, c}); + c *= 2; + } +} + +BENCHMARK(add_nc_q8)->Apply(CharacteristicArguments); +BENCHMARK(add_nc_q8_inplace)->Apply(CharacteristicArguments); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/average-pooling.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/average-pooling.cc new file mode 100644 index 0000000000000..473c1e7ad16bb --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/average-pooling.cc @@ -0,0 +1,194 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +static void average_pooling_q8(benchmark::State& state, const char* net) { + const size_t batchSize = state.range(0); + const size_t inputHeight = state.range(1); + const size_t inputWidth = state.range(2); + const size_t poolingSize = state.range(3); + const size_t paddingSize = state.range(4); + const size_t stride = state.range(5); + const size_t channels = state.range(6); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + const size_t inputPixelStride = channels; + const size_t outputPixelStride = channels; + const size_t outputHeight = + (2 * paddingSize + inputHeight - poolingSize) / stride + 1; + const size_t outputWidth = + (2 * paddingSize + inputWidth - poolingSize) / stride + 1; + + std::vector input( + batchSize * inputHeight * inputWidth * inputPixelStride); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::vector output( + batchSize * outputHeight * outputWidth * outputPixelStride); + std::fill(output.begin(), output.end(), 0xA5); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t poolingOperator = nullptr; + status = pytorch_qnnp_create_average_pooling2d_nhwc_q8( + paddingSize, + paddingSize, + paddingSize, + paddingSize, + poolingSize, + poolingSize, + stride, + stride, + channels, + 127 /* input zero point */, + 0.75f /* input scale */, + 127 /* output zero point */, + 1.25f /* output scale */, + 0, + 255, + 0 /* flags */, + &poolingOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to create Average Pooling operator"); + } + + status = pytorch_qnnp_setup_average_pooling2d_nhwc_q8( + poolingOperator, + batchSize, + inputHeight, + inputWidth, + input.data(), + inputPixelStride, + output.data(), + outputPixelStride, + nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Average Pooling operator"); + } + + for (auto _ : state) { + status = + pytorch_qnnp_run_operator(poolingOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run Average Pooling operator"); + } + } + + status = pytorch_qnnp_delete_operator(poolingOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Average Pooling operator"); + } + poolingOperator = nullptr; + + state.SetBytesProcessed( + uint64_t(state.iterations()) * batchSize * + (inputHeight * inputWidth + outputHeight * outputWidth) * channels * + sizeof(uint8_t)); +} + +/* ShuffleNet v1 with 1 group */ +static void ShuffleNetV1G1(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 56, 56, 3, 1, 2, 24}); + b->Args({1, 28, 28, 3, 1, 2, 144}); + b->Args({1, 14, 14, 3, 1, 2, 288}); + b->Args({1, 7, 7, 3, 1, 2, 576}); +} + +/* ShuffleNet v1 with 2 groups */ +static void ShuffleNetV1G2(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 56, 56, 3, 1, 2, 24}); + b->Args({1, 28, 28, 3, 1, 2, 200}); + b->Args({1, 14, 14, 3, 1, 2, 400}); + b->Args({1, 7, 7, 3, 1, 2, 800}); +} + +/* ShuffleNet v1 with 3 groups */ +static void ShuffleNetV1G3(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 56, 56, 3, 1, 2, 24}); + b->Args({1, 28, 28, 3, 1, 2, 240}); + b->Args({1, 14, 14, 3, 1, 2, 480}); + b->Args({1, 7, 7, 3, 1, 2, 960}); +} + +/* ShuffleNet v1 with 4 groups */ +static void ShuffleNetV1G4(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 56, 56, 3, 1, 2, 24}); + b->Args({1, 28, 28, 3, 1, 2, 272}); + b->Args({1, 14, 14, 3, 1, 2, 576}); + b->Args({1, 7, 7, 3, 1, 2, 1088}); +} + +/* ShuffleNet v1 with 8 groups */ +static void ShuffleNetV1G8(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 56, 56, 3, 1, 2, 24}); + b->Args({1, 28, 28, 3, 1, 2, 384}); + b->Args({1, 14, 14, 3, 1, 2, 768}); + b->Args({1, 7, 7, 3, 1, 2, 1536}); +} + +BENCHMARK_CAPTURE( + average_pooling_q8, + shufflenet_v1_g1, + "ShuffleNet v1 (1 group)") + ->Apply(ShuffleNetV1G1); +BENCHMARK_CAPTURE( + average_pooling_q8, + shufflenet_v1_g2, + "ShuffleNet v1 (2 groups)") + ->Apply(ShuffleNetV1G2); +BENCHMARK_CAPTURE( + average_pooling_q8, + shufflenet_v1_g3, + "ShuffleNet v1 (3 groups)") + ->Apply(ShuffleNetV1G3); +BENCHMARK_CAPTURE( + average_pooling_q8, + shufflenet_v1_g4, + "ShuffleNet v1 (4 groups)") + ->Apply(ShuffleNetV1G4); +BENCHMARK_CAPTURE( + average_pooling_q8, + shufflenet_v1_g8, + "ShuffleNet v1 (8 groups)") + ->Apply(ShuffleNetV1G8); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/channel-shuffle.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/channel-shuffle.cc new file mode 100644 index 0000000000000..3b275eb27aea0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/channel-shuffle.cc @@ -0,0 +1,249 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +static void channel_shuffle_x8(benchmark::State& state, const char* net) { + const size_t batchSize = static_cast(state.range(0)); + const size_t groups = static_cast(state.range(1)); + const size_t groupChannels = static_cast(state.range(2)); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input(batchSize * groups * groupChannels); + std::vector output(batchSize * groups * groupChannels); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t channelShuffleOperator = nullptr; + status = pytorch_qnnp_create_channel_shuffle_nc_x8( + groups, groupChannels, 0 /* flags */, &channelShuffleOperator); + if (status != pytorch_qnnp_status_success || + channelShuffleOperator == nullptr) { + state.SkipWithError("failed to create X8 Channel Shuffle operator"); + } + + status = pytorch_qnnp_setup_channel_shuffle_nc_x8( + channelShuffleOperator, + batchSize, + input.data(), + groups * groupChannels /* input:stride */, + output.data(), + groups * groupChannels /* output:stride */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup X8 Channel Shuffle operator"); + } + + for (auto _ : state) { + status = pytorch_qnnp_run_operator( + channelShuffleOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run X8 Channel Shuffle operator"); + } + } + + const size_t itemsPerIteration = batchSize * groups * groupChannels; + state.SetItemsProcessed( + int64_t(state.iterations()) * int64_t(itemsPerIteration)); + + const size_t bytesPerIteration = 2 * itemsPerIteration * sizeof(uint8_t); + state.SetBytesProcessed( + int64_t(state.iterations()) * int64_t(bytesPerIteration)); + + status = pytorch_qnnp_delete_operator(channelShuffleOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete X8 Channel Shuffle operator"); + } +} + +static void ShuffleNetV1G2Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 ********/ + /* H W G CG */ + b->Args({56 * 56, 2, 25}); + b->Args({28 * 28, 2, 25}); + + /******** Stage 3 ********/ + /* H W G CG */ + b->Args({28 * 28, 2, 50}); + b->Args({14 * 14, 2, 50}); + + /******** Stage 4 ********/ + /* H W G CG */ + b->Args({14 * 14, 2, 100}); + b->Args({7 * 7, 2, 100}); +} + +static void ShuffleNetV1G3Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 *******/ + /* H W G CG */ + b->Args({56 * 56, 3, 20}); + b->Args({28 * 28, 3, 20}); + + /******** Stage 3 *******/ + /* H W G CG */ + b->Args({28 * 28, 3, 40}); + b->Args({14 * 14, 3, 40}); + + /******** Stage 4 *******/ + /* H W G CG */ + b->Args({14 * 14, 3, 80}); + b->Args({7 * 7, 3, 80}); +} + +static void ShuffleNetV1G4Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 *******/ + /* H W G CG */ + b->Args({56 * 56, 4, 17}); + b->Args({28 * 28, 4, 17}); + + /******** Stage 3 *******/ + /* H W G CG */ + b->Args({28 * 28, 4, 34}); + b->Args({14 * 14, 4, 34}); + + /******** Stage 4 *******/ + /* H W G CG */ + b->Args({14 * 14, 4, 68}); + b->Args({7 * 7, 4, 68}); +} + +static void ShuffleNetV1G8Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 *******/ + /* H W G CG */ + b->Args({56 * 56, 8, 12}); + b->Args({28 * 28, 8, 12}); + + /******** Stage 3 *******/ + /* H W G CG */ + b->Args({28 * 28, 8, 24}); + b->Args({14 * 14, 8, 24}); + + /******** Stage 4 *******/ + /* H W G CG */ + b->Args({14 * 14, 8, 48}); + b->Args({7 * 7, 8, 48}); +} + +static void ShuffleNetV2x0_5Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 *******/ + /* H W G CG */ + b->Args({28 * 28, 2, 24}); + + /******** Stage 3 *******/ + /* H W G CG */ + b->Args({14 * 14, 2, 48}); + + /******** Stage 4 *******/ + /* H W G CG */ + b->Args({7 * 7, 2, 96}); +} + +static void ShuffleNetV2x1_0Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 ********/ + /* H W G CG */ + b->Args({28 * 28, 2, 58}); + + /******** Stage 3 ********/ + /* H W G CG */ + b->Args({14 * 14, 2, 116}); + + /******** Stage 4 ********/ + /* H W G CG */ + b->Args({7 * 7, 2, 232}); +} + +static void ShuffleNetV2x1_5Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 ********/ + /* H W G CG */ + b->Args({28 * 28, 2, 88}); + + /******** Stage 3 ********/ + /* H W G CG */ + b->Args({14 * 14, 2, 176}); + + /******** Stage 4 ********/ + /* H W G CG */ + b->Args({7 * 7, 2, 352}); +} + +static void ShuffleNetV2x2_0Arguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "G", "GC"}); + + /******** Stage 2 ********/ + /* H W G CG */ + b->Args({28 * 28, 2, 122}); + + /******** Stage 3 ********/ + /* H W G CG */ + b->Args({14 * 14, 2, 244}); + + /******** Stage 4 ********/ + /* H W G CG */ + b->Args({7 * 7, 2, 488}); +} + +BENCHMARK_CAPTURE( + channel_shuffle_x8, + shufflenet_v1_g2, + "ShuffleNet v1 (2 groups)") + ->Apply(ShuffleNetV1G2Arguments); +BENCHMARK_CAPTURE( + channel_shuffle_x8, + shufflenet_v1_g3, + "ShuffleNet v1 (3 groups)") + ->Apply(ShuffleNetV1G3Arguments); +BENCHMARK_CAPTURE( + channel_shuffle_x8, + shufflenet_v1_g4, + "ShuffleNet v1 (4 groups)") + ->Apply(ShuffleNetV1G4Arguments); +BENCHMARK_CAPTURE( + channel_shuffle_x8, + shufflenet_v1_g8, + "ShuffleNet v1 (8 groups)") + ->Apply(ShuffleNetV1G8Arguments); +BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x05, "ShuffleNet v2 x0.5") + ->Apply(ShuffleNetV2x0_5Arguments); +BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x10, "ShuffleNet v2 x1.0") + ->Apply(ShuffleNetV2x1_0Arguments); +BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x15, "ShuffleNet v2 x1.5") + ->Apply(ShuffleNetV2x1_5Arguments); +BENCHMARK_CAPTURE(channel_shuffle_x8, shufflenet_v2_x20, "ShuffleNet v2 x2.0") + ->Apply(ShuffleNetV2x2_0Arguments); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/convolution.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/convolution.cc new file mode 100644 index 0000000000000..623c2d1081438 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/convolution.cc @@ -0,0 +1,1013 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +static void convolution_q8(benchmark::State& state, const char* net) { + const size_t batchSize = state.range(0); + const size_t inputHeight = state.range(1); + const size_t inputWidth = state.range(2); + const size_t kernelHeight = state.range(3); + const size_t kernelWidth = state.range(4); + const size_t subsampling = state.range(5); + const size_t dilation = state.range(6); + const size_t groups = state.range(7); + const size_t groupInputChannels = state.range(8); + const size_t groupOutputChannels = state.range(9); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + const size_t outputPixelStride = groups * groupOutputChannels; + const size_t inputPixelStride = groups * groupInputChannels; + const size_t effectiveKernelHeight = (kernelHeight - 1) * dilation + 1; + const size_t effectiveKernelWidth = (kernelWidth - 1) * dilation + 1; + const size_t paddingLeft = effectiveKernelWidth / 2; + const size_t paddingTop = effectiveKernelHeight / 2; + const size_t paddingRight = effectiveKernelWidth - 1 - paddingLeft; + const size_t paddingBottom = effectiveKernelHeight - 1 - paddingTop; + const size_t outputHeight = + (paddingTop + inputHeight + paddingBottom - effectiveKernelHeight) / + subsampling + + 1; + const size_t outputWidth = + (paddingLeft + inputWidth + paddingRight - effectiveKernelWidth) / + subsampling + + 1; + + std::vector input( + batchSize * inputHeight * inputWidth * inputPixelStride); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::vector kernel( + groups * groupOutputChannels * kernelHeight * kernelWidth * + groupInputChannels); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::vector bias(groups * groupOutputChannels); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::vector output( + batchSize * outputHeight * outputWidth * outputPixelStride); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t convolutionObject = nullptr; + status = pytorch_qnnp_create_convolution2d_nhwc_q8( + paddingTop, + paddingRight, + paddingBottom, + paddingLeft, + kernelHeight, + kernelWidth, + subsampling, + subsampling, + dilation, + dilation, + groups, + groupInputChannels, + groupOutputChannels, + 127, + 0.5f, + 127, + 0.5f, + kernel.data(), + bias.data(), + 127, + 0.5f, + 0, + 255, + 0 /* flags */, + &convolutionObject); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to create Convolution operator"); + } + + status = pytorch_qnnp_setup_convolution2d_nhwc_q8( + convolutionObject, + batchSize, + inputHeight, + inputWidth, + input.data(), + inputPixelStride, + output.data(), + outputPixelStride, + nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Convolution operator"); + } + + for (auto _ : state) { + pytorch_qnnp_run_operator(convolutionObject, nullptr /* thread pool */); + } + + status = pytorch_qnnp_delete_operator(convolutionObject); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Convolution operator"); + } + convolutionObject = nullptr; + + state.SetItemsProcessed( + uint64_t(state.iterations()) * 2 * batchSize * outputHeight * + outputWidth * groups * groupInputChannels * groupOutputChannels * + kernelHeight * kernelWidth); +} + +/* ShuffleNet v1 with 1 group */ +static void ShuffleNetV1G1(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /*************** Stage 2: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 36}); + b->Args({1, 56, 56, 3, 3, 2, 1, 36, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 36, 120}); + /*************** Stage 2: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 144, 36}); + b->Args({1, 28, 28, 3, 3, 2, 1, 36, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 36, 144}); + /*************** Stage 3: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 144, 72}); + b->Args({1, 28, 28, 3, 3, 2, 1, 72, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 72, 144}); + /*************** Stage 3: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 288, 72}); + b->Args({1, 14, 14, 3, 3, 2, 1, 72, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 72, 288}); + /*************** Stage 4: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 288, 144}); + b->Args({1, 14, 14, 3, 3, 2, 1, 144, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 144, 288}); + /*************** Stage 4: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 576, 144}); + b->Args({1, 7, 7, 3, 3, 2, 1, 144, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 144, 576}); +} + +/* ShuffleNet v1 with 2 groups */ +static void ShuffleNetV1G2(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /*************** Stage 2: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 50}); + b->Args({1, 56, 56, 3, 3, 2, 1, 50, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 2, 25, 88}); + /*************** Stage 2: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 2, 100, 25}); + b->Args({1, 28, 28, 3, 3, 2, 1, 50, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 2, 25, 100}); + /*************** Stage 3: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 2, 100, 50}); + b->Args({1, 28, 28, 3, 3, 2, 1, 100, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 2, 50, 100}); + /*************** Stage 3: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 2, 200, 50}); + b->Args({1, 14, 14, 3, 3, 2, 1, 100, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 2, 50, 200}); + /*************** Stage 4: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 2, 200, 100}); + b->Args({1, 14, 14, 3, 3, 2, 1, 200, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 2, 100, 200}); + /*************** Stage 4: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 2, 400, 100}); + b->Args({1, 7, 7, 3, 3, 2, 1, 200, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 2, 100, 400}); +} + +/* ShuffleNet v1 with 3 groups */ +static void ShuffleNetV1G3(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /*************** Stage 2: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 60}); + b->Args({1, 56, 56, 3, 3, 2, 1, 60, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 3, 20, 72}); + /*************** Stage 2: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 3, 80, 20}); + b->Args({1, 28, 28, 3, 3, 2, 1, 60, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 3, 20, 80}); + /*************** Stage 3: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 3, 80, 40}); + b->Args({1, 28, 28, 3, 3, 2, 1, 120, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 3, 40, 80}); + /*************** Stage 3: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 3, 160, 40}); + b->Args({1, 14, 14, 3, 3, 2, 1, 120, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 3, 40, 160}); + /*************** Stage 4: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 3, 160, 80}); + b->Args({1, 14, 14, 3, 3, 2, 1, 240, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 3, 80, 160}); + /*************** Stage 4: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 3, 320, 80}); + b->Args({1, 7, 7, 3, 3, 2, 1, 240, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 3, 80, 320}); +} + +/* ShuffleNet v1 with 4 groups */ +static void ShuffleNetV1G4(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /*************** Stage 2: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 68}); + b->Args({1, 56, 56, 3, 3, 2, 1, 68, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 4, 17, 62}); + /*************** Stage 2: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 4, 68, 17}); + b->Args({1, 28, 28, 3, 3, 2, 1, 68, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 4, 17, 68}); + /*************** Stage 3: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 4, 68, 34}); + b->Args({1, 28, 28, 3, 3, 2, 1, 136, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 4, 34, 68}); + /*************** Stage 3: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 4, 136, 34}); + b->Args({1, 14, 14, 3, 3, 2, 1, 136, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 4, 34, 136}); + /*************** Stage 4: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 4, 136, 68}); + b->Args({1, 14, 14, 3, 3, 2, 1, 272, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 4, 68, 136}); + /*************** Stage 4: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 4, 272, 68}); + b->Args({1, 7, 7, 3, 3, 2, 1, 272, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 4, 68, 272}); +} + +/* ShuffleNet v1 with 8 groups */ +static void ShuffleNetV1G8(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /*************** Stage 2: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 96}); + b->Args({1, 56, 56, 3, 3, 2, 1, 96, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 8, 12, 45}); + /*************** Stage 2: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 8, 48, 12}); + b->Args({1, 28, 28, 3, 3, 2, 1, 96, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 8, 12, 48}); + /*************** Stage 3: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 8, 48, 24}); + b->Args({1, 28, 28, 3, 3, 2, 1, 192, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 8, 24, 48}); + /*************** Stage 3: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 8, 96, 24}); + b->Args({1, 14, 14, 3, 3, 2, 1, 192, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 8, 24, 96}); + /*************** Stage 4: stride-2 unit **************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 8, 96, 48}); + b->Args({1, 14, 14, 3, 3, 2, 1, 384, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 8, 48, 96}); + /*************** Stage 4: stride-1 units *************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 8, 192, 48}); + b->Args({1, 7, 7, 3, 3, 2, 1, 384, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 8, 48, 192}); +} + +/* ShuffleNet v2 (0.5X scale) */ +static void ShuffleNetV2X05(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /********************** Stage 2 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 24}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 24}); + b->Args({1, 28, 28, 3, 3, 1, 1, 24, 1, 1}); + /********************** Stage 3 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 2, 1, 48, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 48, 48}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 48, 48}); + b->Args({1, 14, 14, 3, 3, 1, 1, 48, 1, 1}); + /********************** Stage 4 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 3, 3, 2, 1, 96, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 96, 96}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 96, 96}); + b->Args({1, 7, 7, 3, 3, 1, 1, 96, 1, 1}); + /*********************** Conv 5 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 192, 1024}); +} + +/* ShuffleNet v2 (1.0X scale) */ +static void ShuffleNetV2X10(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /********************** Stage 2 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 58}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 58}); + b->Args({1, 56, 56, 3, 3, 2, 1, 58, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 58, 58}); + b->Args({1, 28, 28, 3, 3, 1, 1, 58, 1, 1}); + /********************** Stage 3 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 2, 1, 116, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 116, 116}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 116, 116}); + b->Args({1, 14, 14, 3, 3, 1, 1, 116, 1, 1}); + /********************** Stage 4 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 3, 3, 2, 1, 232, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 232, 232}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 232, 232}); + b->Args({1, 7, 7, 3, 3, 1, 1, 232, 1, 1}); + /*********************** Conv 5 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 464, 1024}); +} + +/* ShuffleNet v2 (1.5X scale) */ +static void ShuffleNetV2X15(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /********************** Stage 2 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 88}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 88}); + b->Args({1, 56, 56, 3, 3, 2, 1, 88, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 88, 88}); + b->Args({1, 28, 28, 3, 3, 1, 1, 88, 1, 1}); + /********************** Stage 3 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 2, 1, 176, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 176, 176}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 176, 176}); + b->Args({1, 14, 14, 3, 3, 1, 1, 176, 1, 1}); + /********************** Stage 4 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 3, 3, 2, 1, 352, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 352, 352}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 352, 352}); + b->Args({1, 7, 7, 3, 3, 1, 1, 352, 1, 1}); + /*********************** Conv 5 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 704, 1024}); +} + +/* ShuffleNet v2 (2.0X scale) */ +static void ShuffleNetV2X20(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /*********************** Conv 1 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 24}); + /********************** Stage 2 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 2, 1, 24, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 24, 122}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 122}); + b->Args({1, 56, 56, 3, 3, 2, 1, 122, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 122, 122}); + b->Args({1, 28, 28, 3, 3, 1, 1, 122, 1, 1}); + /********************** Stage 3 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 2, 1, 244, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 244, 244}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 244, 244}); + b->Args({1, 14, 14, 3, 3, 1, 1, 244, 1, 1}); + /********************** Stage 4 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 3, 3, 2, 1, 488, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 488, 488}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 488, 488}); + b->Args({1, 7, 7, 3, 3, 1, 1, 488, 1, 1}); + /*********************** Conv 5 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 976, 2048}); +} + +static void MobileNetV1(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 32}); + b->Args({1, 112, 112, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 112, 112, 1, 1, 1, 1, 1, 32, 64}); + b->Args({1, 112, 112, 3, 3, 2, 1, 64, 1, 1}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 128}); + b->Args({1, 56, 56, 3, 3, 1, 1, 128, 1, 1}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 128, 128}); + b->Args({1, 56, 56, 3, 3, 2, 1, 128, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 128, 256}); + b->Args({1, 28, 28, 3, 3, 1, 1, 256, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 256, 256}); + b->Args({1, 28, 28, 3, 3, 2, 1, 256, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 256, 512}); + b->Args({1, 14, 14, 3, 3, 1, 1, 512, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 512, 512}); + b->Args({1, 14, 14, 3, 3, 2, 1, 512, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 512, 1024}); + b->Args({1, 7, 7, 3, 3, 1, 1, 1024, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 1024, 1024}); +} + +static void MobileNetV2(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 32}); + + /******************** Bottleneck 1 *******************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 112, 112, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 112, 112, 1, 1, 1, 1, 1, 32, 16}); + + /******************** Bottleneck 2 *******************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 112, 112, 1, 1, 1, 1, 1, 16, 96}); + b->Args({1, 112, 112, 3, 3, 2, 1, 96, 1, 1}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 96, 24}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 144}); + b->Args({1, 56, 56, 3, 3, 1, 1, 144, 1, 1}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 144, 24}); + + /******************** Bottleneck 3 *******************/ + /* N H W KH KW S D G GCin GCout */ + // b->Args({1, 56, 56, 1, 1, 1, 1, 1, 24, 144}); + b->Args({1, 56, 56, 3, 3, 2, 1, 144, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 144, 32}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 32, 192}); + b->Args({1, 28, 28, 3, 3, 1, 1, 192, 1, 1}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 192, 32}); + // b->Args({1, 28, 28, 1, 1, 1, 1, 1, 32, 192}); + // b->Args({1, 28, 28, 3, 3, 1, 1, 192, 1, 1}); + // b->Args({1, 28, 28, 1, 1, 1, 1, 1, 192, 32}); + + /******************** Bottleneck 4 *******************/ + /* N H W KH KW S D G GCin GCout */ + // b->Args({1, 28, 28, 1, 1, 1, 1, 1, 32, 192}); + b->Args({1, 28, 28, 3, 3, 2, 1, 192, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 192, 64}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 64, 384}); + b->Args({1, 14, 14, 3, 3, 1, 1, 384, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 384, 64}); + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 64, 384}); + // b->Args({1, 14, 14, 3, 3, 1, 1, 384, 1, 1}); + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 384, 64}); + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 64, 384}); + // b->Args({1, 14, 14, 3, 3, 1, 1, 384, 1, 1}); + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 384, 64}); + + /******************** Bottleneck 5 *******************/ + /* N H W KH KW S D G GCin GCout */ + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 64, 384}); + // b->Args({1, 14, 14, 3, 3, 1, 1, 384, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 384, 96}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 96, 576}); + b->Args({1, 14, 14, 3, 3, 1, 1, 576, 1, 1}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 576, 96}); + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 96, 576}); + // b->Args({1, 14, 14, 3, 3, 1, 1, 576, 1, 1}); + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 576, 96}); + + /******************** Bottleneck 6 *******************/ + /* N H W KH KW S D G GCin GCout */ + // b->Args({1, 14, 14, 1, 1, 1, 1, 1, 96, 576}); + b->Args({1, 14, 14, 3, 3, 2, 1, 576, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 576, 160}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 160, 960}); + b->Args({1, 7, 7, 3, 3, 1, 1, 960, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 960, 160}); + // b->Args({1, 7, 7, 1, 1, 1, 1, 1, 160, 960}); + // b->Args({1, 7, 7, 3, 3, 1, 1, 960, 1, 1}); + // b->Args({1, 7, 7, 1, 1, 1, 1, 1, 960, 160}); + + /******************** Bottleneck 7 *******************/ + /* N H W KH KW S D G GCin GCout */ + // b->Args({1, 7, 7, 1, 1, 1, 1, 1, 160, 960}); + // b->Args({1, 7, 7, 3, 3, 1, 1, 960, 1, 1}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 960, 320}); + + /**************** Pre-pooling Conv2D *****************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 320, 1280}); + /**************** Post-pooling Conv2D ****************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 1, 1, 1, 1, 1, 1, 1, 1280, 1000}); +} + +/* SqueezeNet 1.0 */ +static void SqueezeNetV10(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************** Conv 1 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 7, 7, 2, 1, 1, 3, 96}); + /********************** Fire 2 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 96, 16}); + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 16, 64}); + b->Args({1, 55, 55, 3, 3, 1, 1, 1, 16, 64}); + /********************** Fire 3 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 55, 1, 1, 1, 1, 1, 128, 16}); + /*b->Args({1, 55, 55, 1, 1, 1, 1, 1, 16, 64});*/ + /*b->Args({1, 55, 55, 3, 3, 1, 1, 1, 16, 64});*/ + /********************** Fire 4 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 128, 32}); + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 32, 128}); + b->Args({1, 55, 55, 3, 3, 1, 1, 1, 32, 128}); + /********************** Fire 5 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 256, 32}); + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 32, 128}); + b->Args({1, 27, 27, 3, 3, 1, 1, 1, 32, 128}); + /********************** Fire 6 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 256, 48}); + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 48, 192}); + b->Args({1, 27, 27, 3, 3, 1, 1, 1, 48, 192}); + /********************** Fire 7 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 384, 48}); + /*b->Args({1, 27, 27, 1, 1, 1, 1, 1, 48, 192});*/ + /*b->Args({1, 27, 27, 3, 3, 1, 1, 1, 48, 192});*/ + /********************** Fire 8 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 384, 64}); + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 64, 256}); + b->Args({1, 27, 27, 3, 3, 1, 1, 1, 64, 256}); + /********************** Fire 9 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 64}); + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 64, 256}); + b->Args({1, 13, 13, 3, 3, 1, 1, 1, 64, 256}); + /********************* Conv 10 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 1000}); +} + +/* SqueezeNet 1.1 */ +static void SqueezeNetV11(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************** Conv 1 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 2, 1, 1, 3, 64}); + /********************** Fire 2 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 64, 16}); + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 16, 64}); + b->Args({1, 55, 55, 3, 3, 1, 1, 1, 16, 64}); + /********************** Fire 3 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 55, 55, 1, 1, 1, 1, 1, 128, 16}); + /*b->Args({1, 55, 55, 1, 1, 1, 1, 1, 16, 64});*/ + /*b->Args({1, 55, 55, 3, 3, 1, 1, 1, 16, 64});*/ + /********************** Fire 4 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 128, 32}); + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 32, 128}); + b->Args({1, 27, 27, 3, 3, 1, 1, 1, 32, 128}); + /********************** Fire 5 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 27, 27, 1, 1, 1, 1, 1, 256, 32}); + /*b->Args({1, 27, 27, 1, 1, 1, 1, 1, 32, 128});*/ + /*b->Args({1, 27, 27, 3, 3, 1, 1, 1, 32, 128});*/ + /********************** Fire 6 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 256, 48}); + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 48, 192}); + b->Args({1, 13, 13, 3, 3, 1, 1, 1, 48, 192}); + /********************** Fire 7 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 384, 48}); + /*b->Args({1, 13, 13, 1, 1, 1, 1, 1, 48, 192});*/ + /*b->Args({1, 13, 13, 3, 3, 1, 1, 1, 48, 192});*/ + /********************** Fire 8 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 384, 64}); + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 64, 256}); + b->Args({1, 13, 13, 3, 3, 1, 1, 1, 64, 256}); + /********************** Fire 9 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 64}); + /*b->Args({1, 13, 13, 1, 1, 1, 1, 1, 64, 256});*/ + /*b->Args({1, 13, 13, 3, 3, 1, 1, 1, 64, 256});*/ + /********************* Conv 10 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 13, 13, 1, 1, 1, 1, 1, 512, 1000}); +} + +static void ResNet18(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************* Conv 1 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 7, 7, 2, 1, 1, 3, 64}); + /******************** Conv 2.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 1, 1, 1, 64, 64}); + /******************** Conv 3.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 2, 1, 1, 64, 128}); + b->Args({1, 28, 28, 3, 3, 1, 1, 1, 128, 128}); + b->Args({1, 56, 56, 1, 1, 2, 1, 1, 64, 128}); + /******************** Conv 4.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 2, 1, 1, 128, 256}); + b->Args({1, 14, 14, 3, 3, 1, 1, 1, 256, 256}); + b->Args({1, 28, 28, 1, 1, 2, 1, 1, 128, 256}); + /******************** Conv 5.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 3, 3, 2, 1, 1, 256, 512}); + b->Args({1, 7, 7, 3, 3, 1, 1, 1, 512, 512}); + b->Args({1, 14, 14, 1, 1, 2, 1, 1, 256, 512}); +} + +static void ResNet50(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************* Conv 1 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 7, 7, 2, 1, 1, 3, 64}); + /******************** Conv 2.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 64}); + b->Args({1, 56, 56, 3, 3, 1, 1, 1, 64, 64}); + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 256}); + /*b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 256});*/ + /******************** Conv 2.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 256, 64}); + /*b->Args({1, 56, 56, 3, 3, 1, 1, 1, 64, 64});*/ + /*b->Args({1, 56, 56, 1, 1, 1, 1, 1, 64, 256});*/ + /******************** Conv 3.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 256, 128}); + b->Args({1, 56, 56, 3, 3, 2, 1, 1, 128, 128}); + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 128, 512}); + b->Args({1, 56, 56, 1, 1, 2, 1, 1, 256, 512}); + /******************** Conv 3.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 512, 128}); + b->Args({1, 28, 28, 3, 3, 1, 1, 1, 128, 128}); + /*b->Args({1, 28, 28, 1, 1, 1, 1, 1, 128, 512});*/ + /******************** Conv 4.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 512, 256}); + b->Args({1, 28, 28, 3, 3, 2, 1, 1, 256, 256}); + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 256, 1024}); + b->Args({1, 28, 28, 1, 1, 2, 1, 1, 512, 1024}); + /******************** Conv 4.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 1024, 256}); + b->Args({1, 14, 14, 3, 3, 1, 1, 1, 256, 256}); + /*b->Args({1, 14, 14, 1, 1, 1, 1, 1, 256, 1024});*/ + /******************** Conv 5.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 1024, 512}); + b->Args({1, 14, 14, 3, 3, 2, 1, 1, 512, 512}); + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 512, 2048}); + b->Args({1, 14, 14, 1, 1, 2, 1, 1, 1024, 2048}); + /******************** Conv 5.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 1, 1, 1, 1, 1, 2048, 512}); + b->Args({1, 7, 7, 3, 3, 1, 1, 1, 512, 512}); + /*b->Args({1, 7, 7, 1, 1, 1, 1, 1, 512, 2048});*/ +} + +static void VGG(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************* Conv 1.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 1, 1, 1, 3, 64}); + /********************* Conv 1.2 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 224, 224, 3, 3, 1, 1, 1, 64, 64}); + + /********************* Conv 2.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 112, 112, 3, 3, 1, 1, 1, 64, 128}); + /********************* Conv 2.2 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 112, 112, 3, 3, 1, 1, 1, 128, 128}); + + /********************* Conv 3.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 1, 1, 1, 128, 256}); + /********************* Conv 3.2 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 3, 3, 1, 1, 1, 256, 256}); + /********************* Conv 3.3 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 56, 56, 1, 1, 1, 1, 1, 256, 256}); + + /********************* Conv 4.1 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 1, 1, 1, 256, 512}); + /********************* Conv 4.2 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 3, 3, 1, 1, 1, 512, 512}); + /********************* Conv 4.3 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 28, 28, 1, 1, 1, 1, 1, 512, 512}); + + /********************* Conv 5.X ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 3, 3, 1, 1, 1, 512, 512}); + /********************* Conv 5.3 ********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 14, 14, 1, 1, 1, 1, 1, 512, 512}); +} + +static void DWConv3x3(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************** 96 x 96 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 96, 96, 3, 3, 1, 1, 512, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 256, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 128, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 64, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 48, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 24, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 1, 16, 1, 1}); + /********************** 32 x 32 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 32, 32, 3, 3, 1, 1, 768, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 512, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 256, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 128, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 64, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 48, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 24, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 1, 16, 1, 1}); + /********************** 17 x 17 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 17, 17, 3, 3, 1, 1, 1024, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 768, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 512, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 384, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 256, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 128, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 64, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 1, 16, 1, 1}); + /********************** 11 x 11 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 11, 11, 3, 3, 1, 1, 1024, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 768, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 512, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 384, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 256, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 192, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 128, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 64, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 1, 16, 1, 1}); + /*********************** 7 x 7 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 3, 3, 1, 1, 1024, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 768, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 512, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 384, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 256, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 128, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 64, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 32, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 1, 16, 1, 1}); +} + +static void DWConv3x3d2(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************** 96 x 96 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 96, 96, 3, 3, 1, 2, 512, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 256, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 128, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 64, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 48, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 32, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 24, 1, 1}); + b->Args({1, 96, 96, 3, 3, 1, 2, 16, 1, 1}); + /********************** 32 x 32 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 32, 32, 3, 3, 1, 2, 768, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 512, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 256, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 128, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 64, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 48, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 32, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 24, 1, 1}); + b->Args({1, 32, 32, 3, 3, 1, 2, 16, 1, 1}); + /********************** 17 x 17 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 17, 17, 3, 3, 1, 2, 1024, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 768, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 512, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 384, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 256, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 128, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 64, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 32, 1, 1}); + b->Args({1, 17, 17, 3, 3, 1, 2, 16, 1, 1}); + /********************** 11 x 11 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 11, 11, 3, 3, 1, 2, 1024, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 768, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 512, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 384, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 256, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 192, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 128, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 64, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 32, 1, 1}); + b->Args({1, 11, 11, 3, 3, 1, 2, 16, 1, 1}); + /*********************** 7 x 7 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 3, 3, 1, 2, 1024, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 768, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 512, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 384, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 256, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 128, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 64, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 32, 1, 1}); + b->Args({1, 7, 7, 3, 3, 1, 2, 16, 1, 1}); +} + +static void DWConv5x5(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "KH", "KW", "S", "D", "G", "GCin", "GCout"}); + + /********************** 96 x 96 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 96, 96, 5, 5, 1, 1, 512, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 256, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 128, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 64, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 48, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 32, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 24, 1, 1}); + b->Args({1, 96, 96, 5, 5, 1, 1, 16, 1, 1}); + /********************** 32 x 32 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 32, 32, 5, 5, 1, 1, 768, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 512, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 256, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 128, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 64, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 48, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 32, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 24, 1, 1}); + b->Args({1, 32, 32, 5, 5, 1, 1, 16, 1, 1}); + /********************** 17 x 17 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 17, 17, 5, 5, 1, 1, 1024, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 768, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 512, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 384, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 256, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 128, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 64, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 32, 1, 1}); + b->Args({1, 17, 17, 5, 5, 1, 1, 16, 1, 1}); + /********************** 11 x 11 *********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 11, 11, 5, 5, 1, 1, 1024, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 768, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 512, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 384, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 256, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 128, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 64, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 32, 1, 1}); + b->Args({1, 11, 11, 5, 5, 1, 1, 16, 1, 1}); + /*********************** 7 x 7 **********************/ + /* N H W KH KW S D G GCin GCout */ + b->Args({1, 7, 7, 5, 5, 1, 1, 1024, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 768, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 512, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 384, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 256, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 128, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 64, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 32, 1, 1}); + b->Args({1, 7, 7, 5, 5, 1, 1, 16, 1, 1}); +} + +BENCHMARK_CAPTURE(convolution_q8, mobilenet_v1, "MobileNet v1") + ->Apply(MobileNetV1); +BENCHMARK_CAPTURE(convolution_q8, mobilenet_v2, "MobileNet v2") + ->Apply(MobileNetV2); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g1, "ShuffleNet v1 (1 group)") + ->Apply(ShuffleNetV1G1); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)") + ->Apply(ShuffleNetV1G2); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)") + ->Apply(ShuffleNetV1G3); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)") + ->Apply(ShuffleNetV1G4); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)") + ->Apply(ShuffleNetV1G8); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x05, "ShuffleNet v2 0.5X") + ->Apply(ShuffleNetV2X05); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x10, "ShuffleNet v2 1.0X") + ->Apply(ShuffleNetV2X10); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x15, "ShuffleNet v2 1.5X") + ->Apply(ShuffleNetV2X15); +BENCHMARK_CAPTURE(convolution_q8, shufflenet_v2_x20, "ShuffleNet v2 2.0X") + ->Apply(ShuffleNetV2X20); +BENCHMARK_CAPTURE(convolution_q8, squeezenet_v10, "SqueezeNet 1.0") + ->Apply(SqueezeNetV10); +BENCHMARK_CAPTURE(convolution_q8, squeezenet_v11, "SqueezeNet 1.1") + ->Apply(SqueezeNetV11); +BENCHMARK_CAPTURE(convolution_q8, resnet18, "ResNet-18")->Apply(ResNet18); +BENCHMARK_CAPTURE(convolution_q8, resnet50, "ResNet-50")->Apply(ResNet50); +BENCHMARK_CAPTURE(convolution_q8, vgg, "VGG")->Apply(VGG); +BENCHMARK_CAPTURE(convolution_q8, dwconv3x3, "3x3 DW Convolutions") + ->Apply(DWConv3x3); +BENCHMARK_CAPTURE( + convolution_q8, + dwconv3x3d2, + "3x3 DW Convolutions (dilation 2)") + ->Apply(DWConv3x3d2); +BENCHMARK_CAPTURE(convolution_q8, dwconv5x5, "5x5 DW Convolutions") + ->Apply(DWConv5x5); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/global-average-pooling.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/global-average-pooling.cc new file mode 100644 index 0000000000000..d71c45bae4441 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/global-average-pooling.cc @@ -0,0 +1,99 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +static void global_average_pooling_q8(benchmark::State& state) { + const size_t batchSize = state.range(0); + const size_t inputHeight = state.range(1); + const size_t inputWidth = state.range(2); + const size_t channels = state.range(3); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + const size_t inputPixelStride = channels; + const size_t outputPixelStride = channels; + + std::vector input( + batchSize * inputHeight * inputWidth * inputPixelStride); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::vector output(batchSize * outputPixelStride); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t globalPoolingOperator = nullptr; + status = pytorch_qnnp_create_global_average_pooling_nwc_q8( + channels, + 127 /* input zero point */, + 0.75f /* input scale */, + 127 /* output zero point */, + 1.25f /* output scale */, + 0, + 255, + 0 /* flags */, + &globalPoolingOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to create Global Average Pooling operator"); + } + + status = pytorch_qnnp_setup_global_average_pooling_nwc_q8( + globalPoolingOperator, + batchSize, + inputHeight * inputWidth, + input.data(), + inputPixelStride, + output.data(), + outputPixelStride); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Global Average Pooling operator"); + } + + for (auto _ : state) { + pytorch_qnnp_run_operator(globalPoolingOperator, nullptr /* thread pool */); + } + + status = pytorch_qnnp_delete_operator(globalPoolingOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Global Average Pooling operator"); + } + globalPoolingOperator = nullptr; + + state.SetBytesProcessed( + uint64_t(state.iterations()) * batchSize * + (inputHeight * inputWidth + 1) * channels * sizeof(uint8_t)); +} + +static void ImageNetArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "C"}); + + /* N IH IW C */ + b->Args({1, 7, 7, 1000}); + b->Args({1, 13, 13, 1000}); +} + +BENCHMARK(global_average_pooling_q8)->Apply(ImageNetArguments); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/hgemm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/hgemm.cc new file mode 100644 index 0000000000000..d3d4d4ac2473d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/hgemm.cc @@ -0,0 +1,378 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +inline uint32_t divideRoundUp(uint32_t x, uint32_t q) { + return x / q + uint32_t(x % q != 0); +} + +inline uint32_t roundUp(uint32_t x, uint32_t q) { + return q * divideRoundUp(x, q); +} + +class HGEMM : public benchmark::Fixture { + public: + inline HGEMM(uint32_t mr, uint32_t nr, uint32_t kr) + : mr_(mr), nr_(nr), kr_(kr), mc_(mr), nc_(nr), kc_(kr) {} + + virtual void SetUp(const benchmark::State&) override { + const uint_fast32_t seed = + std::chrono::system_clock::now().time_since_epoch().count(); + auto rng = std::bind( + fp16_ieee_from_fp32_value, + std::bind(std::uniform_real_distribution(), std::mt19937(seed))); + + a_.resize(mc() * kc()); + std::generate(a_.begin(), a_.end(), std::ref(rng)); + k_.resize(nc() * kc()); + std::generate(k_.begin(), k_.end(), std::ref(rng)); + b_.resize(nc()); + std::generate(b_.begin(), b_.end(), std::ref(rng)); + w_.resize(ncStride() * kcStride() + ncStride()); + std::fill(w_.begin(), w_.end(), 0); + pytorch_pack_hgemm_w(nc(), kc(), nr(), kr(), k(), b(), w()); + c_.resize(mc() * nc()); + std::fill(c_.begin(), c_.end(), UINT16_C(0x7E00) /* NaN */); + } + + virtual void TearDown(benchmark::State& state) override { + state.SetItemsProcessed( + uint64_t(state.iterations()) * 2 * mc() * nc() * kc()); + a_.clear(); + k_.clear(); + b_.clear(); + w_.clear(); + c_.clear(); + } + + inline const uint16_t* a() const { + return a_.data(); + } + + inline const uint16_t* k() const { + return k_.data(); + } + + inline const uint16_t* b() const { + return b_.data(); + } + + inline uint16_t* w() { + return w_.data(); + } + + inline const uint16_t* w() const { + return w_.data(); + } + + inline uint16_t* c() { + return c_.data(); + } + + inline uint32_t mr() const { + return mr_; + } + + inline uint32_t mc() const { + return mc_; + } + + inline uint32_t nr() const { + return nr_; + } + + inline uint32_t nc() const { + return nc_; + } + + inline uint32_t ncStride() const { + return roundUp(nc(), nr()); + } + + inline uint32_t kr() const { + return kr_; + } + + inline uint32_t kc() const { + return kc_; + } + + inline uint32_t kcStride() const { + return roundUp(kc(), kr()); + } + + inline const pytorch_qnnp_fp16_clamping_params* clampingParams() const { + return &clampingParams_; + } + + protected: + std::vector a_; + std::vector k_; + std::vector b_; + std::vector> w_; + std::vector c_; + uint32_t mr_{0}; + uint32_t nr_{0}; + uint32_t kr_{0}; + uint32_t mc_{mr_}; + uint32_t nc_{nr_}; + uint32_t kc_{kr_}; + pytorch_qnnp_fp16_clamping_params clampingParams_{0x3C00, 0x7C00, 0xFC00}; +}; + +template +class HGEMM_L1 : public HGEMM { + public: + inline HGEMM_L1() : HGEMM(MR, NR, KR) { + cpuinfo_initialize(); + const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size; + const size_t l1d_reserve = 512; + kc_ = ((l1d_size - l1d_reserve) / sizeof(uint16_t) - mr() * nr()) / + (mr() + nr()); + if (kr() != 1) { + kc_ = kc_ / kr() * kr(); + } else { + kc_ = kc_ / nr() * nr(); + } + } +}; + +template +class HGEMM_Op : public HGEMM { + public: + inline HGEMM_Op() : HGEMM(MR, NR, KR) {} + + virtual void SetUp(const benchmark::State& state) override { + mc_ = state.range(0); + nc_ = state.range(1); + kc_ = state.range(2); + + HGEMM::SetUp(state); + } +}; + +static void ShuffleNetV1G1GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 1 */ + b->Args({56 * 56, 30, 24}); + b->Args({28 * 28, 120, 30}); + b->Args({28 * 28, 36, 144}); + b->Args({28 * 28, 144, 36}); + b->Args({14 * 14, 144, 36}); + b->Args({14 * 14, 72, 288}); + b->Args({14 * 14, 288, 72}); + b->Args({7 * 7, 288, 72}); + b->Args({7 * 7, 144, 576}); + b->Args({7 * 7, 576, 144}); +} + +static void ShuffleNetV1G2GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 2 */ + b->Args({56 * 56, 22, 12}); + b->Args({28 * 28, 88, 22}); + b->Args({28 * 28, 25, 100}); + b->Args({28 * 28, 100, 25}); + b->Args({14 * 14, 100, 25}); + b->Args({14 * 14, 50, 200}); + b->Args({14 * 14, 200, 50}); + b->Args({7 * 7, 200, 50}); + b->Args({7 * 7, 100, 400}); + b->Args({7 * 7, 400, 100}); +} + +static void ShuffleNetV1G3GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 3 */ + b->Args({56 * 56, 18, 8}); + b->Args({28 * 28, 72, 18}); + b->Args({28 * 28, 20, 80}); + b->Args({28 * 28, 80, 20}); + b->Args({14 * 14, 80, 20}); + b->Args({14 * 14, 40, 160}); + b->Args({14 * 14, 160, 40}); + b->Args({7 * 7, 160, 40}); + b->Args({7 * 7, 80, 320}); + b->Args({7 * 7, 320, 80}); +} + +static void ShuffleNetV1G4GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 4 */ + b->Args({56 * 56, 15, 6}); + b->Args({28 * 28, 62, 15}); + b->Args({28 * 28, 17, 68}); + b->Args({28 * 28, 68, 17}); + b->Args({14 * 14, 68, 17}); + b->Args({14 * 14, 34, 136}); + b->Args({14 * 14, 136, 34}); + b->Args({7 * 7, 136, 34}); + b->Args({7 * 7, 68, 272}); + b->Args({7 * 7, 272, 68}); +} + +static void ShuffleNetV1G8GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 8 */ + b->Args({56 * 56, 11, 3}); + b->Args({28 * 28, 45, 11}); + b->Args({28 * 28, 12, 48}); + b->Args({28 * 28, 48, 12}); + b->Args({14 * 14, 48, 12}); + b->Args({14 * 14, 24, 96}); + b->Args({14 * 14, 96, 24}); + b->Args({7 * 7, 96, 24}); + b->Args({7 * 7, 48, 192}); + b->Args({7 * 7, 192, 48}); +} + +static void MobileNetV1GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + b->Args({112 * 112, 32, 3 * 3 * 3}); + b->Args({112 * 112, 64, 32}); + b->Args({56 * 56, 128, 64}); + b->Args({56 * 56, 128, 128}); + b->Args({28 * 28, 256, 128}); + b->Args({28 * 28, 256, 256}); + b->Args({14 * 14, 512, 256}); + b->Args({14 * 14, 512, 512}); + b->Args({7 * 7, 1024, 512}); + b->Args({7 * 7, 1024, 1024}); +} + +static void SqueezeNetV10GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* Conv 1 */ + b->Args({111 * 111, 96, 7 * 7 * 3}); + /* Fire 2 */ + b->Args({55 * 55, 16, 96}); + b->Args({55 * 55, 64, 16}); + b->Args({55 * 55, 64, 3 * 3 * 16}); + /* Fire 3 */ + b->Args({55 * 55, 16, 128}); + b->Args({55 * 55, 64, 16}); + b->Args({55 * 55, 64, 3 * 3 * 16}); + /* Fire 4 */ + b->Args({55 * 55, 32, 128}); + b->Args({55 * 55, 128, 32}); + b->Args({55 * 55, 128, 3 * 3 * 32}); + /* Fire 5 */ + b->Args({27 * 27, 32, 256}); + b->Args({27 * 27, 128, 32}); + b->Args({27 * 27, 128, 3 * 3 * 32}); + /* Fire 6 */ + b->Args({27 * 27, 48, 256}); + b->Args({27 * 27, 192, 48}); + b->Args({27 * 27, 192, 3 * 3 * 48}); + /* Fire 7 */ + b->Args({27 * 27, 48, 384}); + b->Args({27 * 27, 192, 48}); + b->Args({27 * 27, 192, 3 * 3 * 48}); + /* Fire 8 */ + b->Args({27 * 27, 64, 384}); + b->Args({27 * 27, 256, 64}); + b->Args({27 * 27, 256, 3 * 3 * 64}); + /* Fire 9 */ + b->Args({13 * 13, 64, 512}); + b->Args({13 * 13, 256, 64}); + b->Args({13 * 13, 256, 3 * 3 * 64}); + /* Conv 10 */ + b->Args({13 * 13, 1000, 512}); +} + +static void GemmArguments(benchmark::internal::Benchmark* b) { + for (auto S = 15; S <= 128; S *= 2) { + for (int K = 8; K <= 1024; K *= 2) { + b->Args({S * S, K, K}); + } + } +} + +#if CPUINFO_ARCH_ARM +BENCHMARK_TEMPLATE_F(HGEMM_L1, 8x8__aarch32_neonfp16arith, 8, 8, 1) +(benchmark::State& state) { + if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_fp16_arith()) { + state.SkipWithError("NEON FP16 compute is not supported"); + } + for (auto _ : state) { + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint16_t), + w() + nc() * (kcStride() + 1), + c(), + mr() * sizeof(uint16_t), + clampingParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(HGEMM_Op, 8x8__aarch32_neonfp16arith, 8, 8, 1) +(benchmark::State& state) { + if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_fp16_arith()) { + state.SkipWithError("NEON FP16 compute is not supported"); + } + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint16_t), + w() + n * (kcStride() + 1), + c() + m * nc() + n, + nc() * sizeof(uint16_t), + clampingParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(HGEMM_Op, 8x8__aarch32_neonfp16arith) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(HGEMM_Op, 8x8__aarch32_neonfp16arith) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(HGEMM_Op, 8x8__aarch32_neonfp16arith) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(HGEMM_Op, 8x8__aarch32_neonfp16arith) + ->Apply(GemmArguments); +#endif + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/max-pooling.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/max-pooling.cc new file mode 100644 index 0000000000000..9a879f9c94f52 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/max-pooling.cc @@ -0,0 +1,168 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +static void max_pooling_u8(benchmark::State& state, const char* net) { + const size_t batchSize = state.range(0); + const size_t inputHeight = state.range(1); + const size_t inputWidth = state.range(2); + const size_t poolingSize = state.range(3); + const size_t paddingSize = state.range(4); + const size_t stride = state.range(5); + const size_t channels = state.range(6); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + const size_t inputPixelStride = channels; + const size_t outputPixelStride = channels; + const size_t outputHeight = + (2 * paddingSize + inputHeight - poolingSize) / stride + 1; + const size_t outputWidth = + (2 * paddingSize + inputWidth - poolingSize) / stride + 1; + + std::vector input( + batchSize * inputHeight * inputWidth * inputPixelStride); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::vector output( + batchSize * outputHeight * outputWidth * outputPixelStride); + std::fill(output.begin(), output.end(), 0xA5); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t poolingOperator = nullptr; + status = pytorch_qnnp_create_max_pooling2d_nhwc_u8( + paddingSize, + paddingSize, + paddingSize, + paddingSize, + poolingSize, + poolingSize, + stride, + stride, + 1 /* dilation height */, + 1 /* dilation width */, + channels, + 0, + 255, + 0 /* flags */, + &poolingOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to create Max Pooling operator"); + } + + status = pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + poolingOperator, + batchSize, + inputHeight, + inputWidth, + input.data(), + inputPixelStride, + output.data(), + outputPixelStride, + nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Max Pooling operator"); + } + + for (auto _ : state) { + status = + pytorch_qnnp_run_operator(poolingOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run Max Pooling operator"); + } + } + + status = pytorch_qnnp_delete_operator(poolingOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Max Pooling operator"); + } + poolingOperator = nullptr; + + state.SetBytesProcessed( + uint64_t(state.iterations()) * batchSize * + (inputHeight * inputWidth + outputHeight * outputWidth) * channels * + sizeof(uint8_t)); +} + +/* ShuffleNet v1/v2 */ +static void ShuffleNet(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 112, 112, 3, 1, 2, 24}); +} + +/* SqueezeNet 1.0 */ +static void SqueezeNetV10(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /*********** MaxPool 1 ************/ + /* N H W K P S C */ + b->Args({1, 111, 111, 3, 0, 2, 96}); + /*********** MaxPool 4 ************/ + /* N H W K P S C */ + b->Args({1, 27, 27, 3, 0, 2, 256}); + /*********** MaxPool 8 ************/ + /* N H W K P S C */ + b->Args({1, 13, 13, 3, 0, 2, 512}); +} + +/* SqueezeNet 1.1 */ +static void SqueezeNetV11(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /*********** MaxPool 1 ***********/ + /* N H W K P S C */ + b->Args({1, 111, 111, 3, 0, 2, 64}); + /*********** MaxPool 3 ************/ + /* N H W K P S C */ + b->Args({1, 55, 55, 3, 0, 2, 128}); + /*********** MaxPool 5 ************/ + /* N H W K P S C */ + b->Args({1, 13, 13, 3, 0, 2, 256}); +} + +static void VGG(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "H", "W", "K", "P", "S", "C"}); + + /* N H W K P S C */ + b->Args({1, 224, 224, 2, 1, 2, 64}); + b->Args({1, 112, 112, 2, 1, 2, 128}); + b->Args({1, 56, 56, 2, 1, 2, 256}); + b->Args({1, 28, 28, 2, 1, 2, 512}); + b->Args({1, 14, 14, 2, 1, 2, 512}); +} + +BENCHMARK_CAPTURE(max_pooling_u8, shufflenet, "ShuffleNet v1/v2") + ->Apply(ShuffleNet); +BENCHMARK_CAPTURE(max_pooling_u8, squeezenet_v10, "SqueezeNet v1.0") + ->Apply(SqueezeNetV10); +BENCHMARK_CAPTURE(max_pooling_u8, squeezenet_v11, "SqueezeNet v1.1") + ->Apply(SqueezeNetV11); +BENCHMARK_CAPTURE(max_pooling_u8, vgg, "VGG")->Apply(VGG); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/q8gemm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/q8gemm.cc new file mode 100644 index 0000000000000..4895a5d5d0325 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/q8gemm.cc @@ -0,0 +1,1043 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#if PYTORCH_QNNPACK_BENCHMARK_GEMMLOWP +#include +#endif + +inline uint32_t divideRoundUp(uint32_t x, uint32_t q) { + return x / q + uint32_t(x % q != 0); +} + +inline uint32_t roundUp(uint32_t x, uint32_t q) { + return q * divideRoundUp(x, q); +} + +#if PYTORCH_QNNPACK_BENCHMARK_GEMMLOWP +struct GemmlowpOutputPipeline { + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple< + gemmlowp::OutputStageBiasAddition, + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + gemmlowp::OutputStageClamp, + gemmlowp::OutputStageSaturatingCastToUint8> + Pipeline; + + static Pipeline Make( + const int32_t* bias_data, + int output_rows, + int32_t output_offset, + int32_t output_multiplier, + int output_shift, + int32_t output_activation_min, + int32_t output_activation_max) { + ColVectorMap bias_vector(bias_data, output_rows); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_offset; + quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; + quantize_down_stage.result_shift = output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + return std::make_tuple( + bias_addition_stage, + quantize_down_stage, + clamp_stage, + saturating_cast_stage); + } +}; +#endif + +class Q8GEMM : public benchmark::Fixture { + public: + inline Q8GEMM(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr) + : mr_(mr), nr_(nr), np_(np), kr_(kr), mc_(mr), nc_(nr), kc_(kr) {} + + virtual void SetUp(const benchmark::State&) override { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + a_.resize(mc() * kc()); + std::generate(a_.begin(), a_.end(), std::ref(u8rng)); + k_.resize(nc() * kc()); + std::generate(k_.begin(), k_.end(), std::ref(u8rng)); + b_.resize(nc()); + std::generate(b_.begin(), b_.end(), std::ref(s32rng)); + w_.resize( + kcStride() * ncStride() + + ncStride() * sizeof(int32_t) / sizeof(uint8_t)); + std::fill(w_.begin(), w_.end(), 127); + pytorch_pack_q8gemm_w( + nc(), + kc(), + nr(), + np(), + kr(), +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + 127, + 127, +#endif + k(), + b(), + w()); + c_.resize(mc() * nc()); + std::fill(c_.begin(), c_.end(), 0xA5); + + quantizationParams_ = pytorch_qnnp_compute_conv_quantization_params( + 127, 127, 0.75f, 127, 1, 254); + } + + virtual void TearDown(benchmark::State& state) override { + state.SetItemsProcessed( + uint64_t(state.iterations()) * 2 * mc() * nc() * kc()); + a_.clear(); + k_.clear(); + b_.clear(); + w_.clear(); + c_.clear(); + } + + inline const uint8_t* a() const { + return a_.data(); + } + + inline const uint8_t* k() const { + return k_.data(); + } + + inline const int32_t* b() const { + return b_.data(); + } + + inline uint8_t* w() { + return w_.data(); + } + + inline const uint8_t* w() const { + return w_.data(); + } + + inline uint8_t* c() { + return c_.data(); + } + + inline uint32_t mr() const { + return mr_; + } + + inline uint32_t mc() const { + return mc_; + } + + inline uint32_t nr() const { + return nr_; + } + + inline uint32_t np() const { + return np_; + } + + inline uint32_t nc() const { + return nc_; + } + + inline uint32_t ncStride() const { + return roundUp(nc(), nr()); + } + + inline uint32_t kr() const { + return kr_; + } + + inline uint32_t kc() const { + return kc_; + } + + inline uint32_t kcStride() const { + return roundUp(kc(), kr()); + } + + inline const pytorch_qnnp_conv_quantization_params* quantizationParams() + const { + return &quantizationParams_; + } + + protected: + std::vector a_; + std::vector k_; + std::vector b_; + std::vector> w_; + std::vector c_; + uint32_t mr_{0}; + uint32_t nr_{0}; + uint32_t np_{0}; + uint32_t kr_{0}; + uint32_t mc_{mr_}; + uint32_t nc_{nr_}; + uint32_t kc_{kr_}; + pytorch_qnnp_conv_quantization_params quantizationParams_; +}; + +template +class Q8GEMM_L1 : public Q8GEMM { + public: + inline Q8GEMM_L1() : Q8GEMM(MR, NR, NP, KR) { + cpuinfo_initialize(); + const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size; + const size_t l1d_reserve = 512; + kc_ = ((l1d_size - l1d_reserve) / sizeof(uint8_t) - mr() * nr()) / + (mr() + nr()); + if (kr() != 1) { + kc_ = kc_ / kr() * kr(); + } else { + kc_ = kc_ / nr() * nr(); + } + } +}; + +template +class Q8GEMM_Op : public Q8GEMM { + public: + inline Q8GEMM_Op() : Q8GEMM(MR, NR, NP, KR) {} + + virtual void SetUp(const benchmark::State& state) override { + mc_ = state.range(0); + nc_ = state.range(1); + kc_ = state.range(2); + + Q8GEMM::SetUp(state); + } +}; + +class Q8GEMM_XZP : public Q8GEMM { + public: + inline Q8GEMM_XZP(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr) + : Q8GEMM(mr, nr, np, kr) {} + virtual void SetUp(const benchmark::State&) override { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + a_.resize(mc() * kc()); + std::generate(a_.begin(), a_.end(), std::ref(u8rng)); + k_.resize(ncStride() * kcStride()); + std::generate(k_.begin(), k_.end(), std::ref(u8rng)); + b_.resize(roundUp(nc(), nr())); + std::generate(b_.begin(), b_.end(), std::ref(s32rng)); + w_.resize(ncStride() * (kcStride() + sizeof(int32_t) / sizeof(uint8_t))); + std::fill(w_.begin(), w_.end(), 127); + pytorch_pack_swizzle_q8gemm_b( + nc(), + kc(), + np(), + kr(), + 8, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + 127, + 127, +#endif + k(), + b(), + w()); + c_.resize(mc() * nc()); + std::fill(c_.begin(), c_.end(), 0xA5); + aRowSums_.resize(roundUp(mc(), mr())); + std::fill(aRowSums_.begin(), aRowSums_.end(), 0xFE01); + + requantizationParams_ = + pytorch_qnnp_compute_requantization_params(0.75f, 127, 1, 254); + } + + virtual void TearDown(benchmark::State& state) override { + state.SetItemsProcessed( + uint64_t(state.iterations()) * 2 * mc() * nc() * kc()); + a_.clear(); + k_.clear(); + c_.clear(); + aRowSums_.clear(); + } + + inline int32_t* aRowSums() { + return aRowSums_.data(); + } + + inline const int32_t* aRowSums() const { + return aRowSums_.data(); + } + + inline const pytorch_qnnp_q31_requantization_params* requantizationParams() + const { + return &requantizationParams_; + } + + protected: + std::vector aRowSums_; + pytorch_qnnp_q31_requantization_params requantizationParams_; +}; + +template +class Q8GEMM_XZP_L1 : public Q8GEMM_XZP { + public: + inline Q8GEMM_XZP_L1() : Q8GEMM_XZP(MR, NR, NP, KR) { + cpuinfo_initialize(); + const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size; + const size_t l1d_reserve = 512; + kc_ = ((l1d_size - l1d_reserve) / sizeof(uint8_t) - mr() * nr()) / + (mr() + nr()); + if (kr() != 1) { + kc_ = kc_ / kr() * kr(); + } else { + kc_ = kc_ / nr() * nr(); + } + } +}; + +template +class Q8GEMM_XZP_Op : public Q8GEMM_XZP { + public: + inline Q8GEMM_XZP_Op() : Q8GEMM_XZP(MR, NR, NP, KR) {} + + virtual void SetUp(const benchmark::State& state) override { + mc_ = state.range(0); + nc_ = state.range(1); + kc_ = state.range(2); + + Q8GEMM_XZP::SetUp(state); + } +}; + +template +class COMPUTE_ROW_SUM_Op : public Q8GEMM_XZP { + public: + inline COMPUTE_ROW_SUM_Op() : Q8GEMM_XZP(MR, NR, NP, KR) {} + + virtual void SetUp(const benchmark::State& state) override { + mc_ = state.range(0); + nc_ = state.range(1); + kc_ = state.range(2); + + Q8GEMM_XZP::SetUp(state); + } + + virtual void TearDown(benchmark::State& state) override { + state.SetItemsProcessed(uint64_t(state.iterations()) * (mc() * kc())); + a_.clear(); + k_.clear(); + b_.clear(); + c_.clear(); + aRowSums_.clear(); + } +}; + +#if PYTORCH_QNNPACK_BENCHMARK_GEMMLOWP +class GEMMLOWP : public benchmark::Fixture { + public: + virtual void SetUp(const benchmark::State& state) override { + const uint_fast32_t seed = + std::chrono::system_clock::now().time_since_epoch().count(); + auto rng = + std::bind(std::uniform_int_distribution(), std::mt19937(seed)); + + mc_ = state.range(0); + nc_ = state.range(1); + kc_ = state.range(2); + + a_.resize(mc() * kc()); + std::generate(a_.begin(), a_.end(), std::ref(rng)); + k_.resize(nc() * kc()); + std::generate(k_.begin(), k_.end(), std::ref(rng)); + b_.resize(nc()); + std::generate(b_.begin(), b_.end(), std::ref(rng)); + c_.resize(mc() * nc()); + std::fill(c_.begin(), c_.end(), 0xA5); + + threadingContext.set_max_num_threads(1); + } + + virtual void TearDown(benchmark::State& state) override { + state.SetItemsProcessed( + uint64_t(state.iterations()) * 2 * mc() * nc() * kc()); + a_.clear(); + k_.clear(); + c_.clear(); + } + + inline const uint8_t* a() const { + return a_.data(); + } + + inline const uint8_t* k() const { + return k_.data(); + } + + inline const int32_t* b() const { + return b_.data(); + } + + inline uint8_t* c() { + return c_.data(); + } + + inline uint32_t mc() const { + return mc_; + } + + inline uint32_t nc() const { + return nc_; + } + + inline uint32_t kc() const { + return kc_; + } + + protected: + gemmlowp::MultiThreadGemmContext threadingContext; + + private: + std::vector a_; + std::vector k_; + std::vector b_; + std::vector c_; + uint32_t mc_; + uint32_t nc_; + uint32_t kc_; +}; +#endif + +static void ShuffleNetV1G1GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 1 */ + b->Args({56 * 56, 30, 24}); + b->Args({28 * 28, 120, 30}); + b->Args({28 * 28, 36, 144}); + b->Args({28 * 28, 144, 36}); + b->Args({14 * 14, 144, 36}); + b->Args({14 * 14, 72, 288}); + b->Args({14 * 14, 288, 72}); + b->Args({7 * 7, 288, 72}); + b->Args({7 * 7, 144, 576}); + b->Args({7 * 7, 576, 144}); +} + +static void ShuffleNetV1G2GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 2 */ + b->Args({56 * 56, 22, 12}); + b->Args({28 * 28, 88, 22}); + b->Args({28 * 28, 25, 100}); + b->Args({28 * 28, 100, 25}); + b->Args({14 * 14, 100, 25}); + b->Args({14 * 14, 50, 200}); + b->Args({14 * 14, 200, 50}); + b->Args({7 * 7, 200, 50}); + b->Args({7 * 7, 100, 400}); + b->Args({7 * 7, 400, 100}); +} + +static void ShuffleNetV1G3GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 3 */ + b->Args({56 * 56, 18, 8}); + b->Args({28 * 28, 72, 18}); + b->Args({28 * 28, 20, 80}); + b->Args({28 * 28, 80, 20}); + b->Args({14 * 14, 80, 20}); + b->Args({14 * 14, 40, 160}); + b->Args({14 * 14, 160, 40}); + b->Args({7 * 7, 160, 40}); + b->Args({7 * 7, 80, 320}); + b->Args({7 * 7, 320, 80}); +} + +static void ShuffleNetV1G4GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 4 */ + b->Args({56 * 56, 15, 6}); + b->Args({28 * 28, 62, 15}); + b->Args({28 * 28, 17, 68}); + b->Args({28 * 28, 68, 17}); + b->Args({14 * 14, 68, 17}); + b->Args({14 * 14, 34, 136}); + b->Args({14 * 14, 136, 34}); + b->Args({7 * 7, 136, 34}); + b->Args({7 * 7, 68, 272}); + b->Args({7 * 7, 272, 68}); +} + +static void ShuffleNetV1G8GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* group = 8 */ + b->Args({56 * 56, 11, 3}); + b->Args({28 * 28, 45, 11}); + b->Args({28 * 28, 12, 48}); + b->Args({28 * 28, 48, 12}); + b->Args({14 * 14, 48, 12}); + b->Args({14 * 14, 24, 96}); + b->Args({14 * 14, 96, 24}); + b->Args({7 * 7, 96, 24}); + b->Args({7 * 7, 48, 192}); + b->Args({7 * 7, 192, 48}); +} + +static void MobileNetV1GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + b->Args({112 * 112, 32, 3 * 3 * 3}); + b->Args({112 * 112, 64, 32}); + b->Args({56 * 56, 128, 64}); + b->Args({56 * 56, 128, 128}); + b->Args({28 * 28, 256, 128}); + b->Args({28 * 28, 256, 256}); + b->Args({14 * 14, 512, 256}); + b->Args({14 * 14, 512, 512}); + b->Args({7 * 7, 1024, 512}); + b->Args({7 * 7, 1024, 1024}); +} + +static void SqueezeNetV10GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* Conv 1 */ + b->Args({111 * 111, 96, 7 * 7 * 3}); + /* Fire 2 */ + b->Args({55 * 55, 16, 96}); + b->Args({55 * 55, 64, 16}); + b->Args({55 * 55, 64, 3 * 3 * 16}); + /* Fire 3 */ + b->Args({55 * 55, 16, 128}); + b->Args({55 * 55, 64, 16}); + b->Args({55 * 55, 64, 3 * 3 * 16}); + /* Fire 4 */ + b->Args({55 * 55, 32, 128}); + b->Args({55 * 55, 128, 32}); + b->Args({55 * 55, 128, 3 * 3 * 32}); + /* Fire 5 */ + b->Args({27 * 27, 32, 256}); + b->Args({27 * 27, 128, 32}); + b->Args({27 * 27, 128, 3 * 3 * 32}); + /* Fire 6 */ + b->Args({27 * 27, 48, 256}); + b->Args({27 * 27, 192, 48}); + b->Args({27 * 27, 192, 3 * 3 * 48}); + /* Fire 7 */ + b->Args({27 * 27, 48, 384}); + b->Args({27 * 27, 192, 48}); + b->Args({27 * 27, 192, 3 * 3 * 48}); + /* Fire 8 */ + b->Args({27 * 27, 64, 384}); + b->Args({27 * 27, 256, 64}); + b->Args({27 * 27, 256, 3 * 3 * 64}); + /* Fire 9 */ + b->Args({13 * 13, 64, 512}); + b->Args({13 * 13, 256, 64}); + b->Args({13 * 13, 256, 3 * 3 * 64}); + /* Conv 10 */ + b->Args({13 * 13, 1000, 512}); +} + +static void GemmArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + for (auto S = 15; S <= 128; S *= 2) { + for (int K = 8; K <= 1024; K *= 2) { + b->Args({S * S, K, K}); + } + } +} + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +static void q8gemm_compute_row_sum( + const uint8_t* a, + size_t m, + size_t k, + size_t stride, + const int32_t multiplier, + int32_t* row_sum) { + const size_t block_size = 4; + for (size_t block_start = 0; block_start < m; block_start += block_size) { + pytorch_q8sumrows_ukernel_4x__neon( + a + block_start * stride, + std::min(block_size, m - block_start), + k, + stride, + multiplier, + row_sum + block_start); + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ + +#if CPUINFO_ARCH_ARM +BENCHMARK_TEMPLATE_F(Q8GEMM_L1, 4x8__aarch32_neon, 4, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_q8gemm_ukernel_4x8__aarch32_neon( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint8_t), + w(), + c(), + mr() * sizeof(uint8_t), + quantizationParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 4x8__aarch32_neon, 4, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_ukernel_4x8__aarch32_neon( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, + nc() * sizeof(uint8_t), + quantizationParams()); + } + } + } +} +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(GemmArguments); + +BENCHMARK_TEMPLATE_F(Q8GEMM_XZP_L1, 4x8c2__aarch32_neon, 4, 8, 8, 2) +(benchmark::State& state) { + for (auto _ : state) { + q8gemm_compute_row_sum(a(), mr(), kc(), kc(), -64, aRowSums()); + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon( + mr(), + nr(), + kc(), + a(), + kc(), + aRowSums(), + w(), + c(), + mr(), + requantizationParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_XZP_Op, 4x8c2__aarch32_neon, 4, 8, 8, 2) +(benchmark::State& state) { + for (auto _ : state) { + q8gemm_compute_row_sum(a(), mc(), kc(), kc(), -64, aRowSums()); + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon( + mrr, + nrr, + kc(), + a() + m * kc(), + kc(), + aRowSums() + m, + w() + n * (kcStride() + sizeof(int32_t) / sizeof(uint8_t)), + c() + m * nc() + n, + nc(), + requantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2__aarch32_neon) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2__aarch32_neon) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2__aarch32_neon) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2__aarch32_neon)->Apply(GemmArguments); +#endif + +#if CPUINFO_ARCH_ARM64 +BENCHMARK_TEMPLATE_F(Q8GEMM_L1, 8x8__aarch64_neon, 8, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_q8gemm_ukernel_8x8__aarch64_neon( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint8_t), + w(), + c(), + mr() * sizeof(uint8_t), + quantizationParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__aarch64_neon, 8, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_ukernel_8x8__aarch64_neon( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, + nc() * sizeof(uint8_t), + quantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__aarch64_neon) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__aarch64_neon) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__aarch64_neon) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__aarch64_neon)->Apply(GemmArguments); +#endif + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +BENCHMARK_TEMPLATE_F(Q8GEMM_L1, 4x8__neon, 4, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_q8gemm_ukernel_4x8__neon( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint8_t), + w(), + c(), + mr() * sizeof(uint8_t), + quantizationParams()); + } +} + +BENCHMARK_TEMPLATE_F(Q8GEMM_L1, 8x8__neon, 8, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_q8gemm_ukernel_8x8__neon( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint8_t), + w(), + c(), + mr() * sizeof(uint8_t), + quantizationParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 4x8__neon, 4, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_ukernel_4x8__neon( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, + nc() * sizeof(uint8_t), + quantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(GemmArguments); + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__neon, 8, 8, 8, 1) +(benchmark::State& state) { + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_ukernel_8x8__neon( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, + nc() * sizeof(uint8_t), + quantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__neon)->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__neon)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__neon)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 8x8__neon)->Apply(GemmArguments); + +BENCHMARK_TEMPLATE_F(Q8GEMM_XZP_L1, 4x8c2_neon, 4, 8, 8, 2) +(benchmark::State& state) { + for (auto _ : state) { + q8gemm_compute_row_sum(a(), mr(), kc(), kc(), -64, aRowSums()); + pytorch_q8gemm_xzp_ukernel_4x8c2__neon( + mr(), + nr(), + kc(), + a(), + kc(), + aRowSums(), + w(), + c(), + mr(), + requantizationParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_XZP_Op, 4x8c2_neon, 4, 8, 8, 2) +(benchmark::State& state) { + for (auto _ : state) { + q8gemm_compute_row_sum(a(), mc(), kc(), kc(), -64, aRowSums()); + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_xzp_ukernel_4x8c2__neon( + mrr, + nrr, + kc(), + a() + m * kc(), + kc(), + aRowSums() + m, + w() + n * (kcStride() + sizeof(int32_t) / sizeof(uint8_t)), + c() + m * nc() + n, + nc(), + requantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2_neon) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2_neon) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2_neon) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_XZP_Op, 4x8c2_neon)->Apply(GemmArguments); + +BENCHMARK_TEMPLATE_DEFINE_F( + COMPUTE_ROW_SUM_Op, + compute_row_sum_neon, + 4, + 8, + 8, + 2) +(benchmark::State& state) { + for (auto _ : state) { + const size_t block_size = 4; + for (size_t block_start = 0; block_start < mc(); + block_start += block_size) { + pytorch_q8sumrows_ukernel_4x__neon( + a() + block_start * kc(), + min(block_size, mc() - block_start), + kc(), + kc(), + 0x11, + aRowSums() + block_start); + } + } +} + +BENCHMARK_REGISTER_F(COMPUTE_ROW_SUM_Op, compute_row_sum_neon) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(COMPUTE_ROW_SUM_Op, compute_row_sum_neon) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(COMPUTE_ROW_SUM_Op, compute_row_sum_neon) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(COMPUTE_ROW_SUM_Op, compute_row_sum_neon) + ->Apply(GemmArguments); + +#endif + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +BENCHMARK_TEMPLATE_F(Q8GEMM_L1, 2x4c8__sse2, 2, 4, 1, 8) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_q8gemm_ukernel_2x4c8__sse2( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint8_t), + w(), + c(), + mr() * sizeof(uint8_t), + quantizationParams()); + } +} + +BENCHMARK_TEMPLATE_F(Q8GEMM_L1, 4x4c2__sse2, 4, 4, 4, 2) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_q8gemm_ukernel_4x4c2__sse2( + mr(), + nr(), + kc(), + a(), + kc() * sizeof(uint8_t), + w(), + c(), + mr() * sizeof(uint8_t), + quantizationParams()); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 2x4c8__sse2, 2, 4, 1, 8) +(benchmark::State& state) { + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_ukernel_2x4c8__sse2( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, + nc() * sizeof(uint8_t), + quantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_Op, 2x4c8__sse2) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 2x4c8__sse2)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 2x4c8__sse2)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 2x4c8__sse2)->Apply(GemmArguments); + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 4x4c2__sse2, 4, 4, 4, 2) +(benchmark::State& state) { + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + pytorch_q8gemm_ukernel_4x4c2__sse2( + mrr, + nrr, + kc(), + a() + m * kc(), + kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, + nc() * sizeof(uint8_t), + quantizationParams()); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x4c2__sse2) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x4c2__sse2)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x4c2__sse2)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x4c2__sse2)->Apply(GemmArguments); +#endif + +#if PYTORCH_QNNPACK_BENCHMARK_GEMMLOWP +BENCHMARK_DEFINE_F(GEMMLOWP, single_threaded)(benchmark::State& state) { + for (auto _ : state) { + gemmlowp::MatrixMap AM( + a(), mc(), kc(), kc()); + gemmlowp::MatrixMap BM( + k(), kc(), nc(), kc()); + gemmlowp::MatrixMap CM( + c(), mc(), nc(), nc()); + const auto& output_pipeline = + GemmlowpOutputPipeline::Make(b(), nc(), 127, 1, 2, 0, 255); + gemmlowp::GemmWithOutputPipeline< + uint8_t, + uint8_t, + gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + &threadingContext, AM, BM, &CM, 2, 1, output_pipeline); + } +} + +BENCHMARK_REGISTER_F(GEMMLOWP, single_threaded) + ->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(GEMMLOWP, single_threaded) + ->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(GEMMLOWP, single_threaded) + ->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(GEMMLOWP, single_threaded)->Apply(GemmArguments); +#endif + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/requantization.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/requantization.cc new file mode 100644 index 0000000000000..8582d107a90c8 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/requantization.cc @@ -0,0 +1,379 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +inline uint32_t divideRoundUp(uint32_t x, uint32_t q) { + return x / q + uint32_t(x % q != 0); +} + +inline uint32_t roundUp(uint32_t x, uint32_t q) { + return q * divideRoundUp(x, q); +} + +inline uint32_t min(uint32_t a, uint32_t b) { + return a < b ? a : b; +} + +class Requantization : public benchmark::Fixture { + public: + inline Requantization() { + cpuinfo_initialize(); + const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size; + const size_t l1d_reserve = 1024; + n_ = (l1d_size - l1d_reserve) / (sizeof(int32_t) + sizeof(uint8_t)); + n_ = n_ / 16 * 16; + } + + virtual void SetUp(const benchmark::State&) override { + const uint_fast32_t seed = + std::chrono::system_clock::now().time_since_epoch().count(); + auto rng = + std::bind(std::uniform_int_distribution(), std::mt19937(seed)); + + input_.resize(n()); + std::generate(input_.begin(), input_.end(), std::ref(rng)); + output_.resize(n()); + std::fill(output_.begin(), output_.end(), 0xA5); + } + + virtual void TearDown(benchmark::State& state) override { + state.SetItemsProcessed(uint64_t(state.iterations()) * n()); + state.SetBytesProcessed( + uint64_t(state.iterations()) * n() * + (sizeof(int32_t) + sizeof(uint8_t))); + input_.clear(); + output_.clear(); + } + + inline const int32_t* input() const { + return input_.data(); + } + + inline uint8_t* output() { + return output_.data(); + } + + inline size_t n() const { + return n_; + } + + protected: + std::vector> input_; + std::vector output_; + size_t n_; +}; + +BENCHMARK_F(Requantization, precise__scalar_unsigned32) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__scalar_unsigned32( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, precise__scalar_unsigned64) +(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__scalar_unsigned64( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, precise__scalar_signed64)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__scalar_signed64( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, fp32__scalar_lrintf)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_fp32__scalar_lrintf( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, fp32__scalar_magic)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_fp32__scalar_magic( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, gemmlowp__scalar)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_gemmlowp__scalar( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, precise__psimd)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__psimd( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, fp32__psimd)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_fp32__psimd( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +BENCHMARK_F(Requantization, precise__neon)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__neon( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, fp32__neon)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_fp32__neon( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, q31__neon)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_q31__neon( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, gemmlowp__neon)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_gemmlowp__neon( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} +#endif + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +BENCHMARK_F(Requantization, precise__sse2)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__sse2( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, precise__ssse3)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__ssse3( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, precise__sse4)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_precise__sse4( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, fp32__sse2)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_fp32__sse2( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, q31__sse2)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_q31__sse2( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, q31__ssse3)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_q31__ssse3( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, q31__sse4)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_q31__sse4( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, gemmlowp__sse2)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_gemmlowp__sse2( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, gemmlowp__ssse3)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_gemmlowp__ssse3( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} + +BENCHMARK_F(Requantization, gemmlowp__sse4)(benchmark::State& state) { + for (auto _ : state) { + pytorch_qnnp_requantize_gemmlowp__sse4( + n(), + input(), + 0x1.0p-12f /* scale */, + 128 /* zero point */, + 1 /* qmin */, + 254 /* qmax */, + output()); + } +} +#endif + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/sgemm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/sgemm.cc new file mode 100644 index 0000000000000..072855350705f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/sgemm.cc @@ -0,0 +1,624 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +inline uint32_t divideRoundUp(uint32_t x, uint32_t q) { + return x / q + uint32_t(x % q != 0); +} + +inline uint32_t roundUp(uint32_t x, uint32_t q) { + return q * divideRoundUp(x, q); +} + +static void sgemmBenchmark( + benchmark::State& state, + pytorch_sgemm_ukernel_function sgemm, + uint32_t mc, + uint32_t nc, + uint32_t kc, + uint32_t mr, + uint32_t nr, + uint32_t np, + uint32_t kr) { + const size_t ncStride = roundUp(nc, np); + const size_t kcStride = roundUp(kc, kr); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto f32rng = std::bind(std::uniform_real_distribution(), rng); + + std::vector a(mc * kc); + std::generate(a.begin(), a.end(), std::ref(f32rng)); + std::vector k(nc * kc); + std::generate(k.begin(), k.end(), std::ref(f32rng)); + std::vector b(nc); + std::generate(b.begin(), b.end(), std::ref(f32rng)); + std::vector> w( + ncStride * kcStride + ncStride); + std::fill(w.begin(), w.end(), 0.0f); + pytorch_pack_sgemm_w(nc, kc, nr, kr, k.data(), b.data(), w.data()); + std::vector c(mc * nc); + std::fill(c.begin(), c.end(), std::nanf("")); + + pytorch_qnnp_fp32_clamping_params clampingParams{ + std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + + for (auto _ : state) { + for (uint32_t m = 0; m < mc; m += mr) { + const uint32_t mb = min(mc - m, mr); + for (uint32_t n = 0; n < nc; n += nr) { + const uint32_t nb = min(nc - n, nr); + sgemm( + mb, + nb, + kc, + a.data() + m * kc, + kc * sizeof(float), + w.data() + n * (kcStride + 1), + c.data() + m * nc + n, + nc * sizeof(float), + &clampingParams); + } + } + } + + state.SetItemsProcessed(uint64_t(state.iterations()) * 2 * mc * nc * kc); +} + +static void sgemm_in_l1( + benchmark::State& state, + pytorch_sgemm_ukernel_function sgemm, + uint32_t mr, + uint32_t nr, + uint32_t np, + uint32_t kr) { + if (!cpuinfo_initialize()) { + state.SkipWithError("cpuinfo initialization failed"); + } + + const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size; + const size_t l1d_reserve = 512; + const size_t kc = roundUp( + ((l1d_size - l1d_reserve) / sizeof(float) - mr * nr) / (mr + nr), + np * kr); + + sgemmBenchmark(state, sgemm, mr /* mc */, nr /* nc */, kc, mr, nr, np, kr); +} + +static void sgemm( + benchmark::State& state, + pytorch_sgemm_ukernel_function sgemm, + uint32_t mr, + uint32_t nr, + uint32_t np, + uint32_t kr) { + const size_t mc = state.range(0); + const size_t nc = state.range(1); + const size_t kc = state.range(2); + + sgemmBenchmark(state, sgemm, mc, nc, kc, mr, nr, np, kr); +} + +/* ShuffleNet v1 with 1 group */ +static void ShuffleNetV1G1(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 36, 24 * 1 * 1}); + b->Args({28 * 28, 120, 36 * 1 * 1}); + b->Args({28 * 28, 36, 144 * 1 * 1}); + b->Args({28 * 28, 144, 36 * 1 * 1}); + b->Args({28 * 28, 72, 144 * 1 * 1}); + b->Args({14 * 14, 144, 72 * 1 * 1}); + b->Args({14 * 14, 72, 288 * 1 * 1}); + b->Args({14 * 14, 288, 72 * 1 * 1}); + b->Args({14 * 14, 144, 288 * 1 * 1}); + b->Args({7 * 7, 288, 144 * 1 * 1}); + b->Args({7 * 7, 144, 576 * 1 * 1}); + b->Args({7 * 7, 576, 144 * 1 * 1}); +} + +/* ShuffleNet v1 with 2 groups */ +static void ShuffleNetV1G2(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 50, 24 * 1 * 1}); + b->Args({28 * 28, 88, 25 * 1 * 1}); + b->Args({28 * 28, 25, 100 * 1 * 1}); + b->Args({28 * 28, 100, 25 * 1 * 1}); + b->Args({28 * 28, 50, 100 * 1 * 1}); + b->Args({14 * 14, 100, 50 * 1 * 1}); + b->Args({14 * 14, 50, 200 * 1 * 1}); + b->Args({14 * 14, 200, 50 * 1 * 1}); + b->Args({14 * 14, 100, 200 * 1 * 1}); + b->Args({7 * 7, 200, 100 * 1 * 1}); + b->Args({7 * 7, 100, 400 * 1 * 1}); + b->Args({7 * 7, 400, 100 * 1 * 1}); +} + +/* ShuffleNet v1 with 3 groups */ +static void ShuffleNetV1G3(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 60, 24 * 1 * 1}); + b->Args({28 * 28, 72, 20 * 1 * 1}); + b->Args({28 * 28, 20, 80 * 1 * 1}); + b->Args({28 * 28, 80, 20 * 1 * 1}); + b->Args({28 * 28, 40, 80 * 1 * 1}); + b->Args({14 * 14, 80, 40 * 1 * 1}); + b->Args({14 * 14, 40, 160 * 1 * 1}); + b->Args({14 * 14, 160, 40 * 1 * 1}); + b->Args({14 * 14, 80, 160 * 1 * 1}); + b->Args({7 * 7, 160, 80 * 1 * 1}); + b->Args({7 * 7, 80, 320 * 1 * 1}); + b->Args({7 * 7, 320, 80 * 1 * 1}); +} + +/* ShuffleNet v1 with 4 groups */ +static void ShuffleNetV1G4(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 68, 24 * 1 * 1}); + b->Args({28 * 28, 62, 17 * 1 * 1}); + b->Args({28 * 28, 17, 68 * 1 * 1}); + b->Args({28 * 28, 68, 17 * 1 * 1}); + b->Args({28 * 28, 34, 68 * 1 * 1}); + b->Args({14 * 14, 68, 34 * 1 * 1}); + b->Args({14 * 14, 34, 136 * 1 * 1}); + b->Args({14 * 14, 136, 34 * 1 * 1}); + b->Args({14 * 14, 68, 136 * 1 * 1}); + b->Args({7 * 7, 136, 68 * 1 * 1}); + b->Args({7 * 7, 68, 272 * 1 * 1}); + b->Args({7 * 7, 272, 68 * 1 * 1}); +} + +/* ShuffleNet v1 with 8 groups */ +static void ShuffleNetV1G8(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 96, 24 * 1 * 1}); + b->Args({28 * 28, 45, 12 * 1 * 1}); + b->Args({28 * 28, 12, 48 * 1 * 1}); + b->Args({28 * 28, 48, 12 * 1 * 1}); + b->Args({28 * 28, 24, 48 * 1 * 1}); + b->Args({14 * 14, 48, 24 * 1 * 1}); + b->Args({14 * 14, 24, 96 * 1 * 1}); + b->Args({14 * 14, 96, 24 * 1 * 1}); + b->Args({14 * 14, 48, 96 * 1 * 1}); + b->Args({7 * 7, 96, 48 * 1 * 1}); + b->Args({7 * 7, 48, 192 * 1 * 1}); + b->Args({7 * 7, 192, 48 * 1 * 1}); +} + +/* ShuffleNet v2 (0.5X scale) */ +static void ShuffleNetV2X05(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 24, 24 * 1 * 1}); + b->Args({28 * 28, 24, 24 * 1 * 1}); + b->Args({28 * 28, 48, 48 * 1 * 1}); + b->Args({14 * 14, 48, 48 * 1 * 1}); + b->Args({14 * 14, 96, 96 * 1 * 1}); + b->Args({7 * 7, 96, 96 * 1 * 1}); + b->Args({7 * 7, 1024, 192 * 1 * 1}); +} + +/* ShuffleNet v2 (1.0X scale) */ +static void ShuffleNetV2X10(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 58, 24 * 1 * 1}); + b->Args({28 * 28, 58, 24 * 1 * 1}); + b->Args({28 * 28, 58, 58 * 1 * 1}); + b->Args({14 * 14, 116, 116 * 1 * 1}); + b->Args({14 * 14, 116, 116 * 1 * 1}); + b->Args({14 * 14, 232, 232 * 1 * 1}); + b->Args({7 * 7, 232, 232 * 1 * 1}); + b->Args({7 * 7, 1024, 464 * 1 * 1}); +} + +/* ShuffleNet v2 (1.5X scale) */ +static void ShuffleNetV2X15(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 88, 24 * 1 * 1}); + b->Args({28 * 28, 88, 24 * 1 * 1}); + b->Args({28 * 28, 88, 88 * 1 * 1}); + b->Args({28 * 28, 176, 176 * 1 * 1}); + b->Args({14 * 14, 176, 176 * 1 * 1}); + b->Args({14 * 14, 352, 352 * 1 * 1}); + b->Args({7 * 7, 352, 352 * 1 * 1}); + b->Args({7 * 7, 1024, 704 * 1 * 1}); +} + +/* ShuffleNet v2 (2.0X scale) */ +static void ShuffleNetV2X20(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 24, 3 * 3 * 3}); + b->Args({56 * 56, 122, 24 * 1 * 1}); + b->Args({28 * 28, 122, 24 * 1 * 1}); + b->Args({28 * 28, 122, 122 * 1 * 1}); + b->Args({28 * 28, 244, 244 * 1 * 1}); + b->Args({14 * 14, 244, 244 * 1 * 1}); + b->Args({14 * 14, 488, 488 * 1 * 1}); + b->Args({7 * 7, 488, 488 * 1 * 1}); + b->Args({7 * 7, 2048, 976 * 1 * 1}); +} + +static void MobileNetV1(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 32, 3 * 3 * 3}); + b->Args({112 * 112, 64, 32 * 1 * 1}); + b->Args({56 * 56, 128, 64 * 1 * 1}); + b->Args({56 * 56, 128, 128 * 1 * 1}); + b->Args({28 * 28, 256, 128 * 1 * 1}); + b->Args({28 * 28, 256, 256 * 1 * 1}); + b->Args({14 * 14, 512, 256 * 1 * 1}); + b->Args({14 * 14, 512, 512 * 1 * 1}); + b->Args({7 * 7, 1024, 512 * 1 * 1}); + b->Args({7 * 7, 1024, 1024 * 1 * 1}); +} + +static void MobileNetV2(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 32, 3 * 3 * 3}); + /************ Bottleneck 1 ************/ + b->Args({112 * 112, 16, 32 * 1 * 1}); + /************ Bottleneck 2 ************/ + b->Args({112 * 112, 96, 16 * 1 * 1}); + b->Args({56 * 56, 24, 96 * 1 * 1}); + b->Args({56 * 56, 144, 24 * 1 * 1}); + b->Args({56 * 56, 24, 144 * 1 * 1}); + /************ Bottleneck 3 ************/ + b->Args({28 * 28, 32, 144 * 1 * 1}); + b->Args({28 * 28, 192, 32 * 1 * 1}); + b->Args({28 * 28, 32, 192 * 1 * 1}); + /************ Bottleneck 4 ************/ + b->Args({14 * 14, 64, 192 * 1 * 1}); + b->Args({14 * 14, 192, 64 * 1 * 1}); + b->Args({14 * 14, 64, 384 * 1 * 1}); + /************ Bottleneck 5 ************/ + b->Args({14 * 14, 96, 384 * 1 * 1}); + b->Args({14 * 14, 576, 96 * 1 * 1}); + b->Args({14 * 14, 96, 576 * 1 * 1}); + /************ Bottleneck 6 ************/ + b->Args({7 * 7, 160, 576 * 1 * 1}); + b->Args({7 * 7, 960, 160 * 1 * 1}); + b->Args({7 * 7, 160, 960 * 1 * 1}); + /************ Bottleneck 7 ************/ + b->Args({7 * 7, 320, 960 * 1 * 1}); + /********* Pre-pooling Conv2D *********/ + b->Args({7 * 7, 1280, 320 * 1 * 1}); + /******** Post-pooling Conv2D *********/ + b->Args({1 * 1, 1000, 1280 * 1 * 1}); +} + +/* SqueezeNet 1.0 */ +static void SqueezeNetV10(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + /*************** Conv 1 ***************/ + b->Args({111 * 111, 96, 3 * 7 * 7}); + /*************** Fire 2 ***************/ + b->Args({55 * 55, 16, 96 * 1 * 1}); + b->Args({55 * 55, 64, 16 * 1 * 1}); + b->Args({55 * 55, 64, 16 * 3 * 3}); + /*************** Fire 3 ***************/ + b->Args({55 * 55, 16, 128 * 1 * 1}); + /*************** Fire 4 ***************/ + b->Args({55 * 55, 32, 128 * 1 * 1}); + b->Args({55 * 55, 128, 32 * 1 * 1}); + b->Args({55 * 55, 128, 32 * 3 * 3}); + /*************** Fire 5 ***************/ + b->Args({27 * 27, 32, 256 * 1 * 1}); + b->Args({27 * 27, 128, 32 * 1 * 1}); + b->Args({27 * 27, 128, 32 * 3 * 3}); + /*************** Fire 6 ***************/ + b->Args({27 * 27, 48, 256 * 1 * 1}); + b->Args({27 * 27, 192, 48 * 1 * 1}); + b->Args({27 * 27, 192, 48 * 3 * 3}); + /*************** Fire 7 ***************/ + b->Args({27 * 27, 48, 384 * 1 * 1}); + /*************** Fire 8 ***************/ + b->Args({27 * 27, 64, 384 * 1 * 1}); + b->Args({27 * 27, 256, 64 * 1 * 1}); + b->Args({27 * 27, 256, 64 * 3 * 3}); + /*************** Fire 9 ***************/ + b->Args({13 * 13, 64, 512 * 1 * 1}); + b->Args({13 * 13, 256, 64 * 1 * 1}); + b->Args({13 * 13, 256, 64 * 3 * 3}); + /*************** Conv 10 **************/ + b->Args({13 * 13, 1000, 512 * 1 * 1}); +} + +/* SqueezeNet 1.1 */ +static void SqueezeNetV11(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + /*************** Conv 1 ***************/ + b->Args({111 * 111, 64, 3 * 3 * 3}); + /*************** Fire 2 ***************/ + b->Args({55 * 55, 16, 64 * 1 * 1}); + b->Args({55 * 55, 64, 16 * 1 * 1}); + b->Args({55 * 55, 64, 16 * 3 * 3}); + /*************** Fire 3 ***************/ + b->Args({55 * 55, 16, 128 * 1 * 1}); + /*************** Fire 4 ***************/ + b->Args({27 * 27, 32, 128 * 1 * 1}); + b->Args({27 * 27, 128, 32 * 1 * 1}); + b->Args({27 * 27, 128, 32 * 3 * 3}); + /*************** Fire 5 ***************/ + b->Args({27 * 27, 32, 256 * 1 * 1}); + /*************** Fire 6 ***************/ + b->Args({13 * 13, 48, 256 * 1 * 1}); + b->Args({13 * 13, 192, 48 * 1 * 1}); + b->Args({13 * 13, 192, 48 * 3 * 3}); + /*************** Fire 7 ***************/ + b->Args({13 * 13, 48, 384 * 1 * 1}); + /*************** Fire 8 ***************/ + b->Args({13 * 13, 64, 384 * 1 * 1}); + b->Args({13 * 13, 256, 64 * 1 * 1}); + b->Args({13 * 13, 256, 64 * 3 * 3}); + /*************** Fire 9 ***************/ + b->Args({13 * 13, 64, 512 * 1 * 1}); + /*************** Conv 10 **************/ + b->Args({13 * 13, 1000, 512 * 1 * 1}); +} + +static void ResNet18(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + b->Args({112 * 112, 64, 3 * 7 * 7}); + b->Args({56 * 56, 64, 64 * 3 * 3}); + b->Args({28 * 28, 128, 64 * 3 * 3}); + b->Args({28 * 28, 128, 128 * 3 * 3}); + b->Args({28 * 28, 128, 64 * 1 * 1}); + b->Args({14 * 14, 256, 128 * 3 * 3}); + b->Args({14 * 14, 256, 256 * 3 * 3}); + b->Args({14 * 14, 256, 128 * 1 * 1}); + b->Args({7 * 7, 512, 256 * 3 * 3}); + b->Args({7 * 7, 512, 512 * 3 * 3}); + b->Args({7 * 7, 512, 256 * 1 * 1}); +} + +static void ResNet50(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + /**************** Conv 1 ***************/ + b->Args({112 * 112, 64, 3 * 7 * 7}); + /* M N K */ + /*************** Conv 2.X **************/ + b->Args({56 * 56, 64, 64 * 1 * 1}); + b->Args({56 * 56, 64, 64 * 3 * 3}); + b->Args({56 * 56, 256, 64 * 1 * 1}); + b->Args({56 * 56, 64, 256 * 1 * 1}); + /* M N K */ + /*************** Conv 3.X **************/ + b->Args({56 * 56, 128, 256 * 1 * 1}); + b->Args({28 * 28, 128, 128 * 3 * 3}); + b->Args({28 * 28, 512, 128 * 1 * 1}); + b->Args({28 * 28, 512, 256 * 1 * 1}); + b->Args({28 * 28, 128, 512 * 1 * 1}); + /* M N K */ + /*************** Conv 4.X **************/ + b->Args({28 * 28, 256, 512 * 1 * 1}); + b->Args({14 * 14, 256, 256 * 3 * 3}); + b->Args({14 * 14, 1024, 256 * 1 * 1}); + b->Args({14 * 14, 1024, 512 * 1 * 1}); + b->Args({14 * 14, 256, 1024 * 1 * 1}); + /* M N K */ + /*************** Conv 5.X **************/ + b->Args({14 * 14, 512, 1024 * 1 * 1}); + b->Args({7 * 7, 512, 512 * 3 * 3}); + b->Args({7 * 7, 2048, 512 * 1 * 1}); + b->Args({7 * 7, 2048, 1024 * 1 * 1}); + b->Args({7 * 7, 512, 2048 * 1 * 1}); +} + +static void VGG(benchmark::internal::Benchmark* b) { + b->ArgNames({"M", "N", "K"}); + + /* M N K */ + /************** Conv 1.1 *************/ + b->Args({224 * 224, 64, 3 * 3 * 3}); + /************** Conv 1.2 *************/ + b->Args({224 * 224, 64, 64 * 3 * 3}); + /************** Conv 2.1 *************/ + b->Args({112 * 112, 128, 64 * 3 * 3}); + /************** Conv 2.2 *************/ + b->Args({112 * 112, 128, 128 * 3 * 3}); + /************** Conv 3.1 *************/ + b->Args({56 * 56, 256, 128 * 3 * 3}); + /************** Conv 3.3 *************/ + b->Args({56 * 56, 256, 256 * 1 * 1}); + /************** Conv 4.1 *************/ + b->Args({28 * 28, 512, 256 * 3 * 3}); + /************** Conv 4.2 *************/ + b->Args({28 * 28, 512, 512 * 3 * 3}); + /************** Conv 4.3 *************/ + b->Args({28 * 28, 512, 512 * 1 * 1}); + /************** Conv 5.X *************/ + b->Args({14 * 14, 512, 512 * 3 * 3}); + /************** Conv 5.3 *************/ + b->Args({14 * 14, 512, 512 * 1 * 1}); +} + +BENCHMARK_CAPTURE( + sgemm_in_l1, + 6x8__psimd, + pytorch_sgemm_ukernel_6x8__psimd, + 6, + 8, + 8, + 1); +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +BENCHMARK_CAPTURE(sgemm_in_l1, 5x8__neon, pytorch_sgemm_ukernel_5x8__neon, 5, 8, 8, 1); +BENCHMARK_CAPTURE(sgemm_in_l1, 6x8__neon, pytorch_sgemm_ukernel_6x8__neon, 6, 8, 8, 1); +#endif + +static void sgemm_6x8__psimd(benchmark::State& state, const char* net) { + sgemm(state, pytorch_sgemm_ukernel_6x8__psimd, 6, 8, 8, 1); +} + +BENCHMARK_CAPTURE(sgemm_6x8__psimd, mobilenet_v1, "MobileNet v1") + ->Apply(MobileNetV1); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, mobilenet_v2, "MobileNet v2") + ->Apply(MobileNetV2); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, shufflenet_v1_g1, "ShuffleNet v1 (1 group)") + ->Apply(ShuffleNetV1G1); +BENCHMARK_CAPTURE( + sgemm_6x8__psimd, + shufflenet_v1_g2, + "ShuffleNet v1 (2 groups)") + ->Apply(ShuffleNetV1G2); +BENCHMARK_CAPTURE( + sgemm_6x8__psimd, + shufflenet_v1_g3, + "ShuffleNet v1 (3 groups)") + ->Apply(ShuffleNetV1G3); +BENCHMARK_CAPTURE( + sgemm_6x8__psimd, + shufflenet_v1_g4, + "ShuffleNet v1 (4 groups)") + ->Apply(ShuffleNetV1G4); +BENCHMARK_CAPTURE( + sgemm_6x8__psimd, + shufflenet_v1_g8, + "ShuffleNet v1 (8 groups)") + ->Apply(ShuffleNetV1G8); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, shufflenet_v2_x05, "ShuffleNet v2 0.5X") + ->Apply(ShuffleNetV2X05); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, shufflenet_v2_x10, "ShuffleNet v2 1.0X") + ->Apply(ShuffleNetV2X10); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, shufflenet_v2_x15, "ShuffleNet v2 1.5X") + ->Apply(ShuffleNetV2X15); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, shufflenet_v2_x20, "ShuffleNet v2 2.0X") + ->Apply(ShuffleNetV2X20); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, resnet18, "ResNet-18")->Apply(ResNet18); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, resnet50, "ResNet-50")->Apply(ResNet50); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, squeezenet_v10, "SqueezeNet 1.0") + ->Apply(SqueezeNetV10); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, squeezenet_v11, "SqueezeNet 1.1") + ->Apply(SqueezeNetV11); +BENCHMARK_CAPTURE(sgemm_6x8__psimd, vgg, "VGG")->Apply(VGG); + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +static void sgemm_5x8__neon(benchmark::State& state, const char* net) { + sgemm(state, pytorch_sgemm_ukernel_5x8__neon, 5, 8, 8, 1); +} + +static void sgemm_6x8__neon(benchmark::State& state, const char* net) { + sgemm(state, pytorch_sgemm_ukernel_6x8__neon, 6, 8, 8, 1); +} + +BENCHMARK_CAPTURE(sgemm_5x8__neon, mobilenet_v1, "MobileNet v1") + ->Apply(MobileNetV1); +BENCHMARK_CAPTURE(sgemm_5x8__neon, mobilenet_v2, "MobileNet v2") + ->Apply(MobileNetV2); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v1_g1, "ShuffleNet v1 (1 group)") + ->Apply(ShuffleNetV1G1); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)") + ->Apply(ShuffleNetV1G2); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)") + ->Apply(ShuffleNetV1G3); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)") + ->Apply(ShuffleNetV1G4); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)") + ->Apply(ShuffleNetV1G8); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v2_x05, "ShuffleNet v2 0.5X") + ->Apply(ShuffleNetV2X05); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v2_x10, "ShuffleNet v2 1.0X") + ->Apply(ShuffleNetV2X10); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v2_x15, "ShuffleNet v2 1.5X") + ->Apply(ShuffleNetV2X15); +BENCHMARK_CAPTURE(sgemm_5x8__neon, shufflenet_v2_x20, "ShuffleNet v2 2.0X") + ->Apply(ShuffleNetV2X20); +BENCHMARK_CAPTURE(sgemm_5x8__neon, resnet18, "ResNet-18")->Apply(ResNet18); +BENCHMARK_CAPTURE(sgemm_5x8__neon, resnet50, "ResNet-50")->Apply(ResNet50); +BENCHMARK_CAPTURE(sgemm_5x8__neon, squeezenet_v10, "SqueezeNet 1.0") + ->Apply(SqueezeNetV10); +BENCHMARK_CAPTURE(sgemm_5x8__neon, squeezenet_v11, "SqueezeNet 1.1") + ->Apply(SqueezeNetV11); +BENCHMARK_CAPTURE(sgemm_5x8__neon, vgg, "VGG")->Apply(VGG); + +BENCHMARK_CAPTURE(sgemm_6x8__neon, mobilenet_v1, "MobileNet v1") + ->Apply(MobileNetV1); +BENCHMARK_CAPTURE(sgemm_6x8__neon, mobilenet_v2, "MobileNet v2") + ->Apply(MobileNetV2); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v1_g1, "ShuffleNet v1 (1 group)") + ->Apply(ShuffleNetV1G1); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v1_g2, "ShuffleNet v1 (2 groups)") + ->Apply(ShuffleNetV1G2); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v1_g3, "ShuffleNet v1 (3 groups)") + ->Apply(ShuffleNetV1G3); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v1_g4, "ShuffleNet v1 (4 groups)") + ->Apply(ShuffleNetV1G4); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)") + ->Apply(ShuffleNetV1G8); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v2_x05, "ShuffleNet v2 0.5X") + ->Apply(ShuffleNetV2X05); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v2_x10, "ShuffleNet v2 1.0X") + ->Apply(ShuffleNetV2X10); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v2_x15, "ShuffleNet v2 1.5X") + ->Apply(ShuffleNetV2X15); +BENCHMARK_CAPTURE(sgemm_6x8__neon, shufflenet_v2_x20, "ShuffleNet v2 2.0X") + ->Apply(ShuffleNetV2X20); +BENCHMARK_CAPTURE(sgemm_6x8__neon, resnet18, "ResNet-18")->Apply(ResNet18); +BENCHMARK_CAPTURE(sgemm_6x8__neon, resnet50, "ResNet-50")->Apply(ResNet50); +BENCHMARK_CAPTURE(sgemm_6x8__neon, squeezenet_v10, "SqueezeNet 1.0") + ->Apply(SqueezeNetV10); +BENCHMARK_CAPTURE(sgemm_6x8__neon, squeezenet_v11, "SqueezeNet 1.1") + ->Apply(SqueezeNetV11); +BENCHMARK_CAPTURE(sgemm_6x8__neon, vgg, "VGG")->Apply(VGG); +#endif + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/sigmoid.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/sigmoid.cc new file mode 100644 index 0000000000000..5b98dde25b8f4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/sigmoid.cc @@ -0,0 +1,99 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +static void sigmoid_q8(benchmark::State& state) { + const size_t batchSize = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input(batchSize * channels); + std::vector output(batchSize * channels); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t sigmoidOperator = nullptr; + status = pytorch_qnnp_create_sigmoid_nc_q8( + channels, + 127 /* input zero point */, + 1.0f /* input scale */, + 0 /* output zero point */, + 1.0f / 256.0f /* output scale */, + 0 /* output min */, + 255 /* output max */, + 0 /* flags */, + &sigmoidOperator); + if (status != pytorch_qnnp_status_success || sigmoidOperator == nullptr) { + state.SkipWithError("failed to create Sigmoid operator"); + } + + status = pytorch_qnnp_setup_sigmoid_nc_q8( + sigmoidOperator, + batchSize, + input.data(), + channels /* input:stride */, + output.data(), + channels /* output:stride */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup Sigmoid operator"); + } + + for (auto _ : state) { + status = + pytorch_qnnp_run_operator(sigmoidOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run Sigmoid operator"); + } + } + + const size_t itemsPerIteration = batchSize * channels; + state.SetItemsProcessed( + int64_t(state.iterations()) * int64_t(itemsPerIteration)); + + const size_t bytesPerIteration = 2 * itemsPerIteration * sizeof(uint8_t); + state.SetBytesProcessed( + int64_t(state.iterations()) * int64_t(bytesPerIteration)); + + status = pytorch_qnnp_delete_operator(sigmoidOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete Sigmoid operator"); + } +} + +static void CharacteristicArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "C"}); + + int32_t c = 16; + for (int32_t n = 224; n >= 7; n /= 2) { + b->Args({n * n, c}); + c *= 2; + } +} + +BENCHMARK(sigmoid_q8)->Apply(CharacteristicArguments); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/bench/softargmax.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/softargmax.cc new file mode 100644 index 0000000000000..f07d8aeece8a2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/bench/softargmax.cc @@ -0,0 +1,101 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include + +#include + +static void softargmax_q8(benchmark::State& state) { + const size_t batchSize = static_cast(state.range(0)); + const size_t channels = static_cast(state.range(1)); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input(batchSize * channels); + std::vector output(batchSize * channels); + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + pytorch_qnnp_status status = pytorch_qnnp_initialize(); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to initialize QNNPACK"); + } + + pytorch_qnnp_operator_t softArgMaxOperator = nullptr; + status = pytorch_qnnp_create_softargmax_nc_q8( + channels, + 1.0f /* input scale */, + 0 /* output zero point */, + 1.0f / 256.0f /* output scale */, + 0 /* flags */, + &softArgMaxOperator); + if (status != pytorch_qnnp_status_success || softArgMaxOperator == nullptr) { + state.SkipWithError("failed to create SoftArgMax operator"); + } + + status = pytorch_qnnp_setup_softargmax_nc_q8( + softArgMaxOperator, + batchSize, + input.data(), + channels /* input:stride */, + output.data(), + channels /* output:stride */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to setup SoftArgMax operator"); + } + + for (auto _ : state) { + status = pytorch_qnnp_run_operator( + softArgMaxOperator, nullptr /* thread pool */); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to run SoftArgMax operator"); + } + } + + const size_t itemsPerIteration = batchSize * channels; + state.SetItemsProcessed( + int64_t(state.iterations()) * int64_t(itemsPerIteration)); + + const size_t bytesPerIteration = 2 * itemsPerIteration * sizeof(uint8_t); + state.SetBytesProcessed( + int64_t(state.iterations()) * int64_t(bytesPerIteration)); + + status = pytorch_qnnp_delete_operator(softArgMaxOperator); + if (status != pytorch_qnnp_status_success) { + state.SkipWithError("failed to delete SoftArgMax operator"); + } +} + +static void CharacteristicArguments(benchmark::internal::Benchmark* b) { + b->ArgNames({"N", "C"}); + + /* CIFAR-10 */ + b->Args({1, 10}); + /* CIFAR-100 */ + b->Args({1, 100}); + /* ImageNet-1K */ + b->Args({1, 1000}); + /* ImageNet-1K+1 */ + b->Args({1, 1001}); + /* ImageNet-22K */ + b->Args({1, 21841}); +} + +BENCHMARK(softargmax_q8)->Apply(CharacteristicArguments); + +#ifndef PYTORCH_QNNPACK_BENCHMARK_NO_MAIN +BENCHMARK_MAIN(); +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadCpuinfo.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadCpuinfo.cmake new file mode 100644 index 0000000000000..f20e9b33f8229 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadCpuinfo.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(cpuinfo-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(cpuinfo + GIT_REPOSITORY https://github.com/Maratyszcza/cpuinfo.git + GIT_TAG master + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/cpuinfo" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/cpuinfo" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadFP16.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadFP16.cmake new file mode 100644 index 0000000000000..ccbd1fba6a967 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadFP16.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(fp16-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(fp16 + GIT_REPOSITORY https://github.com/Maratyszcza/FP16.git + GIT_TAG master + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/fp16" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/fp16" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadFXdiv.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadFXdiv.cmake new file mode 100644 index 0000000000000..d04bc7cb94304 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadFXdiv.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(fxdiv-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(fxdiv + GIT_REPOSITORY https://github.com/Maratyszcza/FXdiv.git + GIT_TAG master + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/fxdiv" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/fxdiv" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadGoogleBenchmark.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadGoogleBenchmark.cmake new file mode 100644 index 0000000000000..f57744b927e8c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadGoogleBenchmark.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(googlebenchmark-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(googlebenchmark + URL https://github.com/google/benchmark/archive/v1.4.1.zip + URL_HASH SHA256=61ae07eb5d4a0b02753419eb17a82b7d322786bb36ab62bd3df331a4d47c00a7 + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googlebenchmark" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/googlebenchmark" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadGoogleTest.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadGoogleTest.cmake new file mode 100644 index 0000000000000..559d86be1b734 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadGoogleTest.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(googletest-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(googletest + URL https://github.com/google/googletest/archive/release-1.8.0.zip + URL_HASH SHA256=f3ed3b58511efd272eb074a3a6d6fb79d7c2e6a0e374323d1e6bcbcc1ef141bf + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadPSimd.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadPSimd.cmake new file mode 100644 index 0000000000000..4178ea380f249 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadPSimd.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(psimd-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(psimd + GIT_REPOSITORY https://github.com/Maratyszcza/psimd.git + GIT_TAG master + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/psimd" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/psimd" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadPThreadPool.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadPThreadPool.cmake new file mode 100644 index 0000000000000..0fe01e211bcef --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/cmake/DownloadPThreadPool.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(pthreadpool-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(pthreadpool + GIT_REPOSITORY https://github.com/Maratyszcza/pthreadpool.git + GIT_TAG master + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/pthreadpool" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py b/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py new file mode 100755 index 0000000000000..ac06b85f68569 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/configure.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import confu +from confu import arm, x86 + + +parser = confu.standard_parser() + + +def main(args): + options = parser.parse_args(args) + build = confu.Build.from_options(options) + + build.export_cpath("include", ["q8gemm.h"]) + + with build.options( + source_dir="src", + deps=[ + build.deps.cpuinfo, + build.deps.clog, + build.deps.psimd, + build.deps.fxdiv, + build.deps.pthreadpool, + build.deps.FP16, + ], + extra_include_dirs="src", + ): + + requantization_objects = [ + build.cc("requantization/precise-scalar.c"), + build.cc("requantization/fp32-scalar.c"), + build.cc("requantization/q31-scalar.c"), + build.cc("requantization/gemmlowp-scalar.c"), + ] + with build.options(isa=arm.neon if build.target.is_arm else None): + requantization_objects += [ + build.cc("requantization/precise-psimd.c"), + build.cc("requantization/fp32-psimd.c"), + ] + if build.target.is_x86 or build.target.is_x86_64: + with build.options(isa=x86.sse2): + requantization_objects += [ + build.cc("requantization/precise-sse2.c"), + build.cc("requantization/fp32-sse2.c"), + build.cc("requantization/q31-sse2.c"), + build.cc("requantization/gemmlowp-sse2.c"), + ] + with build.options(isa=x86.ssse3): + requantization_objects += [ + build.cc("requantization/precise-ssse3.c"), + build.cc("requantization/q31-ssse3.c"), + build.cc("requantization/gemmlowp-ssse3.c"), + ] + with build.options(isa=x86.sse4_1): + requantization_objects += [ + build.cc("requantization/precise-sse4.c"), + build.cc("requantization/q31-sse4.c"), + build.cc("requantization/gemmlowp-sse4.c"), + ] + if build.target.is_arm or build.target.is_arm64: + with build.options(isa=arm.neon if build.target.is_arm else None): + requantization_objects += [ + build.cc("requantization/precise-neon.c"), + build.cc("requantization/fp32-neon.c"), + build.cc("requantization/q31-neon.c"), + build.cc("requantization/gemmlowp-neon.c"), + ] + + qnnpytorch_pack_objects = [ + # Common parts + build.cc("init.c"), + build.cc("operator-delete.c"), + build.cc("operator-run.c"), + # Operators + build.cc("add.c"), + build.cc("average-pooling.c"), + build.cc("channel-shuffle.c"), + build.cc("clamp.c"), + build.cc("convolution.c"), + build.cc("indirection.c"), + build.cc("deconvolution.c"), + build.cc("fully-connected.c"), + build.cc("global-average-pooling.c"), + build.cc("leaky-relu.c"), + build.cc("max-pooling.c"), + build.cc("sigmoid.c"), + build.cc("softargmax.c"), + # Scalar micro-kernels + build.cc("u8lut32norm/scalar.c"), + build.cc("x8lut/scalar.c"), + ] + + with build.options(isa=arm.neon if build.target.is_arm else None): + qnnpytorch_pack_objects += [ + build.cc("sconv/6x8-psimd.c"), + build.cc("sdwconv/up4x9-psimd.c"), + build.cc("sgemm/6x8-psimd.c"), + ] + + with build.options(isa=arm.neon if build.target.is_arm else None): + if build.target.is_arm or build.target.is_arm64: + qnnpytorch_pack_objects += [ + build.cc("q8avgpool/mp8x9p8q-neon.c"), + build.cc("q8avgpool/up8x9-neon.c"), + build.cc("q8avgpool/up8xm-neon.c"), + build.cc("q8conv/4x8-neon.c"), + build.cc("q8conv/8x8-neon.c"), + build.cc("q8dwconv/mp8x25-neon.c"), + build.cc("q8dwconv/up8x9-neon.c"), + build.cc("q8gavgpool/mp8x7p7q-neon.c"), + build.cc("q8gavgpool/up8x7-neon.c"), + build.cc("q8gavgpool/up8xm-neon.c"), + build.cc("q8gemm/4x-sumrows-neon.c"), + build.cc("q8gemm/4x8-neon.c"), + build.cc("q8gemm/4x8c2-xzp-neon.c"), + build.cc("q8gemm/6x4-neon.c"), + build.cc("q8gemm/8x8-neon.c"), + build.cc("q8vadd/neon.c"), + build.cc("sgemm/5x8-neon.c"), + build.cc("sgemm/6x8-neon.c"), + build.cc("u8clamp/neon.c"), + build.cc("u8maxpool/16x9p8q-neon.c"), + build.cc("u8maxpool/sub16-neon.c"), + build.cc("u8rmax/neon.c"), + build.cc("x8zip/x2-neon.c"), + build.cc("x8zip/x3-neon.c"), + build.cc("x8zip/x4-neon.c"), + build.cc("x8zip/xm-neon.c"), + ] + if build.target.is_arm: + qnnpytorch_pack_objects += [ + build.cc("hgemm/8x8-aarch32-neonfp16arith.S"), + build.cc("q8conv/4x8-aarch32-neon.S"), + build.cc("q8dwconv/up8x9-aarch32-neon.S"), + build.cc("q8gemm/4x8-aarch32-neon.S"), + build.cc("q8gemm/4x8c2-xzp-aarch32-neon.S"), + ] + if build.target.is_arm64: + qnnpytorch_pack_objects += [ + build.cc("q8gemm/8x8-aarch64-neon.S"), + build.cc("q8conv/8x8-aarch64-neon.S"), + ] + if build.target.is_x86 or build.target.is_x86_64: + with build.options(isa=x86.sse2): + qnnpytorch_pack_objects += [ + build.cc("q8avgpool/mp8x9p8q-sse2.c"), + build.cc("q8avgpool/up8x9-sse2.c"), + build.cc("q8avgpool/up8xm-sse2.c"), + build.cc("q8conv/4x4c2-sse2.c"), + build.cc("q8dwconv/mp8x25-sse2.c"), + build.cc("q8dwconv/up8x9-sse2.c"), + build.cc("q8gavgpool/mp8x7p7q-sse2.c"), + build.cc("q8gavgpool/up8x7-sse2.c"), + build.cc("q8gavgpool/up8xm-sse2.c"), + build.cc("q8gemm/2x4c8-sse2.c"), + build.cc("q8gemm/4x4c2-sse2.c"), + build.cc("q8vadd/sse2.c"), + build.cc("u8clamp/sse2.c"), + build.cc("u8maxpool/16x9p8q-sse2.c"), + build.cc("u8maxpool/sub16-sse2.c"), + build.cc("u8rmax/sse2.c"), + build.cc("x8zip/x2-sse2.c"), + build.cc("x8zip/x3-sse2.c"), + build.cc("x8zip/x4-sse2.c"), + build.cc("x8zip/xm-sse2.c"), + ] + build.static_library("qnnpack", qnnpytorch_pack_objects) + + with build.options( + source_dir="test", + deps={ + ( + build, + build.deps.cpuinfo, + build.deps.clog, + build.deps.pthreadpool, + build.deps.FP16, + build.deps.googletest, + ): any, + "log": build.target.is_android, + }, + extra_include_dirs=["src", "test"], + ): + + build.unittest("hgemm-test", build.cxx("hgemm.cc")) + build.unittest("q8avgpool-test", build.cxx("q8avgpool.cc")) + build.unittest("q8conv-test", build.cxx("q8conv.cc")) + build.unittest("q8dwconv-test", build.cxx("q8dwconv.cc")) + build.unittest("q8gavgpool-test", build.cxx("q8gavgpool.cc")) + build.unittest("q8gemm-test", build.cxx("q8gemm.cc")) + build.unittest("q8vadd-test", build.cxx("q8vadd.cc")) + build.unittest("sconv-test", build.cxx("sconv.cc")) + build.unittest("sgemm-test", build.cxx("sgemm.cc")) + build.unittest("u8clamp-test", build.cxx("u8clamp.cc")) + build.unittest("u8lut32norm-test", build.cxx("u8lut32norm.cc")) + build.unittest("u8maxpool-test", build.cxx("u8maxpool.cc")) + build.unittest("u8rmax-test", build.cxx("u8rmax.cc")) + build.unittest("x8lut-test", build.cxx("x8lut.cc")) + build.unittest("x8zip-test", build.cxx("x8zip.cc")) + + build.unittest("add-test", build.cxx("add.cc")) + build.unittest("average-pooling-test", build.cxx("average-pooling.cc")) + build.unittest("channel-shuffle-test", build.cxx("channel-shuffle.cc")) + build.unittest("clamp-test", build.cxx("clamp.cc")) + build.unittest("convolution-test", build.cxx("convolution.cc")) + build.unittest("deconvolution-test", build.cxx("deconvolution.cc")) + build.unittest("fully-connected-test", build.cxx("fully-connected.cc")) + build.unittest( + "global-average-pooling-test", build.cxx("global-average-pooling.cc") + ) + build.unittest("leaky-relu-test", build.cxx("leaky-relu.cc")) + build.unittest("max-pooling-test", build.cxx("max-pooling.cc")) + build.unittest("sigmoid-test", build.cxx("sigmoid.cc")) + build.unittest("softargmax-test", build.cxx("softargmax.cc")) + build.unittest( + "requantization-test", + [build.cxx("requantization.cc")] + requantization_objects, + ) + + benchmark_isa = None + if build.target.is_arm: + benchmark_isa = arm.neon + elif build.target.is_x86: + benchmark_isa = x86.sse4_1 + with build.options( + source_dir="bench", + deps={ + ( + build, + build.deps.cpuinfo, + build.deps.clog, + build.deps.pthreadpool, + build.deps.FP16, + build.deps.googlebenchmark, + ): any, + "log": build.target.is_android, + }, + isa=benchmark_isa, + extra_include_dirs="src", + ): + + build.benchmark("add-bench", build.cxx("add.cc")) + build.benchmark("average-pooling-bench", build.cxx("average-pooling.cc")) + build.benchmark("channel-shuffle-bench", build.cxx("channel-shuffle.cc")) + build.benchmark("convolution-bench", build.cxx("convolution.cc")) + build.benchmark( + "global-average-pooling-bench", build.cxx("global-average-pooling.cc") + ) + build.benchmark("max-pooling-bench", build.cxx("max-pooling.cc")) + build.benchmark("sigmoid-bench", build.cxx("sigmoid.cc")) + build.benchmark("softargmax-bench", build.cxx("softargmax.cc")) + + build.benchmark("q8gemm-bench", build.cxx("q8gemm.cc")) + build.benchmark("hgemm-bench", build.cxx("hgemm.cc")) + build.benchmark("sgemm-bench", build.cxx("sgemm.cc")) + build.benchmark( + "requantization-bench", + [build.cxx("requantization.cc")] + requantization_objects, + ) + + return build + + +if __name__ == "__main__": + import sys + + main(sys.argv[1:]).generate() diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/confu.yaml b/aten/src/ATen/native/quantized/cpu/qnnpack/confu.yaml new file mode 100644 index 0000000000000..5fd39d17d6c5c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/confu.yaml @@ -0,0 +1,18 @@ +name: qnnpack +title: Quantized UINT8 Functions for Mobile +license: Apache 2.0 +deps: + - name: cpuinfo + url: https://github.com/pytorch/cpuinfo.git + - name: fxdiv + url: https://github.com/Maratyszcza/FXdiv.git + - name: psimd + url: https://github.com/Maratyszcza/psimd.git + - name: pthreadpool + url: https://github.com/Maratyszcza/pthreadpool.git + - name: FP16 + url: https://github.com/Maratyszcza/FP16.git + - name: clog + dir: deps/clog + - name: googletest + - name: googlebenchmark diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/.gitignore b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/.gitignore new file mode 100644 index 0000000000000..73b299889044e --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/.gitignore @@ -0,0 +1,19 @@ +# Ninja files +build.ninja + +# Build objects and artifacts +deps/ +build/ +bin/ +lib/ +*.pyc +*.pyo + +# System files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt new file mode 100644 index 0000000000000..746531324765c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/CMakeLists.txt @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 3.1 FATAL_ERROR) + +INCLUDE(GNUInstallDirs) + +# ---[ Project and semantic versioning. +PROJECT(clog C CXX) + +# ---[ Options. +SET(CLOG_RUNTIME_TYPE "default" CACHE STRING "Type of runtime library (shared, static, or default) to use") +SET_PROPERTY(CACHE CLOG_RUNTIME_TYPE PROPERTY STRINGS default static shared) +IF(ANDROID) + OPTION(CLOG_LOG_TO_STDIO "Log errors, warnings, and information to stdout/stderr" OFF) +ELSE() + OPTION(CLOG_LOG_TO_STDIO "Log errors, warnings, and information to stdout/stderr" ON) +ENDIF() +OPTION(CLOG_BUILD_TESTS "Build clog tests" ON) + +# ---[ CMake options +IF(CLOG_BUILD_TESTS) + ENABLE_TESTING() +ENDIF() + +MACRO(CLOG_TARGET_RUNTIME_LIBRARY target) + IF(MSVC AND NOT CLOG_RUNTIME_TYPE STREQUAL "default") + IF(CLOG_RUNTIME_TYPE STREQUAL "shared") + TARGET_COMPILE_OPTIONS(${target} PRIVATE + "/MD$<$:d>") + ELSEIF(CLOG_RUNTIME_TYPE STREQUAL "static") + TARGET_COMPILE_OPTIONS(${target} PRIVATE + "/MT$<$:d>") + ENDIF() + ENDIF() +ENDMACRO() + +# ---[ Download deps +SET(CONFU_DEPENDENCIES_SOURCE_DIR ${CMAKE_SOURCE_DIR}/deps + CACHE PATH "Confu-style dependencies source directory") +SET(CONFU_DEPENDENCIES_BINARY_DIR ${CMAKE_BINARY_DIR}/deps + CACHE PATH "Confu-style dependencies binary directory") + +IF(CLOG_BUILD_TESTS) + IF(NOT DEFINED GOOGLETEST_SOURCE_DIR) + MESSAGE(STATUS "Downloading Google Test to ${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest (define GOOGLETEST_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadGoogleTest.cmake "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest-download") + SET(GOOGLETEST_SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" CACHE STRING "Google Test source directory") + ENDIF() +ENDIF() + +# ---[ clog library +ADD_LIBRARY(clog STATIC src/clog.c) +SET_TARGET_PROPERTIES(clog PROPERTIES + C_STANDARD 99 + C_EXTENSIONS NO) +CLOG_TARGET_RUNTIME_LIBRARY(clog) +SET_TARGET_PROPERTIES(clog PROPERTIES PUBLIC_HEADER include/clog.h) +TARGET_INCLUDE_DIRECTORIES(clog BEFORE PUBLIC include) +IF(CLOG_LOG_TO_STDIO) + TARGET_COMPILE_DEFINITIONS(clog PRIVATE CLOG_LOG_TO_STDIO=1) +ELSE() + TARGET_COMPILE_DEFINITIONS(clog PRIVATE CLOG_LOG_TO_STDIO=0) +ENDIF() +IF(ANDROID AND NOT CLOG_LOG_TO_STDIO) + TARGET_LINK_LIBRARIES(clog PRIVATE log) +ENDIF() + +INSTALL(TARGETS clog + LIBRARY DESTINATION "${CMAKE_INSTALL_LIBDIR}" + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}" + PUBLIC_HEADER DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}") + +# ---[ clog tests +IF(CLOG_BUILD_TESTS) + # ---[ Build google test + IF(NOT TARGET gtest) + IF(MSVC AND NOT CLOG_RUNTIME_TYPE STREQUAL "static") + SET(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + ENDIF() + ADD_SUBDIRECTORY( + "${GOOGLETEST_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest") + ENDIF() + + ADD_EXECUTABLE(clog-test test/clog.cc) + SET_TARGET_PROPERTIES(clog-test PROPERTIES + CXX_STANDARD 11 + CXX_EXTENSIONS NO) + CLOG_TARGET_RUNTIME_LIBRARY(clog-test) + TARGET_LINK_LIBRARIES(clog-test PRIVATE clog gtest gtest_main) + ADD_TEST(clog-test clog-test) +ENDIF() diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/LICENSE b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/LICENSE new file mode 100644 index 0000000000000..306de3d8f1628 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/LICENSE @@ -0,0 +1,26 @@ +Copyright (C) 2018 Marat Dukhan +Copyright (c) 2017-2018 Facebook Inc. +Copyright (c) 2017 Georgia Institute of Technology + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/README.md b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/README.md new file mode 100644 index 0000000000000..17fc709314e6d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/README.md @@ -0,0 +1,57 @@ +# clog: C-style (a-la printf) logging library + +[![BSD (2 clause) License](https://img.shields.io/badge/License-BSD%202--Clause%20%22Simplified%22%20License-blue.svg)](https://github.com/pytorch/cpuinfo/blob/master/deps/clog/LICENSE) + +C-style library for logging errors, warnings, information notes, and debug information. + +## Features + +- printf-style interface for formatting variadic parameters. +- Separate functions for logging errors, warnings, information notes, and debug information. +- Independent logging settings for different modules. +- Logging to logcat on Android and stderr/stdout on other platforms. +- Compatible with C99 and C++. +- Covered with unit tests. + +## Example + +```c +#include + +#ifndef MYMODULE_LOG_LEVEL + #define MYMODULE_LOG_LEVEL CLOG_DEBUG +#endif + +CLOG_DEFINE_LOG_DEBUG(mymodule_, "My Module", MYMODULE_LOG_LEVEL); +CLOG_DEFINE_LOG_INFO(mymodule_, "My Module", MYMODULE_LOG_LEVEL); +CLOG_DEFINE_LOG_WARNING(mymodule_, "My Module", MYMODULE_LOG_LEVEL); +CLOG_DEFINE_LOG_ERROR(mymodule_, "My Module", MYMODULE_LOG_LEVEL); + +... + +void some_function(...) { + int status = ... + if (status != 0) { + mymodule_log_error( + "something really bad happened: " + "operation failed with status %d", status); + } + + uint32_t expected_zero = ... + if (expected_zero != 0) { + mymodule_log_warning( + "something suspicious happened (var = %"PRIu32"), " + "fall back to generic implementation", expected_zero); + } + + void* usually_non_null = ... + if (usually_non_null == NULL) { + mymodule_log_info( + "something unusual, but common, happened: " + "enabling work-around"); + } + + float a = ... + mymodule_log_debug("computed a = %.7f", a); +} +``` diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/cmake/DownloadGoogleTest.cmake b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/cmake/DownloadGoogleTest.cmake new file mode 100644 index 0000000000000..559d86be1b734 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/cmake/DownloadGoogleTest.cmake @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +CMAKE_MINIMUM_REQUIRED(VERSION 2.8.12 FATAL_ERROR) + +PROJECT(googletest-download NONE) + +INCLUDE(ExternalProject) +ExternalProject_Add(googletest + URL https://github.com/google/googletest/archive/release-1.8.0.zip + URL_HASH SHA256=f3ed3b58511efd272eb074a3a6d6fb79d7c2e6a0e374323d1e6bcbcc1ef141bf + SOURCE_DIR "${CONFU_DEPENDENCIES_SOURCE_DIR}/googletest" + BINARY_DIR "${CONFU_DEPENDENCIES_BINARY_DIR}/googletest" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py new file mode 100755 index 0000000000000..d6b829ca8c7a4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/configure.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import confu +parser = confu.standard_parser("clog configuration script") + + +def main(args): + options = parser.parse_args(args) + build = confu.Build.from_options(options) + + build.export_cpath("include", ["clog.h"]) + + with build.options(source_dir="src", extra_include_dirs="src"): + build.static_library("clog", build.cc("clog.c")) + + with build.options(source_dir="test", deps={ + (build, build.deps.googletest): all, + "log": build.target.is_android}): + build.unittest("clog-test", build.cxx("clog.cc")) + + return build + +if __name__ == "__main__": + import sys + main(sys.argv[1:]).generate() diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/confu.yaml b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/confu.yaml new file mode 100644 index 0000000000000..b033fa2761c92 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/confu.yaml @@ -0,0 +1,5 @@ +name: clog +title: C-style (a-la printf) logging library +license: Simplified BSD +deps: + - name: googletest diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include/clog.h b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include/clog.h new file mode 100644 index 0000000000000..bf09cd0cb6de4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include/clog.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#define CLOG_NONE 0 +#define CLOG_FATAL 1 +#define CLOG_ERROR 2 +#define CLOG_WARNING 3 +#define CLOG_INFO 4 +#define CLOG_DEBUG 5 + +#ifndef CLOG_VISIBILITY +#if defined(__ELF__) +#define CLOG_VISIBILITY __attribute__((__visibility__("internal"))) +#elif defined(__MACH__) +#define CLOG_VISIBILITY __attribute__((__visibility__("hidden"))) +#else +#define CLOG_VISIBILITY +#endif +#endif + +#ifndef CLOG_ARGUMENTS_FORMAT +#if defined(__GNUC__) +#define CLOG_ARGUMENTS_FORMAT __attribute__((__format__(__printf__, 1, 2))) +#else +#define CLOG_ARGUMENTS_FORMAT +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +CLOG_VISIBILITY void clog_vlog_debug( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_info( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_warning( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_error( + const char* module, + const char* format, + va_list args); +CLOG_VISIBILITY void clog_vlog_fatal( + const char* module, + const char* format, + va_list args); + +#define CLOG_DEFINE_LOG_DEBUG(log_debug_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_debug_function_name(const char* format, ...) { \ + if (level >= CLOG_DEBUG) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_debug(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_INFO(log_info_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_info_function_name(const char* format, ...) { \ + if (level >= CLOG_INFO) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_info(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_WARNING(log_warning_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_warning_function_name(const char* format, ...) { \ + if (level >= CLOG_WARNING) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_warning(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_ERROR(log_error_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_error_function_name(const char* format, ...) { \ + if (level >= CLOG_ERROR) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_error(module, format, args); \ + va_end(args); \ + } \ + } + +#define CLOG_DEFINE_LOG_FATAL(log_fatal_function_name, module, level) \ + CLOG_ARGUMENTS_FORMAT \ + inline static void log_fatal_function_name(const char* format, ...) { \ + if (level >= CLOG_FATAL) { \ + va_list args; \ + va_start(args, format); \ + clog_vlog_fatal(module, format, args); \ + va_end(args); \ + } \ + abort(); \ + } + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/src/clog.c b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/src/clog.c new file mode 100644 index 0000000000000..8e6073f9a7096 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/src/clog.c @@ -0,0 +1,524 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#ifdef _WIN32 +#include +#else +#include +#endif +#ifdef __ANDROID__ +#include +#endif + +#ifndef CLOG_LOG_TO_STDIO +#ifdef __ANDROID__ +#define CLOG_LOG_TO_STDIO 0 +#else +#define CLOG_LOG_TO_STDIO 1 +#endif +#endif + +#include + +/* Messages up to this size are formatted entirely on-stack, and don't allocate + * heap memory */ +#define CLOG_STACK_BUFFER_SIZE 1024 + +#define CLOG_FATAL_PREFIX "Fatal error: " +#define CLOG_FATAL_PREFIX_LENGTH 13 +#define CLOG_FATAL_PREFIX_FORMAT "Fatal error in %s: " +#define CLOG_ERROR_PREFIX "Error: " +#define CLOG_ERROR_PREFIX_LENGTH 7 +#define CLOG_ERROR_PREFIX_FORMAT "Error in %s: " +#define CLOG_WARNING_PREFIX "Warning: " +#define CLOG_WARNING_PREFIX_LENGTH 9 +#define CLOG_WARNING_PREFIX_FORMAT "Warning in %s: " +#define CLOG_INFO_PREFIX "Note: " +#define CLOG_INFO_PREFIX_LENGTH 6 +#define CLOG_INFO_PREFIX_FORMAT "Note (%s): " +#define CLOG_DEBUG_PREFIX "Debug: " +#define CLOG_DEBUG_PREFIX_LENGTH 7 +#define CLOG_DEBUG_PREFIX_FORMAT "Debug (%s): " +#define CLOG_SUFFIX_LENGTH 1 + +void clog_vlog_fatal(const char* module, const char* format, va_list args) { +#if defined(__ANDROID__) && !CLOG_LOG_TO_STDIO + __android_log_vprint(ANDROID_LOG_FATAL, module, format, args); +#else + char stack_buffer[CLOG_STACK_BUFFER_SIZE]; + char* heap_buffer = NULL; + char* out_buffer = &stack_buffer[0]; + + /* The first call to vsnprintf will clobber args, thus need a copy in case a + * second vsnprintf call is needed */ + va_list args_copy; + va_copy(args_copy, args); + + int prefix_chars = CLOG_FATAL_PREFIX_LENGTH; + if (module == NULL) { + memcpy(stack_buffer, CLOG_FATAL_PREFIX, CLOG_FATAL_PREFIX_LENGTH); + } else { + prefix_chars = snprintf( + stack_buffer, CLOG_STACK_BUFFER_SIZE, CLOG_FATAL_PREFIX_FORMAT, module); + if (prefix_chars < 0) { + /* Format error in prefix (possible if prefix is modified): skip prefix + * and continue as if nothing happened. */ + prefix_chars = 0; + } + } + + int format_chars; + if (prefix_chars + CLOG_SUFFIX_LENGTH >= CLOG_STACK_BUFFER_SIZE) { + /* + * Prefix + suffix alone would overflow the on-stack buffer, thus need to + * use on-heap buffer. Do not even try to format the string into on-stack + * buffer. + */ + format_chars = vsnprintf(NULL, 0, format, args); + } else { + format_chars = vsnprintf( + &stack_buffer[prefix_chars], + CLOG_STACK_BUFFER_SIZE - prefix_chars - CLOG_SUFFIX_LENGTH, + format, + args); + } + if (format_chars < 0) { + /* Format error in the message: silently ignore this particular message. */ + goto cleanup; + } + if (prefix_chars + format_chars + CLOG_SUFFIX_LENGTH > + CLOG_STACK_BUFFER_SIZE) { + /* Allocate a buffer on heap, and vsnprintf to this buffer */ + heap_buffer = malloc(prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); + if (heap_buffer == NULL) { + goto cleanup; + } + + if (prefix_chars > CLOG_STACK_BUFFER_SIZE) { + /* Prefix didn't fit into on-stack buffer, re-format it again to on-heap + * buffer */ + snprintf( + heap_buffer, + prefix_chars + 1 /* for '\0'-terminator */, + CLOG_FATAL_PREFIX_FORMAT, + module); + } else { + /* Copy pre-formatted prefix from on-stack buffer to on-heap buffer */ + memcpy(heap_buffer, stack_buffer, prefix_chars); + } + vsnprintf( + heap_buffer + prefix_chars, + format_chars + CLOG_SUFFIX_LENGTH, + format, + args_copy); + out_buffer = heap_buffer; + } + out_buffer[prefix_chars + format_chars] = '\n'; +#ifdef _WIN32 + DWORD bytes_written; + WriteFile( + GetStdHandle(STD_ERROR_HANDLE), + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH, + &bytes_written, + NULL); +#else + write( + STDERR_FILENO, + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); +#endif + +cleanup: + free(heap_buffer); + va_end(args_copy); +#endif +} + +void clog_vlog_error(const char* module, const char* format, va_list args) { +#if defined(__ANDROID__) && !CLOG_LOG_TO_STDIO + __android_log_vprint(ANDROID_LOG_ERROR, module, format, args); +#else + char stack_buffer[CLOG_STACK_BUFFER_SIZE]; + char* heap_buffer = NULL; + char* out_buffer = &stack_buffer[0]; + + /* The first call to vsnprintf will clobber args, thus need a copy in case a + * second vsnprintf call is needed */ + va_list args_copy; + va_copy(args_copy, args); + + int prefix_chars = CLOG_ERROR_PREFIX_LENGTH; + if (module == NULL) { + memcpy(stack_buffer, CLOG_ERROR_PREFIX, CLOG_ERROR_PREFIX_LENGTH); + } else { + prefix_chars = snprintf( + stack_buffer, CLOG_STACK_BUFFER_SIZE, CLOG_ERROR_PREFIX_FORMAT, module); + if (prefix_chars < 0) { + /* Format error in prefix (possible if prefix is modified): skip prefix + * and continue as if nothing happened. */ + prefix_chars = 0; + } + } + + int format_chars; + if (prefix_chars + CLOG_SUFFIX_LENGTH >= CLOG_STACK_BUFFER_SIZE) { + /* + * Prefix + suffix alone would overflow the on-stack buffer, thus need to + * use on-heap buffer. Do not even try to format the string into on-stack + * buffer. + */ + format_chars = vsnprintf(NULL, 0, format, args); + } else { + format_chars = vsnprintf( + &stack_buffer[prefix_chars], + CLOG_STACK_BUFFER_SIZE - prefix_chars - CLOG_SUFFIX_LENGTH, + format, + args); + } + if (format_chars < 0) { + /* Format error in the message: silently ignore this particular message. */ + goto cleanup; + } + if (prefix_chars + format_chars + CLOG_SUFFIX_LENGTH > + CLOG_STACK_BUFFER_SIZE) { + /* Allocate a buffer on heap, and vsnprintf to this buffer */ + heap_buffer = malloc(prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); + if (heap_buffer == NULL) { + goto cleanup; + } + + if (prefix_chars > CLOG_STACK_BUFFER_SIZE) { + /* Prefix didn't fit into on-stack buffer, re-format it again to on-heap + * buffer */ + snprintf( + heap_buffer, + prefix_chars + 1 /* for '\0'-terminator */, + CLOG_ERROR_PREFIX_FORMAT, + module); + } else { + /* Copy pre-formatted prefix from on-stack buffer to on-heap buffer */ + memcpy(heap_buffer, stack_buffer, prefix_chars); + } + vsnprintf( + heap_buffer + prefix_chars, + format_chars + CLOG_SUFFIX_LENGTH, + format, + args_copy); + out_buffer = heap_buffer; + } + out_buffer[prefix_chars + format_chars] = '\n'; +#ifdef _WIN32 + DWORD bytes_written; + WriteFile( + GetStdHandle(STD_ERROR_HANDLE), + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH, + &bytes_written, + NULL); +#else + write( + STDERR_FILENO, + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); +#endif + +cleanup: + free(heap_buffer); + va_end(args_copy); +#endif +} + +void clog_vlog_warning(const char* module, const char* format, va_list args) { +#if defined(__ANDROID__) && !CLOG_LOG_TO_STDIO + __android_log_vprint(ANDROID_LOG_WARN, module, format, args); +#else + char stack_buffer[CLOG_STACK_BUFFER_SIZE]; + char* heap_buffer = NULL; + char* out_buffer = &stack_buffer[0]; + + /* The first call to vsnprintf will clobber args, thus need a copy in case a + * second vsnprintf call is needed */ + va_list args_copy; + va_copy(args_copy, args); + + int prefix_chars = CLOG_WARNING_PREFIX_LENGTH; + if (module == NULL) { + memcpy(stack_buffer, CLOG_WARNING_PREFIX, CLOG_WARNING_PREFIX_LENGTH); + } else { + prefix_chars = snprintf( + stack_buffer, + CLOG_STACK_BUFFER_SIZE, + CLOG_WARNING_PREFIX_FORMAT, + module); + if (prefix_chars < 0) { + /* Format error in prefix (possible if prefix is modified): skip prefix + * and continue as if nothing happened. */ + prefix_chars = 0; + } + } + + int format_chars; + if (prefix_chars + CLOG_SUFFIX_LENGTH >= CLOG_STACK_BUFFER_SIZE) { + /* + * Prefix + suffix alone would overflow the on-stack buffer, thus need to + * use on-heap buffer. Do not even try to format the string into on-stack + * buffer. + */ + format_chars = vsnprintf(NULL, 0, format, args); + } else { + format_chars = vsnprintf( + &stack_buffer[prefix_chars], + CLOG_STACK_BUFFER_SIZE - prefix_chars - CLOG_SUFFIX_LENGTH, + format, + args); + } + if (format_chars < 0) { + /* Format error in the message: silently ignore this particular message. */ + goto cleanup; + } + if (prefix_chars + format_chars + CLOG_SUFFIX_LENGTH > + CLOG_STACK_BUFFER_SIZE) { + /* Allocate a buffer on heap, and vsnprintf to this buffer */ + heap_buffer = malloc(prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); + if (heap_buffer == NULL) { + goto cleanup; + } + + if (prefix_chars > CLOG_STACK_BUFFER_SIZE) { + /* Prefix didn't fit into on-stack buffer, re-format it again to on-heap + * buffer */ + snprintf( + heap_buffer, + prefix_chars + 1 /* for '\0'-terminator */, + CLOG_WARNING_PREFIX_FORMAT, + module); + } else { + /* Copy pre-formatted prefix from on-stack buffer to on-heap buffer */ + memcpy(heap_buffer, stack_buffer, prefix_chars); + } + vsnprintf( + heap_buffer + prefix_chars, + format_chars + CLOG_SUFFIX_LENGTH, + format, + args_copy); + out_buffer = heap_buffer; + } + out_buffer[prefix_chars + format_chars] = '\n'; +#ifdef _WIN32 + DWORD bytes_written; + WriteFile( + GetStdHandle(STD_ERROR_HANDLE), + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH, + &bytes_written, + NULL); +#else + write( + STDERR_FILENO, + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); +#endif + +cleanup: + free(heap_buffer); + va_end(args_copy); +#endif +} + +void clog_vlog_info(const char* module, const char* format, va_list args) { +#if defined(__ANDROID__) && !CLOG_LOG_TO_STDIO + __android_log_vprint(ANDROID_LOG_INFO, module, format, args); +#else + char stack_buffer[CLOG_STACK_BUFFER_SIZE]; + char* heap_buffer = NULL; + char* out_buffer = &stack_buffer[0]; + + /* The first call to vsnprintf will clobber args, thus need a copy in case a + * second vsnprintf call is needed */ + va_list args_copy; + va_copy(args_copy, args); + + int prefix_chars = CLOG_INFO_PREFIX_LENGTH; + if (module == NULL) { + memcpy(stack_buffer, CLOG_INFO_PREFIX, CLOG_INFO_PREFIX_LENGTH); + } else { + prefix_chars = snprintf( + stack_buffer, CLOG_STACK_BUFFER_SIZE, CLOG_INFO_PREFIX_FORMAT, module); + if (prefix_chars < 0) { + /* Format error in prefix (possible if prefix is modified): skip prefix + * and continue as if nothing happened. */ + prefix_chars = 0; + } + } + + int format_chars; + if (prefix_chars + CLOG_SUFFIX_LENGTH >= CLOG_STACK_BUFFER_SIZE) { + /* + * Prefix + suffix alone would overflow the on-stack buffer, thus need to + * use on-heap buffer. Do not even try to format the string into on-stack + * buffer. + */ + format_chars = vsnprintf(NULL, 0, format, args); + } else { + format_chars = vsnprintf( + &stack_buffer[prefix_chars], + CLOG_STACK_BUFFER_SIZE - prefix_chars - CLOG_SUFFIX_LENGTH, + format, + args); + } + if (format_chars < 0) { + /* Format error in the message: silently ignore this particular message. */ + goto cleanup; + } + if (prefix_chars + format_chars + CLOG_SUFFIX_LENGTH > + CLOG_STACK_BUFFER_SIZE) { + /* Allocate a buffer on heap, and vsnprintf to this buffer */ + heap_buffer = malloc(prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); + if (heap_buffer == NULL) { + goto cleanup; + } + + if (prefix_chars > CLOG_STACK_BUFFER_SIZE) { + /* Prefix didn't fit into on-stack buffer, re-format it again to on-heap + * buffer */ + snprintf( + heap_buffer, + prefix_chars + 1 /* for '\0'-terminator */, + CLOG_INFO_PREFIX_FORMAT, + module); + } else { + /* Copy pre-formatted prefix from on-stack buffer to on-heap buffer */ + memcpy(heap_buffer, stack_buffer, prefix_chars); + } + vsnprintf( + heap_buffer + prefix_chars, + format_chars + CLOG_SUFFIX_LENGTH, + format, + args_copy); + out_buffer = heap_buffer; + } + out_buffer[prefix_chars + format_chars] = '\n'; +#ifdef _WIN32 + DWORD bytes_written; + WriteFile( + GetStdHandle(STD_OUTPUT_HANDLE), + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH, + &bytes_written, + NULL); +#else + write( + STDOUT_FILENO, + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); +#endif + +cleanup: + free(heap_buffer); + va_end(args_copy); +#endif +} + +void clog_vlog_debug(const char* module, const char* format, va_list args) { +#if defined(__ANDROID__) && !CLOG_LOG_TO_STDIO + __android_log_vprint(ANDROID_LOG_DEBUG, module, format, args); +#else + char stack_buffer[CLOG_STACK_BUFFER_SIZE]; + char* heap_buffer = NULL; + char* out_buffer = &stack_buffer[0]; + + /* The first call to vsnprintf will clobber args, thus need a copy in case a + * second vsnprintf call is needed */ + va_list args_copy; + va_copy(args_copy, args); + + int prefix_chars = CLOG_DEBUG_PREFIX_LENGTH; + if (module == NULL) { + memcpy(stack_buffer, CLOG_DEBUG_PREFIX, CLOG_DEBUG_PREFIX_LENGTH); + } else { + prefix_chars = snprintf( + stack_buffer, CLOG_STACK_BUFFER_SIZE, CLOG_DEBUG_PREFIX_FORMAT, module); + if (prefix_chars < 0) { + /* Format error in prefix (possible if prefix is modified): skip prefix + * and continue as if nothing happened. */ + prefix_chars = 0; + } + } + + int format_chars; + if (prefix_chars + CLOG_SUFFIX_LENGTH >= CLOG_STACK_BUFFER_SIZE) { + /* + * Prefix + suffix alone would overflow the on-stack buffer, thus need to + * use on-heap buffer. Do not even try to format the string into on-stack + * buffer. + */ + format_chars = vsnprintf(NULL, 0, format, args); + } else { + format_chars = vsnprintf( + &stack_buffer[prefix_chars], + CLOG_STACK_BUFFER_SIZE - prefix_chars - CLOG_SUFFIX_LENGTH, + format, + args); + } + if (format_chars < 0) { + /* Format error in the message: silently ignore this particular message. */ + goto cleanup; + } + if (prefix_chars + format_chars + CLOG_SUFFIX_LENGTH > + CLOG_STACK_BUFFER_SIZE) { + /* Allocate a buffer on heap, and vsnprintf to this buffer */ + heap_buffer = malloc(prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); + if (heap_buffer == NULL) { + goto cleanup; + } + + if (prefix_chars > CLOG_STACK_BUFFER_SIZE) { + /* Prefix didn't fit into on-stack buffer, re-format it again to on-heap + * buffer */ + snprintf( + heap_buffer, + prefix_chars + 1 /* for '\0'-terminator */, + CLOG_DEBUG_PREFIX_FORMAT, + module); + } else { + /* Copy pre-formatted prefix from on-stack buffer to on-heap buffer */ + memcpy(heap_buffer, stack_buffer, prefix_chars); + } + vsnprintf( + heap_buffer + prefix_chars, + format_chars + CLOG_SUFFIX_LENGTH, + format, + args_copy); + out_buffer = heap_buffer; + } + out_buffer[prefix_chars + format_chars] = '\n'; +#ifdef _WIN32 + DWORD bytes_written; + WriteFile( + GetStdHandle(STD_OUTPUT_HANDLE), + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH, + &bytes_written, + NULL); +#else + write( + STDOUT_FILENO, + out_buffer, + prefix_chars + format_chars + CLOG_SUFFIX_LENGTH); +#endif + +cleanup: + free(heap_buffer); + va_end(args_copy); +#endif +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/test/clog.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/test/clog.cc new file mode 100644 index 0000000000000..988a7b9eb3f58 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/test/clog.cc @@ -0,0 +1,53 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +CLOG_DEFINE_LOG_DEBUG(named_log_debug, "Unit Test", CLOG_DEBUG); +CLOG_DEFINE_LOG_INFO(named_log_info, "Unit Test", CLOG_INFO); +CLOG_DEFINE_LOG_WARNING(named_log_warning, "Unit Test", CLOG_WARNING); +CLOG_DEFINE_LOG_ERROR(named_log_error, "Unit Test", CLOG_ERROR); +CLOG_DEFINE_LOG_FATAL(named_log_fatal, "Unit Test", CLOG_FATAL); + +CLOG_DEFINE_LOG_DEBUG(nameless_log_debug, NULL, CLOG_DEBUG); +CLOG_DEFINE_LOG_INFO(nameless_log_info, NULL, CLOG_INFO); +CLOG_DEFINE_LOG_WARNING(nameless_log_warning, NULL, CLOG_WARNING); +CLOG_DEFINE_LOG_ERROR(nameless_log_error, NULL, CLOG_ERROR); +CLOG_DEFINE_LOG_FATAL(nameless_log_fatal, NULL, CLOG_FATAL); + +CLOG_DEFINE_LOG_DEBUG(suppressed_log_debug, NULL, CLOG_INFO); +CLOG_DEFINE_LOG_INFO(suppressed_log_info, NULL, CLOG_WARNING); +CLOG_DEFINE_LOG_WARNING(suppressed_log_warning, NULL, CLOG_ERROR); +CLOG_DEFINE_LOG_ERROR(suppressed_log_error, NULL, CLOG_FATAL); +CLOG_DEFINE_LOG_FATAL(suppressed_log_fatal, NULL, CLOG_NONE); + +TEST(CLOG, debug) { + named_log_debug("test debug message with a module name"); + nameless_log_debug("test debug message without a module name"); + suppressed_log_debug("test suppressed debug message"); +} + +TEST(CLOG, info) { + named_log_info("test info message with a module name"); + nameless_log_info("test info message without a module name"); + suppressed_log_info("test suppressed info message"); +} + +TEST(CLOG, warning) { + named_log_warning("test warning message with a module name"); + nameless_log_warning("test warning message without a module name"); + suppressed_log_warning("test suppressed warning message"); +} + +TEST(CLOG, error) { + named_log_error("test error message with a module name"); + nameless_log_error("test error message without a module name"); + suppressed_log_error("test suppressed error message"); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/include/conv_utils.h b/aten/src/ATen/native/quantized/cpu/qnnpack/include/conv_utils.h new file mode 100644 index 0000000000000..5dc8880335905 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/include/conv_utils.h @@ -0,0 +1,193 @@ +#pragma once +#include +#include +#include + +#include +#include +#include + +namespace qnnpack { +struct conv_param_t { + const std::array kernel_dims; // kernel width, kernel height + const std::array subsampling_dims; // subsampling width, height + const std::array dilation; // dilation width, height + const std::array pad; // input padding top, left, bottom, right + const uint32_t groups; + const size_t input_channels; + const size_t output_channels; + const uint8_t kernel_zero_point; + const float kernel_scale; + const uint8_t output_min; + const uint8_t output_max; + + // The following are derived parameters + enum pytorch_qnnp_ukernel_type ukernel_type; // kernel type based on input params + size_t group_input_channels; + size_t group_output_channels; + + /** + * @brief Constructor for initializing the convolution parameters. + */ + conv_param_t( + const std::array kernel, + const std::array subsampling, + const std::array dil, + const std::array pd, + const uint32_t grp, + const size_t in_ch, + const size_t out_ch, + const uint8_t kernel_zp, + const float kernel_s, + const uint8_t out_min, + const uint8_t out_max) + : kernel_dims(kernel), + subsampling_dims(subsampling), + dilation(dil), + pad(pd), + groups(grp), + input_channels(in_ch), + output_channels(out_ch), + kernel_zero_point(kernel_zp), + kernel_scale(kernel_s), + output_min(out_min), + output_max(out_max) { + const uint32_t kernel_width = kernel_dims[0]; + const uint32_t kernel_height = kernel_dims[1]; + + const uint32_t input_padding_top = pad[0]; + const uint32_t input_padding_left = pad[1]; + const uint32_t input_padding_bottom = pad[2]; + const uint32_t input_padding_right = pad[3]; + + group_input_channels = input_channels / groups; + group_output_channels = output_channels / groups; + + if (kernel_width == 0 || kernel_height == 0) { + pytorch_qnnp_log_error( + "failed to create convolution with %" PRIu32 "x%" PRIu32 + " kernel: kernel dimensions must be non-zero", + kernel_width, + kernel_height); + assert("Failed to initialize QNNPACK conv_param_t struct."); + } + + if (subsampling_dims[0] == 0 || subsampling_dims[1] == 0) { + pytorch_qnnp_log_error( + "failed to create convolution with %" PRIu32 "x%" PRIu32 + " subsampling: " + "subsampling dimensions must be non-zero", + subsampling_dims[0], + subsampling_dims[1]); + assert("Failed to initialize QNNPACK conv_param_t struct."); + } + + if (dilation[0] == 0 || dilation[1] == 0) { + pytorch_qnnp_log_error( + "failed to create convolution with %" PRIu32 "x%" PRIu32 + " dilation: " + "dilation dimensions must be non-zero", + dilation[0], + dilation[1]); + assert("Failed to initialize QNNPACK conv_param_t struct."); + } + + if (kernel_scale <= 0.0f || !std::isnormal(kernel_scale)) { + pytorch_qnnp_log_error( + "failed to create convolution with %.7g kernel scale: scale must be" + "finite and positive", + kernel_scale); + assert("Failed to initialize QNNPACK conv_param_t struct."); + } + + if (subsampling_dims[1] > kernel_height) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "x%" PRIu32 + " subsampling: " + "height subsampling is greater than kernel height; subsampling should" + " be performed before the convolution", + kernel_width, + kernel_height, + subsampling_dims[0], + subsampling_dims[1]); + } + + if (subsampling_dims[0] > kernel_width) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "x%" PRIu32 + " subsampling: " + "width subsampling is greater than kernel width; subsampling should" + " be performed before the convolution", + kernel_width, + kernel_height, + subsampling_dims[0], + subsampling_dims[1]); + } + + if (input_padding_top >= kernel_height) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " height padding: " + "input top padding is greater or equal to kernel height", + kernel_width, + kernel_height, + input_padding_top, + input_padding_bottom); + } + + if (input_padding_bottom >= kernel_height) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " height padding: " + "input bottom padding is greater or equal to kernel height", + kernel_width, + kernel_height, + input_padding_top, + input_padding_bottom); + } + + if (input_padding_right >= kernel_width) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " width padding: " + "input right padding is greater or equal to kernel width", + kernel_width, + kernel_height, + input_padding_left, + input_padding_right); + } + + if (input_padding_left >= kernel_width) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " width padding: " + "input left padding is greater or equal to kernel width", + kernel_width, + kernel_height, + input_padding_left, + input_padding_right); + } + + const size_t kernel_size = kernel_height * kernel_width; + + ukernel_type = pytorch_qnnp_ukernel_type_none; + const bool any_padding = (input_padding_left | input_padding_top + | input_padding_right | input_padding_bottom) != 0; + + if ((kernel_size == 9 || kernel_size == 25) && + group_input_channels == 1 && group_output_channels == 1 && groups > 1) { + ukernel_type = pytorch_qnnp_ukernel_type_dwconv; + } else if (kernel_size == 1 && subsampling_dims[1] == 1 && subsampling_dims[0] == 1 && !any_padding) { + ukernel_type = group_input_channels >= SIZE_MAX ? pytorch_qnnp_ukernel_type_xzp_gemm : pytorch_qnnp_ukernel_type_gemm; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } + } +}; +} // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/include/pytorch_qnnpack.h b/aten/src/ATen/native/quantized/cpu/qnnpack/include/pytorch_qnnpack.h new file mode 100644 index 0000000000000..a975194766802 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/include/pytorch_qnnpack.h @@ -0,0 +1,336 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Status code for any QNNPACK function call. + */ +enum pytorch_qnnp_status { + /** The call succeeded, and all output arguments now contain valid data. */ + pytorch_qnnp_status_success = 0, + pytorch_qnnp_status_uninitialized = 1, + pytorch_qnnp_status_invalid_parameter = 2, + pytorch_qnnp_status_unsupported_parameter = 3, + pytorch_qnnp_status_unsupported_hardware = 4, + pytorch_qnnp_status_out_of_memory = 5, +}; + +enum pytorch_qnnp_status pytorch_qnnp_initialize(void); + +enum pytorch_qnnp_status pytorch_qnnp_deinitialize(void); + +typedef struct pytorch_qnnp_operator* pytorch_qnnp_operator_t; + +enum pytorch_qnnp_status pytorch_qnnp_create_convolution2d_nhwc_q8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* convolution); + +enum pytorch_qnnp_status pytorch_qnnp_setup_convolution2d_nhwc_q8( + pytorch_qnnp_operator_t convolution, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status pytorch_qnnp_create_deconvolution2d_nhwc_q8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t adjustment_height, + uint32_t adjustment_width, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* deconvolution); + +enum pytorch_qnnp_status pytorch_qnnp_setup_deconvolution2d_nhwc_q8( + pytorch_qnnp_operator_t deconvolution, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_nc_q8( + size_t input_channels, + size_t output_channels, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* fully_connected); + +enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_nc_q8( + pytorch_qnnp_operator_t fully_connected, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_global_average_pooling_nwc_q8( + size_t channels, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* global_average_pooling); + +enum pytorch_qnnp_status pytorch_qnnp_setup_global_average_pooling_nwc_q8( + pytorch_qnnp_operator_t global_average_pooling, + size_t batch_size, + size_t width, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_average_pooling2d_nhwc_q8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + size_t channels, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* average_pooling); + +enum pytorch_qnnp_status pytorch_qnnp_setup_average_pooling2d_nhwc_q8( + pytorch_qnnp_operator_t average_pooling, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status pytorch_qnnp_create_max_pooling2d_nhwc_u8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + size_t channels, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* max_pooling); + +enum pytorch_qnnp_status pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + pytorch_qnnp_operator_t max_pooling, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status pytorch_qnnp_create_channel_shuffle_nc_x8( + size_t groups, + size_t group_channels, + uint32_t flags, + pytorch_qnnp_operator_t* channel_shuffle); + +enum pytorch_qnnp_status pytorch_qnnp_setup_channel_shuffle_nc_x8( + pytorch_qnnp_operator_t channel_shuffle, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_add_nc_q8( + size_t channels, + uint8_t a_zero_point, + float a_scale, + uint8_t b_zero_point, + float b_scale, + uint8_t sum_zero_point, + float sum_scale, + uint8_t sum_min, + uint8_t sum_max, + uint32_t flags, + pytorch_qnnp_operator_t* add); + +enum pytorch_qnnp_status pytorch_qnnp_setup_add_nc_q8( + pytorch_qnnp_operator_t add, + size_t batch_size, + const uint8_t* a, + size_t a_stride, + const uint8_t* b, + size_t b_stride, + uint8_t* sum, + size_t sum_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_clamp_nc_u8( + size_t channels, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* clamp); + +enum pytorch_qnnp_status pytorch_qnnp_setup_clamp_nc_u8( + pytorch_qnnp_operator_t clamp, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_sigmoid_nc_q8( + size_t channels, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* sigmoid); + +enum pytorch_qnnp_status pytorch_qnnp_setup_sigmoid_nc_q8( + pytorch_qnnp_operator_t sigmoid, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_leaky_relu_nc_q8( + size_t channels, + float negative_slope, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* leaky_relu); + +enum pytorch_qnnp_status pytorch_qnnp_setup_leaky_relu_nc_q8( + pytorch_qnnp_operator_t leaky_relu, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_create_softargmax_nc_q8( + size_t channels, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint32_t flags, + pytorch_qnnp_operator_t* softargmax); + +enum pytorch_qnnp_status pytorch_qnnp_setup_softargmax_nc_q8( + pytorch_qnnp_operator_t softargmax, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride); + +enum pytorch_qnnp_status pytorch_qnnp_run_operator( + pytorch_qnnp_operator_t op, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status pytorch_qnnp_delete_operator( + pytorch_qnnp_operator_t op); + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h b/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h new file mode 100644 index 0000000000000..35f3afe7c41a3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/include/qnnpack_func.h @@ -0,0 +1,109 @@ +#pragma once +#include + +namespace qnnpack { +class PrePackConvWeights final { + public: + PrePackConvWeights(const conv_param_t& conv_param, const uint8_t* kernel, const int32_t* bias); + + void* getPackedWeights() const + { + return packed_weights_; + } + + int64_t getOutputChannels() const + { + return output_channels_; + } + + ~PrePackConvWeights() + { + if (packed_weights_ != nullptr) { + free(packed_weights_); + } + } + + PrePackConvWeights() = delete; + PrePackConvWeights(const PrePackConvWeights&) = delete; + PrePackConvWeights& operator=(const PrePackConvWeights&) = delete; + + private: + void* packed_weights_ = nullptr; + int64_t output_channels_; +}; + +class PackBMatrix final { + public: + PackBMatrix( + size_t input_channels, + size_t output_channels, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias); + + void* getPackedWeights() const + { + return packed_weights_; + } + + size_t getInputChannels() const + { + return input_channels_; + } + + size_t getOutputChannels() const + { + return output_channels_; + } + + ~PackBMatrix() + { + if (packed_weights_ != nullptr) { + free(packed_weights_); + } + } + + PackBMatrix() = delete; + PackBMatrix(const PackBMatrix&) = delete; + PackBMatrix& operator=(const PackBMatrix&) = delete; + + private: + void* packed_weights_ = nullptr; + size_t input_channels_; + size_t output_channels_; +}; + +enum pytorch_qnnp_status qnnpackLinear( + const size_t batch_size, + const size_t input_channels, + const size_t output_channels, + const uint8_t input_zero_point, + const float input_scale, + const uint8_t kernel_zero_point, + const float kernel_scale, + const uint8_t output_zero_point, + const float output_scale, + const uint8_t output_min, + const uint8_t output_max, + const uint8_t* input, + const size_t input_stride, + void* packed_weights, + uint8_t* output, + const size_t output_stride, + pthreadpool_t threadpool); + +enum pytorch_qnnp_status qnnpackConv( + const conv_param_t& conv_p, + void* packed_weights, + const size_t batch_size, + const size_t input_height, + const size_t input_width, + const float input_scale, + const uint8_t input_zero_point, + const uint8_t* input, + const float output_scale, + const uint8_t output_zero_point, + uint8_t* output, + pthreadpool_t threadpool); +} // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh new file mode 100755 index 0000000000000..389430b043fe6 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-arm64.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$ANDROID_NDK" ] +then + echo "ANDROID_NDK not set; please set it to the Android NDK directory" + exit 1 +fi + +if [ ! -d "$ANDROID_NDK" ] +then + echo "ANDROID_NDK not a directory; did you install it under ${ANDROID_NDK}?" + exit 1 +fi + +mkdir -p build/android/arm64-v8a + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ] +then + CMAKE_ARGS+=("-GNinja") +fi + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") + +# Cross-compilation options for Google Benchmark +CMAKE_ARGS+=("-DHAVE_POSIX_REGEX=0") +CMAKE_ARGS+=("-DHAVE_STEADY_CLOCK=0") +CMAKE_ARGS+=("-DHAVE_STD_REGEX=0") + +# Android-specific options +CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") +CMAKE_ARGS+=("-DANDROID_ABI=arm64-v8a") +CMAKE_ARGS+=("-DANDROID_PLATFORM=android-21") +CMAKE_ARGS+=("-DANDROID_PIE=ON") +CMAKE_ARGS+=("-DANDROID_STL=c++_static") +CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/android/arm64-v8a && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh new file mode 100755 index 0000000000000..6f32950125e0b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-armv7.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$ANDROID_NDK" ] +then + echo "ANDROID_NDK not set; please set it to the Android NDK directory" + exit 1 +fi + +if [ ! -d "$ANDROID_NDK" ] +then + echo "ANDROID_NDK not a directory; did you install it under ${ANDROID_NDK}?" + exit 1 +fi + +mkdir -p build/android/armeabi-v7a + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ] +then + CMAKE_ARGS+=("-GNinja") +fi + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") + +# Cross-compilation options for Google Benchmark +CMAKE_ARGS+=("-DHAVE_POSIX_REGEX=0") +CMAKE_ARGS+=("-DHAVE_STEADY_CLOCK=0") +CMAKE_ARGS+=("-DHAVE_STD_REGEX=0") + +# Android-specific options +CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") +CMAKE_ARGS+=("-DANDROID_ABI=armeabi-v7a") +CMAKE_ARGS+=("-DANDROID_PLATFORM=android-14") +CMAKE_ARGS+=("-DANDROID_PIE=ON") +CMAKE_ARGS+=("-DANDROID_STL=c++_static") +CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/android/armeabi-v7a && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh new file mode 100755 index 0000000000000..5f19db582fb09 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-android-x86.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$ANDROID_NDK" ] +then + echo "ANDROID_NDK not set; please set it to the Android NDK directory" + exit 1 +fi + +if [ ! -d "$ANDROID_NDK" ] +then + echo "ANDROID_NDK not a directory; did you install it under ${ANDROID_NDK}?" + exit 1 +fi + +mkdir -p build/android/x86 + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ] +then + CMAKE_ARGS+=("-GNinja") +fi + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") + +# Cross-compilation options for Google Benchmark +CMAKE_ARGS+=("-DHAVE_POSIX_REGEX=0") +CMAKE_ARGS+=("-DHAVE_STEADY_CLOCK=0") +CMAKE_ARGS+=("-DHAVE_STD_REGEX=0") + +# Android-specific options +CMAKE_ARGS+=("-DANDROID_NDK=$ANDROID_NDK") +CMAKE_ARGS+=("-DANDROID_ABI=x86") +CMAKE_ARGS+=("-DANDROID_PLATFORM=android-14") +CMAKE_ARGS+=("-DANDROID_PIE=ON") +CMAKE_ARGS+=("-DANDROID_STL=c++_static") +CMAKE_ARGS+=("-DANDROID_CPP_FEATURES=exceptions") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/android/x86 && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh new file mode 100755 index 0000000000000..d155d6f7507df --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not set; please set it to path of CMake toolchain file for iOS" + exit 1 +fi + +if [ ! -f "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not a file path; did you properly setup ${IOS_CMAKE_TOOLCHAIN_FILE}?" + exit 1 +fi + +mkdir -p build/ios/arm64 + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$IOS_CMAKE_TOOLCHAIN_FILE") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# QNNPACK-specific options +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=OFF") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=OFF") + +# iOS-specific options +CMAKE_ARGS+=("-DIOS_PLATFORM=OS64") +CMAKE_ARGS+=("-DIOS_ARCH=arm64") +CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") +CMAKE_ARGS+=("-DENABLE_ARC=OFF") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/ios/arm64 && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh new file mode 100755 index 0000000000000..985315f74a667 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-arm64e.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not set; please set it to path of CMake toolchain file for iOS" + exit 1 +fi + +if [ ! -f "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not a file path; did you properly setup ${IOS_CMAKE_TOOLCHAIN_FILE}?" + exit 1 +fi + +mkdir -p build/ios/arm64e + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$IOS_CMAKE_TOOLCHAIN_FILE") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# QNNPACK-specific options +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=OFF") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=OFF") + +# iOS-specific options +CMAKE_ARGS+=("-DIOS_PLATFORM=OS64") +CMAKE_ARGS+=("-DIOS_ARCH=arm64e") +CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") +CMAKE_ARGS+=("-DENABLE_ARC=OFF") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/ios/arm64e && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh new file mode 100755 index 0000000000000..0431c090db68f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not set; please set it to path of CMake toolchain file for iOS" + exit 1 +fi + +if [ ! -f "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not a file path; did you properly setup ${IOS_CMAKE_TOOLCHAIN_FILE}?" + exit 1 +fi + +mkdir -p build/ios/armv7 + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$IOS_CMAKE_TOOLCHAIN_FILE") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# QNNPACK-specific options +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=OFF") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=OFF") + +# iOS-specific options +CMAKE_ARGS+=("-DIOS_PLATFORM=OS64") +CMAKE_ARGS+=("-DIOS_ARCH=armv7") +CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") +CMAKE_ARGS+=("-DENABLE_ARC=OFF") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/ios/armv7 && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh new file mode 100755 index 0000000000000..e3f3d6b76231d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-armv7s.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not set; please set it to path of CMake toolchain file for iOS" + exit 1 +fi + +if [ ! -f "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not a file path; did you properly setup ${IOS_CMAKE_TOOLCHAIN_FILE}?" + exit 1 +fi + +mkdir -p build/ios/armv7s + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$IOS_CMAKE_TOOLCHAIN_FILE") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# QNNPACK-specific options +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=OFF") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=OFF") + +# iOS-specific options +CMAKE_ARGS+=("-DIOS_PLATFORM=OS64") +CMAKE_ARGS+=("-DIOS_ARCH=armv7s") +CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") +CMAKE_ARGS+=("-DENABLE_ARC=OFF") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/ios/armv7s && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh new file mode 100755 index 0000000000000..e8952148e66ad --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-i386.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not set; please set it to path of CMake toolchain file for iOS" + exit 1 +fi + +if [ ! -f "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not a file path; did you properly setup ${IOS_CMAKE_TOOLCHAIN_FILE}?" + exit 1 +fi + +mkdir -p build/ios/i386 + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$IOS_CMAKE_TOOLCHAIN_FILE") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# QNNPACK-specific options +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=OFF") # Google Benchmark is broken on 32-bit iOS +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") + +# iOS-specific options +CMAKE_ARGS+=("-DIOS_PLATFORM=SIMULATOR") +CMAKE_ARGS+=("-DIOS_ARCH=i386") +CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") +CMAKE_ARGS+=("-DENABLE_ARC=OFF") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/ios/i386 && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh new file mode 100755 index 0000000000000..10a58b843e2a7 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-ios-x86_64.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +if [ -z "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not set; please set it to path of CMake toolchain file for iOS" + exit 1 +fi + +if [ ! -f "$IOS_CMAKE_TOOLCHAIN_FILE" ] +then + echo "IOS_CMAKE_TOOLCHAIN_FILE not a file path; did you properly setup ${IOS_CMAKE_TOOLCHAIN_FILE}?" + exit 1 +fi + +mkdir -p build/ios/x86_64 + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$IOS_CMAKE_TOOLCHAIN_FILE") +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# QNNPACK-specific options +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") + +# Cross-compilation options for Google Benchmark +CMAKE_ARGS+=("-DHAVE_POSIX_REGEX=0") +CMAKE_ARGS+=("-DHAVE_STEADY_CLOCK=0") +CMAKE_ARGS+=("-DHAVE_STD_REGEX=0") + +# iOS-specific options +CMAKE_ARGS+=("-DIOS_PLATFORM=SIMULATOR64") +CMAKE_ARGS+=("-DIOS_ARCH=x86_64") +CMAKE_ARGS+=("-DENABLE_BITCODE=OFF") +CMAKE_ARGS+=("-DENABLE_ARC=OFF") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/ios/x86_64 && cmake ../../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh new file mode 100755 index 0000000000000..b429650c21842 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/build-local.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +mkdir -p build/local + +CMAKE_ARGS=() + +# CMake-level configuration +CMAKE_ARGS+=("-DCMAKE_BUILD_TYPE=Release") +CMAKE_ARGS+=("-DCMAKE_POSITION_INDEPENDENT_CODE=ON") + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ] +then + CMAKE_ARGS+=("-GNinja") +fi + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_LIBRARY_TYPE=static") + +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_BENCHMARKS=ON") +CMAKE_ARGS+=("-DPYTORCH_QNNPACK_BUILD_TESTS=ON") + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cd build/local && cmake ../.. \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ "$(uname)" == "Darwin" ] +then + cmake --build . -- "-j$(sysctl -n hw.ncpu)" +else + cmake --build . -- "-j$(nproc)" +fi diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-arm64.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-arm64.sh new file mode 100755 index 0000000000000..4482dad213f22 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-arm64.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +adb push build/android/arm64-v8a/convolution-test /data/local/tmp/convolution-test +adb push build/android/arm64-v8a/deconvolution-test /data/local/tmp/deconvolution-test +adb push build/android/arm64-v8a/q8gemm-test /data/local/tmp/q8gemm-test +adb push build/android/arm64-v8a/q8conv-test /data/local/tmp/q8conv-test +adb push build/android/arm64-v8a/q8dw-test /data/local/tmp/q8dw-test +adb push build/android/arm64-v8a/hgemm-test /data/local/tmp/hgemm-test +adb push build/android/arm64-v8a/sgemm-test /data/local/tmp/sgemm-test + +adb shell /data/local/tmp/convolution-test --gtest_color=yes +adb shell /data/local/tmp/deconvolution-test --gtest_color=yes +adb shell /data/local/tmp/q8gemm-test --gtest_color=yes +adb shell /data/local/tmp/q8conv-test --gtest_color=yes +adb shell /data/local/tmp/q8dw-test --gtest_color=yes +adb shell /data/local/tmp/hgemm-test --gtest_color=yes +adb shell /data/local/tmp/sgemm-test --gtest_color=yes diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-armv7.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-armv7.sh new file mode 100755 index 0000000000000..bc42fb1d730b1 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-armv7.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +adb push build/android/armeabi-v7a/convolution-test /data/local/tmp/convolution-test +adb push build/android/armeabi-v7a/deconvolution-test /data/local/tmp/deconvolution-test +adb push build/android/armeabi-v7a/q8gemm-test /data/local/tmp/q8gemm-test +adb push build/android/armeabi-v7a/q8conv-test /data/local/tmp/q8conv-test +adb push build/android/armeabi-v7a/q8dw-test /data/local/tmp/q8dw-test +adb push build/android/armeabi-v7a/hgemm-test /data/local/tmp/hgemm-test +adb push build/android/armeabi-v7a/sgemm-test /data/local/tmp/sgemm-test + +adb shell /data/local/tmp/convolution-test --gtest_color=yes +adb shell /data/local/tmp/deconvolution-test --gtest_color=yes +adb shell /data/local/tmp/q8gemm-test --gtest_color=yes +adb shell /data/local/tmp/q8conv-test --gtest_color=yes +adb shell /data/local/tmp/q8dw-test --gtest_color=yes +adb shell /data/local/tmp/hgemm-test --gtest_color=yes +adb shell /data/local/tmp/sgemm-test --gtest_color=yes diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-x86.sh b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-x86.sh new file mode 100755 index 0000000000000..32742b0250127 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/scripts/test-android-x86.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -e + +adb push build/android/x86/convolution-test /data/local/tmp/convolution-test +adb push build/android/x86/deconvolution-test /data/local/tmp/deconvolution-test +adb push build/android/x86/q8gemm-test /data/local/tmp/q8gemm-test +adb push build/android/x86/q8conv-test /data/local/tmp/q8conv-test +adb push build/android/x86/q8dw-test /data/local/tmp/q8dw-test +adb push build/android/x86/hgemm-test /data/local/tmp/hgemm-test +adb push build/android/x86/sgemm-test /data/local/tmp/sgemm-test + +adb shell /data/local/tmp/convolution-test --gtest_color=yes +adb shell /data/local/tmp/deconvolution-test --gtest_color=yes +adb shell /data/local/tmp/q8gemm-test --gtest_color=yes +adb shell /data/local/tmp/q8conv-test --gtest_color=yes +adb shell /data/local/tmp/q8dw-test --gtest_color=yes +adb shell /data/local/tmp/hgemm-test --gtest_color=yes +adb shell /data/local/tmp/sgemm-test --gtest_color=yes diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/add.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/add.c new file mode 100644 index 0000000000000..a3a82281c26ed --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/add.c @@ -0,0 +1,160 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_add_nc_q8( + size_t channels, + uint8_t a_zero_point, + float a_scale, + uint8_t b_zero_point, + float b_scale, + uint8_t sum_zero_point, + float sum_scale, + uint8_t sum_min, + uint8_t sum_max, + uint32_t flags, + pytorch_qnnp_operator_t* add_out) { + pytorch_qnnp_operator_t add_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_add_nc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create add operator with %zu channels: number of channels must be non-zero", + channels); + goto error; + } + + if (a_scale <= 0.0f || !isnormal(a_scale)) { + pytorch_qnnp_log_error( + "failed to create add operator with %.7g A scale: scale must be finite and positive", + a_scale); + goto error; + } + + if (b_scale <= 0.0f || !isnormal(b_scale)) { + pytorch_qnnp_log_error( + "failed to create add operator with %.7g B scale: scale must be finite and positive", + b_scale); + goto error; + } + + if (sum_scale <= 0.0f || !isnormal(sum_scale)) { + pytorch_qnnp_log_error( + "failed to create add operator with %.7g output scale: scale must be finite and positive", + sum_scale); + goto error; + } + + if (sum_min >= sum_max) { + pytorch_qnnp_log_error( + "failed to create add operator with [%" PRIu8 ", %" PRIu8 + "] output range: range min must be below range max", + sum_min, + sum_max); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + const float a_output_scale = a_scale / sum_scale; + if (a_output_scale < 0x1.0p-14f || a_output_scale >= 0x1.0p+8f) { + pytorch_qnnp_log_error( + "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range", + a_output_scale); + goto error; + } + + const float b_output_scale = b_scale / sum_scale; + if (b_output_scale < 0x1.0p-14f || b_output_scale >= 0x1.0p+8f) { + pytorch_qnnp_log_error( + "failed to create add operator with %.7g A-to-output scale ratio: scale ratio must be in [2**-14, 2**8) range", + b_output_scale); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + add_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (add_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + add_op->channels = channels; + add_op->add_quantization_params = + pytorch_qnnp_compute_add_quantization_params( + a_zero_point, + b_zero_point, + sum_zero_point, + a_scale / sum_scale, + b_scale / sum_scale, + sum_min, + sum_max); + + add_op->ukernel_type = pytorch_qnnp_ukernel_type_add; + add_op->format = pytorch_qnnp_format_quint8; + + *add_out = add_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(add_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_add_nc_q8( + pytorch_qnnp_operator_t add_op, + size_t batch_size, + const uint8_t* a, + size_t a_stride, + const uint8_t* b, + size_t b_stride, + uint8_t* sum, + size_t sum_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_add_nc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + add_op->batch_size = 0; + return pytorch_qnnp_status_success; + } + + add_op->batch_size = batch_size; + add_op->input = a; + add_op->input_pixel_stride = a_stride; + add_op->input2 = b; + add_op->input2_pixel_stride = b_stride; + add_op->output = sum; + add_op->output_pixel_stride = sum_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/average-pooling.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/average-pooling.c new file mode 100644 index 0000000000000..1c6c2f7392ce2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/average-pooling.c @@ -0,0 +1,298 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +static inline size_t compute_output_dimension( + size_t padded_input_dimension, + size_t pooling_dimension, + size_t stride_dimension) { + return (padded_input_dimension - pooling_dimension) / stride_dimension + 1; +} + +enum pytorch_qnnp_status pytorch_qnnp_create_average_pooling2d_nhwc_q8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + size_t channels, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* average_pooling_out) { + pytorch_qnnp_operator_t average_pooling = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_average_pooling2d_nhwc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + const uint32_t pooling_size = pooling_height * pooling_width; + if (pooling_size == 0) { + pytorch_qnnp_log_error( + "failed to create average pooling with %" PRIu32 "x%" PRIu32 + " pooling size: " + "pooling size dimensions must be non-zero", + pooling_width, + pooling_height); + goto error; + } + + if (pooling_size == 1) { + pytorch_qnnp_log_error( + "failed to create average pooling with 1 pooling element: " + "1x1 pooling is meaningless"); + goto error; + } + + if (stride_height == 0 || stride_width == 0) { + pytorch_qnnp_log_error( + "failed to create average pooling with %" PRIu32 "x%" PRIu32 + " stride: " + "stride dimensions must be non-zero", + stride_width, + stride_height); + goto error; + } + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create average pooling with %zu channels: " + "number of channels must be non-zero", + channels); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create average pooling with %.7g input scale: " + "scale must be finite and positive", + input_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create average pooling with %.7g output scale: " + "scale must be finite and positive", + output_scale); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + const float input_output_scale = input_scale / output_scale; + if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) { + pytorch_qnnp_log_error( + "failed to create average pooling with %.7g input scale and %.7g output scale: " + "input-to-output scale ratio (%.7f) must be in [2**-8, 2**8) range", + input_scale, + output_scale, + input_output_scale); + goto error; + } + + if (pooling_size >= 16777216) { + pytorch_qnnp_log_error( + "failed to create average pooling with %" PRIu32 " (%" PRIu32 + "x%" PRIu32 + ") pooling elements: " + "the number of elements in the pooling area must be below 2**24", + pooling_size, + pooling_width, + pooling_height); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + average_pooling = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (average_pooling == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + const bool any_padding = (input_padding_left | input_padding_top | + input_padding_right | input_padding_bottom) != 0; + const uint32_t kr = pytorch_qnnp_params.q8avgpool.kr; + const uint32_t mr = pytorch_qnnp_params.q8avgpool.mr; + const uint32_t qr = pytorch_qnnp_params.q8avgpool.qr; + if (any_padding || (channels >= kr || (pooling_size - mr) % qr != 0)) { + void* zero_buffer = malloc(channels); + if (zero_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for zero padding", channels); + goto error; + } + memset(zero_buffer, input_zero_point, channels); + average_pooling->zero_buffer = zero_buffer; + average_pooling->zero_pointer = zero_buffer; + } + + average_pooling->input_padding_top = input_padding_top; + average_pooling->input_padding_right = input_padding_right; + average_pooling->input_padding_bottom = input_padding_bottom; + average_pooling->input_padding_left = input_padding_left; + + average_pooling->kernel_height = pooling_height; + average_pooling->kernel_width = pooling_width; + average_pooling->stride_height = stride_height; + average_pooling->stride_width = stride_width; + average_pooling->dilation_height = 1; + average_pooling->dilation_width = 1; + average_pooling->channels = channels; + + size_t nrows = pooling_height * pooling_width; + if (channels >= pytorch_qnnp_params.q8avgpool.kr) { + if (nrows <= mr) { + nrows = mr; + } else { + nrows = round_up(nrows - mr, qr) + mr; + } + } + + average_pooling->avgpool_quantization_params = + pytorch_qnnp_compute_avgpool_quantization_params( + (int32_t) - ((uint32_t)input_zero_point * (uint32_t)nrows), + input_scale / (output_scale * (float)pooling_size), + output_zero_point, + output_min, + output_max); + + average_pooling->ukernel_type = pytorch_qnnp_ukernel_type_average_pooling; + average_pooling->format = pytorch_qnnp_format_quint8; + + *average_pooling_out = average_pooling; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(average_pooling); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_average_pooling2d_nhwc_q8( + pytorch_qnnp_operator_t average_pooling, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_pixel_stride, + uint8_t* output, + size_t output_pixel_stride, + pthreadpool_t threadpool) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_average_pooling2d_nhwc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + average_pooling->batch_size = 0; + return pytorch_qnnp_status_success; + } + + if (input_width == 0 || input_height == 0) { + pytorch_qnnp_log_error( + "failed to setup average pooling with %zux%zu input: input dimensions must be non-zero", + input_width, + input_height); + return pytorch_qnnp_status_invalid_parameter; + } + + average_pooling->batch_size = batch_size; + average_pooling->input_height = input_height; + average_pooling->input_width = input_width; + average_pooling->input = input; + average_pooling->input_pixel_stride = input_pixel_stride; + + average_pooling->output_height = compute_output_dimension( + average_pooling->input_padding_top + input_height + + average_pooling->input_padding_bottom, + average_pooling->kernel_height, + average_pooling->stride_height); + average_pooling->output_width = compute_output_dimension( + average_pooling->input_padding_left + input_width + + average_pooling->input_padding_right, + average_pooling->kernel_width, + average_pooling->stride_width); + average_pooling->output = output; + average_pooling->output_pixel_stride = output_pixel_stride; + + size_t valid_batch_size = 0; + if (input == average_pooling->last_input && + input_height == average_pooling->last_input_height && + input_width == average_pooling->last_input_width) { + valid_batch_size = average_pooling->valid_batch_size; + if (batch_size <= valid_batch_size) { + return pytorch_qnnp_status_success; + } + } + + const size_t pooling_height = average_pooling->kernel_height; + const size_t pooling_width = average_pooling->kernel_width; + const size_t pooling_size = pooling_height * pooling_width; + const size_t output_height = average_pooling->output_height; + const size_t output_width = average_pooling->output_width; + /* Micro-kernel may read up to (mr - 1) elements after the end of indirection + * buffer */ + const uint32_t mr = pytorch_qnnp_params.q8avgpool.mr; + + const size_t step_width = min(average_pooling->stride_width, pooling_width); + const size_t step_height = + pooling_size + (output_width * step_width - 1) * pooling_height; + const size_t indirection_buffer_size = + sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height); + + const void** indirection_buffer = (const void**)realloc( + average_pooling->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + average_pooling->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_dwconv2d( + average_pooling, valid_batch_size, step_height, step_width); + + average_pooling->last_input = input; + average_pooling->last_input_height = input_height; + average_pooling->last_input_width = input_width; + average_pooling->valid_batch_size = max(valid_batch_size, batch_size); + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/channel-shuffle.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/channel-shuffle.c new file mode 100644 index 0000000000000..6485ee0a473d0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/channel-shuffle.c @@ -0,0 +1,101 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_channel_shuffle_nc_x8( + size_t groups, + size_t group_channels, + uint32_t flags, + pytorch_qnnp_operator_t* channel_shuffle_out) { + pytorch_qnnp_operator_t channel_shuffle_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_channel_shuffle_nc_x8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (groups <= 1) { + pytorch_qnnp_log_error( + "failed to create channel shuffle operator with %zu groups: " + "at least two groups required", + groups); + goto error; + } + + if (group_channels == 0) { + pytorch_qnnp_log_error( + "failed to create channel shuffle operator with %zu group channels: " + "number of group channels must be non-zero", + group_channels); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + channel_shuffle_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (channel_shuffle_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + channel_shuffle_op->groups = groups; + channel_shuffle_op->group_channels = group_channels; + + channel_shuffle_op->ukernel_type = pytorch_qnnp_ukernel_type_channel_shuffle; + channel_shuffle_op->format = pytorch_qnnp_format_quint8; + + *channel_shuffle_out = channel_shuffle_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(channel_shuffle_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_channel_shuffle_nc_x8( + pytorch_qnnp_operator_t channel_shuffle_op, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_channel_shuffle_nc_x8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + channel_shuffle_op->batch_size = 0; + return pytorch_qnnp_status_success; + } + + channel_shuffle_op->batch_size = batch_size; + channel_shuffle_op->input = input; + channel_shuffle_op->input_pixel_stride = input_stride; + channel_shuffle_op->output = output; + channel_shuffle_op->output_pixel_stride = output_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/clamp.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/clamp.c new file mode 100644 index 0000000000000..8bc725c0d2ea3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/clamp.c @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_clamp_nc_u8( + size_t channels, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* clamp_out) { + pytorch_qnnp_operator_t clamp_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_clamp_nc_u8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create Clamp operator with %zu channels: number of channels must be non-zero", + channels); + goto error; + } + + if (output_min > output_max) { + pytorch_qnnp_log_error( + "failed to create Clamp operator with [%" PRIu8 ", %" PRIu8 + "] output range: range min must be below range max", + output_min, + output_max); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + clamp_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (clamp_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + clamp_op->channels = channels; + clamp_op->u8_clamping_params = + pytorch_qnnp_compute_u8_clamping_params(output_min, output_max); + + clamp_op->ukernel_type = pytorch_qnnp_ukernel_type_clamp; + clamp_op->format = pytorch_qnnp_format_quint8; + + *clamp_out = clamp_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(clamp_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_clamp_nc_u8( + pytorch_qnnp_operator_t clamp, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_clamp_nc_u8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + clamp->batch_size = 0; + return pytorch_qnnp_status_success; + } + + clamp->batch_size = batch_size; + clamp->input = input; + clamp->input_pixel_stride = input_stride; + clamp->output = output; + clamp->output_pixel_stride = output_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-prepack.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-prepack.cc new file mode 100644 index 0000000000000..2ec1d896271c4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-prepack.cc @@ -0,0 +1,190 @@ +#include +#include +#include +#include + +namespace qnnpack { + +PrePackConvWeights::PrePackConvWeights( + const conv_param_t& conv_p, + const uint8_t* kernel, + const int32_t* bias) { + output_channels_ = conv_p.output_channels; + enum pytorch_qnnp_ukernel_type ukernel_type = conv_p.ukernel_type; + const uint32_t kernel_width = conv_p.kernel_dims[0]; + const uint32_t kernel_height = conv_p.kernel_dims[1]; + const uint32_t groups = conv_p.groups; + + const size_t kernel_size = kernel_height * kernel_width; + switch (ukernel_type) { + case pytorch_qnnp_ukernel_type_dwconv: { + const uint32_t cr = pytorch_qnnp_params.q8dw9.cr; + const uint32_t c_stride = (groups + (cr - 1)) & -cr; + const size_t packed_weights_size = + (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride; + packed_weights_ = malloc(packed_weights_size); + if (packed_weights_ == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_weights_size); + assert("QNNPACK Runtime Error."); + } + + switch (kernel_size) { + case 9: + pytorch_pack_q8dw_wrq( + kernel_height, + kernel_width, + groups, + cr, + kernel, + bias, + packed_weights_); + break; + case 25: + /* change this later */ + pytorch_pack_q8dw_w_dilation( + kernel_height, + kernel_width, + groups, + cr, + 0, + kernel_height, + 0, + 2, + kernel, + bias, + packed_weights_, + true); + pytorch_pack_q8dw_w_dilation( + kernel_height, + kernel_width, + groups, + cr, + 0, + kernel_height, + 2, + 4, + kernel, + bias, + (char*)packed_weights_ + + (10 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride, + false); + pytorch_pack_q8dw_w_dilation( + kernel_height, + kernel_width, + groups, + cr, + 0, + kernel_height, + 4, + 5, + kernel, + bias, + (char*)packed_weights_ + + (20 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride, + false); + break; + default: + PYTORCH_QNNP_UNREACHABLE; + } + break; + } + case pytorch_qnnp_ukernel_type_xzp_gemm: { + const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; + const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc; + const uint32_t n_stride = (conv_p.group_output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (conv_p.group_input_channels + (kr - 1)) & -kr; + + const size_t packed_group_weights_size = + (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * + n_stride; + packed_weights_ = malloc(packed_group_weights_size * groups); + if (packed_weights_ == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_group_weights_size * groups); + assert("QNNPACK Runtime Error."); + } + /* The XZP ukernel needs the padding to be 0 */ + memset(packed_weights_, 0, packed_group_weights_size * groups); + + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_swizzle_q8gemm_brq( + conv_p.group_output_channels, + conv_p.group_input_channels, + nr, + kr, + sr, + kernel + + group * conv_p.group_output_channels * + conv_p.group_input_channels, + bias + group * conv_p.group_output_channels, + (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size)); + } + break; + } + case pytorch_qnnp_ukernel_type_gemm: + case pytorch_qnnp_ukernel_type_conv: { + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const uint32_t n_stride = (conv_p.group_output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (conv_p.group_input_channels + (kr - 1)) & -kr; + + const size_t packed_group_weights_size = + (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * + n_stride; + packed_weights_ = malloc(packed_group_weights_size * groups); + if (packed_weights_ == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_group_weights_size * groups); + assert("QNNPACK Runtime Error."); + } + memset( + packed_weights_, + conv_p.kernel_zero_point, + packed_group_weights_size * groups); + + switch (ukernel_type) { + case pytorch_qnnp_ukernel_type_gemm: + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_q8gemm_wrq( + conv_p.group_output_channels, + conv_p.group_input_channels, + nr, + nr, + kr, + kernel + + group * conv_p.group_output_channels * + conv_p.group_input_channels, + bias + group * conv_p.group_output_channels, + (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size)); + } + break; + case pytorch_qnnp_ukernel_type_conv: + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_q8conv_wrq( + conv_p.group_output_channels, + kernel_size, + conv_p.group_input_channels, + nr, + kr, + kernel + + group * conv_p.group_output_channels * kernel_size * + conv_p.group_input_channels, + bias + group * conv_p.group_output_channels, + (void*)((uintptr_t)packed_weights_ + group * packed_group_weights_size)); + } + break; + default: + PYTORCH_QNNP_UNREACHABLE; + } + break; + } + default: + PYTORCH_QNNP_UNREACHABLE; + } +} // namespace qnnpack +} // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc new file mode 100644 index 0000000000000..c9a27c3dae657 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc @@ -0,0 +1,690 @@ +#include +#include +#include +#include +#include +#include + +namespace qnnpack { + +static inline size_t compute_output_dimension( + size_t padded_input_dim, + size_t kernel_dimension, + size_t dilation_dimension, + size_t subsampling_dimension) { + const size_t effective_kernel_dim = + (kernel_dimension - 1) * dilation_dimension + 1; + return (padded_input_dim - effective_kernel_dim) / subsampling_dimension + 1; +} + +struct q8gemm_xzp_context { + size_t k; + size_t k_stride; + size_t n; + size_t n_stride; + const uint8_t* a; + size_t a_stride; + const void* packed_w; + uint8_t* c; + size_t c_stride; + const int32_t* a_sum; + size_t groups; + size_t batch_size; + size_t a_sum_stride; + union pytorch_qnnp_q31_requantization_params requantization_params; + const pytorch_q8gemm_xzp_ukernel_function ukernel; +}; +static void compute_q8gemm_xzp( + const struct q8gemm_xzp_context context[1], + size_t group_index, + size_t pixel_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t pixel_range, + size_t mr_block_size, + size_t nr_block_size) { + const size_t k = context->k; + const size_t k_stride = context->k_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t* a = context->a; + const size_t a_stride = context->a_stride; + const void* packed_w = context->packed_w; + uint8_t* c = context->c; + const size_t c_stride = context->c_stride; + const int32_t* a_sum = context->a_sum; + const size_t groups = context->groups; + const size_t a_sum_stride = context->a_sum_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + k, + a + (pixel_index + mr_block_start) * a_stride + group_index * k, + a_stride, + a_sum + pixel_index * groups + group_index * a_sum_stride + + mr_block_start, + (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (pixel_index + mr_block_start) * c_stride + nr_block_start + + group_index * n, + c_stride, + &context->requantization_params); +} + +struct q8gemm_context { + size_t k; + size_t k_stride; + size_t n; + size_t n_stride; + const uint8_t* a; + size_t a_stride; + const uint8_t* packed_w; + uint8_t* c; + size_t c_stride; + union pytorch_qnnp_conv_quantization_params quantization_params; + const pytorch_q8gemm_ukernel_function ukernel; +}; +static void compute_q8gemm( + const struct q8gemm_context context[1], + size_t group_index, + size_t pixel_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t pixel_range, + size_t mr_block_size, + size_t nr_block_size) { + const size_t k = context->k; + const size_t k_stride = context->k_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t* a = context->a; + const size_t a_stride = context->a_stride; + const void* packed_w = context->packed_w; + uint8_t* c = context->c; + const size_t c_stride = context->c_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + k, + a + (pixel_index + mr_block_start) * a_stride + group_index * k, + a_stride, + (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (pixel_index + mr_block_start) * c_stride + nr_block_start + + group_index * n, + c_stride, + &context->quantization_params); +} + +struct q8conv_context { + size_t bs; + size_t ks; + size_t kc; + size_t kc_stride; + size_t m; + size_t m_stride; + size_t n; + size_t n_stride; + const uint8_t** indirect_a; + const void* packed_w; + uint8_t* c; + size_t c_stride; + union pytorch_qnnp_conv_quantization_params quantization_params; + const pytorch_q8conv_ukernel_function ukernel; +}; +static void compute_q8conv( + const struct q8conv_context context[1], + size_t group_index, + size_t image_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t image_range /* always 1 */, + size_t mr_block_size, + size_t nr_block_size) { + const size_t bs = context->bs; + const size_t ks = context->ks; + const size_t kc = context->kc; + const size_t kc_stride = context->kc_stride; + const size_t m = context->m; + const size_t m_stride = context->m_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t** indirect_a = context->indirect_a; + const void* packed_w = context->packed_w; + uint8_t* c = context->c; + const size_t c_stride = context->c_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + kc, + ks, + indirect_a + + (mr_block_start + (image_index + group_index * bs) * m_stride) * ks, + (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (kc_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (mr_block_start + image_index * m) * c_stride + group_index * n + + nr_block_start, + c_stride, + &context->quantization_params); +} + +struct q8sum_rows_context { + const uint8_t* a; + size_t groups; + size_t m; + size_t k; + size_t a_stride; + const int32_t multiplier; + int32_t* a_sum; + size_t a_sum_stride; + const pytorch_q8sum_rows_ukernel_function ukernel; +}; +static void compute_sum_rows( + const struct q8sum_rows_context context[1], + size_t group_index, + size_t batch_index, + size_t block_start, + size_t group_range /* always 1 */, + size_t batch_range /* always 1 */, + size_t block_size) { + const uint8_t* a = context->a; + const size_t groups = context->groups; + const size_t m = context->m; + const size_t k = context->k; + const size_t a_stride = context->a_stride; + const int32_t multiplier = context->multiplier; + int32_t* a_sum = context->a_sum; + const size_t a_sum_stride = context->a_sum_stride; + + context->ukernel( + a + batch_index * m * a_stride + group_index * k + block_start * a_stride, + min(block_size, m - block_start), + k, + a_stride, + multiplier, + a_sum + batch_index * groups * a_sum_stride + group_index * a_sum_stride + + block_start); +} + +struct q8dwconv_context { + size_t groups; + size_t group_stride; + const uint8_t** indirection_buffer; + size_t indirection_buffer_row_stride; + size_t indirection_buffer_col_stride; + const void* packed_weights; + uint8_t* output; + size_t output_height; + size_t output_width; + size_t output_row_stride; + size_t output_col_increment; + union pytorch_qnnp_conv_quantization_params quantization_params; + const pytorch_q8dwconv_up_ukernel_function unipass_ukernel; + const pytorch_q8dwconv_mp_ukernel_function multipass_ukernel; +}; +static void compute_dwconv_unipass( + const struct q8dwconv_context context[1], + size_t image, + size_t output_y) { + const size_t output_height = context->output_height; + + context->unipass_ukernel( + context->groups, + context->output_width, + context->indirection_buffer + + (image * output_height + output_y) * + context->indirection_buffer_row_stride, + context->packed_weights, + context->output + + (image * output_height + output_y) * context->output_row_stride, + context->indirection_buffer_col_stride, + context->output_col_increment, + &context->quantization_params); +} +static void compute_dwconv_multiipass( + const struct q8dwconv_context context[1], + size_t image, + size_t output_y) { + const size_t output_height = context->output_height; + PYTORCH_QNNP_ALIGN(16) +#ifdef _MSC_VER + int32_t* multipass_acc = _malloca(sizeof(int32_t) * context->group_stride); +#else + int32_t multipass_acc[context->group_stride]; +#endif + + context->multipass_ukernel( + context->groups, + context->output_width, + context->indirection_buffer + + (image * output_height + output_y) * + context->indirection_buffer_row_stride, + context->packed_weights, + multipass_acc, + context->output + + (image * output_height + output_y) * context->output_row_stride, + context->indirection_buffer_col_stride, + context->output_col_increment, + &context->quantization_params); + +#ifdef _MSC_VER + _freea(multipass_acc); +#endif +} + +struct QnnpackDeleter { + void operator()(pytorch_qnnp_operator_t op) { + pytorch_qnnp_delete_operator(op); + } +}; + +enum pytorch_qnnp_status qnnpackConv( + const conv_param_t& conv_p, + void* packed_weights, + const size_t batch_size, + const size_t input_height, + const size_t input_width, + const float input_scale, + const uint8_t input_zero_point, + const uint8_t* input, + const float output_scale, + const uint8_t output_zero_point, + uint8_t* output, + pthreadpool_t threadpool) { + const size_t input_pixel_stride = conv_p.input_channels; + const size_t output_pixel_stride = conv_p.output_channels; + const size_t kernel_width = conv_p.kernel_dims[0]; + const size_t kernel_height = conv_p.kernel_dims[1]; + const size_t kernel_size = kernel_height * kernel_width; + const size_t dilation_width = conv_p.dilation[0]; + const size_t dilation_height = conv_p.dilation[1]; + const size_t groups = conv_p.groups; + + const float convolution_scale = + input_scale * conv_p.kernel_scale / output_scale; + if (convolution_scale >= 1.0f) { + pytorch_qnnp_log_error( + "failed to create convolution with %.7g input scale, %.7g kernel scale," + " and %.7g output scale: " + "convolution scale %.7g is greater or equal to 1.0", + input_scale, + conv_p.kernel_scale, + output_scale, + convolution_scale); + } + union pytorch_qnnp_q31_requantization_params requantization_params; + union pytorch_qnnp_conv_quantization_params conv_quantization_params; + if (conv_p.ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) { + requantization_params = pytorch_qnnp_compute_requantization_params( + convolution_scale, + output_zero_point, + conv_p.output_min, + conv_p.output_max); + } else { + conv_quantization_params = pytorch_qnnp_compute_conv_quantization_params( + input_zero_point, + conv_p.kernel_zero_point, + convolution_scale, + output_zero_point, + conv_p.output_min, + conv_p.output_max); + } + uint32_t stride_width = conv_p.subsampling_dims[0]; + uint32_t stride_height = conv_p.subsampling_dims[1]; + + size_t output_height = compute_output_dimension( + conv_p.pad[0] + input_height + conv_p.pad[2], + kernel_height, + dilation_height, + stride_height); + size_t output_width = compute_output_dimension( + conv_p.pad[1] + input_width + conv_p.pad[3], + kernel_width, + dilation_width, + stride_width); + const size_t output_size = output_height * output_width; + + // FIXME temporary solution to create a qnnp_op struct for indirection buffer. + const bool any_padding = + (conv_p.pad[0] | conv_p.pad[1] | conv_p.pad[2] | conv_p.pad[3]) != 0; + size_t zero_size = 0, zero_offset = 0; + + pytorch_qnnp_operator_t convolution{nullptr}; + convolution = + static_cast(calloc(1, sizeof(struct pytorch_qnnp_operator))); + if (convolution == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + return pytorch_qnnp_status_out_of_memory; + } + + std::unique_ptr qnnpack_uniq_ptr(convolution); + + convolution->input = input; + convolution->input_pixel_stride = input_pixel_stride; + convolution->groups = groups; + convolution->group_input_channels = conv_p.group_input_channels; + convolution->batch_size = batch_size; + convolution->input_height = input_height; + convolution->input_width = input_width; + convolution->output_height = output_height; + convolution->output_width = output_width; + convolution->kernel_height = kernel_height; + convolution->kernel_width = kernel_width; + convolution->stride_height = stride_height; + convolution->stride_width = stride_width; + convolution->dilation_height = dilation_height; + convolution->dilation_width = dilation_width; + convolution->input_padding_top = conv_p.pad[0]; + convolution->input_padding_left = conv_p.pad[1]; + + switch (conv_p.ukernel_type) { + case pytorch_qnnp_ukernel_type_dwconv: { + const size_t width_step = + dilation_width == 1 ? stride_width : kernel_width; + const uint32_t cr = pytorch_qnnp_params.q8dw9.cr; + const size_t group_stride = (groups + (cr - 1)) & -cr; + + if (any_padding) { + if (groups >= 8) { + zero_size = sizeof(uint8_t) * group_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * group_stride + 8; + zero_offset = sizeof(uint8_t) * 8; + } + void* zero_buffer = malloc(zero_size); + if (zero_buffer == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for zero padding", zero_size); + return pytorch_qnnp_status_out_of_memory; + } + memset(zero_buffer, input_zero_point, zero_size); + convolution->zero_buffer = zero_buffer; + convolution->zero_pointer = + (void*)((uintptr_t)zero_buffer + zero_offset); + } + const size_t step_width = convolution->dilation_width == 1 + ? convolution->stride_width + : kernel_width; + const size_t step_height = + kernel_size + (output_width * step_width - 1) * kernel_height; + const size_t indirection_buffer_size = + sizeof(void*) * batch_size * output_height * step_height; + + const void** indirection_buffer = (const void**)realloc( + convolution->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + convolution->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_dwconv2d(convolution, 0, step_height, step_width); + + switch (kernel_size) { + case 9: { + struct q8dwconv_context context = { + .groups = groups, + .group_stride = group_stride, + .indirection_buffer = (const uint8_t**)indirection_buffer, + .indirection_buffer_row_stride = + kernel_size + (output_width * width_step - 1) * kernel_height, + .indirection_buffer_col_stride = + kernel_height * width_step * sizeof(void*), + .packed_weights = packed_weights, + .output = output, + .output_height = output_height, + .output_width = output_width, + .output_row_stride = output_width * output_pixel_stride, + .output_col_increment = + (output_pixel_stride - groups) * sizeof(uint8_t), + .quantization_params = conv_quantization_params, + .unipass_ukernel = pytorch_qnnp_params.q8dw9.updw, + .multipass_ukernel = pytorch_qnnp_params.q8dw25.mpdw, + }; + pthreadpool_compute_2d( + threadpool, + (pthreadpool_function_2d_t)compute_dwconv_unipass, + &context, + batch_size, + output_height); + break; + } + case 25: { + struct q8dwconv_context context = { + .groups = groups, + .group_stride = group_stride, + .indirection_buffer = + (const uint8_t**)convolution->indirection_buffer, + .indirection_buffer_row_stride = + kernel_size + (output_width * width_step - 1) * kernel_height, + .indirection_buffer_col_stride = + kernel_height * width_step * sizeof(void*), + .packed_weights = packed_weights, + .output = output, + .output_height = output_height, + .output_width = output_width, + .output_row_stride = output_width * output_pixel_stride, + .output_col_increment = + (output_pixel_stride - groups) * sizeof(uint8_t), + .quantization_params = conv_quantization_params, + .unipass_ukernel = pytorch_qnnp_params.q8dw9.updw, + .multipass_ukernel = pytorch_qnnp_params.q8dw25.mpdw, + }; + pthreadpool_compute_2d( + threadpool, + (pthreadpool_function_2d_t)compute_dwconv_multiipass, + &context, + batch_size, + output_height); + break; + } + default: + PYTORCH_QNNP_UNREACHABLE; + } + break; + } + case pytorch_qnnp_ukernel_type_xzp_gemm: { + const size_t group_input_channels = conv_p.group_input_channels; + const size_t group_output_channels = conv_p.group_output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv_xzp.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + + /* compute input row sum */ + const size_t input_size = input_height * input_width; + int32_t* a_sum = (int32_t*)realloc( + convolution->a_sum, + sizeof(int32_t) * batch_size * groups * input_height * input_width); + if (a_sum == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for row sum data", + sizeof(int32_t) * batch_size * groups * input_height * input_width); + return pytorch_qnnp_status_out_of_memory; + } + convolution->a_sum = a_sum; + struct q8sum_rows_context context = { + .a = input, + .groups = groups, + .m = input_size, + .k = conv_p.group_input_channels, + .a_stride = input_pixel_stride, + .multiplier = (int32_t)-conv_p.kernel_zero_point, + .a_sum = a_sum, + .a_sum_stride = input_size, + .ukernel = pytorch_qnnp_params.q8sum_rows.sum_rows, + }; + pthreadpool_compute_3d_tiled( + threadpool, + (pthreadpool_function_3d_tiled_t)compute_sum_rows, + &context, + groups, + batch_size, + input_size, + 1, + 1, + pytorch_qnnp_params.q8sum_rows.m); + + struct q8gemm_xzp_context q8gemm_xzp_context = { + .k = conv_p.group_input_channels, + .k_stride = k_stride, + .n = conv_p.group_output_channels, + .n_stride = n_stride, + .a = input, + .a_stride = input_pixel_stride, + .packed_w = packed_weights, + .c = output, + .c_stride = output_pixel_stride, + .a_sum = a_sum, + .groups = groups, + .batch_size = batch_size, + .a_sum_stride = input_size, + .requantization_params = requantization_params, + .ukernel = pytorch_qnnp_params.q8conv_xzp.gemm, + }; + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t)compute_q8gemm_xzp, + &q8gemm_xzp_context, + groups, + batch_size * input_size, + input_size, + group_output_channels, + 1, + input_size, + mr, + nr); + break; + } + case pytorch_qnnp_ukernel_type_gemm: { + const size_t group_input_channels = conv_p.group_input_channels; + const size_t group_output_channels = conv_p.group_output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + + struct q8gemm_context q8gemm_context = { + .k = conv_p.group_input_channels, + .k_stride = k_stride, + .n = conv_p.group_output_channels, + .n_stride = n_stride, + .a = input, + .a_stride = input_pixel_stride, + .packed_w = (uint8_t*)packed_weights, + .c = output, + .c_stride = output_pixel_stride, + .quantization_params = conv_quantization_params, + .ukernel = pytorch_qnnp_params.q8conv.gemm, + }; + + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t)compute_q8gemm, + &q8gemm_context, + groups, + batch_size * output_size, + output_size, + group_output_channels, + 1, + output_size, + mr, + nr); + break; + } + case pytorch_qnnp_ukernel_type_conv: { + const size_t group_input_channels = conv_p.group_input_channels; + const size_t group_output_channels = conv_p.group_output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + const size_t m_stride = round_up(output_size, mr); + + if (any_padding) { + if (group_input_channels >= 8) { + zero_size = sizeof(uint8_t) * k_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * k_stride + 8; + zero_offset = 8; + } + void* zero_buffer = malloc(zero_size); + if (zero_buffer == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for zero padding", zero_size); + return pytorch_qnnp_status_out_of_memory; + } + memset(zero_buffer, input_zero_point, zero_size); + convolution->zero_buffer = zero_buffer; + convolution->zero_pointer = + (void*)((uintptr_t)zero_buffer + zero_offset); + } + + const size_t output_tile_size = pytorch_qnnp_params.q8conv.mr; + const size_t tiled_output_size = round_up(output_size, output_tile_size); + const size_t indirection_buffer_size = + sizeof(void*) * batch_size * groups * tiled_output_size * kernel_size; + const void** indirection_buffer = (const void**)realloc( + convolution->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == nullptr) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + convolution->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_conv2d( + convolution, output_tile_size, tiled_output_size); + + struct q8conv_context q8conv_context = { + .bs = batch_size, + .ks = kernel_size, + .kc = group_input_channels, + .kc_stride = k_stride * kernel_size, + .m = output_size, + .m_stride = m_stride, + .n = group_output_channels, + .n_stride = n_stride, + .indirect_a = (const uint8_t**)convolution->indirection_buffer, + .packed_w = packed_weights, + .c = output, + .c_stride = output_pixel_stride, + .quantization_params = conv_quantization_params, + .ukernel = pytorch_qnnp_params.q8conv.conv, + }; + + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t)compute_q8conv, + &q8conv_context, + groups, + batch_size, + output_size, + group_output_channels, + 1, + 1, + mr, + nr); + break; + } + default: { + pytorch_qnnp_log_error("Invalid kernel type. QNNPACK convolution run failed."); + PYTORCH_QNNP_UNREACHABLE; + } + } + return pytorch_qnnp_status_success; +} +} // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c new file mode 100644 index 0000000000000..56cf87083df47 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/convolution.c @@ -0,0 +1,625 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +static inline size_t compute_output_dimension( + size_t padded_input_dimension, + size_t kernel_dimension, + size_t dilation_dimension, + size_t subsampling_dimension) { + const size_t effective_kernel_dimension = + (kernel_dimension - 1) * dilation_dimension + 1; + return (padded_input_dimension - effective_kernel_dimension) / + subsampling_dimension + + 1; +} + +enum pytorch_qnnp_status pytorch_qnnp_create_convolution2d_nhwc_q8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* convolution_out) { + pytorch_qnnp_operator_t convolution = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_convolution2d_nhwc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (kernel_width == 0 || kernel_height == 0) { + pytorch_qnnp_log_error( + "failed to create convolution with %" PRIu32 "x%" PRIu32 + " kernel: kernel dimensions must be non-zero", + kernel_width, + kernel_height); + goto error; + } + + if (subsampling_width == 0 || subsampling_height == 0) { + pytorch_qnnp_log_error( + "failed to create convolution with %" PRIu32 "x%" PRIu32 + " subsampling: " + "subsampling dimensions must be non-zero", + subsampling_width, + subsampling_height); + goto error; + } + + if (dilation_width == 0 || dilation_height == 0) { + pytorch_qnnp_log_error( + "failed to create convolution with %" PRIu32 "x%" PRIu32 + " dilation: " + "dilation dimensions must be non-zero", + dilation_width, + dilation_height); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create convolution with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) { + pytorch_qnnp_log_error( + "failed to create convolution with %.7g kernel scale: scale must be finite and positive", + kernel_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create convolution with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + if (subsampling_height > kernel_height) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "x%" PRIu32 + " subsampling: " + "height subsampling is greater than kernel height; subsampling should be performed before the convolution", + kernel_width, + kernel_height, + subsampling_width, + subsampling_height); + } + + if (subsampling_width > kernel_width) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "x%" PRIu32 + " subsampling: " + "width subsampling is greater than kernel width; subsampling should be performed before the convolution", + kernel_width, + kernel_height, + subsampling_width, + subsampling_height); + } + + if (input_padding_top >= kernel_height) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " height padding: " + "input top padding is greater or equal to kernel height", + kernel_width, + kernel_height, + input_padding_top, + input_padding_bottom); + } + + if (input_padding_bottom >= kernel_height) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " height padding: " + "input bottom padding is greater or equal to kernel height", + kernel_width, + kernel_height, + input_padding_top, + input_padding_bottom); + } + + if (input_padding_right >= kernel_width) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " width padding: " + "input right padding is greater or equal to kernel width", + kernel_width, + kernel_height, + input_padding_left, + input_padding_right); + } + + if (input_padding_left >= kernel_width) { + pytorch_qnnp_log_info( + "inefficiency in convolution with %" PRIu32 "x%" PRIu32 + " kernel and %" PRIu32 "+%" PRIu32 + " width padding: " + "input left padding is greater or equal to kernel width", + kernel_width, + kernel_height, + input_padding_left, + input_padding_right); + } + + const float convolution_scale = input_scale * kernel_scale / output_scale; + if (convolution_scale >= 1.0f) { + pytorch_qnnp_log_error( + "failed to create convolution with %.7g input scale, %.7g kernel scale, and %.7g output scale: " + "convolution scale %.7g is greater or equal to 1.0", + input_scale, + kernel_scale, + output_scale, + convolution_scale); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + convolution = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (convolution == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + const size_t kernel_size = kernel_height * kernel_width; + + enum pytorch_qnnp_ukernel_type ukernel_type = pytorch_qnnp_ukernel_type_none; + const bool any_padding = (input_padding_left | input_padding_top | + input_padding_right | input_padding_bottom) != 0; + if ((kernel_size == 9 || kernel_size == 25) && group_input_channels == 1 && + group_output_channels == 1 && groups > 1) { + ukernel_type = pytorch_qnnp_ukernel_type_dwconv; + } else if ( + kernel_size == 1 && subsampling_height == 1 && subsampling_width == 1 && + !any_padding) { + ukernel_type = + group_input_channels >= pytorch_qnnp_params.q8conv_xzp.kthreshold + ? pytorch_qnnp_ukernel_type_xzp_gemm + : pytorch_qnnp_ukernel_type_gemm; + } else { + ukernel_type = pytorch_qnnp_ukernel_type_conv; + } + size_t zero_size = 0, zero_offset = 0; + + switch (ukernel_type) { + case pytorch_qnnp_ukernel_type_dwconv: { + const uint32_t cr = pytorch_qnnp_params.q8dw9.cr; + const uint32_t c_stride = (groups + (cr - 1)) & -cr; + convolution->group_stride = c_stride; + const size_t packed_weights_size = + (sizeof(uint8_t) * kernel_size + sizeof(int32_t)) * c_stride; + convolution->packed_weights = malloc(packed_weights_size); + if (convolution->packed_weights == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_weights_size); + goto error; + } + + switch (kernel_size) { + case 9: + pytorch_pack_q8dw_w( + kernel_height, + kernel_width, + groups, + cr, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + input_zero_point, + kernel_zero_point, +#endif + kernel, + bias, + convolution->packed_weights); + break; + case 25: + /* change this later */ + pytorch_pack_q8dw_w_dilation( + kernel_height, + kernel_width, + groups, + cr, + 0, + kernel_height, + 0, + 2, + kernel, + bias, + convolution->packed_weights, + true); + pytorch_pack_q8dw_w_dilation( + kernel_height, + kernel_width, + groups, + cr, + 0, + kernel_height, + 2, + 4, + kernel, + bias, + (char*)convolution->packed_weights + + (10 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride, + false); + pytorch_pack_q8dw_w_dilation( + kernel_height, + kernel_width, + groups, + cr, + 0, + kernel_height, + 4, + 5, + kernel, + bias, + (char*)convolution->packed_weights + + (20 + sizeof(int32_t) / sizeof(uint8_t)) * c_stride, + false); + break; + default: + PYTORCH_QNNP_UNREACHABLE; + } + + if (groups >= 8) { + zero_size = sizeof(uint8_t) * c_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * c_stride + 8; + zero_offset = sizeof(uint8_t) * 8; + } + break; + } + case pytorch_qnnp_ukernel_type_xzp_gemm: { + const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; + const uint32_t sr = pytorch_qnnp_params.q8conv_xzp.kc; + const uint32_t n_stride = (group_output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (group_input_channels + (kr - 1)) & -kr; + + const size_t packed_group_weights_size = + (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * + n_stride; + convolution->packed_weights = malloc(packed_group_weights_size * groups); + if (convolution->packed_weights == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_group_weights_size * groups); + goto error; + } + /* The XZP ukernel needs the padding to be 0 */ + memset( + convolution->packed_weights, 0, packed_group_weights_size * groups); + + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_swizzle_q8gemm_b( + group_output_channels, + group_input_channels, + nr, + kr, + sr, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + input_zero_point, + kernel_zero_point, +#endif + kernel + group * group_output_channels * group_input_channels, + bias + group * group_output_channels, + (void*)((uintptr_t)convolution->packed_weights + group * packed_group_weights_size)); + } + break; + } + case pytorch_qnnp_ukernel_type_gemm: + case pytorch_qnnp_ukernel_type_conv: { + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const uint32_t n_stride = (group_output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (group_input_channels + (kr - 1)) & -kr; + + const size_t packed_group_weights_size = + (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * + n_stride; + convolution->packed_weights = malloc(packed_group_weights_size * groups); + if (convolution->packed_weights == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_group_weights_size * groups); + goto error; + } + memset( + convolution->packed_weights, + kernel_zero_point, + packed_group_weights_size * groups); + + switch (ukernel_type) { + case pytorch_qnnp_ukernel_type_gemm: + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_q8gemm_w( + group_output_channels, + group_input_channels, + nr, + nr, + kr, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + input_zero_point, + kernel_zero_point, +#endif + kernel + group * group_output_channels * group_input_channels, + bias + group * group_output_channels, + (void*)((uintptr_t)convolution->packed_weights + group * packed_group_weights_size)); + } + break; + case pytorch_qnnp_ukernel_type_conv: + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_q8conv_w( + group_output_channels, + kernel_size, + group_input_channels, + nr, + kr, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + input_zero_point, + kernel_zero_point, +#endif + kernel + + group * group_output_channels * kernel_size * + group_input_channels, + bias + group * group_output_channels, + (void*)((uintptr_t)convolution->packed_weights + group * packed_group_weights_size)); + } + break; + default: + PYTORCH_QNNP_UNREACHABLE; + } + + if (group_input_channels >= 8) { + zero_size = sizeof(uint8_t) * k_stride; + zero_offset = 0; + } else { + zero_size = sizeof(uint8_t) * k_stride + 8; + zero_offset = 8; + } + break; + } + default: + PYTORCH_QNNP_UNREACHABLE; + } + + if (any_padding) { + void* zero_buffer = malloc(zero_size); + if (zero_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for zero padding", zero_size); + goto error; + } + memset(zero_buffer, input_zero_point, zero_size); + convolution->zero_buffer = zero_buffer; + convolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset); + } + + convolution->input_padding_top = input_padding_top; + convolution->input_padding_right = input_padding_right; + convolution->input_padding_bottom = input_padding_bottom; + convolution->input_padding_left = input_padding_left; + + convolution->kernel_height = kernel_height; + convolution->kernel_width = kernel_width; + convolution->stride_height = subsampling_height; + convolution->stride_width = subsampling_width; + convolution->dilation_height = dilation_height; + convolution->dilation_width = dilation_width; + convolution->groups = groups; + convolution->group_input_channels = group_input_channels; + convolution->group_output_channels = group_output_channels; + + convolution->kernel_zero_point = kernel_zero_point; + + if (ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) { + convolution->requantization_params = + pytorch_qnnp_compute_requantization_params( + convolution_scale, output_zero_point, output_min, output_max); + } else { + convolution->conv_quantization_params = + pytorch_qnnp_compute_conv_quantization_params( + input_zero_point, + kernel_zero_point, + convolution_scale, + output_zero_point, + output_min, + output_max); + } + + convolution->ukernel_type = ukernel_type; + convolution->format = pytorch_qnnp_format_quint8; + + *convolution_out = convolution; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(convolution); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_convolution2d_nhwc_q8( + pytorch_qnnp_operator_t convolution, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_pixel_stride, + uint8_t* output, + size_t output_pixel_stride, + pthreadpool_t threadpool) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_convolution2d_nhwc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + convolution->batch_size = 0; + return pytorch_qnnp_status_success; + } + + if (input_width == 0 || input_height == 0) { + pytorch_qnnp_log_error( + "failed to setup convolution with %zux%zu input: input dimensions must be non-zero", + input_width, + input_height); + return pytorch_qnnp_status_invalid_parameter; + } + + convolution->batch_size = batch_size; + convolution->input_height = input_height; + convolution->input_width = input_width; + convolution->input = input; + convolution->input_pixel_stride = input_pixel_stride; + + convolution->output_height = compute_output_dimension( + convolution->input_padding_top + input_height + + convolution->input_padding_bottom, + convolution->kernel_height, + convolution->dilation_height, + convolution->stride_height); + convolution->output_width = compute_output_dimension( + convolution->input_padding_left + input_width + + convolution->input_padding_right, + convolution->kernel_width, + convolution->dilation_width, + convolution->stride_width); + convolution->output = output; + convolution->output_pixel_stride = output_pixel_stride; + + switch (convolution->ukernel_type) { + case pytorch_qnnp_ukernel_type_gemm: + /* Convolution maps directly to GEMM and doesn't use indirection buffer */ + return pytorch_qnnp_status_success; + case pytorch_qnnp_ukernel_type_xzp_gemm: { + const size_t groups = convolution->groups; + void* a_sum = (void*)realloc( + convolution->a_sum, + sizeof(int32_t) * batch_size * groups * input_height * input_width); + if (a_sum == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for row sum data", + sizeof(int32_t) * batch_size * groups * input_height * input_width); + return pytorch_qnnp_status_out_of_memory; + } + convolution->a_sum = a_sum; + return pytorch_qnnp_status_success; + } + case pytorch_qnnp_ukernel_type_conv: { + const size_t groups = convolution->groups; + const size_t kernel_height = convolution->kernel_height; + const size_t kernel_width = convolution->kernel_width; + const size_t kernel_size = kernel_height * kernel_width; + const size_t output_height = convolution->output_height; + const size_t output_width = convolution->output_width; + const size_t output_size = output_height * output_width; + const size_t output_tile_size = pytorch_qnnp_params.q8conv.mr; + const size_t tiled_output_size = round_up(output_size, output_tile_size); + const size_t indirection_buffer_size = + sizeof(void*) * batch_size * groups * tiled_output_size * kernel_size; + + const void** indirection_buffer = (const void**)realloc( + convolution->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + convolution->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_conv2d( + convolution, output_tile_size, tiled_output_size); + return pytorch_qnnp_status_success; + } + case pytorch_qnnp_ukernel_type_dwconv: { + const size_t kernel_height = convolution->kernel_height; + const size_t kernel_width = convolution->kernel_width; + const size_t kernel_size = kernel_height * kernel_width; + const size_t output_height = convolution->output_height; + const size_t output_width = convolution->output_width; + const size_t step_width = convolution->dilation_width == 1 + ? convolution->stride_width + : kernel_width; + const size_t step_height = + kernel_size + (output_width * step_width - 1) * kernel_height; + const size_t indirection_buffer_size = + sizeof(void*) * batch_size * output_height * step_height; + + const void** indirection_buffer = (const void**)realloc( + convolution->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + convolution->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_dwconv2d( + convolution, 0, step_height, step_width); + return pytorch_qnnp_status_success; + } + default: + PYTORCH_QNNP_UNREACHABLE; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/deconvolution.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/deconvolution.c new file mode 100644 index 0000000000000..4dba1cde74bb3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/deconvolution.c @@ -0,0 +1,326 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +static inline size_t compute_output_dimension( + size_t input_dimension, + size_t input_padding_dimension, + size_t adjustment_dimension, + size_t kernel_dimension, + size_t dilation_dimension, + size_t stride_dimension) { + const size_t effective_kernel_dimension = + (kernel_dimension - 1) * dilation_dimension + 1; + return stride_dimension * (input_dimension - 1) + adjustment_dimension + + effective_kernel_dimension - input_padding_dimension; +} + +enum pytorch_qnnp_status pytorch_qnnp_create_deconvolution2d_nhwc_q8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t adjustment_height, + uint32_t adjustment_width, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* deconvolution_out) { + pytorch_qnnp_operator_t deconvolution = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_deconvolution2d_nhwc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (kernel_width == 0 || kernel_height == 0) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %" PRIu32 "x%" PRIu32 + " kernel: kernel dimensions must be non-zero", + kernel_width, + kernel_height); + goto error; + } + + if (stride_width == 0 || stride_height == 0) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %" PRIu32 "x%" PRIu32 + " stride: " + "stride dimensions must be non-zero", + stride_width, + stride_height); + goto error; + } + + if (dilation_width == 0 || dilation_height == 0) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %" PRIu32 "x%" PRIu32 + " dilation: " + "dilation dimensions must be non-zero", + dilation_width, + dilation_height); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %.7g kernel scale: scale must be finite and positive", + kernel_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + const float deconvolution_scale = input_scale * kernel_scale / output_scale; + if (deconvolution_scale >= 1.0f) { + pytorch_qnnp_log_error( + "failed to create deconvolution with %.7g input scale, %.7g kernel scale, and %.7g output scale: " + "deconvolution scale %.7g is greater or equal to 1.0", + input_scale, + kernel_scale, + output_scale, + deconvolution_scale); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + deconvolution = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (deconvolution == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + + const uint32_t n_stride = (group_output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const uint32_t kernel_size = kernel_height * kernel_width; + const size_t packed_group_weights_size = + (sizeof(uint8_t) * kernel_size * k_stride + sizeof(int32_t)) * n_stride; + deconvolution->packed_weights = malloc(packed_group_weights_size * groups); + if (deconvolution->packed_weights == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + packed_group_weights_size * groups); + goto error; + } + memset( + deconvolution->packed_weights, + kernel_zero_point, + packed_group_weights_size * groups); + + for (uint32_t group = 0; group < groups; group++) { + pytorch_pack_q8deconv_w( + group_output_channels, + kernel_size, + group_input_channels, + nr, + kr, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + input_zero_point, + kernel_zero_point, +#endif + kernel + + group * group_output_channels * kernel_size * group_input_channels, + bias + group * group_output_channels, + (void*)((uintptr_t)deconvolution->packed_weights + group * packed_group_weights_size)); + } + + size_t zero_size = sizeof(uint8_t) * k_stride; + size_t zero_offset = 0; + if (group_input_channels < 8) { + zero_size += 8; + zero_offset = 8; + } + + void* zero_buffer = malloc(zero_size); + if (zero_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for zero padding", zero_size); + goto error; + } + memset(zero_buffer, input_zero_point, zero_size); + deconvolution->zero_buffer = zero_buffer; + deconvolution->zero_pointer = (void*)((uintptr_t)zero_buffer + zero_offset); + + deconvolution->input_padding_top = input_padding_top; + deconvolution->input_padding_right = input_padding_right; + deconvolution->input_padding_bottom = input_padding_bottom; + deconvolution->input_padding_left = input_padding_left; + deconvolution->adjustment_height = adjustment_height; + deconvolution->adjustment_width = adjustment_width; + + deconvolution->kernel_height = kernel_height; + deconvolution->kernel_width = kernel_width; + deconvolution->stride_height = stride_height; + deconvolution->stride_width = stride_width; + deconvolution->dilation_height = dilation_height; + deconvolution->dilation_width = dilation_width; + deconvolution->groups = groups; + deconvolution->group_input_channels = group_input_channels; + deconvolution->group_output_channels = group_output_channels; + + deconvolution->kernel_zero_point = kernel_zero_point; + + deconvolution->conv_quantization_params = + pytorch_qnnp_compute_conv_quantization_params( + input_zero_point, + kernel_zero_point, + deconvolution_scale, + output_zero_point, + output_min, + output_max); + + deconvolution->ukernel_type = pytorch_qnnp_ukernel_type_conv; + deconvolution->format = pytorch_qnnp_format_quint8; + + *deconvolution_out = deconvolution; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(deconvolution); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_deconvolution2d_nhwc_q8( + pytorch_qnnp_operator_t deconvolution, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_pixel_stride, + uint8_t* output, + size_t output_pixel_stride, + pthreadpool_t threadpool) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_deconvolution2d_nhwc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + deconvolution->batch_size = 0; + return pytorch_qnnp_status_success; + } + + if (input_width == 0 || input_height == 0) { + pytorch_qnnp_log_error( + "failed to setup deconvolution with %zux%zu input: input dimensions must be non-zero", + input_width, + input_height); + return pytorch_qnnp_status_invalid_parameter; + } + + deconvolution->batch_size = batch_size; + deconvolution->input_height = input_height; + deconvolution->input_width = input_width; + deconvolution->input = input; + deconvolution->input_pixel_stride = input_pixel_stride; + deconvolution->output = output; + deconvolution->output_pixel_stride = output_pixel_stride; + + const size_t kernel_height = deconvolution->kernel_height; + const size_t kernel_width = deconvolution->kernel_width; + const size_t kernel_size = kernel_height * kernel_width; + const size_t stride_height = deconvolution->stride_height; + const size_t stride_width = deconvolution->stride_width; + const size_t output_height = deconvolution->output_height = + compute_output_dimension( + input_height, + deconvolution->input_padding_top + + deconvolution->input_padding_bottom, + deconvolution->adjustment_height, + kernel_height, + deconvolution->dilation_height, + stride_height); + const size_t output_width = deconvolution->output_width = + compute_output_dimension( + input_width, + deconvolution->input_padding_left + + deconvolution->input_padding_right, + deconvolution->adjustment_width, + kernel_width, + deconvolution->dilation_width, + stride_width); + + const size_t groups = deconvolution->groups; + const size_t output_size = output_height * output_width; + const size_t output_tile_size = pytorch_qnnp_params.q8conv.mr; + const size_t tiled_output_size = round_up(output_size, output_tile_size); + const size_t indirection_buffer_size = + sizeof(void*) * batch_size * groups * tiled_output_size * kernel_size; + + const void** indirection_buffer = (const void**)realloc( + deconvolution->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + deconvolution->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_deconv2d( + deconvolution, output_tile_size, tiled_output_size); + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc new file mode 100644 index 0000000000000..f172c4c576e12 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc @@ -0,0 +1,54 @@ +#include +#include +#include +#include +#include + +namespace qnnpack { +PackBMatrix::PackBMatrix( + const size_t input_channels, + const size_t output_channels, + const uint8_t kernel_zero_point, + const float kernel_scale, + const uint8_t* kernel, + const int32_t* bias) { + if (kernel_scale <= 0.0f || !std::isnormal(kernel_scale)) { + pytorch_qnnp_log_error( + "failed to create fully connected operator with %.7g kernel scale: " + "scale must be finite and positive", + kernel_scale); + assert("QNNPACK Runtime Error."); + } + + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + + const uint32_t n_stride = (output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (input_channels + (kr - 1)) & -kr; + + input_channels_ = input_channels; + output_channels_ = output_channels; + packed_weights_ = + malloc(n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); + if (packed_weights_ == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); + assert("QNNPACK Runtime Error."); + } + memset( + packed_weights_, + kernel_zero_point, + n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); + + pytorch_pack_q8gemm_wrq( + output_channels, + input_channels, + nr, + nr, + kr, + kernel, + bias, + packed_weights_); +} +} // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-run.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-run.cc new file mode 100644 index 0000000000000..19c737c2e6818 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-run.cc @@ -0,0 +1,126 @@ +#include +#include +#include + +namespace qnnpack { +struct q8gemm_context { + size_t k; + size_t k_stride; + size_t n; + size_t n_stride; + const uint8_t* a; + size_t a_stride; + const uint8_t* packed_w; + uint8_t* c; + size_t c_stride; + union pytorch_qnnp_conv_quantization_params quantization_params; + const pytorch_q8gemm_ukernel_function ukernel; +}; + +static void compute_q8gemm( + const struct q8gemm_context context[1], + size_t group_index, + size_t pixel_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t pixel_range, + size_t mr_block_size, + size_t nr_block_size) +{ + const size_t k = context->k; + const size_t k_stride = context->k_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t* a = context->a; + const size_t a_stride = context->a_stride; + const void* packed_w = context->packed_w; + uint8_t* c = context->c; + const size_t c_stride = context->c_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + k, + a + (pixel_index + mr_block_start) * a_stride + group_index * k, + a_stride, + (const void*) ((uintptr_t) packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (pixel_index + mr_block_start) * c_stride + nr_block_start + group_index * n, + c_stride, + &context->quantization_params); +} + +enum pytorch_qnnp_status qnnpackLinear( + const size_t batch_size, + const size_t input_channels, + const size_t output_channels, + const uint8_t input_zero_point, + const float input_scale, + const uint8_t kernel_zero_point, + const float kernel_scale, + const uint8_t output_zero_point, + const float output_scale, + const uint8_t output_min, + const uint8_t output_max, + const uint8_t* input, + const size_t input_stride, + void* packed_weights, + uint8_t* output, + const size_t output_stride, + pthreadpool_t threadpool) +{ + const size_t groups = 1; + const size_t group_input_channels = input_channels; + const size_t group_output_channels = output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + + const size_t output_size = batch_size * 1; + const float requantization_scale = input_scale * kernel_scale / output_scale; + if (requantization_scale >= 1.0f) { + pytorch_qnnp_log_error( + "failed to create fully connected operator with %.7g input scale, %.7g " + "kernel scale, and %.7g output scale: " + "requantization scale %.7g is greater or equal to 1.0", + input_scale, + kernel_scale, + output_scale, + requantization_scale); + return pytorch_qnnp_status_unsupported_parameter; + } + union pytorch_qnnp_conv_quantization_params conv_quantization_params = pytorch_qnnp_compute_conv_quantization_params( + input_zero_point, kernel_zero_point, requantization_scale, output_zero_point, output_min, output_max); + + struct q8gemm_context q8gemm_context = { + .k = group_input_channels, + .k_stride = k_stride, + .n = group_output_channels, + .n_stride = n_stride, + .a = input, + .a_stride = input_stride, + .packed_w = (uint8_t*) packed_weights, + .c = output, + .c_stride = output_stride, + .quantization_params = conv_quantization_params, + .ukernel = pytorch_qnnp_params.q8conv.gemm, + }; + + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t) compute_q8gemm, + &q8gemm_context, + groups, + 1 * output_size, + output_size, + group_output_channels, + 1, + output_size, + mr, + nr); + + return pytorch_qnnp_status_success; +} +} // namespace qnnpack diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/fully-connected.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fully-connected.c new file mode 100644 index 0000000000000..c650d521f5b6f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/fully-connected.c @@ -0,0 +1,184 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_fully_connected_nc_q8( + size_t input_channels, + size_t output_channels, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* fully_connected_out) { + pytorch_qnnp_operator_t fully_connected = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_fully_connected_nc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create fully connected operator with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (kernel_scale <= 0.0f || !isnormal(kernel_scale)) { + pytorch_qnnp_log_error( + "failed to create fully connected operator with %.7g kernel scale: scale must be finite and positive", + kernel_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create fully connected operator with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + const float requantization_scale = input_scale * kernel_scale / output_scale; + if (requantization_scale >= 1.0f) { + pytorch_qnnp_log_error( + "failed to create fully connected operator with %.7g input scale, %.7g kernel scale, and %.7g output scale: " + "requantization scale %.7g is greater or equal to 1.0", + input_scale, + kernel_scale, + output_scale, + requantization_scale); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + fully_connected = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (fully_connected == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + + const uint32_t n_stride = (output_channels + (nr - 1)) & -nr; + const uint32_t k_stride = (input_channels + (kr - 1)) & -kr; + + fully_connected->packed_weights = + malloc(n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); + if (fully_connected->packed_weights == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for packed weights", + n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); + goto error; + } + memset( + fully_connected->packed_weights, + kernel_zero_point, + n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t))); + + pytorch_pack_q8gemm_w( + output_channels, + input_channels, + nr, + nr, + kr, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + input_zero_point, + kernel_zero_point, +#endif + kernel, + bias, + fully_connected->packed_weights); + + fully_connected->groups = 1; + fully_connected->group_input_channels = input_channels; + fully_connected->group_output_channels = output_channels; + + fully_connected->kernel_zero_point = kernel_zero_point; + + fully_connected->conv_quantization_params = + pytorch_qnnp_compute_conv_quantization_params( + input_zero_point, + kernel_zero_point, + requantization_scale, + output_zero_point, + output_min, + output_max); + + fully_connected->ukernel_type = pytorch_qnnp_ukernel_type_gemm; + fully_connected->format = pytorch_qnnp_format_quint8; + + *fully_connected_out = fully_connected; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(fully_connected); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_fully_connected_nc_q8( + pytorch_qnnp_operator_t fully_connected, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_fully_connected_nc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + fully_connected->batch_size = 0; + return pytorch_qnnp_status_success; + } + + fully_connected->batch_size = 1; + fully_connected->input_height = batch_size; + fully_connected->input_width = 1; + fully_connected->input = input; + fully_connected->input_pixel_stride = input_stride; + + fully_connected->output_height = batch_size; + fully_connected->output_width = 1; + fully_connected->output = output; + fully_connected->output_pixel_stride = output_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/global-average-pooling.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/global-average-pooling.c new file mode 100644 index 0000000000000..00be41ec7c176 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/global-average-pooling.c @@ -0,0 +1,158 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_global_average_pooling_nwc_q8( + size_t channels, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* global_average_pooling_out) { + pytorch_qnnp_operator_t global_average_pooling_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_global_average_pooling_nwc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create global average pooling operator with %zu channels: number of channels must be non-zero", + channels); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create global average pooling operator with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create global average pooling operator with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + const float input_output_scale = input_scale / output_scale; + if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) { + pytorch_qnnp_log_error( + "failed to create global average pooling operator with %.7g input-to-output scale ratio: " + "scale ratio must be in [2**-8, 2**8) range", + input_output_scale); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + global_average_pooling_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (global_average_pooling_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + void* zero_buffer = calloc(channels, sizeof(uint8_t)); + if (zero_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for zero padding", + channels * sizeof(uint8_t)); + goto error; + } + global_average_pooling_op->zero_buffer = zero_buffer; + global_average_pooling_op->zero_pointer = zero_buffer; + + global_average_pooling_op->channels = channels; + global_average_pooling_op->input_zero_point = input_zero_point; + global_average_pooling_op->output_zero_point = output_zero_point; + global_average_pooling_op->input_scale = input_scale; + global_average_pooling_op->output_scale = output_scale; + global_average_pooling_op->output_min = output_min; + global_average_pooling_op->output_max = output_max; + + global_average_pooling_op->ukernel_type = + pytorch_qnnp_ukernel_type_global_average_pooling; + global_average_pooling_op->format = pytorch_qnnp_format_quint8; + + *global_average_pooling_out = global_average_pooling_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(global_average_pooling_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_global_average_pooling_nwc_q8( + pytorch_qnnp_operator_t global_average_pooling_op, + size_t batch_size, + size_t width, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_global_average_pooling_nwc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + global_average_pooling_op->batch_size = 0; + return pytorch_qnnp_status_success; + } + + if (width == 0) { + pytorch_qnnp_log_error( + "failed to setup global average pooling operator with width %zu: width must be non-zero", + width); + return pytorch_qnnp_status_invalid_parameter; + } + + global_average_pooling_op->batch_size = batch_size; + global_average_pooling_op->input_width = width; + global_average_pooling_op->input = input; + global_average_pooling_op->input_pixel_stride = input_stride; + global_average_pooling_op->output = output; + global_average_pooling_op->output_pixel_stride = output_stride; + + global_average_pooling_op->avgpool_quantization_params = + pytorch_qnnp_compute_avgpool_quantization_params( + -(int32_t)width * + (int32_t)(uint32_t)global_average_pooling_op->input_zero_point, + global_average_pooling_op->input_scale / + (global_average_pooling_op->output_scale * (float)width), + global_average_pooling_op->output_zero_point, + global_average_pooling_op->output_min, + global_average_pooling_op->output_max); + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/hgemm/8x8-aarch32-neonfp16arith.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/hgemm/8x8-aarch32-neonfp16arith.S new file mode 100644 index 0000000000000..995ff104ecb35 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/hgemm/8x8-aarch32-neonfp16arith.S @@ -0,0 +1,512 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +.syntax unified + +# void pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith( +# size_t mr, +# size_t nr, +# size_t k, +# const __fp16*restrict a, +# size_t a_stride, +# const __fp16*restrict w, +# __fp16*restrict c, +# size_t c_stride, +# const struct pytorch_qnnp_fp16_clamping_params clamping_params[restrict static 1]) +BEGIN_FUNCTION pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + # Load w + # - ip = w + LDR ip, [sp, 4] + PUSH {r4, r5, r6, r7, r8, r9, r10, r11} + + VPUSH {d8-d15} + + # Initialize vacc0x01234567 + # - q8 = d16:d17 := vacc0x01234567 = bias01234567 + VLD1.16 {d16-d17}, [ip:64]! + + # Load a_stride + # - r10 = a_stride + LDR r10, [sp, 96] + + # Initialize vacc1x01234567 + # - q9 := vacc1x01234567 = vacc0x01234567 + VMOV.I16 q9, q8 + + # Initialize vacc2x01234567 + # - q10 := vacc2x01234567 = vacc0x01234567 + VMOV.I16 q10, q8 + + # Initialize vacc3x01234567 + # - q11 := vacc3x01234567 = vacc0x01234567 + VMOV.I16 q11, q8 + + # Initialize vacc4x01234567 + # - q12 := vacc4x01234567 = vacc0x01234567 + VMOV.I16 q12, q8 + + # Initialize vacc5x01234567 + # - q13 := vacc5x01234567 = vacc0x01234567 + VMOV.I16 q13, q8 + + # Initialize vacc6x01234567 + # - q14 := vacc6x01234567 = vacc0x01234567 + VMOV.I16 q14, q8 + + # Initialize vacc7x01234567 + # - q15 := vacc7x01234567 = vacc0x01234567 + VMOV.I16 q15, q8 + + CMP r0, 2 + ADD r4, r3, r10 + MOVLO r4, r3 + ADD r5, r4, r10 + MOVLS r5, r4 + + CMP r0, 4 + ADD r6, r5, r10 + MOVLO r6, r5 + ADD r7, r6, r10 + MOVLS r7, r6 + + CMP r0, 6 + ADD r8, r7, r10 + MOVLO r8, r7 + ADD r9, r8, r10 + MOVLS r9, r8 + + CMP r0, 8 + ADD r10, r9, r10 + MOVNE r10, r9 + + SUBS r2, r2, 4 + BLO 1f + + .p2align 5 +0: + # Load a0 + # - d0 = a0 + VLD1.16 {d0}, [r3]! + + # Load a1 + # - d1 = a1 + VLD1.16 {d1}, [r4]! + + # Load a2 + # - d2 = a2 + VLD1.16 {d2}, [r5]! + + # Load a3 + # - d3 = a3 + VLD1.16 {d3}, [r6]! + + # Load a4 + # - d4 = a4 + VLD1.16 {d4}, [r7]! + + # Load a5 + # - d5 = a5 + VLD1.16 {d5}, [r8]! + + # Load a6 + # - d6 = a6 + VLD1.16 {d6}, [r9]! + + # Load a7 + # - d7 = a7 + VLD1.16 {d7}, [r10]! + + ### Channel 0 ### + + # Load b0-b15 (channel 0) + # - q4 = d8:d9 = b0-b15 + VLD1.8 {d8-d9}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[0]; + .word 0xF3D80140 @ VMLA.F16 q8, q4, d0[0] + # vacc1x01234567 += vb01234567 * va1[0]; + .word 0xF3D82141 @ VMLA.F16 q9, q4, d1[0] + # vacc2x01234567 += vb01234567 * va2[0]; + .word 0xF3D84142 @ VMLA.F16 q10, q4, d2[0] + # vacc3x01234567 += vb01234567 * va3[0]; + .word 0xF3D86143 @ VMLA.F16 q11, q4, d3[0] + # vacc4x01234567 += vb01234567 * va4[0]; + .word 0xF3D88144 @ VMLA.F16 q12, q4, d4[0] + # vacc5x01234567 += vb01234567 * va5[0]; + .word 0xF3D8A145 @ VMLA.F16 q13, q4, d5[0] + # vacc6x01234567 += vb01234567 * va6[0]; + .word 0xF3D8C146 @ VMLA.F16 q14, q4, d6[0] + # vacc7x01234567 += vb01234567 * va7[0]; + .word 0xF3D8E147 @ VMLA.F16 q15, q4, d7[0] + + ### Channel 1 ### + + # Load b0-b15 (channel 1) + # - q5 = d10:d11 = b0-b15 + VLD1.8 {d10-d11}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[1]; + .word 0xF3DA0148 @ VMLA.F16 q8, q5, d0[1] + # vacc1x01234567 += vb01234567 * va1[1]; + .word 0xF3DA2149 @ VMLA.F16 q9, q5, d1[1] + # vacc2x01234567 += vb01234567 * va2[1]; + .word 0xF3DA414A @ VMLA.F16 q10, q5, d2[1] + # vacc3x01234567 += vb01234567 * va3[1]; + .word 0xF3DA614B @ VMLA.F16 q11, q5, d3[1] + # vacc4x01234567 += vb01234567 * va4[1]; + .word 0xF3DA814C @ VMLA.F16 q12, q5, d4[1] + # vacc5x01234567 += vb01234567 * va5[1]; + .word 0xF3DAA14D @ VMLA.F16 q13, q5, d5[1] + # vacc6x01234567 += vb01234567 * va6[1]; + .word 0xF3DAC14E @ VMLA.F16 q14, q5, d6[1] + # vacc7x01234567 += vb01234567 * va7[1]; + .word 0xF3DAE14F @ VMLA.F16 q15, q5, d7[1] + + ### Channel 2 ### + + # Load b0-b15 (channel 2) + # - q6 = d12:d13 = b0-b15 + VLD1.8 {d12-d13}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[2]; + .word 0xF3DC0160 @ VMLA.F16 q8, q6, d0[2] + # vacc1x01234567 += vb01234567 * va1[2]; + .word 0xF3DC2161 @ VMLA.F16 q9, q6, d1[2] + # vacc2x01234567 += vb01234567 * va2[2]; + .word 0xF3DC4162 @ VMLA.F16 q10, q6, d2[2] + # vacc3x01234567 += vb01234567 * va3[2]; + .word 0xF3DC6163 @ VMLA.F16 q11, q6, d3[2] + # vacc4x01234567 += vb01234567 * va4[2]; + .word 0xF3DC8164 @ VMLA.F16 q12, q6, d4[2] + # vacc5x01234567 += vb01234567 * va5[2]; + .word 0xF3DCA165 @ VMLA.F16 q13, q6, d5[2] + # vacc6x01234567 += vb01234567 * va6[2]; + .word 0xF3DCC166 @ VMLA.F16 q14, q6, d6[2] + # vacc7x01234567 += vb01234567 * va7[2]; + .word 0xF3DCE167 @ VMLA.F16 q15, q6, d7[2] + + ### Channel 3 ### + + # Load b0-b15 (channel 3) + # - q7 = d14:d15 = b0-b15 + VLD1.8 {d14-d15}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[3]; + .word 0xF3DE0168 @ VMLA.F16 q8, q7, d0[3] + # vacc1x01234567 += vb01234567 * va1[3]; + .word 0xF3DE2169 @ VMLA.F16 q9, q7, d1[3] + # vacc2x01234567 += vb01234567 * va2[3]; + .word 0xF3DE416A @ VMLA.F16 q10, q7, d2[3] + # vacc3x01234567 += vb01234567 * va3[3]; + .word 0xF3DE616B @ VMLA.F16 q11, q7, d3[3] + # vacc4x01234567 += vb01234567 * va4[3]; + .word 0xF3DE816C @ VMLA.F16 q12, q7, d4[3] + # vacc5x01234567 += vb01234567 * va5[3]; + .word 0xF3DEA16D @ VMLA.F16 q13, q7, d5[3] + # vacc6x01234567 += vb01234567 * va6[3]; + .word 0xF3DEC16E @ VMLA.F16 q14, q7, d6[3] + # vacc7x01234567 += vb01234567 * va7[3]; + .word 0xF3DEE16F @ VMLA.F16 q15, q7, d7[3] + + SUBS r2, r2, 4 + BHS 0b + +1: + CMP r2, -4 + BEQ 2f + + ADD r3, r3, r2, LSL #1 + ADD r4, r4, r2, LSL #1 + ADD r5, r5, r2, LSL #1 + ADD r6, r6, r2, LSL #1 + ADD r7, r7, r2, LSL #1 + ADD r8, r8, r2, LSL #1 + ADD r9, r9, r2, LSL #1 + ADD r10, r10, r2, LSL #1 + + LSL r2, r2, 4 + VDUP.32 d14, r2 + + # Load a0 + # - d0 = a0 + VLD1.16 {d0}, [r3]! + VSHL.U64 d0, d0, d14 + + # Load a1 + # - d1 = a1 + VLD1.16 {d1}, [r4]! + VSHL.U64 d1, d1, d14 + + # Load a2 + # - d2 = a2 + VLD1.16 {d2}, [r5]! + VSHL.U64 d2, d2, d14 + + # Load a3 + # - d3 = a3 + VLD1.16 {d3}, [r6]! + VSHL.U64 d3, d3, d14 + + # Load a4 + # - d4 = a4 + VLD1.16 {d4}, [r7]! + VSHL.U64 d4, d4, d14 + + # Load a5 + # - d5 = a5 + VLD1.16 {d5}, [r8]! + VSHL.U64 d5, d5, d14 + + # Load a6 + # - d6 = a6 + VLD1.16 {d6}, [r9]! + VSHL.U64 d6, d6, d14 + + # Load a7 + # - d7 = a7 + VLD1.16 {d7}, [r10]! + VSHL.U64 d7, d7, d14 + + ### Channel 0 ### + + # Load b0-b15 (channel 0) + # - q4 = d8:d9 = b0-b15 + VLD1.8 {d8-d9}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[0]; + .word 0xF3D80140 @ VMLA.F16 q8, q4, d0[0] + # vacc1x01234567 += vb01234567 * va1[0]; + .word 0xF3D82141 @ VMLA.F16 q9, q4, d1[0] + # vacc2x01234567 += vb01234567 * va2[0]; + .word 0xF3D84142 @ VMLA.F16 q10, q4, d2[0] + # vacc3x01234567 += vb01234567 * va3[0]; + .word 0xF3D86143 @ VMLA.F16 q11, q4, d3[0] + # vacc4x01234567 += vb01234567 * va4[0]; + .word 0xF3D88144 @ VMLA.F16 q12, q4, d4[0] + # vacc5x01234567 += vb01234567 * va5[0]; + .word 0xF3D8A145 @ VMLA.F16 q13, q4, d5[0] + # vacc6x01234567 += vb01234567 * va6[0]; + .word 0xF3D8C146 @ VMLA.F16 q14, q4, d6[0] + # vacc7x01234567 += vb01234567 * va7[0]; + .word 0xF3D8E147 @ VMLA.F16 q15, q4, d7[0] + + CMP r2, -32 + BLO 2f + + ### Channel 1 ### + + # Load b0-b15 (channel 1) + # - q5 = d10:d11 = b0-b15 + VLD1.8 {d10-d11}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[1]; + .word 0xF3DA0148 @ VMLA.F16 q8, q5, d0[1] + # vacc1x01234567 += vb01234567 * va1[1]; + .word 0xF3DA2149 @ VMLA.F16 q9, q5, d1[1] + # vacc2x01234567 += vb01234567 * va2[1]; + .word 0xF3DA414A @ VMLA.F16 q10, q5, d2[1] + # vacc3x01234567 += vb01234567 * va3[1]; + .word 0xF3DA614B @ VMLA.F16 q11, q5, d3[1] + # vacc4x01234567 += vb01234567 * va4[1]; + .word 0xF3DA814C @ VMLA.F16 q12, q5, d4[1] + # vacc5x01234567 += vb01234567 * va5[1]; + .word 0xF3DAA14D @ VMLA.F16 q13, q5, d5[1] + # vacc6x01234567 += vb01234567 * va6[1]; + .word 0xF3DAC14E @ VMLA.F16 q14, q5, d6[1] + # vacc7x01234567 += vb01234567 * va7[1]; + .word 0xF3DAE14F @ VMLA.F16 q15, q5, d7[1] + + BLS 2f + + ### Channel 2 ### + + # Load b0-b15 (channel 2) + # - q6 = d12:d13 = b0-b15 + VLD1.8 {d12-d13}, [ip:64]! + + # vacc0x01234567 += vb01234567 * va0[2]; + .word 0xF3DC0160 @ VMLA.F16 q8, q6, d0[2] + # vacc1x01234567 += vb01234567 * va1[2]; + .word 0xF3DC2161 @ VMLA.F16 q9, q6, d1[2] + # vacc2x01234567 += vb01234567 * va2[2]; + .word 0xF3DC4162 @ VMLA.F16 q10, q6, d2[2] + # vacc3x01234567 += vb01234567 * va3[2]; + .word 0xF3DC6163 @ VMLA.F16 q11, q6, d3[2] + # vacc4x01234567 += vb01234567 * va4[2]; + .word 0xF3DC8164 @ VMLA.F16 q12, q6, d4[2] + # vacc5x01234567 += vb01234567 * va5[2]; + .word 0xF3DCA165 @ VMLA.F16 q13, q6, d5[2] + # vacc6x01234567 += vb01234567 * va6[2]; + .word 0xF3DCC166 @ VMLA.F16 q14, q6, d6[2] + # vacc7x01234567 += vb01234567 * va7[2]; + .word 0xF3DCE167 @ VMLA.F16 q15, q6, d7[2] + + .p2align 4 +2: + # Load params: + # - ip = params + LDR ip, [sp, 112] + + # Load scale: + # - q0 = d0:d1 = vscale + VLD1.16 {d0[], d1[]}, [ip]! + + .word 0xF3500DD0 @ VMUL.F16 q8, q8, q0 + .word 0xF3522DD0 @ VMUL.F16 q9, q9, q0 + .word 0xF3544DD0 @ VMUL.F16 q10, q10, q0 + .word 0xF3566DD0 @ VMUL.F16 q11, q11, q0 + .word 0xF3588DD0 @ VMUL.F16 q12, q12, q0 + .word 0xF35AADD0 @ VMUL.F16 q13, q13, q0 + .word 0xF35CCDD0 @ VMUL.F16 q14, q14, q0 + .word 0xF35EEDD0 @ VMUL.F16 q15, q15, q0 + + # Load max: + # - q1 = d2:d3 = vmax + VLD1.16 {d2[], d3[]}, [ip]! + + .word 0xF2700FC2 @ VMIN.F16 q8, q8, q1 + .word 0xF2722FC2 @ VMIN.F16 q9, q9, q1 + .word 0xF2744FC2 @ VMIN.F16 q10, q10, q1 + .word 0xF2766FC2 @ VMIN.F16 q11, q11, q1 + .word 0xF2788FC2 @ VMIN.F16 q12, q12, q1 + .word 0xF27AAFC2 @ VMIN.F16 q13, q13, q1 + .word 0xF27CCFC2 @ VMIN.F16 q14, q14, q1 + .word 0xF27EEFC2 @ VMIN.F16 q15, q15, q1 + + # Load min: + # - q2 = d4:d5 = vmin + VLD1.16 {d4[], d5[]}, [ip] + + .word 0xF2500FC4 @ VMAX.F16 q8, q8, q2 + .word 0xF2522FC4 @ VMAX.F16 q9, q9, q2 + .word 0xF2544FC4 @ VMAX.F16 q10, q10, q2 + .word 0xF2566FC4 @ VMAX.F16 q11, q11, q2 + .word 0xF2588FC4 @ VMAX.F16 q12, q12, q2 + .word 0xF25AAFC4 @ VMAX.F16 q13, q13, q2 + .word 0xF25CCFC4 @ VMAX.F16 q14, q14, q2 + .word 0xF25EEFC4 @ VMAX.F16 q15, q15, q2 + + # Load c, c_stride: + # - r2 = c + # - r3 = c_stride + LDRD r2, r3, [sp, 104] + + CMP r0, 2 + ADD r4, r2, r3 + MOVLO r4, r2 + ADD r5, r4, r3 + MOVLS r5, r4 + + CMP r0, 4 + ADD r6, r5, r3 + MOVLO r6, r5 + ADD r7, r6, r3 + MOVLS r7, r6 + + CMP r0, 6 + ADD r8, r7, r3 + MOVLO r8, r7 + ADD r9, r8, r3 + MOVLS r9, r8 + + CMP r0, 8 + ADD r3, r9, r3 + MOVNE r3, r9 + + CMP r1, 8 + BNE 4f + + VST1.16 {d16-d17}, [r2] + VST1.16 {d18-d19}, [r4] + VST1.16 {d20-d21}, [r5] + VST1.16 {d22-d23}, [r6] + VST1.16 {d24-d25}, [r7] + VST1.16 {d26-d27}, [r8] + VST1.16 {d28-d29}, [r9] + VST1.16 {d30-d31}, [r3] + + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11} + BX lr + + .p2align 3 +4: + CMP r1, 4 + BLO 5f + + VST1.16 {d16}, [r2]! + VST1.16 {d18}, [r4]! + VST1.16 {d20}, [r5]! + VST1.16 {d22}, [r6]! + VST1.16 {d24}, [r7]! + VST1.16 {d26}, [r8]! + VST1.16 {d28}, [r9]! + VST1.16 {d30}, [r3]! + + SUB r1, 4 + VMOV.I16 d16, d17 + VMOV.I16 d18, d19 + VMOV.I16 d20, d21 + VMOV.I16 d22, d23 + VMOV.I16 d24, d25 + VMOV.I16 d26, d27 + VMOV.I16 d28, d29 + VMOV.I16 d30, d31 + +5: + CMP r1, 2 + BLO 6f + + VST1.32 {d16[0]}, [r2]! + VST1.32 {d18[0]}, [r4]! + VST1.32 {d20[0]}, [r5]! + VST1.32 {d22[0]}, [r6]! + VST1.32 {d24[0]}, [r7]! + VST1.32 {d26[0]}, [r8]! + VST1.32 {d28[0]}, [r9]! + VST1.32 {d30[0]}, [r3]! + + SUB r1, 2 + VEXT.8 d16, d16, d16, 4 + VEXT.8 d18, d18, d18, 4 + VEXT.8 d20, d20, d20, 4 + VEXT.8 d22, d22, d22, 4 + VEXT.8 d24, d24, d24, 4 + VEXT.8 d26, d26, d26, 4 + VEXT.8 d28, d28, d28, 4 + VEXT.8 d30, d30, d30, 4 + +6: + TEQ r1, 0 + BEQ 7f + + VST1.16 {d16[0]}, [r2] + VST1.16 {d18[0]}, [r4] + VST1.16 {d20[0]}, [r5] + VST1.16 {d22[0]}, [r6] + VST1.16 {d24[0]}, [r7] + VST1.16 {d26[0]}, [r8] + VST1.16 {d28[0]}, [r9] + VST1.16 {d30[0]}, [r3] + +7: + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11} + BX lr +END_FUNCTION pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/hgemm/8x8-neonfp16arith.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/hgemm/8x8-neonfp16arith.c new file mode 100644 index 0000000000000..0de4dc4d1535c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/hgemm/8x8-neonfp16arith.c @@ -0,0 +1,370 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_hgemm_ukernel_8x8__neonfp16arith( + size_t mr, + size_t nr, + size_t k, + const void* restrict a, + size_t a_stride, + const void* restrict w, + void* restrict c, + size_t c_stride, + const struct pytorch_qnnp_fp16_clamping_params + clamping_params[restrict static 1]) { + float16x8_t vacc0x01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + float16x8_t vacc1x01234567 = vacc0x01234567; + float16x8_t vacc2x01234567 = vacc0x01234567; + float16x8_t vacc3x01234567 = vacc0x01234567; + float16x8_t vacc4x01234567 = vacc0x01234567; + float16x8_t vacc5x01234567 = vacc0x01234567; + float16x8_t vacc6x01234567 = vacc0x01234567; + float16x8_t vacc7x01234567 = vacc0x01234567; + + const __fp16* a0 = a; + const __fp16* a1 = (const __fp16*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const __fp16* a2 = (const __fp16*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const __fp16* a3 = (const __fp16*)((uintptr_t)a2 + a_stride); + if (mr < 4) { + a3 = a2; + } + const __fp16* a4 = (const __fp16*)((uintptr_t)a3 + a_stride); + if (mr <= 4) { + a4 = a3; + } + const __fp16* a5 = (const __fp16*)((uintptr_t)a4 + a_stride); + if (mr < 6) { + a5 = a4; + } + const __fp16* a6 = (const __fp16*)((uintptr_t)a5 + a_stride); + if (mr <= 6) { + a6 = a5; + } + const __fp16* a7 = (const __fp16*)((uintptr_t)a6 + a_stride); + if (mr != 8) { + a7 = a6; + } + + for (; k >= 4; k -= 4) { + const float16x4_t va0 = vld1_f16(a0); + a0 += 4; + const float16x4_t va1 = vld1_f16(a1); + a1 += 4; + const float16x4_t va2 = vld1_f16(a2); + a2 += 4; + const float16x4_t va3 = vld1_f16(a3); + a3 += 4; + const float16x4_t va4 = vld1_f16(a4); + a4 += 4; + const float16x4_t va5 = vld1_f16(a5); + a5 += 4; + const float16x4_t va6 = vld1_f16(a6); + a6 += 4; + const float16x4_t va7 = vld1_f16(a7); + a7 += 4; + + { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 0); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 0); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 0); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 0); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 0); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 0); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 0); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 0); + } + + { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 1); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 1); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 1); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 1); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 1); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 1); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 1); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 1); + } + + { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 2); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 2); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 2); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 2); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 2); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 2); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 2); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 2); + } + + { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 3); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 3); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 3); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 3); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 3); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 3); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 3); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 3); + } + } + if (k != 0) { + const size_t a_predecrement = 4 - k; + const int64x1_t va_shift = vmov_n_s64(-16 * a_predecrement); + const float16x4_t va0 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a0 - a_predecrement)), va_shift)); + const float16x4_t va1 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a1 - a_predecrement)), va_shift)); + const float16x4_t va2 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a2 - a_predecrement)), va_shift)); + const float16x4_t va3 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a3 - a_predecrement)), va_shift)); + const float16x4_t va4 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a4 - a_predecrement)), va_shift)); + const float16x4_t va5 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a5 - a_predecrement)), va_shift)); + const float16x4_t va6 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a6 - a_predecrement)), va_shift)); + const float16x4_t va7 = vreinterpret_f16_u64(vshl_u64( + vreinterpret_u64_f16(vld1_f16(a7 - a_predecrement)), va_shift)); + + { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 0); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 0); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 0); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 0); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 0); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 0); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 0); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 0); + } + + if (k >= 2) { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 1); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 1); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 1); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 1); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 1); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 1); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 1); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 1); + + if (k > 2) { + const float16x8_t vb01234567 = vld1q_f16(w); + w = (void*)((uintptr_t)w + sizeof(float16x8_t)); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 2); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 2); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 2); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 2); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 2); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 2); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 2); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 2); + + if (k >= 4) { + const float16x8_t vb01234567 = vld1q_f16(w); + + vacc0x01234567 = vmlaq_lane_f16(vacc0x01234567, vb01234567, va0, 3); + vacc1x01234567 = vmlaq_lane_f16(vacc1x01234567, vb01234567, va1, 3); + vacc2x01234567 = vmlaq_lane_f16(vacc2x01234567, vb01234567, va2, 3); + vacc3x01234567 = vmlaq_lane_f16(vacc3x01234567, vb01234567, va3, 3); + vacc4x01234567 = vmlaq_lane_f16(vacc4x01234567, vb01234567, va4, 3); + vacc5x01234567 = vmlaq_lane_f16(vacc5x01234567, vb01234567, va5, 3); + vacc6x01234567 = vmlaq_lane_f16(vacc6x01234567, vb01234567, va6, 3); + vacc7x01234567 = vmlaq_lane_f16(vacc7x01234567, vb01234567, va7, 3); + } + } + } + } + const float16x8_t vscale = + vld1q_dup_f16((const __fp16*)&clamping_params->scale); + vacc0x01234567 = vmulq_f16(vacc0x01234567, vscale); + vacc1x01234567 = vmulq_f16(vacc1x01234567, vscale); + vacc2x01234567 = vmulq_f16(vacc2x01234567, vscale); + vacc3x01234567 = vmulq_f16(vacc3x01234567, vscale); + vacc4x01234567 = vmulq_f16(vacc4x01234567, vscale); + vacc5x01234567 = vmulq_f16(vacc5x01234567, vscale); + vacc6x01234567 = vmulq_f16(vacc6x01234567, vscale); + vacc7x01234567 = vmulq_f16(vacc7x01234567, vscale); + + const float16x8_t vmax = vld1q_dup_f16((const __fp16*)&clamping_params->max); + vacc0x01234567 = vminq_f16(vacc0x01234567, vmax); + vacc1x01234567 = vminq_f16(vacc1x01234567, vmax); + vacc2x01234567 = vminq_f16(vacc2x01234567, vmax); + vacc3x01234567 = vminq_f16(vacc3x01234567, vmax); + vacc4x01234567 = vminq_f16(vacc4x01234567, vmax); + vacc5x01234567 = vminq_f16(vacc5x01234567, vmax); + vacc6x01234567 = vminq_f16(vacc6x01234567, vmax); + vacc7x01234567 = vminq_f16(vacc7x01234567, vmax); + + const float16x8_t vmin = vld1q_dup_f16((const __fp16*)&clamping_params->min); + vacc0x01234567 = vmaxq_f16(vacc0x01234567, vmin); + vacc1x01234567 = vmaxq_f16(vacc1x01234567, vmin); + vacc2x01234567 = vmaxq_f16(vacc2x01234567, vmin); + vacc3x01234567 = vmaxq_f16(vacc3x01234567, vmin); + vacc4x01234567 = vmaxq_f16(vacc4x01234567, vmin); + vacc5x01234567 = vmaxq_f16(vacc5x01234567, vmin); + vacc6x01234567 = vmaxq_f16(vacc6x01234567, vmin); + vacc7x01234567 = vmaxq_f16(vacc7x01234567, vmin); + + __fp16* c0 = c; + __fp16* c1 = (__fp16*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + __fp16* c2 = (__fp16*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + __fp16* c3 = (__fp16*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + __fp16* c4 = (__fp16*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + __fp16* c5 = (__fp16*)((uintptr_t)c4 + c_stride); + if (mr < 6) { + c5 = c4; + } + __fp16* c6 = (__fp16*)((uintptr_t)c5 + c_stride); + if (mr <= 6) { + c6 = c5; + } + __fp16* c7 = (__fp16*)((uintptr_t)c6 + c_stride); + if (mr != 8) { + c7 = c6; + } + if (nr == 8) { + vst1q_f16(c0, vacc0x01234567); + vst1q_f16(c1, vacc1x01234567); + vst1q_f16(c2, vacc2x01234567); + vst1q_f16(c3, vacc3x01234567); + vst1q_f16(c4, vacc4x01234567); + vst1q_f16(c5, vacc5x01234567); + vst1q_f16(c6, vacc6x01234567); + vst1q_f16(c7, vacc7x01234567); + } else { + if (nr & 4) { + vst1_f16(c0, vget_low_f16(vacc0x01234567)); + c0 += 4; + vst1_f16(c1, vget_low_f16(vacc1x01234567)); + c1 += 4; + vst1_f16(c2, vget_low_f16(vacc2x01234567)); + c2 += 4; + vst1_f16(c3, vget_low_f16(vacc3x01234567)); + c3 += 4; + vst1_f16(c4, vget_low_f16(vacc4x01234567)); + c4 += 4; + vst1_f16(c5, vget_low_f16(vacc5x01234567)); + c5 += 4; + vst1_f16(c6, vget_low_f16(vacc6x01234567)); + c6 += 4; + vst1_f16(c7, vget_low_f16(vacc7x01234567)); + c7 += 4; + vacc0x01234567 = vextq_f16(vacc0x01234567, vacc0x01234567, 4); + vacc1x01234567 = vextq_f16(vacc1x01234567, vacc1x01234567, 4); + vacc2x01234567 = vextq_f16(vacc2x01234567, vacc2x01234567, 4); + vacc3x01234567 = vextq_f16(vacc3x01234567, vacc3x01234567, 4); + vacc4x01234567 = vextq_f16(vacc4x01234567, vacc4x01234567, 4); + vacc5x01234567 = vextq_f16(vacc5x01234567, vacc5x01234567, 4); + vacc6x01234567 = vextq_f16(vacc6x01234567, vacc6x01234567, 4); + vacc7x01234567 = vextq_f16(vacc7x01234567, vacc7x01234567, 4); + } + if (nr & 2) { + vst1_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpret_u32_f16(vget_low_f16(vacc0x01234567)), + 0); + c0 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpret_u32_f16(vget_low_f16(vacc1x01234567)), + 0); + c1 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpret_u32_f16(vget_low_f16(vacc2x01234567)), + 0); + c2 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpret_u32_f16(vget_low_f16(vacc3x01234567)), + 0); + c3 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c4, 1), + vreinterpret_u32_f16(vget_low_f16(vacc4x01234567)), + 0); + c4 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c5, 1), + vreinterpret_u32_f16(vget_low_f16(vacc5x01234567)), + 0); + c5 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c6, 1), + vreinterpret_u32_f16(vget_low_f16(vacc6x01234567)), + 0); + c6 += 2; + vst1_lane_u32( + __builtin_assume_aligned(c7, 1), + vreinterpret_u32_f16(vget_low_f16(vacc7x01234567)), + 0); + c7 += 2; + vacc0x01234567 = vextq_f16(vacc0x01234567, vacc0x01234567, 2); + vacc1x01234567 = vextq_f16(vacc1x01234567, vacc1x01234567, 2); + vacc2x01234567 = vextq_f16(vacc2x01234567, vacc2x01234567, 2); + vacc3x01234567 = vextq_f16(vacc3x01234567, vacc3x01234567, 2); + vacc4x01234567 = vextq_f16(vacc4x01234567, vacc4x01234567, 2); + vacc5x01234567 = vextq_f16(vacc5x01234567, vacc5x01234567, 2); + vacc6x01234567 = vextq_f16(vacc6x01234567, vacc6x01234567, 2); + vacc7x01234567 = vextq_f16(vacc7x01234567, vacc7x01234567, 2); + } + if (nr & 1) { + vst1q_lane_f16(c0, vacc0x01234567, 0); + vst1q_lane_f16(c1, vacc1x01234567, 0); + vst1q_lane_f16(c2, vacc2x01234567, 0); + vst1q_lane_f16(c3, vacc3x01234567, 0); + vst1q_lane_f16(c4, vacc4x01234567, 0); + vst1q_lane_f16(c5, vacc5x01234567, 0); + vst1q_lane_f16(c6, vacc6x01234567, 0); + vst1q_lane_f16(c7, vacc7x01234567, 0); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/indirection.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/indirection.c new file mode 100644 index 0000000000000..8785b3bbeda9e --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/indirection.c @@ -0,0 +1,277 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include +#include + +void pytorch_qnnp_indirection_init_conv2d( + pytorch_qnnp_operator_t op, + size_t output_tile_size, + size_t tiled_output_size) { + const void** indirection_buffer = op->indirection_buffer; + const void* input = op->input; + const size_t input_pixel_stride = op->input_pixel_stride; + const void* zero = op->zero_pointer; + const size_t groups = op->groups; + const size_t group_input_channels = op->group_input_channels; + const size_t batch_size = op->batch_size; + const size_t input_height = op->input_height; + const size_t input_width = op->input_width; + const size_t output_height = op->output_height; + const size_t output_width = op->output_width; + const size_t kernel_height = op->kernel_height; + const size_t kernel_width = op->kernel_width; + const size_t stride_height = op->stride_height; + const size_t stride_width = op->stride_width; + const size_t dilation_height = op->dilation_height; + const size_t dilation_width = op->dilation_width; + const size_t input_padding_top = op->input_padding_top; + const size_t input_padding_left = op->input_padding_left; + + const size_t output_size = output_height * output_width; + const size_t kernel_size = kernel_height * kernel_width; + const struct fxdiv_divisor_size_t output_width_divisor = + fxdiv_init_size_t(output_width); + for (size_t group = 0; group < groups; group++) { + for (size_t image = 0; image < batch_size; image++) { + for (size_t output_tile_start = 0; output_tile_start < tiled_output_size; + output_tile_start += output_tile_size) { + for (size_t output_tile_offset = 0; + output_tile_offset < output_tile_size; + output_tile_offset++) { + const size_t tiled_output_index = + output_tile_start + output_tile_offset; + const size_t output_index = min(tiled_output_index, output_size - 1); + const struct fxdiv_result_size_t output_index_components = + fxdiv_divide_size_t(output_index, output_width_divisor); + const size_t output_y = output_index_components.quotient; + const size_t output_x = output_index_components.remainder; + for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) { + const size_t input_y = output_y * stride_height + + kernel_y * dilation_height - input_padding_top; + if (input_y < input_height) { + for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { + const size_t input_x = output_x * stride_width + + kernel_x * dilation_width - input_padding_left; + const size_t index = (group * batch_size + image) * + tiled_output_size * kernel_size + + output_tile_start * kernel_size + + (kernel_y * kernel_width + kernel_x) * output_tile_size + + output_tile_offset; + if (input_x < input_width) { + indirection_buffer[index] = (char*)input + + ((image * input_height + input_y) * input_width + + input_x) * + input_pixel_stride + + group * group_input_channels; + } else { + indirection_buffer[index] = zero; + } + } + } else { + for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { + const size_t index = (group * batch_size + image) * + tiled_output_size * kernel_size + + output_tile_start * kernel_size + + (kernel_y * kernel_width + kernel_x) * output_tile_size + + output_tile_offset; + indirection_buffer[index] = zero; + } + } + } + } + } + } + } +} + +void pytorch_qnnp_indirection_init_dwconv2d( + pytorch_qnnp_operator_t op, + size_t batch_start, + size_t step_height, + size_t step_width) { + const void** indirection_buffer = op->indirection_buffer; + const void* input = op->input; + const size_t input_pixel_stride = op->input_pixel_stride; + const void* zero = op->zero_pointer; + const size_t batch_size = op->batch_size; + const size_t input_height = op->input_height; + const size_t input_width = op->input_width; + const size_t output_height = op->output_height; + const size_t output_width = op->output_width; + const size_t kernel_height = op->kernel_height; + const size_t kernel_width = op->kernel_width; + const size_t stride_height = op->stride_height; + const size_t stride_width = op->stride_width; + const size_t dilation_height = op->dilation_height; + const size_t dilation_width = op->dilation_width; + const size_t input_padding_top = op->input_padding_top; + const size_t input_padding_left = op->input_padding_left; + + for (size_t image = batch_start; image < batch_size; image++) { + for (size_t output_y = 0; output_y < output_height; output_y++) { + for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) { + const size_t input_y = output_y * stride_height + + kernel_y * dilation_height - input_padding_top; + if (input_y < input_height) { + for (size_t output_x = 0; output_x < output_width; output_x++) { + for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { + const size_t input_x = output_x * stride_width + + kernel_x * dilation_width - input_padding_left; + const size_t index = + (image * output_height + output_y) * step_height + + output_x * step_width * kernel_height + + kernel_x * kernel_height + kernel_y; + if (input_x < input_width) { + indirection_buffer[index] = (char*)input + + ((image * input_height + input_y) * input_width + input_x) * + input_pixel_stride; + } else { + indirection_buffer[index] = zero; + } + } + } + } else { + for (size_t output_x = 0; output_x < output_width; output_x++) { + for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { + const size_t index = + (image * output_height + output_y) * step_height + + output_x * step_width * kernel_height + + kernel_x * kernel_height + kernel_y; + indirection_buffer[index] = zero; + } + } + } + } + } + } +} + +void pytorch_qnnp_indirection_init_deconv2d( + pytorch_qnnp_operator_t op, + size_t output_tile_size, + size_t tiled_output_size) { + const void** indirection_buffer = op->indirection_buffer; + const void* input = op->input; + const size_t input_pixel_stride = op->input_pixel_stride; + const void* zero = op->zero_pointer; + const size_t groups = op->groups; + const size_t group_input_channels = op->group_input_channels; + const size_t batch_size = op->batch_size; + const size_t input_height = op->input_height; + const size_t input_width = op->input_width; + const size_t output_height = op->output_height; + const size_t output_width = op->output_width; + const size_t kernel_height = op->kernel_height; + const size_t kernel_width = op->kernel_width; + const size_t stride_height = op->stride_height; + const size_t stride_width = op->stride_width; + const size_t dilation_height = op->dilation_height; + const size_t dilation_width = op->dilation_width; + const size_t input_padding_top = op->input_padding_top; + const size_t input_padding_left = op->input_padding_left; + + const size_t output_size = output_height * output_width; + const size_t kernel_size = kernel_height * kernel_width; + + for (size_t group = 0; group < groups; group++) { + for (size_t image = 0; image < batch_size; image++) { + for (size_t output_tile_start = 0; output_tile_start < tiled_output_size; + output_tile_start += output_tile_size) { + for (size_t output_tile_offset = 0; + output_tile_offset < output_tile_size; + output_tile_offset++) { + const size_t tiled_output_index = + output_tile_start + output_tile_offset; + const size_t output_index = min(tiled_output_index, output_size - 1); + const size_t output_y = output_index / output_width; + const size_t output_x = output_index % output_width; + for (size_t kernel_y = 0; kernel_y < kernel_height; kernel_y++) { + const size_t y = + output_y + input_padding_top - kernel_y * dilation_height; + const size_t input_y = y / stride_height; + for (size_t kernel_x = 0; kernel_x < kernel_width; kernel_x++) { + const size_t x = + output_x + input_padding_left - kernel_x * dilation_width; + const size_t input_x = x / stride_width; + const size_t index = (group * batch_size + image) * + tiled_output_size * kernel_size + + output_tile_start * kernel_size + + (kernel_y * kernel_width + kernel_x) * output_tile_size + + output_tile_offset; + if (input_y * stride_height == y && input_y < input_height && + input_x * stride_width == x && input_x < input_width) { + indirection_buffer[index] = (char*)input + + ((image * input_height + input_y) * input_width + input_x) * + input_pixel_stride + + group * group_input_channels; + } else { + indirection_buffer[index] = zero; + } + } + } + } + } + } + } +} + +void pytorch_qnnp_indirection_init_maxpool2d( + pytorch_qnnp_operator_t op, + size_t batch_start, + size_t step_height, + size_t step_width) { + const void** indirection_buffer = op->indirection_buffer; + const void* input = op->input; + const size_t input_pixel_stride = op->input_pixel_stride; + const size_t batch_size = op->batch_size; + const size_t input_height = op->input_height; + const size_t input_width = op->input_width; + const size_t output_height = op->output_height; + const size_t output_width = op->output_width; + const size_t pooling_height = op->kernel_height; + const size_t pooling_width = op->kernel_width; + const size_t stride_height = op->stride_height; + const size_t stride_width = op->stride_width; + const size_t dilation_height = op->dilation_height; + const size_t dilation_width = op->dilation_width; + const size_t input_padding_top = op->input_padding_top; + const size_t input_padding_left = op->input_padding_left; + + for (size_t image = batch_start; image < batch_size; image++) { + for (size_t output_y = 0; output_y < output_height; output_y++) { + for (size_t pooling_y = 0; pooling_y < pooling_height; pooling_y++) { + const size_t input_y = + doz(output_y * stride_height + pooling_y * dilation_height, + input_padding_top); + const size_t clamped_input_y = min(input_y, input_height - 1); + for (size_t output_x = 0; output_x < output_width; output_x++) { + for (size_t pooling_x = 0; pooling_x < pooling_width; pooling_x++) { + const size_t input_x = + doz(output_x * stride_width + pooling_x * dilation_width, + input_padding_left); + const size_t clamped_input_x = min(input_x, input_width - 1); + const size_t index = + (image * output_height + output_y) * step_height + + output_x * step_width * pooling_height + + pooling_x * pooling_height + pooling_y; + indirection_buffer[index] = (char*)input + + ((image * input_height + clamped_input_y) * input_width + + clamped_input_x) * + input_pixel_stride; + } + } + } + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/init.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/init.c new file mode 100644 index 0000000000000..f1c55806f103e --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/init.c @@ -0,0 +1,277 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#ifdef _MSC_VER +#include +#else +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#ifdef _MSC_VER +static INIT_ONCE init_guard; +BOOL CALLBACK init_win(PINIT_ONCE InitOnce, PVOID Parameter, PVOID* lpContex); +#else +static pthread_once_t init_guard = PTHREAD_ONCE_INIT; +#endif + +struct pytorch_qnnp_parameters pytorch_qnnp_params = {.initialized = false}; + +static void init(void) { +#if CPUINFO_ARCH_ARM + if (!cpuinfo_has_arm_neon()) { + pytorch_qnnp_log_error( + "QNNPACK initialization failed: NEON is not supported"); + return; + } + pytorch_qnnp_params.q8conv = (struct pytorch_q8conv_parameters){ + .gemm = pytorch_q8gemm_ukernel_4x8__aarch32_neon, + .conv = pytorch_q8conv_ukernel_4x8__aarch32_neon, + .mr = 4, + .nr = 8, + .kr = 1, + }; +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){ + .gemm = pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon, + .mr = 4, + .nr = 8, + .kr = 2, + .kc = 8, + .kthreshold = SIZE_MAX, + }; + /* setup xzp threshold based on measurements */ + switch (cpuinfo_get_core(0)->uarch) { + case cpuinfo_uarch_cortex_a72: + pytorch_qnnp_params.q8conv_xzp.kthreshold = 64; + break; + case cpuinfo_uarch_cortex_a73: + pytorch_qnnp_params.q8conv_xzp.kthreshold = 256; + break; + case cpuinfo_uarch_cortex_a75: + pytorch_qnnp_params.q8conv_xzp.kthreshold = 32; + break; + case cpuinfo_uarch_cortex_a76: + pytorch_qnnp_params.q8conv_xzp.kthreshold = 16; + break; + default: + break; + } +#else + pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){ + .kthreshold = SIZE_MAX, + }; +#endif + pytorch_qnnp_params.q8dw9 = (struct pytorch_q8dwconv_up_parameters){ + .updw = pytorch_q8dwconv_ukernel_up8x9__aarch32_neon, + .cr = 8, + }; + pytorch_qnnp_params.q8dw25 = (struct pytorch_q8dwconv_mp_parameters){ + .mpdw = pytorch_q8dwconv_ukernel_mp8x25__neon, + .cr = 8, + }; + pytorch_qnnp_params.q8sum_rows = (struct pytorch_q8sum_rows_parameters){ + .sum_rows = pytorch_q8sumrows_ukernel_4x__neon, + .m = 4, + }; + pytorch_qnnp_params.q8vadd = pytorch_q8vadd_ukernel__neon; + pytorch_qnnp_params.q8gavgpool = (struct pytorch_q8gavgpool_parameters){ + .ltnr = pytorch_q8gavgpool_ukernel_up8xm__neon, + .genr_lemr = pytorch_q8gavgpool_ukernel_up8x7__neon, + .genr_gtmr = pytorch_q8gavgpool_ukernel_mp8x7p7q__neon, + .mr = 7, + .nr = 8, + }; + pytorch_qnnp_params.q8avgpool = (struct pytorch_q8avgpool_parameters){ + .ltkr = pytorch_q8avgpool_ukernel_up8xm__neon, + .gekr_lemr = pytorch_q8avgpool_ukernel_up8x9__neon, + .gekr_gtmr = pytorch_q8avgpool_ukernel_mp8x9p8q__neon, + .mr = 9, + .qr = 8, + .kr = 8, + }; + pytorch_qnnp_params.u8maxpool = (struct pytorch_u8maxpool_parameters){ + .ltkr = pytorch_u8maxpool_ukernel_sub16__neon, + .gekr = pytorch_u8maxpool_ukernel_16x9p8q__neon, + .mr = 9, + .qr = 8, + .kr = 16, + }; + pytorch_qnnp_params.x8zip = (struct pytorch_x8zip_parameters){ + .x2 = pytorch_qnnp_x8zip_x2__neon, + .x3 = pytorch_qnnp_x8zip_x3__neon, + .x4 = pytorch_qnnp_x8zip_x4__neon, + .xm = pytorch_qnnp_x8zip_xm__neon, + }; + pytorch_qnnp_params.u8clamp = pytorch_u8clamp_ukernel__neon; + pytorch_qnnp_params.u8rmax = pytorch_u8rmax_ukernel__neon; + pytorch_qnnp_params.u8lut32norm = pytorch_u8lut32norm_ukernel__scalar; + pytorch_qnnp_params.x8lut = pytorch_x8lut_ukernel__scalar; +#elif CPUINFO_ARCH_ARM64 + pytorch_qnnp_params.q8conv = (struct pytorch_q8conv_parameters){ + .gemm = pytorch_q8gemm_ukernel_8x8__aarch64_neon, + .conv = pytorch_q8conv_ukernel_8x8__aarch64_neon, + .mr = 8, + .nr = 8, + .kr = 1, + }; + pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){ + .kthreshold = SIZE_MAX, + }; + pytorch_qnnp_params.q8dw9 = (struct pytorch_q8dwconv_up_parameters){ + .updw = pytorch_q8dwconv_ukernel_up8x9__neon, + .cr = 8, + }; + pytorch_qnnp_params.q8dw25 = (struct pytorch_q8dwconv_mp_parameters){ + .mpdw = pytorch_q8dwconv_ukernel_mp8x25__neon, + .cr = 8, + }; + pytorch_qnnp_params.q8vadd = pytorch_q8vadd_ukernel__neon; + pytorch_qnnp_params.q8gavgpool = (struct pytorch_q8gavgpool_parameters){ + .ltnr = pytorch_q8gavgpool_ukernel_up8xm__neon, + .genr_lemr = pytorch_q8gavgpool_ukernel_up8x7__neon, + .genr_gtmr = pytorch_q8gavgpool_ukernel_mp8x7p7q__neon, + .mr = 7, + .nr = 8, + }; + pytorch_qnnp_params.q8avgpool = (struct pytorch_q8avgpool_parameters){ + .ltkr = pytorch_q8avgpool_ukernel_up8xm__neon, + .gekr_lemr = pytorch_q8avgpool_ukernel_up8x9__neon, + .gekr_gtmr = pytorch_q8avgpool_ukernel_mp8x9p8q__neon, + .mr = 9, + .qr = 8, + .kr = 8, + }; + pytorch_qnnp_params.u8maxpool = (struct pytorch_u8maxpool_parameters){ + .ltkr = pytorch_u8maxpool_ukernel_sub16__neon, + .gekr = pytorch_u8maxpool_ukernel_16x9p8q__neon, + .mr = 9, + .qr = 8, + .kr = 16, + }; + pytorch_qnnp_params.x8zip = (struct pytorch_x8zip_parameters){ + .x2 = pytorch_qnnp_x8zip_x2__neon, + .x3 = pytorch_qnnp_x8zip_x3__neon, + .x4 = pytorch_qnnp_x8zip_x4__neon, + .xm = pytorch_qnnp_x8zip_xm__neon, + }; + pytorch_qnnp_params.u8clamp = pytorch_u8clamp_ukernel__neon; + pytorch_qnnp_params.u8rmax = pytorch_u8rmax_ukernel__neon; + pytorch_qnnp_params.u8lut32norm = pytorch_u8lut32norm_ukernel__scalar; + pytorch_qnnp_params.x8lut = pytorch_x8lut_ukernel__scalar; +#elif CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + if (!cpuinfo_has_x86_sse2()) { + pytorch_qnnp_log_error( + "QNNPACK initialization failed: SSE2 is not supported"); + return; + } + pytorch_qnnp_params.q8conv = (struct pytorch_q8conv_parameters){ + .gemm = pytorch_q8gemm_ukernel_4x4c2__sse2, + .conv = pytorch_q8conv_ukernel_4x4c2__sse2, + .mr = 4, + .nr = 4, + .kr = 2, + }; + pytorch_qnnp_params.q8conv_xzp = (struct pytorch_q8conv_xzp_parameters){ + .kthreshold = SIZE_MAX, + }; + pytorch_qnnp_params.q8dw9 = (struct pytorch_q8dwconv_up_parameters){ + .updw = pytorch_q8dwconv_ukernel_up8x9__sse2, + .cr = 8, + }; + pytorch_qnnp_params.q8dw25 = (struct pytorch_q8dwconv_mp_parameters){ + .mpdw = pytorch_q8dwconv_ukernel_mp8x25__sse2, + .cr = 8, + }; + pytorch_qnnp_params.q8vadd = pytorch_q8vadd_ukernel__sse2; + pytorch_qnnp_params.q8gavgpool = (struct pytorch_q8gavgpool_parameters){ + .ltnr = pytorch_q8gavgpool_ukernel_up8xm__sse2, + .genr_lemr = pytorch_q8gavgpool_ukernel_up8x7__sse2, + .genr_gtmr = pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2, + .mr = 7, + .nr = 8, + }; + pytorch_qnnp_params.q8avgpool = (struct pytorch_q8avgpool_parameters){ + .ltkr = pytorch_q8avgpool_ukernel_up8xm__sse2, + .gekr_lemr = pytorch_q8avgpool_ukernel_up8x9__sse2, + .gekr_gtmr = pytorch_q8avgpool_ukernel_mp8x9p8q__sse2, + .mr = 9, + .qr = 8, + .kr = 8, + }; + pytorch_qnnp_params.u8maxpool = (struct pytorch_u8maxpool_parameters){ + .ltkr = pytorch_u8maxpool_ukernel_sub16__sse2, + .gekr = pytorch_u8maxpool_ukernel_16x9p8q__sse2, + .mr = 9, + .qr = 8, + .kr = 16, + }; + pytorch_qnnp_params.x8zip = (struct pytorch_x8zip_parameters){ + .x2 = pytorch_qnnp_x8zip_x2__sse2, + .x3 = pytorch_qnnp_x8zip_x3__sse2, + .x4 = pytorch_qnnp_x8zip_x4__sse2, + .xm = pytorch_qnnp_x8zip_xm__sse2, + }; + pytorch_qnnp_params.u8clamp = pytorch_u8clamp_ukernel__sse2; + pytorch_qnnp_params.u8rmax = pytorch_u8rmax_ukernel__sse2; + pytorch_qnnp_params.u8lut32norm = pytorch_u8lut32norm_ukernel__scalar; + pytorch_qnnp_params.x8lut = pytorch_x8lut_ukernel__scalar; +#else +#error "Unsupported architecture" +#endif + pytorch_qnnp_params.initialized = true; +} + +enum pytorch_qnnp_status pytorch_qnnp_initialize(void) { + if (!cpuinfo_initialize()) { + return pytorch_qnnp_status_out_of_memory; + } +#ifdef _MSC_VER + InitOnceExecuteOnce(&init_guard, init_win, NULL, NULL); +#else + pthread_once(&init_guard, &init); +#endif + if (pytorch_qnnp_params.initialized) { + return pytorch_qnnp_status_success; + } else { + return pytorch_qnnp_status_unsupported_hardware; + } +} + +enum pytorch_qnnp_status pytorch_qnnp_deinitialize(void) { + cpuinfo_deinitialize(); + return pytorch_qnnp_status_success; +} + +#ifdef _MSC_VER +BOOL CALLBACK init_win(PINIT_ONCE InitOnce, PVOID Parameter, PVOID* lpContex) { + init(); + return TRUE; +} +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/leaky-relu.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/leaky-relu.c new file mode 100644 index 0000000000000..a04d39a315db6 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/leaky-relu.c @@ -0,0 +1,169 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_leaky_relu_nc_q8( + size_t channels, + float negative_slope, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* leaky_relu_out) { + pytorch_qnnp_operator_t leaky_relu_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_leaky_relu_nc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with %zu channels: number of channels must be non-zero", + channels); + goto error; + } + + if (negative_slope <= 0.0f || !isnormal(negative_slope)) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with %.7g negative slope: slope must be finite and positive", + negative_slope); + goto error; + } + + if (negative_slope > 1.0f) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with %.7g negative slope: slope must not exceed 1.0", + negative_slope); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + if (output_min >= output_max) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with [%" PRIu8 ", %" PRIu8 + "] output range: range min must be below range max", + output_min, + output_max); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + const float input_output_scale = input_scale / output_scale; + if (input_output_scale < 0x1.0p-8f || input_output_scale >= 0x1.0p+8f) { + pytorch_qnnp_log_error( + "failed to create Leaky ReLU operator with %.7g input-to-output scale ratio: " + "scale ratio must be in [2**-8, 2**8) range", + input_output_scale); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + leaky_relu_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (leaky_relu_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + leaky_relu_op->lookup_table = malloc(256 * sizeof(uint8_t)); + if (leaky_relu_op->lookup_table == NULL) { + pytorch_qnnp_log_error( + "failed to allocate 256 bytes for Leaky ReLU lookup table"); + goto error; + } + + uint8_t* lookup_table = leaky_relu_op->lookup_table; + const float scaled_min_less_zero_point = + (float)((int32_t)output_min - (int32_t)output_zero_point); + const float scaled_max_less_zero_point = + (float)((int32_t)output_max - (int32_t)output_zero_point); + for (int32_t i = 0; i < 256; i++) { + const float x = + input_output_scale * (float)(i - (int32_t)(uint32_t)input_zero_point); + float y = x < 0.0f ? x * negative_slope : x; + if (y < scaled_min_less_zero_point) { + y = scaled_min_less_zero_point; + } + if (y > scaled_max_less_zero_point) { + y = scaled_max_less_zero_point; + } + lookup_table[(uint32_t)i] = (uint8_t)(lrintf(y) + (long)output_zero_point); + } + + leaky_relu_op->channels = channels; + + leaky_relu_op->ukernel_type = pytorch_qnnp_ukernel_type_lut; + leaky_relu_op->format = pytorch_qnnp_format_quint8; + + *leaky_relu_out = leaky_relu_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(leaky_relu_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_leaky_relu_nc_q8( + pytorch_qnnp_operator_t leaky_relu, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_leaky_relu_nc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + leaky_relu->batch_size = 0; + return pytorch_qnnp_status_success; + } + + leaky_relu->batch_size = batch_size; + leaky_relu->input = input; + leaky_relu->input_pixel_stride = input_stride; + leaky_relu->output = output; + leaky_relu->output_pixel_stride = output_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/max-pooling.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/max-pooling.c new file mode 100644 index 0000000000000..7e1c27575f41b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/max-pooling.c @@ -0,0 +1,243 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +static inline size_t compute_output_dimension( + size_t padded_input_dimension, + size_t kernel_dimension, + size_t dilation_dimension, + size_t stride_dimension) { + const size_t effective_kernel_dimension = + (kernel_dimension - 1) * dilation_dimension + 1; + return (padded_input_dimension - effective_kernel_dimension) / + stride_dimension + + 1; +} + +enum pytorch_qnnp_status pytorch_qnnp_create_max_pooling2d_nhwc_u8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + size_t channels, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* max_pooling_out) { + pytorch_qnnp_operator_t max_pooling = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_max_pooling2d_nhwc_u8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + const uint32_t pooling_size = pooling_height * pooling_width; + if (pooling_size == 0) { + pytorch_qnnp_log_error( + "failed to create max pooling with %" PRIu32 "x%" PRIu32 + " pooling size: " + "pooling size dimensions must be non-zero", + pooling_width, + pooling_height); + goto error; + } + + if (pooling_size == 1) { + pytorch_qnnp_log_error( + "failed to create max pooling with 1 pooling element: " + "1x1 pooling is meaningless"); + goto error; + } + + if (stride_height == 0 || stride_width == 0) { + pytorch_qnnp_log_error( + "failed to create max pooling with %" PRIu32 "x%" PRIu32 + " stride: " + "stride dimensions must be non-zero", + stride_width, + stride_height); + goto error; + } + + if (dilation_height == 0 || dilation_width == 0) { + pytorch_qnnp_log_error( + "failed to create max pooling with %" PRIu32 "x%" PRIu32 + " dilation: " + "dilation dimensions must be non-zero", + dilation_width, + dilation_height); + goto error; + } + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create max pooling with %zu channels: " + "number of channels must be non-zero", + channels); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + max_pooling = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (max_pooling == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + max_pooling->input_padding_top = input_padding_top; + max_pooling->input_padding_right = input_padding_right; + max_pooling->input_padding_bottom = input_padding_bottom; + max_pooling->input_padding_left = input_padding_left; + + max_pooling->kernel_height = pooling_height; + max_pooling->kernel_width = pooling_width; + max_pooling->stride_height = stride_height; + max_pooling->stride_width = stride_width; + max_pooling->dilation_height = dilation_height; + max_pooling->dilation_width = dilation_width; + max_pooling->channels = channels; + + max_pooling->u8_clamping_params = + pytorch_qnnp_compute_u8_clamping_params(output_min, output_max); + + max_pooling->ukernel_type = pytorch_qnnp_ukernel_type_max_pooling; + max_pooling->format = pytorch_qnnp_format_quint8; + + *max_pooling_out = max_pooling; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(max_pooling); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + pytorch_qnnp_operator_t max_pooling, + size_t batch_size, + size_t input_height, + size_t input_width, + const uint8_t* input, + size_t input_pixel_stride, + uint8_t* output, + size_t output_pixel_stride, + pthreadpool_t threadpool) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_max_pooling2d_nhwc_u8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + max_pooling->batch_size = 0; + return pytorch_qnnp_status_success; + } + + if (input_width == 0 || input_height == 0) { + pytorch_qnnp_log_error( + "failed to setup max pooling with %zux%zu input: input dimensions must be non-zero", + input_width, + input_height); + return pytorch_qnnp_status_invalid_parameter; + } + + max_pooling->batch_size = batch_size; + max_pooling->input_height = input_height; + max_pooling->input_width = input_width; + max_pooling->input = input; + max_pooling->input_pixel_stride = input_pixel_stride; + + max_pooling->output_height = compute_output_dimension( + max_pooling->input_padding_top + input_height + + max_pooling->input_padding_bottom, + max_pooling->kernel_height, + max_pooling->dilation_height, + max_pooling->stride_height); + max_pooling->output_width = compute_output_dimension( + max_pooling->input_padding_left + input_width + + max_pooling->input_padding_right, + max_pooling->kernel_width, + max_pooling->dilation_width, + max_pooling->stride_width); + max_pooling->output = output; + max_pooling->output_pixel_stride = output_pixel_stride; + + size_t valid_batch_size = 0; + if (input == max_pooling->last_input && + input_height == max_pooling->last_input_height && + input_width == max_pooling->last_input_width) { + valid_batch_size = max_pooling->valid_batch_size; + if (batch_size <= valid_batch_size) { + return pytorch_qnnp_status_success; + } + } + + const size_t pooling_height = max_pooling->kernel_height; + const size_t pooling_width = max_pooling->kernel_width; + const size_t pooling_size = pooling_height * pooling_width; + const size_t output_height = max_pooling->output_height; + const size_t output_width = max_pooling->output_width; + /* Micro-kernel may read up to (mr - 1) elements after the end of indirection + * buffer */ + const uint32_t mr = pytorch_qnnp_params.u8maxpool.mr; + + const size_t step_width = max_pooling->dilation_width > 1 + ? pooling_width + : min(max_pooling->stride_width, pooling_width); + const size_t step_height = + pooling_size + (output_width * step_width - 1) * pooling_height; + const size_t indirection_buffer_size = + sizeof(void*) * ((mr - 1) + batch_size * output_height * step_height); + + const void** indirection_buffer = (const void**)realloc( + max_pooling->indirection_buffer, indirection_buffer_size); + if (indirection_buffer == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for indirection buffer", + indirection_buffer_size); + return pytorch_qnnp_status_out_of_memory; + } + max_pooling->indirection_buffer = indirection_buffer; + + pytorch_qnnp_indirection_init_maxpool2d( + max_pooling, valid_batch_size, step_height, step_width); + + max_pooling->last_input = input; + max_pooling->last_input_height = input_height; + max_pooling->last_input_width = input_width; + max_pooling->valid_batch_size = max(valid_batch_size, batch_size); + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-delete.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-delete.c new file mode 100644 index 0000000000000..a53403546bab2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-delete.c @@ -0,0 +1,27 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_delete_operator( + pytorch_qnnp_operator_t op) { + if (op == NULL) { + return pytorch_qnnp_status_invalid_parameter; + } + + free(op->indirection_buffer); + free(op->packed_weights); + free(op->a_sum); + free(op->zero_buffer); + free(op->lookup_table); + free(op); + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-run.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-run.c new file mode 100644 index 0000000000000..08e42af5c1a31 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/operator-run.c @@ -0,0 +1,1237 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#include +#endif + +struct q8gemm_context { + size_t k; + size_t k_stride; + size_t n; + size_t n_stride; + const uint8_t* a; + size_t a_stride; + const uint8_t* packed_w; + uint8_t* c; + size_t c_stride; + union pytorch_qnnp_conv_quantization_params quantization_params; + const pytorch_q8gemm_ukernel_function ukernel; +}; + +static void compute_q8gemm( + const struct q8gemm_context context[RESTRICT_STATIC 1], + size_t group_index, + size_t pixel_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t pixel_range, + size_t mr_block_size, + size_t nr_block_size) { + const size_t k = context->k; + const size_t k_stride = context->k_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t* restrict a = context->a; + const size_t a_stride = context->a_stride; + const void* restrict packed_w = context->packed_w; + uint8_t* restrict c = context->c; + const size_t c_stride = context->c_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + k, + a + (pixel_index + mr_block_start) * a_stride + group_index * k, + a_stride, + (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (pixel_index + mr_block_start) * c_stride + nr_block_start + + group_index * n, + c_stride, + &context->quantization_params); +} + +struct q8sum_rows_context { + const uint8_t* a; + size_t groups; + size_t m; + size_t k; + size_t a_stride; + const int32_t multiplier; + int32_t* a_sum; + size_t a_sum_stride; + const pytorch_q8sum_rows_ukernel_function ukernel; +}; + +static void compute_sum_rows( + const struct q8sum_rows_context context[RESTRICT_STATIC 1], + size_t group_index, + size_t batch_index, + size_t block_start, + size_t group_range /* always 1 */, + size_t batch_range /* always 1 */, + size_t block_size) { + const uint8_t* a = context->a; + const size_t groups = context->groups; + const size_t m = context->m; + const size_t k = context->k; + const size_t a_stride = context->a_stride; + const int32_t multiplier = context->multiplier; + int32_t* a_sum = context->a_sum; + const size_t a_sum_stride = context->a_sum_stride; + + context->ukernel( + a + batch_index * m * a_stride + group_index * k + block_start * a_stride, + min(block_size, m - block_start), + k, + a_stride, + multiplier, + a_sum + batch_index * groups * a_sum_stride + group_index * a_sum_stride + + block_start); +} + +struct q8gemm_xzp_context { + size_t k; + size_t k_stride; + size_t n; + size_t n_stride; + const uint8_t* a; + size_t a_stride; + const void* packed_w; + uint8_t* c; + size_t c_stride; + const int32_t* a_sum; + size_t groups; + size_t batch_size; + size_t a_sum_stride; + union pytorch_qnnp_q31_requantization_params requantization_params; + const pytorch_q8gemm_xzp_ukernel_function ukernel; +}; + +static void compute_q8gemm_xzp( + const struct q8gemm_xzp_context context[RESTRICT_STATIC 1], + size_t group_index, + size_t pixel_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t pixel_range, + size_t mr_block_size, + size_t nr_block_size) { + const size_t k = context->k; + const size_t k_stride = context->k_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t* restrict a = context->a; + const size_t a_stride = context->a_stride; + const void* restrict packed_w = context->packed_w; + uint8_t* restrict c = context->c; + const size_t c_stride = context->c_stride; + const int32_t* a_sum = context->a_sum; + const size_t groups = context->groups; + const size_t a_sum_stride = context->a_sum_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + k, + a + (pixel_index + mr_block_start) * a_stride + group_index * k, + a_stride, + a_sum + pixel_index * groups + group_index * a_sum_stride + + mr_block_start, + (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (k_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (pixel_index + mr_block_start) * c_stride + nr_block_start + + group_index * n, + c_stride, + &context->requantization_params); +} + +struct q8conv_context { + size_t bs; + size_t ks; + size_t kc; + size_t kc_stride; + size_t m; + size_t m_stride; + size_t n; + size_t n_stride; + const uint8_t** indirect_a; + const void* packed_w; + uint8_t* c; + size_t c_stride; + union pytorch_qnnp_conv_quantization_params quantization_params; + const pytorch_q8conv_ukernel_function ukernel; +}; + +static void compute_q8conv( + const struct q8conv_context context[RESTRICT_STATIC 1], + size_t group_index, + size_t image_index, + size_t mr_block_start, + size_t nr_block_start, + size_t group_range /* always 1 */, + size_t image_range /* always 1 */, + size_t mr_block_size, + size_t nr_block_size) { + const size_t bs = context->bs; + const size_t ks = context->ks; + const size_t kc = context->kc; + const size_t kc_stride = context->kc_stride; + const size_t m = context->m; + const size_t m_stride = context->m_stride; + const size_t n = context->n; + const size_t n_stride = context->n_stride; + const uint8_t** restrict indirect_a = context->indirect_a; + const void* restrict packed_w = context->packed_w; + uint8_t* restrict c = context->c; + const size_t c_stride = context->c_stride; + + context->ukernel( + mr_block_size, + nr_block_size, + kc, + ks, + indirect_a + + (mr_block_start + (image_index + group_index * bs) * m_stride) * ks, + (const void*)((uintptr_t)packed_w + (nr_block_start + group_index * n_stride) * (kc_stride * sizeof(uint8_t) + sizeof(int32_t))), + c + (mr_block_start + image_index * m) * c_stride + group_index * n + + nr_block_start, + c_stride, + &context->quantization_params); +} + +struct q8dwconv_context { + size_t groups; + size_t group_stride; + const uint8_t** indirection_buffer; + size_t indirection_buffer_row_stride; + size_t indirection_buffer_col_stride; + const void* packed_weights; + uint8_t* output; + size_t output_height; + size_t output_width; + size_t output_row_stride; + size_t output_col_increment; + union pytorch_qnnp_conv_quantization_params quantization_params; + union { + const pytorch_q8dwconv_up_ukernel_function unipass_ukernel; + const pytorch_q8dwconv_mp_ukernel_function multipass_ukernel; + }; +}; + +static void compute_dwconv_unipass( + const struct q8dwconv_context context[RESTRICT_STATIC 1], + size_t image, + size_t output_y) { + const size_t output_height = context->output_height; + + context->unipass_ukernel( + context->groups, + context->output_width, + context->indirection_buffer + + (image * output_height + output_y) * + context->indirection_buffer_row_stride, + context->packed_weights, + context->output + + (image * output_height + output_y) * context->output_row_stride, + context->indirection_buffer_col_stride, + context->output_col_increment, + &context->quantization_params); +} + +static void compute_dwconv_multiipass( + const struct q8dwconv_context context[RESTRICT_STATIC 1], + size_t image, + size_t output_y) { + const size_t output_height = context->output_height; + PYTORCH_QNNP_ALIGN(16) +#ifdef _MSC_VER + int32_t* multipass_acc = _malloca(sizeof(int32_t) * context->group_stride); +#else + int32_t multipass_acc[context->group_stride]; +#endif + + context->multipass_ukernel( + context->groups, + context->output_width, + context->indirection_buffer + + (image * output_height + output_y) * + context->indirection_buffer_row_stride, + context->packed_weights, + multipass_acc, + context->output + + (image * output_height + output_y) * context->output_row_stride, + context->indirection_buffer_col_stride, + context->output_col_increment, + &context->quantization_params); + +#ifdef _MSC_VER + _freea(multipass_acc); +#endif +} + +struct max_pooling_context { + const void** indirect_input; + size_t indirect_input_batch_stride; + size_t indirect_input_height_stride; + void* output; + size_t output_batch_stride; + size_t output_height_stride; + size_t output_width; + size_t pooling_size; + size_t channels; + size_t input_increment; + size_t output_increment; + union pytorch_qnnp_u8_clamping_params params; + pytorch_u8maxpool_ukernel_function ukernel; +}; + +static void compute_max_pooling( + const struct max_pooling_context context[RESTRICT_STATIC 1], + size_t batch_index, + size_t output_y) { + const void** indirect_input = + (const void**) ((uintptr_t) context->indirect_input + + batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride); + void* output = + (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride); + + context->ukernel( + context->output_width, + context->pooling_size, + context->channels, + (const uint8_t**)indirect_input, + output, + context->input_increment, + context->output_increment, + &context->params); +} + +struct average_pooling_context { + const void** indirect_input; + size_t indirect_input_batch_stride; + size_t indirect_input_height_stride; + void* output; + size_t output_batch_stride; + size_t output_height_stride; + size_t output_width; + size_t pooling_size; + size_t channels; + size_t packed_channels; + const void* zero; + size_t input_increment; + size_t output_increment; + union pytorch_qnnp_avgpool_quantization_params quantization_params; + union { + pytorch_q8avgpool_up_ukernel_function unipass_ukernel; + pytorch_q8avgpool_mp_ukernel_function multipass_ukernel; + }; +}; + +static void compute_average_pooling_unipass( + const struct average_pooling_context context[RESTRICT_STATIC 1], + size_t batch_index, + size_t output_y) { + const void** indirect_input = + (const void**) ((uintptr_t) context->indirect_input + + batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride); + void* output = + (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride); + + context->unipass_ukernel( + context->output_width, + context->pooling_size, + context->channels, + (const uint8_t**)indirect_input, + context->zero, + output, + context->input_increment, + context->output_increment, + &context->quantization_params); +} + +static void compute_average_pooling_multipass( + const struct average_pooling_context context[RESTRICT_STATIC 1], + size_t batch_index, + size_t output_y) { + const void** indirect_input = + (const void**) ((uintptr_t) context->indirect_input + + batch_index * context->indirect_input_batch_stride + output_y * context->indirect_input_height_stride); + void* output = + (void*) ((uintptr_t) context->output + batch_index * context->output_batch_stride + output_y * context->output_height_stride); + PYTORCH_QNNP_ALIGN(16) +#ifdef _MSC_VER + int32_t* multipass_buffer = + _malloca(sizeof(int32_t) * context->packed_channels); +#else + int32_t multipass_buffer[context->packed_channels]; +#endif + + context->multipass_ukernel( + context->output_width, + context->pooling_size, + context->channels, + (const uint8_t**)indirect_input, + context->zero, + multipass_buffer, + output, + context->input_increment, + context->output_increment, + &context->quantization_params); + +#ifdef _MSC_VER + _freea(multipass_buffer); +#endif +} + +struct global_average_pooling_context { + const void* input; + const void* zero; + size_t input_pixel_stride; + size_t input_batch_stride; + size_t input_elements; + size_t channels; + size_t packed_channels; + void* output; + size_t output_batch_stride; + union pytorch_qnnp_avgpool_quantization_params quantization_params; + union { + pytorch_q8gavgpool_up_ukernel_function unipass_ukernel; + pytorch_q8gavgpool_mp_ukernel_function multipass_ukernel; + }; +}; + +static void compute_global_average_pooling_unipass( + const struct global_average_pooling_context context[RESTRICT_STATIC 1], + size_t batch_index) { + const void* input = + (const void*)((uintptr_t)context->input + batch_index * context->input_batch_stride); + void* output = + (void*)((uintptr_t)context->output + batch_index * context->output_batch_stride); + + context->unipass_ukernel( + context->input_elements, + context->channels, + input, + context->input_pixel_stride, + context->zero, + output, + &context->quantization_params); +} + +static void compute_global_average_pooling_multipass( + const struct global_average_pooling_context context[RESTRICT_STATIC 1], + size_t batch_index) { + const void* input = + (const void*)((uintptr_t)context->input + batch_index * context->input_batch_stride); + void* output = + (void*)((uintptr_t)context->output + batch_index * context->output_batch_stride); + PYTORCH_QNNP_ALIGN(16) +#ifdef _MSC_VER + int32_t* multipass_buffer = + _malloca(sizeof(int32_t) * context->packed_channels); +#else + int32_t multipass_buffer[context->packed_channels]; +#endif + + context->multipass_ukernel( + context->input_elements, + context->channels, + input, + context->input_pixel_stride, + context->zero, + multipass_buffer, + output, + &context->quantization_params); + +#ifdef _MSC_VER + _freea(multipass_buffer); +#endif +} + +struct q8add_strided_context { + size_t n; + const uint8_t* a; + size_t a_stride; + const uint8_t* b; + size_t b_stride; + const uint8_t* y; + size_t y_stride; + union pytorch_qnnp_add_quantization_params quantization_params; + pytorch_q8vadd_ukernel_function ukernel; +}; + +static void compute_q8add_strided( + const struct q8add_strided_context context[RESTRICT_STATIC 1], + size_t batch_offset, + size_t batch_range /* always 1 */) { + assert(batch_range == 1); + + const size_t n = context->n; + const size_t a_stride = context->a_stride; + const size_t b_stride = context->b_stride; + const size_t y_stride = context->y_stride; + const void* a = + (const void*)((uintptr_t)context->a + a_stride * batch_offset); + const void* b = + (const void*)((uintptr_t)context->b + b_stride * batch_offset); + void* y = (void*)((uintptr_t)context->y + y_stride * batch_offset); + + context->ukernel(n, a, b, y, &context->quantization_params); +} + +struct q8add_contiguous_context { + const uint8_t* a; + const uint8_t* b; + uint8_t* y; + union pytorch_qnnp_add_quantization_params quantization_params; + pytorch_q8vadd_ukernel_function ukernel; +}; + +static void compute_q8add_contiguous( + const struct q8add_contiguous_context context[RESTRICT_STATIC 1], + size_t offset, + size_t size) { + const void* a = (const void*)((uintptr_t)context->a + offset); + const void* b = (const void*)((uintptr_t)context->b + offset); + void* y = (void*)((uintptr_t)context->y + offset); + context->ukernel(size, a, b, y, &context->quantization_params); +} + +struct channel_shuffle_context { + const void* x; + size_t x_stride; + void* y; + size_t y_stride; + size_t n; + size_t m; + union { + pytorch_xzipc_ukernel_function fixed_ukernel; + pytorch_xzipv_ukernel_function variable_ukernel; + }; +}; + +static void compute_channel_shuffle_fixed( + const struct channel_shuffle_context context[RESTRICT_STATIC 1], + size_t index) { + const void* x = + (const void*)((uintptr_t)context->x + index * context->x_stride); + void* y = (void*)((uintptr_t)context->y + index * context->y_stride); + + context->fixed_ukernel(context->n, x, y); +} + +static void compute_channel_shuffle_variable( + const struct channel_shuffle_context context[RESTRICT_STATIC 1], + size_t index) { + const void* x = + (const void*)((uintptr_t)context->x + index * context->x_stride); + void* y = (void*)((uintptr_t)context->y + index * context->y_stride); + + context->variable_ukernel(context->n, context->m, x, y); +} + +struct lut_strided_context { + size_t n; + const void* x; + size_t x_stride; + const void* t; + void* y; + size_t y_stride; + pytorch_x8lut_ukernel_function ukernel; +}; + +static void compute_lut_strided( + const struct lut_strided_context context[RESTRICT_STATIC 1], + size_t batch_index) { + const void* x = + (const void*)((uintptr_t)context->x + context->x_stride * batch_index); + void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index); + + context->ukernel(context->n, x, context->t, y); +} + +struct lut_contiguous_context { + const void* x; + size_t x_stride; + const void* t; + void* y; + size_t y_stride; + pytorch_x8lut_ukernel_function ukernel; +}; + +static void compute_lut_contiguous( + const struct lut_contiguous_context context[RESTRICT_STATIC 1], + size_t offset, + size_t size) { + const void* x = (const void*)((uintptr_t)context->x + offset); + void* y = (void*)((uintptr_t)context->y + offset); + + context->ukernel(size, x, context->t, y); +} + +struct clamp_strided_context { + size_t n; + const void* x; + size_t x_stride; + void* y; + size_t y_stride; + pytorch_u8clamp_ukernel_function ukernel; + union pytorch_qnnp_u8_clamping_params params; +}; + +static void compute_clamp_strided( + const struct clamp_strided_context context[RESTRICT_STATIC 1], + size_t batch_index) { + const void* x = + (const void*)((uintptr_t)context->x + context->x_stride * batch_index); + void* y = (void*)((uintptr_t)context->y + context->y_stride * batch_index); + context->ukernel(context->n, x, y, &context->params); +} + +struct clamp_contiguous_context { + const void* x; + size_t x_stride; + void* y; + size_t y_stride; + pytorch_u8clamp_ukernel_function ukernel; + union pytorch_qnnp_u8_clamping_params params; +}; + +static void compute_clamp_contiguous( + const struct clamp_contiguous_context context[RESTRICT_STATIC 1], + size_t offset, + size_t size) { + const void* x = (const void*)((uintptr_t)context->x + offset); + void* y = (void*)((uintptr_t)context->y + offset); + context->ukernel(size, x, y, &context->params); +} + +struct u8softargmax_context { + size_t n; + const uint8_t* x; + size_t x_stride; + const uint32_t* t; + uint8_t* y; + size_t y_stride; + pytorch_u8rmax_ukernel_function rmax_ukernel; + pytorch_u8lut32norm_ukernel_function lut_norm_ukernel; +}; + +static void compute_u8softargmax( + const struct u8softargmax_context context[RESTRICT_STATIC 1], + size_t batch_index) { + const uint8_t* x = + (const uint8_t*)((uintptr_t)context->x + context->x_stride * batch_index); + uint8_t* y = + (uint8_t*)((uintptr_t)context->y + context->y_stride * batch_index); + const size_t n = context->n; + + const uint8_t x_max = context->rmax_ukernel(n, x); + const size_t adjustment = x_max ^ 255; + const uint32_t* t = (const uint32_t*)context->t + adjustment; + context->lut_norm_ukernel(n, x, t, y); +} + +enum pytorch_qnnp_status pytorch_qnnp_run_operator( + pytorch_qnnp_operator_t op, + pthreadpool_t threadpool) { + // For any ukernel type, there is no work to do if the batch size is 0. + if (op->batch_size == 0) { + return pytorch_qnnp_status_success; + } + + switch (op->ukernel_type) { + case pytorch_qnnp_ukernel_type_dwconv: { + const size_t batch_size = op->batch_size; + const size_t groups = op->groups; + const size_t kernel_height = op->kernel_height; + const size_t kernel_width = op->kernel_width; + const size_t kernel_size = kernel_height * kernel_width; + const size_t width_step = + op->dilation_width == 1 ? op->stride_width : op->kernel_width; + const size_t output_height = op->output_height; + const size_t output_width = op->output_width; + + switch (kernel_size) { + case 9: { + struct q8dwconv_context context = { + .groups = groups, + .indirection_buffer = (const uint8_t**)op->indirection_buffer, + .indirection_buffer_row_stride = + kernel_size + (output_width * width_step - 1) * kernel_height, + .indirection_buffer_col_stride = + kernel_height * width_step * sizeof(void*), + .packed_weights = op->packed_weights, + .output = op->output, + .output_height = output_height, + .output_width = output_width, + .output_row_stride = output_width * op->output_pixel_stride, + .output_col_increment = + (op->output_pixel_stride - groups) * sizeof(uint8_t), + .quantization_params = op->conv_quantization_params, + .unipass_ukernel = pytorch_qnnp_params.q8dw9.updw, + }; + pthreadpool_compute_2d( + threadpool, + (pthreadpool_function_2d_t)compute_dwconv_unipass, + &context, + batch_size, + output_height); + break; + } + case 25: { + struct q8dwconv_context context = { + .groups = groups, + .group_stride = op->group_stride, + .indirection_buffer = (const uint8_t**)op->indirection_buffer, + .indirection_buffer_row_stride = + kernel_size + (output_width * width_step - 1) * kernel_height, + .indirection_buffer_col_stride = + kernel_height * width_step * sizeof(void*), + .packed_weights = op->packed_weights, + .output = op->output, + .output_height = output_height, + .output_width = output_width, + .output_row_stride = output_width * op->output_pixel_stride, + .output_col_increment = + (op->output_pixel_stride - groups) * sizeof(uint8_t), + .quantization_params = op->conv_quantization_params, + .multipass_ukernel = pytorch_qnnp_params.q8dw25.mpdw, + }; + pthreadpool_compute_2d( + threadpool, + (pthreadpool_function_2d_t)compute_dwconv_multiipass, + &context, + batch_size, + output_height); + break; + } + default: + PYTORCH_QNNP_UNREACHABLE; + } + break; + } + case pytorch_qnnp_ukernel_type_xzp_gemm: { + const size_t batch_size = op->batch_size; + const size_t groups = op->groups; + const size_t group_input_channels = op->group_input_channels; + const size_t group_output_channels = op->group_output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv_xzp.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv_xzp.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv_xzp.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + + /* compute input row sum */ + const size_t input_size = op->input_height * op->input_width; + int32_t* a_sum = (int32_t*)op->a_sum; + + struct q8sum_rows_context context = { + .a = op->input, + .groups = groups, + .m = input_size, + .k = group_input_channels, + .a_stride = op->input_pixel_stride, + .multiplier = (int32_t)-op->kernel_zero_point, + .a_sum = a_sum, + .a_sum_stride = input_size, + .ukernel = pytorch_qnnp_params.q8sum_rows.sum_rows, + }; + pthreadpool_compute_3d_tiled( + threadpool, + (pthreadpool_function_3d_tiled_t)compute_sum_rows, + &context, + groups, + batch_size, + input_size, + 1, + 1, + pytorch_qnnp_params.q8sum_rows.m); + + struct q8gemm_xzp_context q8gemm_xzp_context = { + .k = group_input_channels, + .k_stride = k_stride, + .n = group_output_channels, + .n_stride = n_stride, + .a = op->input, + .a_stride = op->input_pixel_stride, + .packed_w = op->packed_weights, + .c = op->output, + .c_stride = op->output_pixel_stride, + .a_sum = a_sum, + .groups = op->groups, + .batch_size = batch_size, + .a_sum_stride = input_size, + .requantization_params = op->requantization_params, + .ukernel = pytorch_qnnp_params.q8conv_xzp.gemm, + }; + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t)compute_q8gemm_xzp, + &q8gemm_xzp_context, + groups, + batch_size * input_size, + input_size, + group_output_channels, + 1, + input_size, + mr, + nr); + break; + } + case pytorch_qnnp_ukernel_type_gemm: { + const size_t batch_size = op->batch_size; + const size_t groups = op->groups; + const size_t group_input_channels = op->group_input_channels; + const size_t group_output_channels = op->group_output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + + const size_t output_size = op->output_height * op->output_width; + struct q8gemm_context q8gemm_context = { + .k = group_input_channels, + .k_stride = k_stride, + .n = group_output_channels, + .n_stride = n_stride, + .a = op->input, + .a_stride = op->input_pixel_stride, + .packed_w = op->packed_weights, + .c = op->output, + .c_stride = op->output_pixel_stride, + .quantization_params = op->conv_quantization_params, + .ukernel = pytorch_qnnp_params.q8conv.gemm, + }; + + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t)compute_q8gemm, + &q8gemm_context, + groups, + batch_size * output_size, + output_size, + group_output_channels, + 1, + output_size, + mr, + nr); + break; + } + case pytorch_qnnp_ukernel_type_conv: { + const size_t batch_size = op->batch_size; + const size_t groups = op->groups; + const size_t group_input_channels = op->group_input_channels; + const size_t group_output_channels = op->group_output_channels; + const uint32_t mr = pytorch_qnnp_params.q8conv.mr; + const uint32_t nr = pytorch_qnnp_params.q8conv.nr; + const uint32_t kr = pytorch_qnnp_params.q8conv.kr; + const size_t k_stride = (group_input_channels + (kr - 1)) & -kr; + const size_t n_stride = (group_output_channels + (nr - 1)) & -nr; + + const size_t output_size = op->output_height * op->output_width; + const size_t kernel_size = op->kernel_height * op->kernel_width; + const size_t m_stride = round_up(output_size, mr); + struct q8conv_context q8conv_context = { + .bs = batch_size, + .ks = kernel_size, + .kc = group_input_channels, + .kc_stride = k_stride * kernel_size, + .m = output_size, + .m_stride = m_stride, + .n = group_output_channels, + .n_stride = n_stride, + .indirect_a = (const uint8_t**)op->indirection_buffer, + .packed_w = op->packed_weights, + .c = op->output, + .c_stride = op->output_pixel_stride, + .quantization_params = op->conv_quantization_params, + .ukernel = pytorch_qnnp_params.q8conv.conv, + }; + + pthreadpool_compute_4d_tiled( + threadpool, + (pthreadpool_function_4d_tiled_t)compute_q8conv, + &q8conv_context, + groups, + batch_size, + output_size, + group_output_channels, + 1, + 1, + mr, + nr); + break; + } + case pytorch_qnnp_ukernel_type_average_pooling: { + const uint32_t kr = pytorch_qnnp_params.q8avgpool.kr; + const uint32_t mr = pytorch_qnnp_params.q8avgpool.mr; + const uint32_t qr = pytorch_qnnp_params.q8avgpool.qr; + const size_t channels = op->channels; + const size_t output_width = op->output_width; + const size_t output_height = op->output_height; + const size_t pooling_height = op->kernel_height; + const size_t pooling_width = op->kernel_width; + const size_t pooling_size = pooling_height * pooling_width; + + const size_t width_step = min(op->stride_width, pooling_width); + const size_t indirect_input_height_stride = + (pooling_size + (output_width * width_step - 1) * pooling_height) * + sizeof(void*); + const size_t output_height_stride = + output_width * op->output_pixel_stride; + + size_t multipass_adjustment = 0; + if (channels >= kr && pooling_size > mr) { + multipass_adjustment = round_up(pooling_size - mr, qr) + mr - qr; + } + struct average_pooling_context context = { + .indirect_input = op->indirection_buffer, + .indirect_input_batch_stride = + output_height * indirect_input_height_stride, + .indirect_input_height_stride = indirect_input_height_stride, + .output = op->output, + .output_batch_stride = output_height * output_height_stride, + .output_height_stride = output_height_stride, + .output_width = output_width, + .pooling_size = pooling_size, + .channels = channels, + .packed_channels = (channels + (kr - 1)) & -kr, + .zero = op->zero_pointer, + .input_increment = + (pooling_height * width_step - multipass_adjustment) * + sizeof(void*), + .output_increment = + (op->output_pixel_stride - channels) * sizeof(uint8_t), + .quantization_params = op->avgpool_quantization_params, + }; + + pthreadpool_function_2d_t compute_function = NULL; + if (channels < kr) { + compute_function = + (pthreadpool_function_2d_t)compute_average_pooling_unipass; + context.unipass_ukernel = pytorch_qnnp_params.q8avgpool.ltkr; + } else { + if (pooling_size <= mr) { + compute_function = + (pthreadpool_function_2d_t)compute_average_pooling_unipass; + context.unipass_ukernel = pytorch_qnnp_params.q8avgpool.gekr_lemr; + } else { + compute_function = + (pthreadpool_function_2d_t)compute_average_pooling_multipass; + context.multipass_ukernel = pytorch_qnnp_params.q8avgpool.gekr_gtmr; + } + } + + pthreadpool_compute_2d( + threadpool, + compute_function, + &context, + op->batch_size, + output_height); + break; + } + case pytorch_qnnp_ukernel_type_max_pooling: { + const uint32_t kr = pytorch_qnnp_params.u8maxpool.kr; + const uint32_t mr = pytorch_qnnp_params.u8maxpool.mr; + const uint32_t qr = pytorch_qnnp_params.u8maxpool.qr; + const size_t channels = op->channels; + const size_t output_width = op->output_width; + const size_t output_height = op->output_height; + const size_t pooling_height = op->kernel_height; + const size_t pooling_width = op->kernel_width; + const size_t pooling_size = pooling_height * pooling_width; + + const size_t width_step = op->dilation_width > 1 + ? pooling_width + : min(op->stride_width, pooling_width); + const size_t indirect_input_height_stride = + (pooling_size + (output_width * width_step - 1) * pooling_height) * + sizeof(void*); + const size_t output_height_stride = + output_width * op->output_pixel_stride; + + size_t multipass_adjustment = pooling_size; + if (channels >= kr) { + multipass_adjustment = round_up(doz(pooling_size, mr), qr) + mr; + } + struct max_pooling_context context = { + .indirect_input = op->indirection_buffer, + .indirect_input_batch_stride = + output_height * indirect_input_height_stride, + .indirect_input_height_stride = indirect_input_height_stride, + .output = op->output, + .output_batch_stride = output_height * output_height_stride, + .output_height_stride = output_height_stride, + .output_width = output_width, + .pooling_size = pooling_size, + .channels = channels, + .input_increment = + (pooling_height * width_step - multipass_adjustment) * + sizeof(void*), + .output_increment = + (op->output_pixel_stride - channels) * sizeof(uint8_t), + .params = op->u8_clamping_params, + .ukernel = channels < kr ? pytorch_qnnp_params.u8maxpool.ltkr + : pytorch_qnnp_params.u8maxpool.gekr, + }; + + pthreadpool_compute_2d( + threadpool, + (pthreadpool_function_2d_t)compute_max_pooling, + &context, + op->batch_size, + output_height); + break; + }; + case pytorch_qnnp_ukernel_type_add: { + const size_t batch_size = op->batch_size; + const size_t channels = op->channels; + const size_t a_stride = op->input_pixel_stride; + const size_t b_stride = op->input2_pixel_stride; + const size_t y_stride = op->output_pixel_stride; + if ((((a_stride ^ channels) | (b_stride ^ channels) | + (y_stride ^ channels)) == 0) || + batch_size == 1) { + const size_t block_size = 4096; + struct q8add_contiguous_context add_context = { + .a = op->input, + .b = op->input2, + .y = op->output, + .quantization_params = op->add_quantization_params, + .ukernel = pytorch_qnnp_params.q8vadd, + }; + pthreadpool_compute_1d_tiled( + threadpool, + (pthreadpool_function_1d_tiled_t)compute_q8add_contiguous, + &add_context, + batch_size * channels * sizeof(uint8_t), + block_size); + } else { + struct q8add_strided_context add_context = { + .a = op->input, + .a_stride = a_stride * sizeof(uint8_t), + .b = op->input2, + .b_stride = b_stride * sizeof(uint8_t), + .y = op->output, + .y_stride = y_stride * sizeof(uint8_t), + .n = channels, + .quantization_params = op->add_quantization_params, + .ukernel = pytorch_qnnp_params.q8vadd, + }; + pthreadpool_compute_1d_tiled( + threadpool, + (pthreadpool_function_1d_tiled_t)compute_q8add_strided, + &add_context, + batch_size, + 1); + } + break; + } + case pytorch_qnnp_ukernel_type_global_average_pooling: { + const uint32_t nr = pytorch_qnnp_params.q8gavgpool.nr; + const uint32_t mr = pytorch_qnnp_params.q8gavgpool.mr; + const size_t input_pixel_stride = + op->input_pixel_stride * sizeof(uint8_t); + const size_t input_width = op->input_width; + const size_t channels = op->channels; + struct global_average_pooling_context context = { + .input = op->input, + .zero = op->zero_pointer, + .input_pixel_stride = input_pixel_stride, + .input_batch_stride = input_pixel_stride * input_width, + .input_elements = input_width, + .channels = channels, + .packed_channels = (channels + (nr - 1)) & -nr, + .output = op->output, + .output_batch_stride = op->output_pixel_stride * sizeof(uint8_t), + .quantization_params = op->avgpool_quantization_params, + }; + pthreadpool_function_1d_t compute_function = NULL; + if (channels < nr) { + compute_function = + (pthreadpool_function_1d_t)compute_global_average_pooling_unipass; + context.unipass_ukernel = pytorch_qnnp_params.q8gavgpool.ltnr; + } else { + if (input_width <= mr) { + compute_function = + (pthreadpool_function_1d_t)compute_global_average_pooling_unipass; + context.unipass_ukernel = pytorch_qnnp_params.q8gavgpool.genr_lemr; + } else { + compute_function = (pthreadpool_function_1d_t) + compute_global_average_pooling_multipass; + context.multipass_ukernel = pytorch_qnnp_params.q8gavgpool.genr_gtmr; + } + } + + pthreadpool_compute_1d( + threadpool, compute_function, &context, op->batch_size); + break; + } + case pytorch_qnnp_ukernel_type_lut: { + const size_t batch_size = op->batch_size; + const size_t channels = op->channels; + const size_t x_stride = op->input_pixel_stride; + const size_t y_stride = op->output_pixel_stride; + if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) || + batch_size == 1) { + const size_t block_size = 1024; + struct lut_contiguous_context context = { + .x = op->input, + .x_stride = x_stride * sizeof(uint8_t), + .t = op->lookup_table, + .y = op->output, + .y_stride = y_stride * sizeof(uint8_t), + .ukernel = pytorch_qnnp_params.x8lut, + }; + pthreadpool_compute_1d_tiled( + threadpool, + (pthreadpool_function_1d_tiled_t)compute_lut_contiguous, + &context, + batch_size * channels * sizeof(uint8_t), + block_size); + } else { + struct lut_strided_context context = { + .n = channels, + .x = op->input, + .x_stride = x_stride * sizeof(uint8_t), + .t = op->lookup_table, + .y = op->output, + .y_stride = y_stride * sizeof(uint8_t), + .ukernel = pytorch_qnnp_params.x8lut, + }; + pthreadpool_compute_1d( + threadpool, + (pthreadpool_function_1d_t)compute_lut_strided, + &context, + batch_size); + } + break; + } + case pytorch_qnnp_ukernel_type_clamp: { + const size_t batch_size = op->batch_size; + const size_t channels = op->channels; + const size_t x_stride = op->input_pixel_stride; + const size_t y_stride = op->output_pixel_stride; + if ((((x_stride ^ channels) | (y_stride ^ channels)) == 0) || + batch_size == 1) { + const size_t block_size = 4096; + struct clamp_contiguous_context context = { + .x = op->input, + .x_stride = x_stride * sizeof(uint8_t), + .y = op->output, + .y_stride = y_stride * sizeof(uint8_t), + .ukernel = pytorch_qnnp_params.u8clamp, + .params = op->u8_clamping_params, + }; + pthreadpool_compute_1d_tiled( + threadpool, + (pthreadpool_function_1d_tiled_t)compute_clamp_contiguous, + &context, + batch_size * channels * sizeof(uint8_t), + block_size); + } else { + struct clamp_strided_context context = { + .n = channels, + .x = op->input, + .x_stride = x_stride * sizeof(uint8_t), + .y = op->output, + .y_stride = y_stride * sizeof(uint8_t), + .ukernel = pytorch_qnnp_params.u8clamp, + .params = op->u8_clamping_params, + }; + pthreadpool_compute_1d( + threadpool, + (pthreadpool_function_1d_t)compute_clamp_strided, + &context, + batch_size); + } + break; + } + case pytorch_qnnp_ukernel_type_softargmax: { + struct u8softargmax_context context = { + .n = op->channels, + .x = op->input, + .x_stride = op->input_pixel_stride * sizeof(uint8_t), + .t = op->lookup_table, + .y = op->output, + .y_stride = op->output_pixel_stride * sizeof(uint8_t), + .rmax_ukernel = pytorch_qnnp_params.u8rmax, + .lut_norm_ukernel = pytorch_qnnp_params.u8lut32norm, + }; + pthreadpool_compute_1d( + threadpool, + (pthreadpool_function_1d_t)compute_u8softargmax, + &context, + op->batch_size); + break; + } + case pytorch_qnnp_ukernel_type_channel_shuffle: { + const size_t groups = op->groups; + struct channel_shuffle_context channel_shuffle_context = { + .x = op->input, + .x_stride = op->input_pixel_stride * sizeof(uint8_t), + .y = op->output, + .y_stride = op->output_pixel_stride * sizeof(uint8_t), + .n = op->group_channels * sizeof(uint8_t), + .m = groups, + }; + pthreadpool_function_1d_t compute_function = NULL; + switch (groups) { + case 2: + compute_function = + (pthreadpool_function_1d_t)compute_channel_shuffle_fixed; + channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x2; + break; + case 3: + compute_function = + (pthreadpool_function_1d_t)compute_channel_shuffle_fixed; + channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x3; + break; + case 4: + compute_function = + (pthreadpool_function_1d_t)compute_channel_shuffle_fixed; + channel_shuffle_context.fixed_ukernel = pytorch_qnnp_params.x8zip.x4; + break; + default: + compute_function = + (pthreadpool_function_1d_t)compute_channel_shuffle_variable; + channel_shuffle_context.variable_ukernel = + pytorch_qnnp_params.x8zip.xm; + break; + case 0: + case 1: + PYTORCH_QNNP_UNREACHABLE; + } + pthreadpool_compute_1d( + threadpool, + compute_function, + &channel_shuffle_context, + op->batch_size); + break; + } + default: + PYTORCH_QNNP_UNREACHABLE; + } + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c new file mode 100644 index 0000000000000..d57efa24cc7a5 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-neon.c @@ -0,0 +1,547 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8avgpool_ukernel_mp8x9p8q__neon( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + const uint8_t* zero, + int32_t* buffer, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[restrict static 1]) { + assert(n != 0); + assert(ks > 9); + assert(kc >= 8); + + const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); +#ifdef __aarch64__ + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); +#else + const int32x2_t vmultiplier = + vld1_dup_s32(&quantization_params->neon.multiplier); +#endif + const int64x2_t vleft_shift = + vld1q_dup_s64(&quantization_params->neon.left_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + + do { + { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + const uint8_t* i8 = *input++; + + size_t k = kc; + int32_t* acc = buffer; + while (k >= 8) { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + const uint8x8_t vi8 = vld1_u8(i8); + i8 += 8; + + const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8); + const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); + const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); + const uint16x8_t vsum67 = vaddl_u8(vi6, vi7); + + const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45); + const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67); + const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678); + + const int32x4_t vacc_lo = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum))); + const int32x4_t vacc_hi = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum))); + + vst1q_s32(acc, vacc_lo); + acc += 4; + vst1q_s32(acc, vacc_hi); + acc += 4; + + k -= 8; + } + if (k != 0) { + const size_t address_increment = k - 8; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + i8 = (const uint8_t*)((uintptr_t)i8 + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift)); + const uint8x8_t vi8 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vshift)); + + const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8); + const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); + const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); + const uint16x8_t vsum67 = vaddl_u8(vi6, vi7); + + const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45); + const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67); + const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678); + + const int32x4_t vacc_lo = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum))); + const int32x4_t vacc_hi = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum))); + + vst1q_s32(acc, vacc_lo); + acc += 4; + vst1q_s32(acc, vacc_hi); + } + } + + size_t m = ks; + for (m -= 9; m > 8; m -= 8) { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + + size_t k = kc; + int32_t* acc = buffer; + while (k >= 8) { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + int32x4_t vacc_lo = vld1q_s32(acc); + int32x4_t vacc_hi = vld1q_s32(acc + 4); + + const uint16x8_t vsum01 = vaddl_u8(vi0, vi1); + const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); + const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); + const uint16x8_t vsum67 = vaddl_u8(vi6, vi7); + + const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23); + const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67); + const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567); + + vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum))); + vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum))); + + vst1q_s32(acc, vacc_lo); + acc += 4; + vst1q_s32(acc, vacc_hi); + acc += 4; + + k -= 8; + } + if (k != 0) { + const size_t address_increment = k - 8; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift)); + int32x4_t vacc_lo = vld1q_s32(acc); + int32x4_t vacc_hi = vld1q_s32(acc + 4); + + const uint16x8_t vsum01 = vaddl_u8(vi0, vi1); + const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); + const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); + const uint16x8_t vsum67 = vaddl_u8(vi6, vi7); + + const uint16x8_t vsum0123 = vaddq_u16(vsum01, vsum23); + const uint16x8_t vsum4567 = vaddq_u16(vsum45, vsum67); + const uint16x8_t vsum = vaddq_u16(vsum0123, vsum4567); + + vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vsum))); + vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vsum))); + + vst1q_s32(acc, vacc_lo); + acc += 4; + vst1q_s32(acc, vacc_hi); + } + } + + { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + input = (const uint8_t**)((uintptr_t)input + input_increment); + if (m < 2) { + i1 = zero; + } + if (m <= 2) { + i2 = zero; + } + if (m < 4) { + i3 = zero; + } + if (m <= 4) { + i4 = zero; + } + if (m < 6) { + i5 = zero; + } + if (m <= 6) { + i6 = zero; + } + if (m != 8) { + i7 = zero; + } + + size_t k = kc; + int32_t* acc = buffer; + while (k >= 8) { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + int32x4_t vacc_lo = vld1q_s32(acc); + acc += 4; + int32x4_t vacc_hi = vld1q_s32(acc); + acc += 4; + + const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7)); + + const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23); + const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67); + const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567); + + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum)); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = + vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = + vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = + vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = + vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + vst1_u8(output, vout); + output += 8; + + k -= 8; + } + if (k != 0) { + const size_t address_increment = k - 8; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift)); + int32x4_t vacc_lo = vld1q_s32(acc); + acc += 4; + int32x4_t vacc_hi = vld1q_s32(acc); + + const int16x8_t vsum01 = vreinterpretq_s16_u16(vaddl_u8(vi0, vi1)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + const int16x8_t vsum67 = vreinterpretq_s16_u16(vaddl_u8(vi6, vi7)); + + const int16x8_t vsum0123 = vaddq_s16(vsum01, vsum23); + const int16x8_t vsum4567 = vaddq_s16(vsum45, vsum67); + const int16x8_t vsum = vaddq_s16(vsum0123, vsum4567); + + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum)); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = + vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = + vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = + vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = + vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (k & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u8(vout), + 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (k & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), + vreinterpret_u16_u8(vout), + 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (k & 1) { + vst1_lane_u8(output, vout, 0); + output += 1; + } + } + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-sse2.c new file mode 100644 index 0000000000000..d63a0a0705c88 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/mp8x9p8q-sse2.c @@ -0,0 +1,562 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8avgpool_ukernel_mp8x9p8q__sse2( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + const uint8_t* zero, + int32_t* buffer, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + assert(n != 0); + assert(ks > 9); + assert(kc >= 8); + + const __m128i vbias = + _mm_load_si128((const __m128i*)&quantization_params->sse2.bias); + const __m128i vzero = _mm_setzero_si128(); + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + const __m128i vright_shift = + _mm_loadl_epi64((const __m128i*)quantization_params->sse2.right_shift); + + do { + { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + const uint8_t* i8 = *input++; + + size_t k = kc; + int32_t* acc = buffer; + while (k >= 8) { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7); + i7 += 8; + const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8); + i8 += 8; + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero); + + const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45); + const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67); + const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678); + + const __m128i vacc_lo = + _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero)); + const __m128i vacc_hi = + _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero)); + + _mm_store_si128((__m128i*)acc, vacc_lo); + _mm_store_si128((__m128i*)acc + 1, vacc_hi); + acc += 8; + + k -= 8; + } + if (k != 0) { + const size_t address_decrement = 8 - k; + i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement); + i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement); + i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement); + i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement); + i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement); + i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement); + i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement); + i7 = (const uint8_t*)((uintptr_t)i7 - address_decrement); + i8 = (const uint8_t*)((uintptr_t)i8 - address_decrement); + const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vshift); + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vshift); + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vshift); + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vshift); + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vshift); + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vshift); + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vshift); + const __m128i vi7 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vshift); + const __m128i vi8 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vshift); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero); + + const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45); + const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67); + const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678); + + const __m128i vacc_lo = + _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero)); + const __m128i vacc_hi = + _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero)); + + _mm_store_si128((__m128i*)acc, vacc_lo); + _mm_store_si128((__m128i*)acc + 1, vacc_hi); + } + } + + size_t m = ks; + for (m -= 9; m > 8; m -= 8) { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + + size_t k = kc; + int32_t* acc = buffer; + while (k >= 8) { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7); + i7 += 8; + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + + const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23); + const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67); + const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); + + _mm_store_si128((__m128i*)acc, vacc_lo); + _mm_store_si128((__m128i*)acc + 1, vacc_hi); + acc += 8; + + k -= 8; + } + if (k != 0) { + const size_t address_decrement = 8 - k; + i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement); + i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement); + i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement); + i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement); + i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement); + i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement); + i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement); + i7 = (const uint8_t*)((uintptr_t)i7 - address_decrement); + const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vshift); + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vshift); + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vshift); + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vshift); + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vshift); + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vshift); + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vshift); + const __m128i vi7 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vshift); + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + + const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23); + const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67); + const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); + + _mm_store_si128((__m128i*)acc, vacc_lo); + _mm_store_si128((__m128i*)acc + 1, vacc_hi); + } + } + + { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + input = (const uint8_t**)((uintptr_t)input + input_increment); + if (m < 2) { + i1 = zero; + } + if (m <= 2) { + i2 = zero; + } + if (m < 4) { + i3 = zero; + } + if (m <= 4) { + i4 = zero; + } + if (m < 6) { + i5 = zero; + } + if (m <= 6) { + i6 = zero; + } + if (m != 8) { + i7 = zero; + } + + size_t k = kc; + int32_t* acc = buffer; + while (k >= 8) { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7); + i7 += 8; + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + acc += 8; + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + + const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23); + const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67); + const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); + + const __m128i vneg_mask_lo = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_min)); + + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + + k -= 8; + } + if (k != 0) { + const size_t address_decrement = 8 - k; + i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement); + i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement); + i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement); + i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement); + i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement); + i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement); + i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement); + i7 = (const uint8_t*)((uintptr_t)i7 - address_decrement); + const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vshift); + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vshift); + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vshift); + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vshift); + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vshift); + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vshift); + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vshift); + const __m128i vi7 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vshift); + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + + const __m128i vsum01 = _mm_add_epi16(vxi0, vxi1); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum0123 = _mm_add_epi16(vsum01, vsum23); + const __m128i vsum4567 = _mm_add_epi16(vsum45, vsum67); + const __m128i vsum = _mm_add_epi16(vsum0123, vsum4567); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vsum, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vsum, vzero)); + + const __m128i vneg_mask_lo = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_min)); + + if (k & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (k & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (k & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + output += 1; + } + } + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c new file mode 100644 index 0000000000000..f7ec90c1a464c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-neon.c @@ -0,0 +1,338 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8avgpool_ukernel_up8x9__neon( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + const uint8_t* zero, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[restrict static 1]) { + assert(n != 0); + assert(ks <= 9); + assert(kc >= 8); + + const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); +#ifdef __aarch64__ + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); +#else + const int32x2_t vmultiplier = + vld1_dup_s32(&quantization_params->neon.multiplier); +#endif + const int64x2_t vleft_shift = + vld1q_dup_s64(&quantization_params->neon.left_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + + do { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + const uint8_t* i8 = input[8]; + input = (const uint8_t**)((uintptr_t)input + input_increment); + if (ks < 2) { + i1 = zero; + } + if (ks <= 2) { + i2 = zero; + } + if (ks < 4) { + i3 = zero; + } + if (ks <= 4) { + i4 = zero; + } + if (ks < 6) { + i5 = zero; + } + if (ks <= 6) { + i6 = zero; + } + if (ks < 8) { + i7 = zero; + } + if (ks <= 8) { + i8 = zero; + } + + size_t k = kc; + while (k >= 8) { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + const uint8x8_t vi8 = vld1_u8(i8); + i8 += 8; + + const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8); + const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); + const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); + const uint16x8_t vsum67 = vaddl_u8(vi6, vi7); + + const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45); + const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67); + const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678); + + int32x4_t vacc_lo = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum))); + int32x4_t vacc_hi = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum))); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = + vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = + vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = + vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = + vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + vst1_u8(output, vout); + output += 8; + + k -= 8; + } + if (k != 0) { + const size_t address_increment = k - 8; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + i8 = (const uint8_t*)((uintptr_t)i8 + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vshift)); + const uint8x8_t vi8 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vshift)); + + const uint16x8_t vsum018 = vaddw_u8(vaddl_u8(vi0, vi1), vi8); + const uint16x8_t vsum23 = vaddl_u8(vi2, vi3); + const uint16x8_t vsum45 = vaddl_u8(vi4, vi5); + const uint16x8_t vsum67 = vaddl_u8(vi6, vi7); + + const uint16x8_t vsum2345 = vaddq_u16(vsum23, vsum45); + const uint16x8_t vsum01678 = vaddq_u16(vsum018, vsum67); + const uint16x8_t vsum = vaddq_u16(vsum2345, vsum01678); + + int32x4_t vacc_lo = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_low_u16(vsum))); + int32x4_t vacc_hi = + vaddw_s16(vbias, vreinterpret_s16_u16(vget_high_u16(vsum))); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = + vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = + vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = + vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = + vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (k & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (k & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (k & 1) { + vst1_lane_u8(output, vout, 0); + output += 1; + } + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-sse2.c new file mode 100644 index 0000000000000..f1be4f11543cb --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8x9-sse2.c @@ -0,0 +1,327 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8avgpool_ukernel_up8x9__sse2( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + const uint8_t* zero, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + assert(n != 0); + assert(ks <= 9); + assert(kc >= 8); + + const __m128i vbias = + _mm_load_si128((const __m128i*)&quantization_params->sse2.bias); + const __m128i vzero = _mm_setzero_si128(); + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + const __m128i vright_shift = + _mm_loadl_epi64((const __m128i*)quantization_params->sse2.right_shift); + + do { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + const uint8_t* i8 = input[8]; + input = (const uint8_t**)((uintptr_t)input + input_increment); + if (ks < 2) { + i1 = zero; + } + if (ks <= 2) { + i2 = zero; + } + if (ks < 4) { + i3 = zero; + } + if (ks <= 4) { + i4 = zero; + } + if (ks < 6) { + i5 = zero; + } + if (ks <= 6) { + i6 = zero; + } + if (ks < 8) { + i7 = zero; + } + if (ks <= 8) { + i8 = zero; + } + + size_t k = kc; + while (k >= 8) { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7); + i7 += 8; + const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8); + i8 += 8; + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero); + + const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45); + const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67); + const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678); + + const __m128i vacc_lo = + _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero)); + const __m128i vacc_hi = + _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero)); + + const __m128i vneg_mask_lo = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_min)); + + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + + k -= 8; + } + if (k != 0) { + const size_t address_decrement = 8 - k; + i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement); + i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement); + i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement); + i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement); + i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement); + i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement); + i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement); + i7 = (const uint8_t*)((uintptr_t)i7 - address_decrement); + i8 = (const uint8_t*)((uintptr_t)i8 - address_decrement); + const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vshift); + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vshift); + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vshift); + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vshift); + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vshift); + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vshift); + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vshift); + const __m128i vi7 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vshift); + const __m128i vi8 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vshift); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + const __m128i vxi7 = _mm_unpacklo_epi8(vi7, vzero); + const __m128i vxi8 = _mm_unpacklo_epi8(vi8, vzero); + + const __m128i vsum018 = _mm_add_epi16(_mm_add_epi16(vxi0, vxi1), vxi8); + const __m128i vsum23 = _mm_add_epi16(vxi2, vxi3); + const __m128i vsum45 = _mm_add_epi16(vxi4, vxi5); + const __m128i vsum67 = _mm_add_epi16(vxi6, vxi7); + + const __m128i vsum2345 = _mm_add_epi16(vsum23, vsum45); + const __m128i vsum01678 = _mm_add_epi16(vsum018, vsum67); + const __m128i vsum = _mm_add_epi16(vsum2345, vsum01678); + + const __m128i vacc_lo = + _mm_add_epi32(vbias, _mm_unpacklo_epi16(vsum, vzero)); + const __m128i vacc_hi = + _mm_add_epi32(vbias, _mm_unpackhi_epi16(vsum, vzero)); + + const __m128i vneg_mask_lo = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128( + (const __m128i*)&quantization_params->sse2.output_min)); + + if (k & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (k & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (k & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + output += 1; + } + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-neon.c new file mode 100644 index 0000000000000..a95296f0983b4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-neon.c @@ -0,0 +1,169 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8avgpool_ukernel_up8xm__neon( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + const uint8_t* zero, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[restrict static 1]) { + assert(n != 0); + assert(ks != 0); + assert(kc < 8); + + const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); +#ifdef __aarch64__ + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); +#else + const int32x2_t vmultiplier = + vld1_dup_s32(&quantization_params->neon.multiplier); +#endif + const int64x2_t vleft_shift = + vld1q_dup_s64(&quantization_params->neon.left_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + + do { + int32x4_t vacc_lo = vbias; + int32x4_t vacc_hi = vbias; + const uint8_t** next_input = + (const uint8_t**)((uintptr_t)input + input_increment); + + size_t m = ks; + do { + const uint8_t* i = *input++; + i += kc; + uint8x8_t vi = vmov_n_u8(0); + if (kc & 1) { + i -= 1; + vi = vld1_lane_u8(i, vi, 0); + } + if (kc & 2) { + vi = vext_u8(vi, vi, 6); + i -= 2; + vi = vreinterpret_u8_u16(vld1_lane_u16( + __builtin_assume_aligned(i, 1), vreinterpret_u16_u8(vi), 0)); + } + if (kc & 4) { + vi = vext_u8(vi, vi, 4); + i -= 4; + vi = vreinterpret_u8_u32(vld1_lane_u32( + __builtin_assume_aligned(i, 1), vreinterpret_u32_u8(vi), 0)); + } + + const uint16x8_t vxi = vmovl_u8(vi); + vacc_lo = vaddw_s16(vacc_lo, vreinterpret_s16_u16(vget_low_u16(vxi))); + vacc_hi = vaddw_s16(vacc_hi, vreinterpret_s16_u16(vget_high_u16(vxi))); + } while (--m != 0); + input = next_input; + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (kc & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (kc & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (kc & 1) { + vst1_lane_u8(output, vout, 0); + output += 1; + } + output = (uint8_t*)((uintptr_t)output + output_increment); + + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-sse2.c new file mode 100644 index 0000000000000..a98cd666d9a4d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8avgpool/up8xm-sse2.c @@ -0,0 +1,148 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8avgpool_ukernel_up8xm__sse2( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + const uint8_t* zero, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + assert(n != 0); + assert(ks != 0); + assert(kc < 8); + + const __m128i vbias = + _mm_load_si128((const __m128i*)&quantization_params->sse2.bias); + const __m128i vzero = _mm_setzero_si128(); + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + const __m128i vright_shift = + _mm_loadl_epi64((const __m128i*)quantization_params->sse2.right_shift); + + do { + const uint8_t** next_input = + (const uint8_t**)((uintptr_t)input + input_increment); + __m128i vacc_lo = vbias; + __m128i vacc_hi = vbias; + + size_t m = ks; + do { + const uint8_t* i = *input++; + i += kc; + __m128i vi = _mm_setzero_si128(); + if (kc & 1) { + i -= 1; + vi = _mm_cvtsi32_si128((int)(uint32_t)*i); + } + if (kc & 2) { + vi = _mm_slli_epi32(vi, 16); + i -= 2; + vi = _mm_insert_epi16(vi, *((const uint16_t*)i), 0); + } + if (kc & 4) { + i -= 4; + vi = _mm_unpacklo_epi32( + _mm_cvtsi32_si128((int)*((const uint32_t*)i)), vi); + } + + const __m128i vxi = _mm_unpacklo_epi8(vi, vzero); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi, vzero)); + } while (--m != 0); + input = next_input; + + const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + if (kc & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (kc & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (kc & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + output += 1; + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x4c2-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x4c2-sse2.c new file mode 100644 index 0000000000000..d25d6aa03206f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x4c2-sse2.c @@ -0,0 +1,465 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8conv_ukernel_4x4c2__sse2( + size_t mr, + size_t nr, + size_t kc, + size_t ks, + const uint8_t** restrict a, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*)w); + __m128i vacc1x0123 = vacc0x0123; + __m128i vacc2x0123 = vacc0x0123; + __m128i vacc3x0123 = vacc0x0123; + w = (const void*)((uintptr_t)w + 16); + + const __m128i va_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.input_zero_point); + const __m128i vb_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.kernel_zero_point); + const __m128i vzero = _mm_setzero_si128(); + do { + const uint8_t* restrict a0 = *a++; + const uint8_t* restrict a1 = *a++; + const uint8_t* restrict a2 = *a++; + const uint8_t* restrict a3 = *a++; + + size_t k = kc; + for (; k >= 8; k -= 8) { + const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0); + const __m128i vxa0 = + sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point); + a0 += 8; + const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1); + const __m128i vxa1 = + sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point); + a1 += 8; + const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2); + const __m128i vxa2 = + sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point); + a2 += 8; + const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3); + const __m128i vxa3 = + sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point); + a3 += 8; + + const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point); + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + + const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxb1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point); + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + + const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxb2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point); + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + + const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxb3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + + w = (void*)((uintptr_t)w + 32); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement); + + const __m128i va0 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift); + const __m128i vxa0 = + sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point); + const __m128i va1 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift); + const __m128i vxa1 = + sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point); + const __m128i va2 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift); + const __m128i vxa2 = + sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point); + const __m128i va3 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift); + const __m128i vxa3 = + sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point); + + const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point); + w = (void*)((uintptr_t)w + 8); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + + if (k > 2) { + const __m128i vb1 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point); + w = (void*)((uintptr_t)w + 8); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + + if (k > 4) { + const __m128i vb2 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point); + w = (void*)((uintptr_t)w + 8); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + + if (k > 6) { + const __m128i vb3 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); + w = (void*)((uintptr_t)w + 8); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + } + } + } + } + } while (--ks != 0); + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask0x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc0x0123); + const __m128i vnmask1x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc1x0123); + const __m128i vnmask2x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc2x0123); + const __m128i vnmask3x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc3x0123); + + const __m128i vabsacc0x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc0x0123, vnmask0x0123), vnmask0x0123); + const __m128i vabsacc1x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc1x0123, vnmask1x0123), vnmask1x0123); + const __m128i vabsacc2x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc2x0123, vnmask2x0123), vnmask2x0123); + const __m128i vabsacc3x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc3x0123, vnmask3x0123), vnmask3x0123); + + const __m128i vabsacc0x1032 = + _mm_shuffle_epi32(vabsacc0x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc1x1032 = + _mm_shuffle_epi32(vabsacc1x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc2x1032 = + _mm_shuffle_epi32(vabsacc2x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc3x1032 = + _mm_shuffle_epi32(vabsacc3x0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod0x02 = _mm_mul_epu32(vabsacc0x0123, vmultiplier); + const __m128i vabsprod1x02 = _mm_mul_epu32(vabsacc1x0123, vmultiplier); + const __m128i vabsprod2x02 = _mm_mul_epu32(vabsacc2x0123, vmultiplier); + const __m128i vabsprod3x02 = _mm_mul_epu32(vabsacc3x0123, vmultiplier); + + const __m128i vnmask0x02 = + _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask1x02 = + _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask2x02 = + _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask3x02 = + _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod0x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod0x02, vnmask0x02), vnmask0x02); + const __m128i vprod1x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod1x02, vnmask1x02), vnmask1x02); + const __m128i vprod2x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod2x02, vnmask2x02), vnmask2x02); + const __m128i vprod3x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod3x02, vnmask3x02), vnmask3x02); + + const __m128i vq31prod0x02 = + _mm_srli_epi64(_mm_add_epi64(vprod0x02, vrounding), 31); + const __m128i vq31prod1x02 = + _mm_srli_epi64(_mm_add_epi64(vprod1x02, vrounding), 31); + const __m128i vq31prod2x02 = + _mm_srli_epi64(_mm_add_epi64(vprod2x02, vrounding), 31); + const __m128i vq31prod3x02 = + _mm_srli_epi64(_mm_add_epi64(vprod3x02, vrounding), 31); + + const __m128i vabsprod0x13 = _mm_mul_epu32(vabsacc0x1032, vmultiplier); + const __m128i vabsprod1x13 = _mm_mul_epu32(vabsacc1x1032, vmultiplier); + const __m128i vabsprod2x13 = _mm_mul_epu32(vabsacc2x1032, vmultiplier); + const __m128i vabsprod3x13 = _mm_mul_epu32(vabsacc3x1032, vmultiplier); + + const __m128i vnmask0x13 = + _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask1x13 = + _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask2x13 = + _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask3x13 = + _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod0x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod0x13, vnmask0x13), vnmask0x13); + const __m128i vprod1x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod1x13, vnmask1x13), vnmask1x13); + const __m128i vprod2x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod2x13, vnmask2x13), vnmask2x13); + const __m128i vprod3x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod3x13, vnmask3x13), vnmask3x13); + + const __m128i vq31prod0x13 = + _mm_srli_epi64(_mm_add_epi64(vprod0x13, vrounding), 31); + const __m128i vq31prod1x13 = + _mm_srli_epi64(_mm_add_epi64(vprod1x13, vrounding), 31); + const __m128i vq31prod2x13 = + _mm_srli_epi64(_mm_add_epi64(vprod2x13, vrounding), 31); + const __m128i vq31prod3x13 = + _mm_srli_epi64(_mm_add_epi64(vprod3x13, vrounding), 31); + + const __m128i vq31prod0x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod0x02), + _mm_castsi128_ps(vq31prod0x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod1x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod1x02), + _mm_castsi128_ps(vq31prod1x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod2x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod2x02), + _mm_castsi128_ps(vq31prod2x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod3x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod3x02), + _mm_castsi128_ps(vq31prod3x13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod0x0123 = + _mm_shuffle_epi32(vq31prod0x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod1x0123 = + _mm_shuffle_epi32(vq31prod1x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod2x0123 = + _mm_shuffle_epi32(vq31prod2x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod3x0123 = + _mm_shuffle_epi32(vq31prod3x0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = + _mm_load_si128((const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem0x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod0x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod0x0123)); + const __m128i vrem1x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod1x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod1x0123)); + const __m128i vrem2x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod2x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod2x0123)); + const __m128i vrem3x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod3x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod3x0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + vacc0x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod0x0123, vshift), + _mm_cmpgt_epi32(vrem0x0123, vremainder_threshold)); + vacc1x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod1x0123, vshift), + _mm_cmpgt_epi32(vrem1x0123, vremainder_threshold)); + vacc2x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod2x0123, vshift), + _mm_cmpgt_epi32(vrem2x0123, vremainder_threshold)); + vacc3x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod3x0123, vshift), + _mm_cmpgt_epi32(vrem3x0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + const __m128i vacc01x0123 = _mm_adds_epi16( + _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point); + const __m128i vacc23x0123 = _mm_adds_epi16( + _mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point); + __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 4) { + *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout); + *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32)); + *((uint32_t*)c2) = + (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout)); + *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12)); + } else { + if (nr >= 2) { + *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0); + c0 += 2; + *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2); + c1 += 2; + *((uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4); + c2 += 2; + *((uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6); + c3 += 2; + vout = _mm_srli_epi32(vout, 16); + nr -= 2; + } + if (nr != 0) { + *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout); + *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2); + *((uint8_t*)c2) = (uint8_t)_mm_extract_epi16(vout, 4); + *((uint8_t*)c3) = (uint8_t)_mm_extract_epi16(vout, 6); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S new file mode 100644 index 0000000000000..4b0dfeb5440c3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-aarch32-neon.S @@ -0,0 +1,793 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +.syntax unified + +# void pytorch_q8conv_ukernel_4x8__aarch32_neon( +# size_t mr, +# size_t nr, +# size_t kc, +# size_t ks, +# const uint8_t**restrict a, +# const void*restrict w, +# uint8_t*restrict c, +# size_t c_stride, +# const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static 1]) +BEGIN_FUNCTION pytorch_q8conv_ukernel_4x8__aarch32_neon + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + # Load w + # - ip = w + LDR ip, [sp, 4] + PUSH {r4, r5, r6, r7, r8, r9, r10, r11} + + VPUSH {d8-d15} + + # Load bias0123, bias4567 + VLDM ip!, {d16-d19} + # Load params: + # - r9 = params + LDR r9, [sp, 112] + + # q10 := vacc1x0123 + VMOV.I32 q10, q8 + MOV r4, 2 + # q11 := vacc1x4567 + VMOV.I32 q11, q9 + # Load a + # - r8 = a + LDR r8, [sp, 96] + # q12 := vacc2x0123 + VMOV.I32 q12, q8 + # q13 := vacc2x4567 + VMOV.I32 q13, q9 + # q14 := vacc3x0123 + VMOV.I32 q14, q8 + # Load b_zero_point: + # - d15 = b_zero_point + VLD1.8 {d15[]}, [r9], r4 + # Load a_zero_point: + # - d14 = a_zero_point + VLD1.8 {d14[]}, [r9], r4 + # q15 := vacc3x4567 + VMOV.I32 q15, q9 + # Load multiplier: + # - d12 = vmultiplier + VLD1.32 {d12[]}, [r9]! + + .p2align 5 +0: + SUBS r10, r2, 8 + + # Load a0, a1, a2, a3 + # - r4 = a0 + # - r5 = a1 + # - r6 = a2 + # - r7 = a3 + LDM r8!, {r4-r7} + + BLO 2f + +1: + # Load va0 + # - d1 = va0 + VLD1.8 {d1}, [r4]! + + # Load va1 + # - d3 = va1 + VLD1.8 {d3}, [r5]! + + # Load vb0-vb7 (channel 0) + # - d9 = vb0-vb7 + VLD1.8 {d9}, [ip:64]! + + # Load va2 + # - d5 = va2 + VLD1.8 {d5}, [r6]! + + # q0 = va0 = a0 + SUB_ZERO_POINT q0, d1, d14 + + # Load va3 + # - d7 = va3 + VLD1.8 {d7}, [r7]! + + # q1 = va1 = a1 + SUB_ZERO_POINT q1, d3, d14 + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + SUB_ZERO_POINT q2, d5, d14 + # q3 = va3 = a3 + SUB_ZERO_POINT q3, d7, d14 + + ### Channel 0 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + ### Channel 1 ### + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + + # Load b0-b7 (channel 3) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 3) + # - d11 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + + # Load b0-b7 (channel 4) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 4) + # - d11 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + + # Load b0-b7 (channel 5) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 5) + # - d9 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + + # Load b0-b7 (channel 7) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 7) + # - d11 = vb4567 (channel 7) + VSUBL.U8 q5, d11, d15 + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + ### Channel 8 ### + SUBS r10, r10, 8 + + # vacc0x0123 += vb0123 * va0[7] + VMLAL.S16 q8, d10, d1[3] + # vacc0x4567 += vb4567 * va0[7] + VMLAL.S16 q9, d11, d1[3] + + # vacc1x0123 += vb0123 * va1[7] + VMLAL.S16 q10, d10, d3[3] + # vacc1x4567 += vb4567 * va1[7] + VMLAL.S16 q11, d11, d3[3] + + # vacc2x0123 += vb0123 * va2[7] + VMLAL.S16 q12, d10, d5[3] + # vacc2x4567 += vb4567 * va2[7] + VMLAL.S16 q13, d11, d5[3] + + # vacc3x0123 += vb0123 * va3[7] + VMLAL.S16 q14, d10, d7[3] + # vacc3x4567 += vb4567 * va3[7] + VMLAL.S16 q15, d11, d7[3] + + BHS 1b + +2: + CMP r10, -8 + BEQ 3f + + # Adjust a0, a1, a2, a3 + ADD r4, r10 + ADD r5, r10 + ADD r6, r10 + ADD r7, r10 + + # a_shift = 8 * k - 64 + LSL r10, r10, 3 + VDUP.32 d13, r10 + + # Load va0 + # - d1 = va0 + VLD1.8 {d1}, [r4] + + # Load va1 + # - d3 = va1 + VLD1.8 {d3}, [r5] + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r6] + + # q0 = va0 = a0 + VSHL.U64 d1, d1, d13 + SUB_ZERO_POINT q0, d1, d14 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r7] + + # q1 = va1 = a1 + VSHL.U64 d3, d3, d13 + SUB_ZERO_POINT q1, d3, d14 + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + VSHL.U64 d5, d5, d13 + SUB_ZERO_POINT q2, d5, d14 + # q3 = va3 = a3 + VSHL.U64 d7, d7, d13 + SUB_ZERO_POINT q3, d7, d14 + + ### Channel 0 ### + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + CMP r10, -48 + BLO 3f + + ### Channel 1 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + BLS 3f + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + CMP r10, -32 + BLO 3f + + # Load b0-b7 (channel 3) + # - d9 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 3) + # - d9 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + BLS 3f + + # Load b0-b7 (channel 4) + # - d11 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 4) + # - d11 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + CMP r10, -16 + BLO 3f + + # Load b0-b7 (channel 5) + # - d13 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - vb_zero_point + # - d10 = vb0123 (channel 5) + # - d11 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + BLS 3f + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q4 = b0:7 - vb_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + .p2align 4 +3: + SUBS r3, r3, 1 + BNE 0b + + # Load right_shift + # - q4 = d8:d9 = vright_shift + VLD1.32 {d8[], d9[]}, [r9]! + + VQRDMULH.S32 q8, q8, d12[0] + VQRDMULH.S32 q9, q9, d12[0] + VQRDMULH.S32 q10, q10, d12[0] + VQRDMULH.S32 q11, q11, d12[0] + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask + VCEQ.S32 q5, q4, 0 + + VQRDMULH.S32 q12, q12, d12[0] + VQRDMULH.S32 q13, q13, d12[0] + VQRDMULH.S32 q14, q14, d12[0] + VQRDMULH.S32 q15, q15, d12[0] + + VBIC q0, q8, q5 + VBIC q1, q9, q5 + VBIC q2, q10, q5 + VBIC q3, q11, q5 + + VSRA.S32 q8, q0, 31 + VSRA.S32 q9, q1, 31 + VSRA.S32 q10, q2, 31 + VSRA.S32 q11, q3, 31 + + # Load output_zero_point + # - q7 = d14:d15 = voutput_zero_point + VLD1.16 {d14[], d15[]}, [r9]! + + VBIC q0, q12, q5 + VBIC q1, q13, q5 + VBIC q2, q14, q5 + VBIC q3, q15, q5 + + VSRA.S32 q12, q0, 31 + VSRA.S32 q13, q1, 31 + VSRA.S32 q14, q2, 31 + VSRA.S32 q15, q3, 31 + + # Load max: + # - q5 = d10:d11 = voutput_max + VLD1.8 {d10[], d11[]}, [r9]! + + VRSHL.S32 q8, q8, q4 + VRSHL.S32 q9, q9, q4 + VRSHL.S32 q10, q10, q4 + VRSHL.S32 q11, q11, q4 + VRSHL.S32 q12, q12, q4 + VRSHL.S32 q13, q13, q4 + VRSHL.S32 q14, q14, q4 + VRSHL.S32 q15, q15, q4 + + # Load c, c_stride: + # - r2 = c + # - r3 = c_stride + LDRD r2, r3, [sp, 104] + + VQMOVN.S32 d16, q8 + VQMOVN.S32 d17, q9 + VQMOVN.S32 d18, q10 + VQMOVN.S32 d19, q11 + VQMOVN.S32 d20, q12 + VQMOVN.S32 d21, q13 + VQMOVN.S32 d22, q14 + VQMOVN.S32 d23, q15 + + # Load min: + # - q4 = q8:q9 = voutput_min + VLD1.8 {d8[], d9[]}, [r9]! + ADD r4, r2, r3 + + VQADD.S16 q8, q8, q7 + VQADD.S16 q9, q9, q7 + CMP r0, 2 + VQADD.S16 q10, q10, q7 + VQADD.S16 q11, q11, q7 + MOVLO r4, r2 + + VQMOVUN.S16 d16, q8 + VQMOVUN.S16 d17, q9 + ADD r5, r4, r3 + VQMOVUN.S16 d18, q10 + VQMOVUN.S16 d19, q11 + MOVLS r5, r4 + + VMIN.U8 q8, q8, q5 + CMP r0, 4 + VMIN.U8 q9, q9, q5 + ADD r3, r5, r3 + + VMAX.U8 q8, q8, q4 + MOVNE r3, r5 + CMP r1, 8 + VMAX.U8 q9, q9, q4 + + BNE 5f + + VST1.8 {d16}, [r2] + VST1.8 {d17}, [r4] + VST1.8 {d18}, [r5] + VST1.8 {d19}, [r3] + + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11} + BX lr + + .p2align 3 +5: + CMP r1, 4 + BLO 6f + + VST1.32 {d16[0]}, [r2]! + VST1.32 {d17[0]}, [r4]! + VST1.32 {d18[0]}, [r5]! + VST1.32 {d19[0]}, [r3]! + + SUB r1, 4 + VEXT.8 q8, q8, q8, 4 + VEXT.8 q9, q9, q9, 4 + +6: + CMP r1, 2 + BLO 7f + + VST1.16 {d16[0]}, [r2]! + VST1.16 {d17[0]}, [r4]! + VST1.16 {d18[0]}, [r5]! + VST1.16 {d19[0]}, [r3]! + + SUB r1, 2 + VEXT.8 q8, q8, q8, 2 + VEXT.8 q9, q9, q9, 2 + +7: + TEQ r1, 0 + BEQ 8f + + VST1.8 {d16[0]}, [r2] + VST1.8 {d17[0]}, [r4] + VST1.8 {d18[0]}, [r5] + VST1.8 {d19[0]}, [r3] + +8: + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11} + BX lr +END_FUNCTION pytorch_q8conv_ukernel_4x8__aarch32_neon + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-neon.c new file mode 100644 index 0000000000000..2316ee0442276 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/4x8-neon.c @@ -0,0 +1,686 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8conv_ukernel_4x8__neon( + size_t mr, + size_t nr, + size_t kc, + size_t ks, + const uint8_t** restrict a, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + const uint8x8_t va_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vb_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + + int32x4_t vacc0x0123 = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc0x4567 = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + + do { + const uint8_t* restrict a0 = *a++; + const uint8_t* restrict a1 = *a++; + const uint8_t* restrict a2 = *a++; + const uint8_t* restrict a3 = *a++; + + size_t k = kc; + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); + a0 += 8; + const uint8x8_t va1 = vld1_u8(a1); + a1 += 8; + const uint8x8_t va2 = vld1_u8(a2); + a2 += 8; + const uint8x8_t va3 = vld1_u8(a3); + a3 += 8; + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3); + } + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0); + } + + if (k >= 2) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1); + + if (k > 2) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2); + + if (k >= 4) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3); + + if (k > 4) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa0), + 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa0), + 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa1), + 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa1), + 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa2), + 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa2), + 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa3), + 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa3), + 0); + + if (k >= 6) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa0), + 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa0), + 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa1), + 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa1), + 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa2), + 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa2), + 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa3), + 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa3), + 1); + + if (k > 6) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = vreinterpretq_s16_u16( + vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa0), + 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa0), + 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa1), + 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa1), + 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa2), + 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa2), + 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa3), + 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa3), + 2); + } + } + } + } + } + } + } + } while (--ks != 0); + + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = + vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = + vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = + vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = + vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc2x0123 = + vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = + vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc3x0123 = + vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 = + vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), + voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), + voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), + voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), + voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = + vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); +#endif + const uint8x16_t voutput_min = + vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = + vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 0); + c0 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 2); + c1 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 0); + c2 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 2); + c3 += 4; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 0); + c0 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 4); + c1 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 0); + c2 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 4); + c3 += 2; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S new file mode 100644 index 0000000000000..7a89fd3df3740 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-aarch64-neon.S @@ -0,0 +1,765 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + + +# void pytorch_q8conv_ukernel_8x8__aarch64_neon( +# size_t mr, +# size_t nr, +# size_t kc, +# size_t ks, +# const uint8_t** restrict a, +# const void* restrict w, +# uint8_t* restrict c, +# size_t c_stride, +# const union pytorch_qnnp_q31_requantization_params quantization_params[restrict static 1]) +BEGIN_FUNCTION pytorch_q8conv_ukernel_8x8__aarch64_neon + # Load params + LDR x8, [sp] + + STP d15, d14, [sp, -16] + STP d13, d12, [sp, -32] + STP d11, d10, [sp, -48] + STP d9, d8, [sp, -64] + + # Load bias0123, bias4567 + LD1 {v8.4s, v9.4s}, [x5], 32 + MOV x9, 2 + + # v10 := vacc1x0123 + MOV v10.16b, v8.16b + + # v11 := vacc1x4567 + MOV v11.16b, v9.16b + # Load b_zero_point + LD1R {v25.8b}, [x8], x9 + # Load a_zero_point + LD1R {v24.8b}, [x8], x9 + + # v12 := vacc2x0123 + MOV v12.16b, v8.16b + # v13 := vacc2x4567 + MOV v13.16b, v9.16b + + # v14 := vacc3x0123 + MOV v14.16b, v8.16b + # v15 := vacc3x4567 + MOV v15.16b, v9.16b + + # v16 := vacc4x0123 + MOV v16.16b, v8.16b + # v17 := vacc4x4567 + MOV v17.16b, v9.16b + + # v18 := vacc5x0123 + MOV v18.16b, v8.16b + # v19 := vacc5x4567 + MOV v19.16b, v9.16b + + # v20 := vacc6x0123 + MOV v20.16b, v8.16b + # v21 := vacc6x4567 + MOV v21.16b, v9.16b + + # v22 := vacc7x0123 + MOV v22.16b, v8.16b + # v23 := vacc7x4567 + MOV v23.16b, v9.16b + + # Load multiplier + # - v26 = vmultiplier + LD1R {v26.4s}, [x8], 4 + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 4 +#endif +3: + MOV x17, x2 + + LDR x16, [x4], 8 // a0 + LDR x9, [x4], 8 // a1 + LDR x10, [x4], 8 // a2 + LDR x11, [x4], 8 // a3 + LDR x12, [x4], 8 // a4 + LDR x13, [x4], 8 // a5 + LDR x14, [x4], 8 // a6 + LDR x15, [x4], 8 // a7 + + SUBS x17, x17, 8 + B.LO 1f + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 5 +#endif +0: + # b0-7 (channel 0) + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + # va0 - va7 := va - va_offset + LD1 {v0.8b}, [x16], 8 + SUB_ZERO_POINT v0.8h, v0.8b, v24.8b + LD1 {v1.8b}, [x9], 8 + SUB_ZERO_POINT v1.8h, v1.8b, v24.8b + LD1 {v2.8b}, [x10], 8 + SUB_ZERO_POINT v2.8h, v2.8b, v24.8b + LD1 {v3.8b}, [x11], 8 + SUB_ZERO_POINT v3.8h, v3.8b, v24.8b + LD1 {v4.8b}, [x12], 8 + SUB_ZERO_POINT v4.8h, v4.8b, v24.8b + LD1 {v5.8b}, [x13], 8 + SUB_ZERO_POINT v5.8h, v5.8b, v24.8b + LD1 {v6.8b}, [x14], 8 + SUB_ZERO_POINT v6.8h, v6.8b, v24.8b + LD1 {v7.8b}, [x15], 8 + SUB_ZERO_POINT v7.8h, v7.8b, v24.8b + + // b0-7 (channel 1) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[0] // vacc0x0123 += vb0123 * va0[0] + SMLAL2 v9.4s, v27.8h, v0.h[0] // vacc0x4567 += vb4567 * va0[0] + SMLAL v10.4s, v27.4h, v1.h[0] // vacc1x0123 += vb0123 * va1[0] + SMLAL2 v11.4s, v27.8h, v1.h[0] // vacc1x4567 += vb4567 * va1[0] + SMLAL v12.4s, v27.4h, v2.h[0] // vacc2x0123 += vb0123 * va2[0] + SMLAL2 v13.4s, v27.8h, v2.h[0] // vacc2x4567 += vb4567 * va2[0] + SMLAL v14.4s, v27.4h, v3.h[0] // vacc3x0123 += vb0123 * va3[0] + SMLAL2 v15.4s, v27.8h, v3.h[0] // vacc3x4567 += vb4567 * va3[0] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[0] // vacc4x0123 += vb0123 * va4[0] + SMLAL2 v17.4s, v27.8h, v4.h[0] // vacc4x4567 += vb4567 * va4[0] + SMLAL v18.4s, v27.4h, v5.h[0] // vacc5x0123 += vb0123 * va5[0] + SMLAL2 v19.4s, v27.8h, v5.h[0] // vacc5x4567 += vb4567 * va5[0] + SMLAL v20.4s, v27.4h, v6.h[0] // vacc6x0123 += vb0123 * va6[0] + SMLAL2 v21.4s, v27.8h, v6.h[0] // vacc6x4567 += vb4567 * va6[0] + SMLAL v22.4s, v27.4h, v7.h[0] // vacc7x0123 += vb0123 * va7[0] + SMLAL2 v23.4s, v27.8h, v7.h[0] // vacc7x4567 += vb4567 * va7[0] + + // b0-7 (channel 2) + LD1 {v27.8b}, [x5], 8 + + SMLAL v8.4s, v28.4h, v0.h[1] // vacc0x0123 += vb0123 * va0[1] + SMLAL2 v9.4s, v28.8h, v0.h[1] // vacc0x4567 += vb4567 * va0[1] + SMLAL v10.4s, v28.4h, v1.h[1] // vacc1x0123 += vb0123 * va1[1] + SMLAL2 v11.4s, v28.8h, v1.h[1] // vacc1x4567 += vb4567 * va1[1] + SMLAL v12.4s, v28.4h, v2.h[1] // vacc2x0123 += vb0123 * va2[1] + SMLAL2 v13.4s, v28.8h, v2.h[1] // vacc2x4567 += vb4567 * va2[1] + SMLAL v14.4s, v28.4h, v3.h[1] // vacc3x0123 += vb0123 * va3[1] + SMLAL2 v15.4s, v28.8h, v3.h[1] // vacc3x4567 += vb4567 * va3[1] + USUBL v27.8h, v27.8b, v25.8b + SMLAL v16.4s, v28.4h, v4.h[1] // vacc4x0123 += vb0123 * va4[1] + SMLAL2 v17.4s, v28.8h, v4.h[1] // vacc4x4567 += vb4567 * va4[1] + SMLAL v18.4s, v28.4h, v5.h[1] // vacc5x0123 += vb0123 * va5[1] + SMLAL2 v19.4s, v28.8h, v5.h[1] // vacc5x4567 += vb4567 * va5[1] + SMLAL v20.4s, v28.4h, v6.h[1] // vacc6x0123 += vb0123 * va6[1] + SMLAL2 v21.4s, v28.8h, v6.h[1] // vacc6x4567 += vb4567 * va6[1] + SMLAL v22.4s, v28.4h, v7.h[1] // vacc7x0123 += vb0123 * va7[1] + SMLAL2 v23.4s, v28.8h, v7.h[1] // vacc7x4567 += vb4567 * va7[1] + + // b0-7 (channel 3) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[2] // vacc0x0123 += vb0123 * va0[2] + SMLAL2 v9.4s, v27.8h, v0.h[2] // vacc0x4567 += vb4567 * va0[2] + SMLAL v10.4s, v27.4h, v1.h[2] // vacc1x0123 += vb0123 * va1[2] + SMLAL2 v11.4s, v27.8h, v1.h[2] // vacc1x4567 += vb4567 * va1[2] + SMLAL v12.4s, v27.4h, v2.h[2] // vacc2x0123 += vb0123 * va2[2] + SMLAL2 v13.4s, v27.8h, v2.h[2] // vacc2x4567 += vb4567 * va2[2] + SMLAL v14.4s, v27.4h, v3.h[2] // vacc3x0123 += vb0123 * va3[2] + SMLAL2 v15.4s, v27.8h, v3.h[2] // vacc3x4567 += vb4567 * va3[2] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[2] // vacc4x0123 += vb0123 * va4[2] + SMLAL2 v17.4s, v27.8h, v4.h[2] // vacc4x4567 += vb4567 * va4[2] + SMLAL v18.4s, v27.4h, v5.h[2] // vacc5x0123 += vb0123 * va5[2] + SMLAL2 v19.4s, v27.8h, v5.h[2] // vacc5x4567 += vb4567 * va5[2] + SMLAL v20.4s, v27.4h, v6.h[2] // vacc6x0123 += vb0123 * va6[2] + SMLAL2 v21.4s, v27.8h, v6.h[2] // vacc6x4567 += vb4567 * va6[2] + SMLAL v22.4s, v27.4h, v7.h[2] // vacc7x0123 += vb0123 * va7[2] + SMLAL2 v23.4s, v27.8h, v7.h[2] // vacc7x4567 += vb4567 * va7[2] + + // b0-7 (channel 4) + LD1 {v27.8b}, [x5], 8 + + SMLAL v8.4s, v28.4h, v0.h[3] // vacc0x0123 += vb0123 * va0[3] + SMLAL2 v9.4s, v28.8h, v0.h[3] // vacc0x4567 += vb4567 * va0[3] + SMLAL v10.4s, v28.4h, v1.h[3] // vacc1x0123 += vb0123 * va1[3] + SMLAL2 v11.4s, v28.8h, v1.h[3] // vacc1x4567 += vb4567 * va1[3] + SMLAL v12.4s, v28.4h, v2.h[3] // vacc2x0123 += vb0123 * va2[3] + SMLAL2 v13.4s, v28.8h, v2.h[3] // vacc2x4567 += vb4567 * va2[3] + SMLAL v14.4s, v28.4h, v3.h[3] // vacc3x0123 += vb0123 * va3[3] + SMLAL2 v15.4s, v28.8h, v3.h[3] // vacc3x4567 += vb4567 * va3[3] + USUBL v27.8h, v27.8b, v25.8b + SMLAL v16.4s, v28.4h, v4.h[3] // vacc4x0123 += vb0123 * va4[3] + SMLAL2 v17.4s, v28.8h, v4.h[3] // vacc4x4567 += vb4567 * va4[3] + SMLAL v18.4s, v28.4h, v5.h[3] // vacc5x0123 += vb0123 * va5[3] + SMLAL2 v19.4s, v28.8h, v5.h[3] // vacc5x4567 += vb4567 * va5[3] + SMLAL v20.4s, v28.4h, v6.h[3] // vacc6x0123 += vb0123 * va6[3] + SMLAL2 v21.4s, v28.8h, v6.h[3] // vacc6x4567 += vb4567 * va6[3] + SMLAL v22.4s, v28.4h, v7.h[3] // vacc7x0123 += vb0123 * va7[3] + SMLAL2 v23.4s, v28.8h, v7.h[3] // vacc7x4567 += vb4567 * va7[3] + + // b0-7 (channel 5) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[4] // vacc0x0123 += vb0123 * va0[4] + SMLAL2 v9.4s, v27.8h, v0.h[4] // vacc0x4567 += vb4567 * va0[4] + SMLAL v10.4s, v27.4h, v1.h[4] // vacc1x0123 += vb0123 * va1[4] + SMLAL2 v11.4s, v27.8h, v1.h[4] // vacc1x4567 += vb4567 * va1[4] + SMLAL v12.4s, v27.4h, v2.h[4] // vacc2x0123 += vb0123 * va2[4] + SMLAL2 v13.4s, v27.8h, v2.h[4] // vacc2x4567 += vb4567 * va2[4] + SMLAL v14.4s, v27.4h, v3.h[4] // vacc3x0123 += vb0123 * va3[4] + SMLAL2 v15.4s, v27.8h, v3.h[4] // vacc3x4567 += vb4567 * va3[4] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[4] // vacc4x0123 += vb0123 * va4[4] + SMLAL2 v17.4s, v27.8h, v4.h[4] // vacc4x4567 += vb4567 * va4[4] + SMLAL v18.4s, v27.4h, v5.h[4] // vacc5x0123 += vb0123 * va5[4] + SMLAL2 v19.4s, v27.8h, v5.h[4] // vacc5x4567 += vb4567 * va5[4] + SMLAL v20.4s, v27.4h, v6.h[4] // vacc6x0123 += vb0123 * va6[4] + SMLAL2 v21.4s, v27.8h, v6.h[4] // vacc6x4567 += vb4567 * va6[4] + SMLAL v22.4s, v27.4h, v7.h[4] // vacc7x0123 += vb0123 * va7[4] + SMLAL2 v23.4s, v27.8h, v7.h[4] // vacc7x4567 += vb4567 * va7[4] + + // b0-7 (channel 6) + LD1 {v27.8b}, [x5], 8 + + SMLAL v8.4s, v28.4h, v0.h[5] // vacc0x0123 += vb0123 * va0[5] + SMLAL2 v9.4s, v28.8h, v0.h[5] // vacc0x4567 += vb4567 * va0[5] + SMLAL v10.4s, v28.4h, v1.h[5] // vacc1x0123 += vb0123 * va1[5] + SMLAL2 v11.4s, v28.8h, v1.h[5] // vacc1x4567 += vb4567 * va1[5] + SMLAL v12.4s, v28.4h, v2.h[5] // vacc2x0123 += vb0123 * va2[5] + SMLAL2 v13.4s, v28.8h, v2.h[5] // vacc2x4567 += vb4567 * va2[5] + SMLAL v14.4s, v28.4h, v3.h[5] // vacc3x0123 += vb0123 * va3[5] + SMLAL2 v15.4s, v28.8h, v3.h[5] // vacc3x4567 += vb4567 * va3[5] + USUBL v27.8h, v27.8b, v25.8b + SMLAL v16.4s, v28.4h, v4.h[5] // vacc4x0123 += vb0123 * va4[5] + SMLAL2 v17.4s, v28.8h, v4.h[5] // vacc4x4567 += vb4567 * va4[5] + SMLAL v18.4s, v28.4h, v5.h[5] // vacc5x0123 += vb0123 * va5[5] + SMLAL2 v19.4s, v28.8h, v5.h[5] // vacc5x4567 += vb4567 * va5[5] + SMLAL v20.4s, v28.4h, v6.h[5] // vacc6x0123 += vb0123 * va6[5] + SMLAL2 v21.4s, v28.8h, v6.h[5] // vacc6x4567 += vb4567 * va6[5] + SMLAL v22.4s, v28.4h, v7.h[5] // vacc7x0123 += vb0123 * va7[5] + SMLAL2 v23.4s, v28.8h, v7.h[5] // vacc7x4567 += vb4567 * va7[5] + + // b0-7 (channel 7) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[6] // vacc0x0123 += vb0123 * va0[6] + SMLAL2 v9.4s, v27.8h, v0.h[6] // vacc0x4567 += vb4567 * va0[6] + SMLAL v10.4s, v27.4h, v1.h[6] // vacc1x0123 += vb0123 * va1[6] + SMLAL2 v11.4s, v27.8h, v1.h[6] // vacc1x4567 += vb4567 * va1[6] + SMLAL v12.4s, v27.4h, v2.h[6] // vacc2x0123 += vb0123 * va2[6] + SMLAL2 v13.4s, v27.8h, v2.h[6] // vacc2x4567 += vb4567 * va2[6] + SMLAL v14.4s, v27.4h, v3.h[6] // vacc3x0123 += vb0123 * va3[6] + SMLAL2 v15.4s, v27.8h, v3.h[6] // vacc3x4567 += vb4567 * va3[6] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[6] // vacc4x0123 += vb0123 * va4[6] + SMLAL2 v17.4s, v27.8h, v4.h[6] // vacc4x4567 += vb4567 * va4[6] + SMLAL v18.4s, v27.4h, v5.h[6] // vacc5x0123 += vb0123 * va5[6] + SMLAL2 v19.4s, v27.8h, v5.h[6] // vacc5x4567 += vb4567 * va5[6] + SMLAL v20.4s, v27.4h, v6.h[6] // vacc6x0123 += vb0123 * va6[6] + SMLAL2 v21.4s, v27.8h, v6.h[6] // vacc6x4567 += vb4567 * va6[6] + SMLAL v22.4s, v27.4h, v7.h[6] // vacc7x0123 += vb0123 * va7[6] + SMLAL2 v23.4s, v27.8h, v7.h[6] // vacc7x4567 += vb4567 * va7[6] + + SUBS x17, x17, 8 + + SMLAL v8.4s, v28.4h, v0.h[7] // vacc0x0123 += vb0123 * va0[7] + SMLAL2 v9.4s, v28.8h, v0.h[7] // vacc0x4567 += vb4567 * va0[7] + SMLAL v10.4s, v28.4h, v1.h[7] // vacc1x0123 += vb0123 * va1[7] + SMLAL2 v11.4s, v28.8h, v1.h[7] // vacc1x4567 += vb4567 * va1[7] + SMLAL v12.4s, v28.4h, v2.h[7] // vacc2x0123 += vb0123 * va2[7] + SMLAL2 v13.4s, v28.8h, v2.h[7] // vacc2x4567 += vb4567 * va2[7] + SMLAL v14.4s, v28.4h, v3.h[7] // vacc3x0123 += vb0123 * va3[7] + SMLAL2 v15.4s, v28.8h, v3.h[7] // vacc3x4567 += vb4567 * va3[7] + SMLAL v16.4s, v28.4h, v4.h[7] // vacc4x0123 += vb0123 * va4[7] + SMLAL2 v17.4s, v28.8h, v4.h[7] // vacc4x4567 += vb4567 * va4[7] + SMLAL v18.4s, v28.4h, v5.h[7] // vacc5x0123 += vb0123 * va5[7] + SMLAL2 v19.4s, v28.8h, v5.h[7] // vacc5x4567 += vb4567 * va5[7] + SMLAL v20.4s, v28.4h, v6.h[7] // vacc6x0123 += vb0123 * va6[7] + SMLAL2 v21.4s, v28.8h, v6.h[7] // vacc6x4567 += vb4567 * va6[7] + SMLAL v22.4s, v28.4h, v7.h[7] // vacc7x0123 += vb0123 * va7[7] + SMLAL2 v23.4s, v28.8h, v7.h[7] // vacc7x4567 += vb4567 * va7[7] + + B.HS 0b + +1: + CMP x17, -8 + B.EQ 2f + + // Adjust a0-a7 + ADD x16, x16, x17 + ADD x9, x9, x17 + ADD x10, x10, x17 + ADD x11, x11, x17 + ADD x12, x12, x17 + ADD x13, x13, x17 + ADD x14, x14, x17 + ADD x15, x15, x17 + + // a_shift = 8 * k - 64 + LSL x17, x17, 3 + FMOV d29, x17 + USHL d31, d24, d29 + + // Load x0-a7 + LD1 {v0.8b}, [x16], 8 + USHL d0, d0, d29 + SUB_ZERO_POINT v0.8h, v0.8b, v24.8b + + LD1 {v1.8b}, [x9], 8 + USHL d1, d1, d29 + SUB_ZERO_POINT v1.8h, v1.8b, v24.8b + + LD1 {v2.8b}, [x10], 8 + USHL d2, d2, d29 + SUB_ZERO_POINT v2.8h, v2.8b, v24.8b + + LD1 {v3.8b}, [x11], 8 + USHL d3, d3, d29 + SUB_ZERO_POINT v3.8h, v3.8b, v24.8b + + LD1 {v4.8b}, [x12], 8 + USHL d4, d4, d29 + SUB_ZERO_POINT v4.8h, v4.8b, v24.8b + + LD1 {v5.8b}, [x13], 8 + USHL d5, d5, d29 + SUB_ZERO_POINT v5.8h, v5.8b, v24.8b + + LD1 {v6.8b}, [x14], 8 + USHL d6, d6, d29 + SUB_ZERO_POINT v6.8h, v6.8b, v24.8b + + LD1 {v7.8b}, [x15], 8 + USHL d7, d7, d29 + SUB_ZERO_POINT v7.8h, v7.8b, v24.8b + + // Channel 0 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[0] // vacc0x0123 += vb0123 * va0[0] + SMLAL2 v9.4s, v27.8h, v0.h[0] // vacc0x4567 += vb4567 * va0[0] + SMLAL v10.4s, v27.4h, v1.h[0] // vacc1x0123 += vb0123 * va1[0] + SMLAL2 v11.4s, v27.8h, v1.h[0] // vacc1x4567 += vb4567 * va1[0] + SMLAL v12.4s, v27.4h, v2.h[0] // vacc2x0123 += vb0123 * va2[0] + SMLAL2 v13.4s, v27.8h, v2.h[0] // vacc2x4567 += vb4567 * va2[0] + SMLAL v14.4s, v27.4h, v3.h[0] // vacc3x0123 += vb0123 * va3[0] + SMLAL2 v15.4s, v27.8h, v3.h[0] // vacc3x4567 += vb4567 * va3[0] + SMLAL v16.4s, v27.4h, v4.h[0] // vacc4x0123 += vb0123 * va4[0] + SMLAL2 v17.4s, v27.8h, v4.h[0] // vacc4x4567 += vb4567 * va4[0] + SMLAL v18.4s, v27.4h, v5.h[0] // vacc5x0123 += vb0123 * va5[0] + SMLAL2 v19.4s, v27.8h, v5.h[0] // vacc5x4567 += vb4567 * va5[0] + SMLAL v20.4s, v27.4h, v6.h[0] // vacc6x0123 += vb0123 * va6[0] + SMLAL2 v21.4s, v27.8h, v6.h[0] // vacc6x4567 += vb4567 * va6[0] + SMLAL v22.4s, v27.4h, v7.h[0] // vacc7x0123 += vb0123 * va7[0] + SMLAL2 v23.4s, v27.8h, v7.h[0] // vacc7x4567 += vb4567 * va7[0] + + CMP x17, -48 + B.LO 2f + + // Channel 1 + LD1 {v28.8b}, [x5], 8 + USUBL v28.8h, v28.8b, v25.8b + + SMLAL v8.4s, v28.4h, v0.h[1] // vacc0x0123 += vb0123 * va0[1] + SMLAL2 v9.4s, v28.8h, v0.h[1] // vacc0x4567 += vb4567 * va0[1] + SMLAL v10.4s, v28.4h, v1.h[1] // vacc1x0123 += vb0123 * va1[1] + SMLAL2 v11.4s, v28.8h, v1.h[1] // vacc1x4567 += vb4567 * va1[1] + SMLAL v12.4s, v28.4h, v2.h[1] // vacc2x0123 += vb0123 * va2[1] + SMLAL2 v13.4s, v28.8h, v2.h[1] // vacc2x4567 += vb4567 * va2[1] + SMLAL v14.4s, v28.4h, v3.h[1] // vacc3x0123 += vb0123 * va3[1] + SMLAL2 v15.4s, v28.8h, v3.h[1] // vacc3x4567 += vb4567 * va3[1] + SMLAL v16.4s, v28.4h, v4.h[1] // vacc4x0123 += vb0123 * va4[1] + SMLAL2 v17.4s, v28.8h, v4.h[1] // vacc4x4567 += vb4567 * va4[1] + SMLAL v18.4s, v28.4h, v5.h[1] // vacc5x0123 += vb0123 * va5[1] + SMLAL2 v19.4s, v28.8h, v5.h[1] // vacc5x4567 += vb4567 * va5[1] + SMLAL v20.4s, v28.4h, v6.h[1] // vacc6x0123 += vb0123 * va6[1] + SMLAL2 v21.4s, v28.8h, v6.h[1] // vacc6x4567 += vb4567 * va6[1] + SMLAL v22.4s, v28.4h, v7.h[1] // vacc7x0123 += vb0123 * va7[1] + SMLAL2 v23.4s, v28.8h, v7.h[1] // vacc7x4567 += vb4567 * va7[1] + + B.LS 2f + + // Channel 2 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[2] // vacc0x0123 += vb0123 * va0[2] + SMLAL2 v9.4s, v27.8h, v0.h[2] // vacc0x4567 += vb4567 * va0[2] + SMLAL v10.4s, v27.4h, v1.h[2] // vacc1x0123 += vb0123 * va1[2] + SMLAL2 v11.4s, v27.8h, v1.h[2] // vacc1x4567 += vb4567 * va1[2] + SMLAL v12.4s, v27.4h, v2.h[2] // vacc2x0123 += vb0123 * va2[2] + SMLAL2 v13.4s, v27.8h, v2.h[2] // vacc2x4567 += vb4567 * va2[2] + SMLAL v14.4s, v27.4h, v3.h[2] // vacc3x0123 += vb0123 * va3[2] + SMLAL2 v15.4s, v27.8h, v3.h[2] // vacc3x4567 += vb4567 * va3[2] + SMLAL v16.4s, v27.4h, v4.h[2] // vacc4x0123 += vb0123 * va4[2] + SMLAL2 v17.4s, v27.8h, v4.h[2] // vacc4x4567 += vb4567 * va4[2] + SMLAL v18.4s, v27.4h, v5.h[2] // vacc5x0123 += vb0123 * va5[2] + SMLAL2 v19.4s, v27.8h, v5.h[2] // vacc5x4567 += vb4567 * va5[2] + SMLAL v20.4s, v27.4h, v6.h[2] // vacc6x0123 += vb0123 * va6[2] + SMLAL2 v21.4s, v27.8h, v6.h[2] // vacc6x4567 += vb4567 * va6[2] + SMLAL v22.4s, v27.4h, v7.h[2] // vacc7x0123 += vb0123 * va7[2] + SMLAL2 v23.4s, v27.8h, v7.h[2] // vacc7x4567 += vb4567 * va7[2] + + CMP x17, -32 + B.LO 2f + + // Channel 3 + LD1 {v28.8b}, [x5], 8 + USUBL v28.8h, v28.8b, v25.8b + + SMLAL v8.4s, v28.4h, v0.h[3] // vacc0x0123 += vb0123 * va0[3] + SMLAL2 v9.4s, v28.8h, v0.h[3] // vacc0x4567 += vb4567 * va0[3] + SMLAL v10.4s, v28.4h, v1.h[3] // vacc1x0123 += vb0123 * va1[3] + SMLAL2 v11.4s, v28.8h, v1.h[3] // vacc1x4567 += vb4567 * va1[3] + SMLAL v12.4s, v28.4h, v2.h[3] // vacc2x0123 += vb0123 * va2[3] + SMLAL2 v13.4s, v28.8h, v2.h[3] // vacc2x4567 += vb4567 * va2[3] + SMLAL v14.4s, v28.4h, v3.h[3] // vacc3x0123 += vb0123 * va3[3] + SMLAL2 v15.4s, v28.8h, v3.h[3] // vacc3x4567 += vb4567 * va3[3] + SMLAL v16.4s, v28.4h, v4.h[3] // vacc4x0123 += vb0123 * va4[3] + SMLAL2 v17.4s, v28.8h, v4.h[3] // vacc4x4567 += vb4567 * va4[3] + SMLAL v18.4s, v28.4h, v5.h[3] // vacc5x0123 += vb0123 * va5[3] + SMLAL2 v19.4s, v28.8h, v5.h[3] // vacc5x4567 += vb4567 * va5[3] + SMLAL v20.4s, v28.4h, v6.h[3] // vacc6x0123 += vb0123 * va6[3] + SMLAL2 v21.4s, v28.8h, v6.h[3] // vacc6x4567 += vb4567 * va6[3] + SMLAL v22.4s, v28.4h, v7.h[3] // vacc7x0123 += vb0123 * va7[3] + SMLAL2 v23.4s, v28.8h, v7.h[3] // vacc7x4567 += vb4567 * va7[3] + + B.LS 2f + + // Channel 4 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[4] // vacc0x0123 += vb0123 * va0[4] + SMLAL2 v9.4s, v27.8h, v0.h[4] // vacc0x4567 += vb4567 * va0[4] + SMLAL v10.4s, v27.4h, v1.h[4] // vacc1x0123 += vb0123 * va1[4] + SMLAL2 v11.4s, v27.8h, v1.h[4] // vacc1x4567 += vb4567 * va1[4] + SMLAL v12.4s, v27.4h, v2.h[4] // vacc2x0123 += vb0123 * va2[4] + SMLAL2 v13.4s, v27.8h, v2.h[4] // vacc2x4567 += vb4567 * va2[4] + SMLAL v14.4s, v27.4h, v3.h[4] // vacc3x0123 += vb0123 * va3[4] + SMLAL2 v15.4s, v27.8h, v3.h[4] // vacc3x4567 += vb4567 * va3[4] + SMLAL v16.4s, v27.4h, v4.h[4] // vacc4x0123 += vb0123 * va4[4] + SMLAL2 v17.4s, v27.8h, v4.h[4] // vacc4x4567 += vb4567 * va4[4] + SMLAL v18.4s, v27.4h, v5.h[4] // vacc5x0123 += vb0123 * va5[4] + SMLAL2 v19.4s, v27.8h, v5.h[4] // vacc5x4567 += vb4567 * va5[4] + SMLAL v20.4s, v27.4h, v6.h[4] // vacc6x0123 += vb0123 * va6[4] + SMLAL2 v21.4s, v27.8h, v6.h[4] // vacc6x4567 += vb4567 * va6[4] + SMLAL v22.4s, v27.4h, v7.h[4] // vacc7x0123 += vb0123 * va7[4] + SMLAL2 v23.4s, v27.8h, v7.h[4] // vacc7x4567 += vb4567 * va7[4] + + CMP x17, -16 + B.LO 2f + + // Channel 5 + LD1 {v28.8b}, [x5], 8 + USUBL v28.8h, v28.8b, v25.8b + + SMLAL v8.4s, v28.4h, v0.h[5] // vacc0x0123 += vb0123 * va0[5] + SMLAL2 v9.4s, v28.8h, v0.h[5] // vacc0x4567 += vb4567 * va0[5] + SMLAL v10.4s, v28.4h, v1.h[5] // vacc1x0123 += vb0123 * va1[5] + SMLAL2 v11.4s, v28.8h, v1.h[5] // vacc1x4567 += vb4567 * va1[5] + SMLAL v12.4s, v28.4h, v2.h[5] // vacc2x0123 += vb0123 * va2[5] + SMLAL2 v13.4s, v28.8h, v2.h[5] // vacc2x4567 += vb4567 * va2[5] + SMLAL v14.4s, v28.4h, v3.h[5] // vacc3x0123 += vb0123 * va3[5] + SMLAL2 v15.4s, v28.8h, v3.h[5] // vacc3x4567 += vb4567 * va3[5] + SMLAL v16.4s, v28.4h, v4.h[5] // vacc4x0123 += vb0123 * va4[5] + SMLAL2 v17.4s, v28.8h, v4.h[5] // vacc4x4567 += vb4567 * va4[5] + SMLAL v18.4s, v28.4h, v5.h[5] // vacc5x0123 += vb0123 * va5[5] + SMLAL2 v19.4s, v28.8h, v5.h[5] // vacc5x4567 += vb4567 * va5[5] + SMLAL v20.4s, v28.4h, v6.h[5] // vacc6x0123 += vb0123 * va6[5] + SMLAL2 v21.4s, v28.8h, v6.h[5] // vacc6x4567 += vb4567 * va6[5] + SMLAL v22.4s, v28.4h, v7.h[5] // vacc7x0123 += vb0123 * va7[5] + SMLAL2 v23.4s, v28.8h, v7.h[5] // vacc7x4567 += vb4567 * va7[5] + + B.LS 2f + + // Channel 6 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[6] // vacc0x0123 += vb0123 * va0[6] + SMLAL2 v9.4s, v27.8h, v0.h[6] // vacc0x4567 += vb4567 * va0[6] + SMLAL v10.4s, v27.4h, v1.h[6] // vacc1x0123 += vb0123 * va1[6] + SMLAL2 v11.4s, v27.8h, v1.h[6] // vacc1x4567 += vb4567 * va1[6] + SMLAL v12.4s, v27.4h, v2.h[6] // vacc2x0123 += vb0123 * va2[6] + SMLAL2 v13.4s, v27.8h, v2.h[6] // vacc2x4567 += vb4567 * va2[6] + SMLAL v14.4s, v27.4h, v3.h[6] // vacc3x0123 += vb0123 * va3[6] + SMLAL2 v15.4s, v27.8h, v3.h[6] // vacc3x4567 += vb4567 * va3[6] + SMLAL v16.4s, v27.4h, v4.h[6] // vacc4x0123 += vb0123 * va4[6] + SMLAL2 v17.4s, v27.8h, v4.h[6] // vacc4x4567 += vb4567 * va4[6] + SMLAL v18.4s, v27.4h, v5.h[6] // vacc5x0123 += vb0123 * va5[6] + SMLAL2 v19.4s, v27.8h, v5.h[6] // vacc5x4567 += vb4567 * va5[6] + SMLAL v20.4s, v27.4h, v6.h[6] // vacc6x0123 += vb0123 * va6[6] + SMLAL2 v21.4s, v27.8h, v6.h[6] // vacc6x4567 += vb4567 * va6[6] + SMLAL v22.4s, v27.4h, v7.h[6] // vacc7x0123 += vb0123 * va7[6] + SMLAL2 v23.4s, v27.8h, v7.h[6] // vacc7x4567 += vb4567 * va7[6] + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 4 +#endif +2: + + SUB x3, x3, 1 + CBNZ x3, 3b + + // Load right_shift: + // - v27 = vright_shift + LD1R {v27.4s}, [x8], 4 + + SQRDMULH v8.4s, v8.4s, v26.4s + SQRDMULH v9.4s, v9.4s, v26.4s + SQRDMULH v10.4s, v10.4s, v26.4s + SQRDMULH v11.4s, v11.4s, v26.4s + SQRDMULH v12.4s, v12.4s, v26.4s + SQRDMULH v13.4s, v13.4s, v26.4s + SQRDMULH v14.4s, v14.4s, v26.4s + SQRDMULH v15.4s, v15.4s, v26.4s + + // Compute vzero_shift_mask + // - v28 = vzero_shift_mask + CMEQ v28.4s, v27.4s, 0 + + SQRDMULH v16.4s, v16.4s, v26.4s + SQRDMULH v17.4s, v17.4s, v26.4s + SQRDMULH v18.4s, v18.4s, v26.4s + SQRDMULH v19.4s, v19.4s, v26.4s + SQRDMULH v20.4s, v20.4s, v26.4s + SQRDMULH v21.4s, v21.4s, v26.4s + SQRDMULH v22.4s, v22.4s, v26.4s + SQRDMULH v23.4s, v23.4s, v26.4s + + // Load zero_point: + // - v29 = vzero_point + LD1R {v29.8h}, [x8], 2 + + BIC v0.16b, v8.16b, v28.16b + BIC v1.16b, v9.16b, v28.16b + BIC v2.16b, v10.16b, v28.16b + BIC v3.16b, v11.16b, v28.16b + BIC v4.16b, v12.16b, v28.16b + BIC v5.16b, v13.16b, v28.16b + BIC v6.16b, v14.16b, v28.16b + BIC v7.16b, v15.16b, v28.16b + + SSRA v8.4s, v0.4s, 31 + SSRA v9.4s, v1.4s, 31 + SSRA v10.4s, v2.4s, 31 + SSRA v11.4s, v3.4s, 31 + SSRA v12.4s, v4.4s, 31 + SSRA v13.4s, v5.4s, 31 + SSRA v14.4s, v6.4s, 31 + SSRA v15.4s, v7.4s, 31 + + // Load max: + // - v30 = vmax + LD1R {v30.16b}, [x8], 1 + + BIC v0.16b, v16.16b, v28.16b + BIC v1.16b, v17.16b, v28.16b + BIC v2.16b, v18.16b, v28.16b + BIC v3.16b, v19.16b, v28.16b + BIC v4.16b, v20.16b, v28.16b + BIC v5.16b, v21.16b, v28.16b + BIC v6.16b, v22.16b, v28.16b + BIC v7.16b, v23.16b, v28.16b + + SSRA v16.4s, v0.4s, 31 + SSRA v17.4s, v1.4s, 31 + SSRA v18.4s, v2.4s, 31 + SSRA v19.4s, v3.4s, 31 + SSRA v20.4s, v4.4s, 31 + SSRA v21.4s, v5.4s, 31 + SSRA v22.4s, v6.4s, 31 + SSRA v23.4s, v7.4s, 31 + + // Load min: + // - v31 = vmin + LD1R {v31.16b}, [x8] + + SRSHL v8.4s, v8.4s, v27.4s + SRSHL v9.4s, v9.4s, v27.4s + SRSHL v10.4s, v10.4s, v27.4s + SRSHL v11.4s, v11.4s, v27.4s + SRSHL v12.4s, v12.4s, v27.4s + SRSHL v13.4s, v13.4s, v27.4s + SRSHL v14.4s, v14.4s, v27.4s + SRSHL v15.4s, v15.4s, v27.4s + SRSHL v16.4s, v16.4s, v27.4s + SRSHL v17.4s, v17.4s, v27.4s + SRSHL v18.4s, v18.4s, v27.4s + SRSHL v19.4s, v19.4s, v27.4s + SRSHL v20.4s, v20.4s, v27.4s + SRSHL v21.4s, v21.4s, v27.4s + SRSHL v22.4s, v22.4s, v27.4s + SRSHL v23.4s, v23.4s, v27.4s + + SQXTN v8.4h, v8.4s + SQXTN v10.4h, v10.4s + SQXTN v12.4h, v12.4s + SQXTN v14.4h, v14.4s + SQXTN v16.4h, v16.4s + SQXTN v18.4h, v18.4s + SQXTN v20.4h, v20.4s + SQXTN v22.4h, v22.4s + + SQXTN2 v8.8h, v9.4s + SQXTN2 v10.8h, v11.4s + SQXTN2 v12.8h, v13.4s + SQXTN2 v14.8h, v15.4s + SQXTN2 v16.8h, v17.4s + SQXTN2 v18.8h, v19.4s + SQXTN2 v20.8h, v21.4s + SQXTN2 v22.8h, v23.4s + + SQADD v8.8h, v8.8h, v29.8h + SQADD v10.8h, v10.8h, v29.8h + SQADD v12.8h, v12.8h, v29.8h + SQADD v14.8h, v14.8h, v29.8h + SQADD v16.8h, v16.8h, v29.8h + SQADD v18.8h, v18.8h, v29.8h + SQADD v20.8h, v20.8h, v29.8h + SQADD v22.8h, v22.8h, v29.8h + + SQXTUN v8.8b, v8.8h + SQXTUN v12.8b, v12.8h + SQXTUN v16.8b, v16.8h + SQXTUN v20.8b, v20.8h + + SQXTUN2 v8.16b, v10.8h + SQXTUN2 v12.16b, v14.8h + SQXTUN2 v16.16b, v18.8h + SQXTUN2 v20.16b, v22.8h + + UMIN v8.16b, v8.16b, v30.16b + UMIN v12.16b, v12.16b, v30.16b + UMIN v16.16b, v16.16b, v30.16b + UMIN v20.16b, v20.16b, v30.16b + + UMAX v8.16b, v8.16b, v31.16b + UMAX v12.16b, v12.16b, v31.16b + UMAX v16.16b, v16.16b, v31.16b + UMAX v20.16b, v20.16b, v31.16b + + // Compute c0-c7 + + ADD x9, x6, x7 + CMP x0, 2 + CSEL x9, x6, x9, LO + + ADD x10, x9, x7 + CSEL x10, x9, x10, LS + + ADD x11, x10, x7 + CMP x0, 4 + CSEL x11, x10, x11, LO + + ADD x12, x11, x7 + CSEL x12, x11, x12, LS + + ADD x13, x12, x7 + CMP x0, 6 + CSEL x13, x12, x13, LO + + ADD x14, x13, x7 + CSEL x14, x13, x14, LS + + ADD x15, x14, x7 + CMP x0, 8 + CSEL x15, x14, x15, NE + + CMP x1, 8 + B.NE 4f + + // Store results + ST1 {v8.d}[0], [x6] + ST1 {v8.d}[1], [x9] + ST1 {v12.d}[0], [x10] + ST1 {v12.d}[1], [x11] + ST1 {v16.d}[0], [x12] + ST1 {v16.d}[1], [x13] + ST1 {v20.d}[0], [x14] + ST1 {v20.d}[1], [x15] + + LDP d9, d8, [sp, -64] + LDP d11, d10, [sp, -48] + LDP d13, d12, [sp, -32] + LDP d15, d14, [sp, -16] + + RET + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 3 +#endif +4: + CMP x1, 4 + B.LO 5f + + ST1 {v8.s}[0], [x6], 4 + ST1 {v8.s}[2], [x9], 4 + ST1 {v12.s}[0], [x10], 4 + ST1 {v12.s}[2], [x11], 4 + ST1 {v16.s}[0], [x12], 4 + ST1 {v16.s}[2], [x13], 4 + ST1 {v20.s}[0], [x14], 4 + ST1 {v20.s}[2], [x15], 4 + + SUB x1, x1, 4 + EXT v8.16b, v8.16b, v8.16b, 4 + EXT v12.16b, v12.16b, v12.16b, 4 + EXT v16.16b, v16.16b, v16.16b, 4 + EXT v20.16b, v20.16b, v20.16b, 4 + +5: + CMP x1, 2 + B.LO 6f + + ST1 {v8.h}[0], [x6], 2 + ST1 {v8.h}[4], [x9], 2 + ST1 {v12.h}[0], [x10], 2 + ST1 {v12.h}[4], [x11], 2 + ST1 {v16.h}[0], [x12], 2 + ST1 {v16.h}[4], [x13], 2 + ST1 {v20.h}[0], [x14], 2 + ST1 {v20.h}[4], [x15], 2 + + SUB x1, x1, 2 + EXT v8.16b, v8.16b, v8.16b, 2 + EXT v12.16b, v12.16b, v12.16b, 2 + EXT v16.16b, v16.16b, v16.16b, 2 + EXT v20.16b, v20.16b, v20.16b, 2 + +6: + CMP x1, 1 + B.LO 7f + + ST1 {v8.b}[0], [x6] + ST1 {v8.b}[8], [x9] + ST1 {v12.b}[0], [x10] + ST1 {v12.b}[8], [x11] + ST1 {v16.b}[0], [x12] + ST1 {v16.b}[8], [x13] + ST1 {v20.b}[0], [x14] + ST1 {v20.b}[8], [x15] + +7: + LDP d9, d8, [sp, -64] + LDP d11, d10, [sp, -48] + LDP d13, d12, [sp, -32] + LDP d15, d14, [sp, -16] + + RET + +END_FUNCTION pytorch_q8conv_ukernel_8x8__aarch64_neon + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-neon.c new file mode 100644 index 0000000000000..50e72ec2c1e17 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8conv/8x8-neon.c @@ -0,0 +1,1178 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8conv_ukernel_8x8__neon( + size_t mr, + size_t nr, + size_t kc, + size_t ks, + const uint8_t** restrict a, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + const uint8x8_t va_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vb_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + + int32x4_t vacc0x0123 = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc0x4567 = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + int32x4_t vacc4x0123 = vacc0x0123; + int32x4_t vacc4x4567 = vacc0x4567; + int32x4_t vacc5x0123 = vacc0x0123; + int32x4_t vacc5x4567 = vacc0x4567; + int32x4_t vacc6x0123 = vacc0x0123; + int32x4_t vacc6x4567 = vacc0x4567; + int32x4_t vacc7x0123 = vacc0x0123; + int32x4_t vacc7x4567 = vacc0x4567; + + do { + const uint8_t* restrict a0 = *a++; + const uint8_t* restrict a1 = *a++; + const uint8_t* restrict a2 = *a++; + const uint8_t* restrict a3 = *a++; + const uint8_t* restrict a4 = *a++; + const uint8_t* restrict a5 = *a++; + const uint8_t* restrict a6 = *a++; + const uint8_t* restrict a7 = *a++; + + size_t k = kc; + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); + a0 += 8; + const uint8x8_t va1 = vld1_u8(a1); + a1 += 8; + const uint8x8_t va2 = vld1_u8(a2); + a2 += 8; + const uint8x8_t va3 = vld1_u8(a3); + a3 += 8; + const uint8x8_t va4 = vld1_u8(a4); + a4 += 8; + const uint8x8_t va5 = vld1_u8(a5); + a5 += 8; + const uint8x8_t va6 = vld1_u8(a6); + a6 += 8; + const uint8x8_t va7 = vld1_u8(a7); + a7 += 8; + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + const int16x8_t vxa4 = + vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point)); + const int16x8_t vxa5 = + vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point)); + const int16x8_t vxa6 = + vreinterpretq_s16_u16(sub_zero_point(va6, va_zero_point)); + const int16x8_t vxa7 = + vreinterpretq_s16_u16(sub_zero_point(va7, va_zero_point)); + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 0); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 1); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 2); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 3); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 3); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 3); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 3); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 3); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 3); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 3); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 0); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 1); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 2); + } + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa4), 3); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa5), 3); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa5), 3); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa6), 3); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa6), 3); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_high_s16(vxa7), 3); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_high_s16(vxa7), 3); + } + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const uint8x8_t va4 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a4 - a_predecrement)), va_shift)); + const uint8x8_t va5 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a5 - a_predecrement)), va_shift)); + const uint8x8_t va6 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a6 - a_predecrement)), va_shift)); + const uint8x8_t va7 = vreinterpret_u8_u64(vshl_u64( + vreinterpret_u64_u8(vld1_u8(a7 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + const int16x8_t vxa4 = + vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point)); + const int16x8_t vxa5 = + vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point)); + const int16x8_t vxa6 = + vreinterpretq_s16_u16(sub_zero_point(va6, va_zero_point)); + const int16x8_t vxa7 = + vreinterpretq_s16_u16(sub_zero_point(va7, va_zero_point)); + + { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 0); + } + + if (k >= 2) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 1); + + if (k > 2) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 2); + + if (k >= 4) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa4), 3); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa5), 3); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa5), 3); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa6), 3); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa6), 3); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567), vget_low_s16(vxa7), 3); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567), vget_low_s16(vxa7), 3); + + if (k > 4) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa0), + 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa0), + 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa1), + 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa1), + 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa2), + 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa2), + 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa3), + 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa3), + 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa4), + 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa4), + 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa5), + 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa5), + 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa6), + 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa6), + 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa7), + 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa7), + 0); + + if (k >= 6) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa0), + 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa0), + 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa1), + 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa1), + 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa2), + 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa2), + 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa3), + 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa3), + 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa4), + 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa4), + 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa5), + 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa5), + 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa6), + 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa6), + 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa7), + 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa7), + 1); + + if (k > 6) { + const uint8x8_t vb01234567 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const int16x8_t vxb01234567 = vreinterpretq_s16_u16( + vsubl_u8(vb01234567, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa0), + 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa0), + 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa1), + 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa1), + 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa2), + 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa2), + 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa3), + 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa3), + 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa4), + 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa4), + 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa5), + 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa5), + 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa6), + 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa6), + 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, + vget_low_s16(vxb01234567), + vget_high_s16(vxa7), + 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, + vget_high_s16(vxb01234567), + vget_high_s16(vxa7), + 2); + } + } + } + } + } + } + } + } while (--ks != 0); + + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + vacc4x0123 = vqrdmulhq_s32(vacc4x0123, vmultiplier); + vacc4x4567 = vqrdmulhq_s32(vacc4x4567, vmultiplier); + vacc5x0123 = vqrdmulhq_s32(vacc5x0123, vmultiplier); + vacc5x4567 = vqrdmulhq_s32(vacc5x4567, vmultiplier); + vacc6x0123 = vqrdmulhq_s32(vacc6x0123, vmultiplier); + vacc6x4567 = vqrdmulhq_s32(vacc6x4567, vmultiplier); + vacc7x0123 = vqrdmulhq_s32(vacc7x0123, vmultiplier); + vacc7x4567 = vqrdmulhq_s32(vacc7x4567, vmultiplier); + + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = + vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = + vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = + vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = + vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc2x0123 = + vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = + vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc3x0123 = + vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 = + vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + vacc4x0123 = + vsraq_n_s32(vacc4x0123, vbicq_s32(vacc4x0123, vzero_shift_mask), 31); + vacc4x4567 = + vsraq_n_s32(vacc4x4567, vbicq_s32(vacc4x4567, vzero_shift_mask), 31); + vacc5x0123 = + vsraq_n_s32(vacc5x0123, vbicq_s32(vacc5x0123, vzero_shift_mask), 31); + vacc5x4567 = + vsraq_n_s32(vacc5x4567, vbicq_s32(vacc5x4567, vzero_shift_mask), 31); + vacc6x0123 = + vsraq_n_s32(vacc6x0123, vbicq_s32(vacc6x0123, vzero_shift_mask), 31); + vacc6x4567 = + vsraq_n_s32(vacc6x4567, vbicq_s32(vacc6x4567, vzero_shift_mask), 31); + vacc7x0123 = + vsraq_n_s32(vacc7x0123, vbicq_s32(vacc7x0123, vzero_shift_mask), 31); + vacc7x4567 = + vsraq_n_s32(vacc7x4567, vbicq_s32(vacc7x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + vacc4x0123 = vrshlq_s32(vacc4x0123, vright_shift); + vacc4x4567 = vrshlq_s32(vacc4x4567, vright_shift); + vacc5x0123 = vrshlq_s32(vacc5x0123, vright_shift); + vacc5x4567 = vrshlq_s32(vacc5x4567, vright_shift); + vacc6x0123 = vrshlq_s32(vacc6x0123, vright_shift); + vacc6x4567 = vrshlq_s32(vacc6x4567, vright_shift); + vacc7x0123 = vrshlq_s32(vacc7x0123, vright_shift); + vacc7x4567 = vrshlq_s32(vacc7x4567, vright_shift); + + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + const int16x8_t vacc4x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc4x4567), voutput_zero_point); + const int16x8_t vacc5x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc5x0123), vacc5x4567), voutput_zero_point); + const int16x8_t vacc6x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc6x0123), vacc6x4567), voutput_zero_point); + const int16x8_t vacc7x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc7x0123), vacc7x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); + uint8x16_t vout4x01234567_5x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc4x01234567), vacc5x01234567); + uint8x16_t vout6x01234567_7x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc6x01234567), vacc7x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), + voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), + voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), + voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), + voutput_zero_point); + const int16x8_t vacc4x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc4x4567)), + voutput_zero_point); + const int16x8_t vacc5x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc5x0123), vqmovn_s32(vacc5x4567)), + voutput_zero_point); + const int16x8_t vacc6x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc6x0123), vqmovn_s32(vacc6x4567)), + voutput_zero_point); + const int16x8_t vacc7x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc7x0123), vqmovn_s32(vacc7x4567)), + voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = + vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); + uint8x16_t vout4x01234567_5x01234567 = + vcombine_u8(vqmovun_s16(vacc4x01234567), vqmovun_s16(vacc5x01234567)); + uint8x16_t vout6x01234567_7x01234567 = + vcombine_u8(vqmovun_s16(vacc6x01234567), vqmovun_s16(vacc7x01234567)); +#endif + const uint8x16_t voutput_min = + vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = + vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout4x01234567_5x01234567 = vmaxq_u8(vout4x01234567_5x01234567, voutput_min); + vout6x01234567_7x01234567 = vmaxq_u8(vout6x01234567_7x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); + vout4x01234567_5x01234567 = vminq_u8(vout4x01234567_5x01234567, voutput_max); + vout6x01234567_7x01234567 = vminq_u8(vout6x01234567_7x01234567, voutput_max); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + uint8_t* c4 = (uint8_t*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + uint8_t* c5 = (uint8_t*)((uintptr_t)c4 + c_stride); + if (mr < 6) { + c5 = c4; + } + uint8_t* c6 = (uint8_t*)((uintptr_t)c5 + c_stride); + if (mr <= 6) { + c6 = c5; + } + uint8_t* c7 = (uint8_t*)((uintptr_t)c6 + c_stride); + if (mr != 8) { + c7 = c6; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + vst1_u8(c4, vget_low_u8(vout4x01234567_5x01234567)); + vst1_u8(c5, vget_high_u8(vout4x01234567_5x01234567)); + vst1_u8(c6, vget_low_u8(vout6x01234567_7x01234567)); + vst1_u8(c7, vget_high_u8(vout6x01234567_7x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 0); + c0 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 2); + c1 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 0); + c2 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 2); + c3 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c4, 1), + vreinterpretq_u32_u8(vout4x01234567_5x01234567), + 0); + c4 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c5, 1), + vreinterpretq_u32_u8(vout4x01234567_5x01234567), + 2); + c5 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c6, 1), + vreinterpretq_u32_u8(vout6x01234567_7x01234567), + 0); + c6 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c7, 1), + vreinterpretq_u32_u8(vout6x01234567_7x01234567), + 2); + c7 += 4; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + vout4x01234567_5x01234567 = + vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 4); + vout6x01234567_7x01234567 = + vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 0); + c0 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 4); + c1 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 0); + c2 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 4); + c3 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c4, 1), + vreinterpretq_u16_u8(vout4x01234567_5x01234567), + 0); + c4 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c5, 1), + vreinterpretq_u16_u8(vout4x01234567_5x01234567), + 4); + c5 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c6, 1), + vreinterpretq_u16_u8(vout6x01234567_7x01234567), + 0); + c6 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c7, 1), + vreinterpretq_u16_u8(vout6x01234567_7x01234567), + 4); + c7 += 2; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + vout4x01234567_5x01234567 = + vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 2); + vout6x01234567_7x01234567 = + vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + vst1q_lane_u8(c4, vout4x01234567_5x01234567, 0); + vst1q_lane_u8(c5, vout4x01234567_5x01234567, 8); + vst1q_lane_u8(c6, vout6x01234567_7x01234567, 0); + vst1q_lane_u8(c7, vout6x01234567_7x01234567, 8); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c new file mode 100644 index 0000000000000..85231d28272d3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-neon.c @@ -0,0 +1,908 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_q8dwconv_ukernel_mp8x25__neon( + size_t channels, + size_t output_width, + const uint8_t** input, + const void* weights, + int32_t* outacc32, + uint8_t* output, + size_t input_stride, + size_t output_increment, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + const uint8x8_t vinput_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vkernel_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int16x8_t vzero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t vmin = vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t vmax = vld1_dup_u8(&quantization_params->neon.output_max); + + do { + uint8_t* output_start = output; + int32_t* outacc = outacc32; + const void* w = weights; + { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + const uint8_t* i8 = input[8]; + const uint8_t* i9 = input[9]; + + size_t c = channels; + for (; c >= 8; c -= 8) { + int32x4_t vaccX1_lo = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vaccX1_hi = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + + const uint8x8_t vk0 = vld1_u8(w); + w += 8; + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = + vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + const uint8x8_t vk5 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const int16x8_t vxk5 = + vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point)); + const int16x8_t vxi5 = + vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5)); + + const uint8x8_t vk6 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const int16x8_t vxk6 = + vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point)); + const int16x8_t vxi6 = + vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6)); + + const uint8x8_t vk7 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + const int16x8_t vxk7 = + vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point)); + const int16x8_t vxi7 = + vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7)); + + const uint8x8_t vk8 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi8 = vld1_u8(i8); + i8 += 8; + const int16x8_t vxk8 = + vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point)); + const int16x8_t vxi8 = + vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8)); + + const uint8x8_t vk9 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi9 = vld1_u8(i9); + i9 += 8; + const int16x8_t vxk9 = + vreinterpretq_s16_u16(vsubl_u8(vk9, vkernel_zero_point)); + const int16x8_t vxi9 = + vreinterpretq_s16_u16(vsubl_u8(vi9, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk9), vget_low_s16(vxi9)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk9), vget_high_s16(vxi9)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + vst1q_s32(outacc, vacc_lo); + outacc += 4; + vst1q_s32(outacc, vacc_hi); + outacc += 4; + } + if (c != 0) { + const size_t c_predecrement = 8 - c; + const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement); + i0 -= c_predecrement; + i1 -= c_predecrement; + i2 -= c_predecrement; + i3 -= c_predecrement; + i4 -= c_predecrement; + i5 -= c_predecrement; + i6 -= c_predecrement; + i7 -= c_predecrement; + i8 -= c_predecrement; + i9 -= c_predecrement; + + int32x4_t vaccX1_lo = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vaccX1_hi = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + + const uint8x8_t vk0 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift)); + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = + vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift)); + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift)); + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift)); + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift)); + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + const uint8x8_t vk5 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift)); + const int16x8_t vxk5 = + vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point)); + const int16x8_t vxi5 = + vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5)); + + const uint8x8_t vk6 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift)); + const int16x8_t vxk6 = + vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point)); + const int16x8_t vxi6 = + vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6)); + + const uint8x8_t vk7 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift)); + const int16x8_t vxk7 = + vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point)); + const int16x8_t vxi7 = + vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7)); + + const uint8x8_t vk8 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi8 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift)); + const int16x8_t vxk8 = + vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point)); + const int16x8_t vxi8 = + vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8)); + + const uint8x8_t vk9 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi9 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i9)), vi_shift)); + const int16x8_t vxk9 = + vreinterpretq_s16_u16(vsubl_u8(vk9, vkernel_zero_point)); + const int16x8_t vxi9 = + vreinterpretq_s16_u16(vsubl_u8(vi9, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk9), vget_low_s16(vxi9)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk9), vget_high_s16(vxi9)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + vst1q_s32(outacc, vacc_lo); + outacc += 4; + vst1q_s32(outacc, vacc_hi); + outacc += 4; + } + } + { + const uint8_t* i0 = input[10]; + const uint8_t* i1 = input[11]; + const uint8_t* i2 = input[12]; + const uint8_t* i3 = input[13]; + const uint8_t* i4 = input[14]; + const uint8_t* i5 = input[15]; + const uint8_t* i6 = input[16]; + const uint8_t* i7 = input[17]; + const uint8_t* i8 = input[18]; + const uint8_t* i9 = input[19]; + output = output_start; + outacc = outacc32; + + size_t c = channels; + for (; c >= 8; c -= 8) { + const uint8x8_t vk0 = vld1_u8(w); + w += 8; + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = + vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point)); + int32x4_t vaccX1_lo = vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1)); + int32x4_t vaccX1_hi = + vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + const uint8x8_t vk5 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const int16x8_t vxk5 = + vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point)); + const int16x8_t vxi5 = + vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5)); + + const uint8x8_t vk6 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const int16x8_t vxk6 = + vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point)); + const int16x8_t vxi6 = + vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6)); + + const uint8x8_t vk7 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + const int16x8_t vxk7 = + vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point)); + const int16x8_t vxi7 = + vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7)); + + const uint8x8_t vk8 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi8 = vld1_u8(i8); + i8 += 8; + const int16x8_t vxk8 = + vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point)); + const int16x8_t vxi8 = + vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8)); + + const uint8x8_t vk9 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi9 = vld1_u8(i9); + i9 += 8; + const int16x8_t vxk9 = + vreinterpretq_s16_u16(vsubl_u8(vk9, vkernel_zero_point)); + const int16x8_t vxi9 = + vreinterpretq_s16_u16(vsubl_u8(vi9, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk9), vget_low_s16(vxi9)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk9), vget_high_s16(vxi9)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + const int32x4_t vacc_lo_old = vld1q_s32(outacc); + const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4); + vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old); + vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old); + vst1q_s32(outacc, vacc_lo); + outacc += 4; + vst1q_s32(outacc, vacc_hi); + outacc += 4; + } + if (c != 0) { + const size_t c_predecrement = 8 - c; + const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement); + i0 -= c_predecrement; + i1 -= c_predecrement; + i2 -= c_predecrement; + i3 -= c_predecrement; + i4 -= c_predecrement; + i5 -= c_predecrement; + i6 -= c_predecrement; + i7 -= c_predecrement; + i8 -= c_predecrement; + i9 -= c_predecrement; + + const uint8x8_t vk0 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift)); + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = + vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift)); + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point)); + int32x4_t vaccX1_lo = vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1)); + int32x4_t vaccX1_hi = + vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift)); + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift)); + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift)); + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + const uint8x8_t vk5 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift)); + const int16x8_t vxk5 = + vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point)); + const int16x8_t vxi5 = + vreinterpretq_s16_u16(vsubl_u8(vi5, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5)); + + const uint8x8_t vk6 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift)); + const int16x8_t vxk6 = + vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point)); + const int16x8_t vxi6 = + vreinterpretq_s16_u16(vsubl_u8(vi6, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6)); + + const uint8x8_t vk7 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift)); + const int16x8_t vxk7 = + vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point)); + const int16x8_t vxi7 = + vreinterpretq_s16_u16(vsubl_u8(vi7, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7)); + + const uint8x8_t vk8 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi8 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift)); + const int16x8_t vxk8 = + vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point)); + const int16x8_t vxi8 = + vreinterpretq_s16_u16(vsubl_u8(vi8, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8)); + + const uint8x8_t vk9 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi9 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i9)), vi_shift)); + const int16x8_t vxk9 = + vreinterpretq_s16_u16(vsubl_u8(vk9, vkernel_zero_point)); + const int16x8_t vxi9 = + vreinterpretq_s16_u16(vsubl_u8(vi9, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk9), vget_low_s16(vxi9)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk9), vget_high_s16(vxi9)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + const int32x4_t vacc_lo_old = vld1q_s32(outacc); + const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4); + vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old); + vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old); + vst1q_s32(outacc, vacc_lo); + outacc += 4; + vst1q_s32(outacc, vacc_hi); + outacc += 4; + } + } + + { + const uint8_t* i0 = input[20]; + const uint8_t* i1 = input[21]; + const uint8_t* i2 = input[22]; + const uint8_t* i3 = input[23]; + const uint8_t* i4 = input[24]; + input = (const uint8_t**)((uintptr_t)input + input_stride); + output = output_start; + outacc = outacc32; + + size_t c = channels; + for (; c >= 8; c -= 8) { + const uint8x8_t vk0 = vld1_u8(w); + w += 8; + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = + vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point)); + int32x4_t vaccX1_lo = vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1)); + int32x4_t vaccX1_hi = + vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + const int32x4_t vacc_lo_old = vld1q_s32(outacc); + outacc += 4; + const int32x4_t vacc_hi_old = vld1q_s32(outacc); + outacc += 4; + vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old); + vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old); + vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier); + vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier); + + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc_lo = + vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31); + vacc_hi = + vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31); + + vacc_lo = vrshlq_s32(vacc_lo, vright_shift); + vacc_hi = vrshlq_s32(vacc_hi, vright_shift); + +#ifdef __aarch64__ + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vzero_point); +#else + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + vzero_point); +#endif + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, vmin); + vout = vmin_u8(vout, vmax); + + vst1_u8(output, vout); + output += 8; + } + if (c != 0) { + const size_t c_predecrement = 8 - c; + const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement); + i0 -= c_predecrement; + i1 -= c_predecrement; + i2 -= c_predecrement; + i3 -= c_predecrement; + i4 -= c_predecrement; + + const uint8x8_t vk0 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift)); + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(vsubl_u8(vi0, vinput_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = + vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift)); + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(vsubl_u8(vi1, vinput_zero_point)); + int32x4_t vaccX1_lo = vmull_s16(vget_low_s16(vxk1), vget_low_s16(vxi1)); + int32x4_t vaccX1_hi = + vmull_s16(vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift)); + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(vsubl_u8(vi2, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift)); + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(vsubl_u8(vi3, vinput_zero_point)); + vaccX1_lo = + vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift)); + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(vsubl_u8(vi4, vinput_zero_point)); + vaccX0_lo = + vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + const int32x4_t vacc_lo_old = vld1q_s32(outacc); + const int32x4_t vacc_hi_old = vld1q_s32(outacc + 4); + vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old); + vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old); + + vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier); + vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier); + + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc_lo = + vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31); + vacc_hi = + vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31); + + vacc_lo = vrshlq_s32(vacc_lo, vright_shift); + vacc_hi = vrshlq_s32(vacc_hi, vright_shift); + +#ifdef __aarch64__ + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vzero_point); +#else + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + vzero_point); +#endif + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, vmin); + vout = vmin_u8(vout, vmax); + + if (c & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u8(vout), + 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (c & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), + vreinterpret_u16_u8(vout), + 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (c & 1) { + vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0); + output++; + } + } + } + + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--output_width != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-sse2.c new file mode 100644 index 0000000000000..f4b7d128da480 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/mp8x25-sse2.c @@ -0,0 +1,1138 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_q8dwconv_ukernel_mp8x25__sse2( + size_t channels, + size_t output_width, + const uint8_t** input, + const void* weights, + int32_t* outacc32, + uint8_t* output, + size_t input_stride, + size_t output_increment, + const union pytorch_qnnp_conv_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + const __m128i vinput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.input_zero_point); + const __m128i vkernel_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.kernel_zero_point); + const __m128i vzero = _mm_setzero_si128(); + + do { + int32_t* outacc = outacc32; + const void* w = weights; + { + const uint8_t* i00 = input[0]; + const uint8_t* i01 = input[1]; + const uint8_t* i02 = input[2]; + const uint8_t* i10 = input[3]; + const uint8_t* i11 = input[4]; + const uint8_t* i12 = input[5]; + const uint8_t* i20 = input[6]; + const uint8_t* i21 = input[7]; + const uint8_t* i22 = input[8]; + const uint8_t* i23 = input[9]; + + size_t c = channels; + for (; c >= 8; c -= 8) { + __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w); + __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16)); + + const __m128i vi00 = _mm_loadl_epi64((const __m128i*)i00); + i00 += 8; + const __m128i vxi00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point); + const __m128i vk00 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point); + const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00); + const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod00_odd, vprod00_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod00_odd, vprod00_even)); + + const __m128i vi01 = _mm_loadl_epi64((const __m128i*)i01); + i01 += 8; + const __m128i vxi01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi01, vzero), vinput_zero_point); + const __m128i vk01 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40)); + const __m128i vxk01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point); + const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01); + const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod01_odd, vprod01_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod01_odd, vprod01_even)); + + const __m128i vi02 = _mm_loadl_epi64((const __m128i*)i02); + i02 += 8; + const __m128i vxi02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi02, vzero), vinput_zero_point); + const __m128i vk02 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48)); + const __m128i vxk02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point); + const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02); + const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod02_odd, vprod02_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod02_odd, vprod02_even)); + + const __m128i vi10 = _mm_loadl_epi64((const __m128i*)i10); + i10 += 8; + const __m128i vxi10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi10, vzero), vinput_zero_point); + const __m128i vk10 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56)); + const __m128i vxk10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point); + const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10); + const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod10_odd, vprod10_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod10_odd, vprod10_even)); + + const __m128i vi11 = _mm_loadl_epi64((const __m128i*)i11); + i11 += 8; + const __m128i vxi11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi11, vzero), vinput_zero_point); + const __m128i vk11 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64)); + const __m128i vxk11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point); + const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11); + const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod11_odd, vprod11_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod11_odd, vprod11_even)); + + const __m128i vi12 = _mm_loadl_epi64((const __m128i*)i12); + i12 += 8; + const __m128i vxi12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi12, vzero), vinput_zero_point); + const __m128i vk12 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72)); + const __m128i vxk12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point); + const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12); + const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod12_odd, vprod12_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod12_odd, vprod12_even)); + + const __m128i vi20 = _mm_loadl_epi64((const __m128i*)i20); + i20 += 8; + const __m128i vxi20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi20, vzero), vinput_zero_point); + const __m128i vk20 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80)); + const __m128i vxk20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point); + const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20); + const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod20_odd, vprod20_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod20_odd, vprod20_even)); + + const __m128i vi21 = _mm_loadl_epi64((const __m128i*)i21); + i21 += 8; + const __m128i vxi21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi21, vzero), vinput_zero_point); + const __m128i vk21 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88)); + const __m128i vxk21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point); + const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21); + const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod21_odd, vprod21_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod21_odd, vprod21_even)); + + const __m128i vi22 = _mm_loadl_epi64((const __m128i*)i22); + i22 += 8; + const __m128i vxi22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi22, vzero), vinput_zero_point); + const __m128i vk22 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96)); + const __m128i vxk22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point); + const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22); + const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod22_odd, vprod22_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod22_odd, vprod22_even)); + + const __m128i vi23 = _mm_loadl_epi64((const __m128i*)i23); + i23 += 8; + const __m128i vxi23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi23, vzero), vinput_zero_point); + const __m128i vk23 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 104)); + const __m128i vxk23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point); + const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23); + const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod23_odd, vprod23_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod23_odd, vprod23_even)); + + w = (const void*)((uintptr_t)w + 112); + _mm_storeu_si128((__m128i*)outacc, vacc_lo); + outacc += 4; + _mm_storeu_si128((__m128i*)outacc, vacc_hi); + outacc += 4; + } + if (c != 0) { + const size_t i_predecrement = 8 - c; + const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement); + i00 -= i_predecrement; + i01 -= i_predecrement; + i02 -= i_predecrement; + i10 -= i_predecrement; + i11 -= i_predecrement; + i12 -= i_predecrement; + i20 -= i_predecrement; + i21 -= i_predecrement; + i22 -= i_predecrement; + i23 -= i_predecrement; + + __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w); + __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16)); + + const __m128i vi00 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i00), vi_shift); + const __m128i vxi00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point); + const __m128i vk00 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point); + const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00); + const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod00_odd, vprod00_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod00_odd, vprod00_even)); + + const __m128i vi01 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i01), vi_shift); + const __m128i vxi01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi01, vzero), vinput_zero_point); + const __m128i vk01 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40)); + const __m128i vxk01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point); + const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01); + const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod01_odd, vprod01_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod01_odd, vprod01_even)); + + const __m128i vi02 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i02), vi_shift); + const __m128i vxi02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi02, vzero), vinput_zero_point); + const __m128i vk02 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48)); + const __m128i vxk02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point); + const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02); + const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod02_odd, vprod02_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod02_odd, vprod02_even)); + + const __m128i vi10 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i10), vi_shift); + const __m128i vxi10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi10, vzero), vinput_zero_point); + const __m128i vk10 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56)); + const __m128i vxk10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point); + const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10); + const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod10_odd, vprod10_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod10_odd, vprod10_even)); + + const __m128i vi11 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i11), vi_shift); + const __m128i vxi11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi11, vzero), vinput_zero_point); + const __m128i vk11 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64)); + const __m128i vxk11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point); + const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11); + const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod11_odd, vprod11_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod11_odd, vprod11_even)); + + const __m128i vi12 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i12), vi_shift); + const __m128i vxi12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi12, vzero), vinput_zero_point); + const __m128i vk12 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72)); + const __m128i vxk12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point); + const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12); + const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod12_odd, vprod12_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod12_odd, vprod12_even)); + + const __m128i vi20 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i20), vi_shift); + const __m128i vxi20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi20, vzero), vinput_zero_point); + const __m128i vk20 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80)); + const __m128i vxk20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point); + const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20); + const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod20_odd, vprod20_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod20_odd, vprod20_even)); + + const __m128i vi21 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i21), vi_shift); + const __m128i vxi21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi21, vzero), vinput_zero_point); + const __m128i vk21 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88)); + const __m128i vxk21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point); + const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21); + const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod21_odd, vprod21_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod21_odd, vprod21_even)); + + const __m128i vi22 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i22), vi_shift); + const __m128i vxi22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi22, vzero), vinput_zero_point); + const __m128i vk22 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96)); + const __m128i vxk22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point); + const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22); + const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod22_odd, vprod22_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod22_odd, vprod22_even)); + + const __m128i vi23 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i23), vi_shift); + const __m128i vxi23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi23, vzero), vinput_zero_point); + const __m128i vk23 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 104)); + const __m128i vxk23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point); + const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23); + const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod23_odd, vprod23_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod23_odd, vprod23_even)); + + w = (const void*)((uintptr_t)w + 112); + _mm_storeu_si128((__m128i*)outacc, vacc_lo); + outacc += 4; + _mm_storeu_si128((__m128i*)outacc, vacc_hi); + outacc += 4; + } + } + { + const uint8_t* i00 = input[10]; + const uint8_t* i01 = input[11]; + const uint8_t* i02 = input[12]; + const uint8_t* i10 = input[13]; + const uint8_t* i11 = input[14]; + const uint8_t* i12 = input[15]; + const uint8_t* i20 = input[16]; + const uint8_t* i21 = input[17]; + const uint8_t* i22 = input[18]; + const uint8_t* i23 = input[19]; + outacc = outacc32; + + size_t c = channels; + for (; c >= 8; c -= 8) { + const __m128i vi00 = _mm_loadl_epi64((const __m128i*)i00); + i00 += 8; + const __m128i vxi00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point); + const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w)); + const __m128i vxk00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point); + const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00); + const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00); + __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even); + __m128i vacc_hi = _mm_unpackhi_epi16(vprod00_odd, vprod00_even); + + const __m128i vi01 = _mm_loadl_epi64((const __m128i*)i01); + i01 += 8; + const __m128i vxi01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi01, vzero), vinput_zero_point); + const __m128i vk01 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxk01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point); + const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01); + const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod01_odd, vprod01_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod01_odd, vprod01_even)); + + const __m128i vi02 = _mm_loadl_epi64((const __m128i*)i02); + i02 += 8; + const __m128i vxi02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi02, vzero), vinput_zero_point); + const __m128i vk02 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxk02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point); + const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02); + const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod02_odd, vprod02_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod02_odd, vprod02_even)); + + const __m128i vi10 = _mm_loadl_epi64((const __m128i*)i10); + i10 += 8; + const __m128i vxi10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi10, vzero), vinput_zero_point); + const __m128i vk10 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxk10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point); + const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10); + const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod10_odd, vprod10_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod10_odd, vprod10_even)); + + const __m128i vi11 = _mm_loadl_epi64((const __m128i*)i11); + i11 += 8; + const __m128i vxi11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi11, vzero), vinput_zero_point); + const __m128i vk11 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point); + const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11); + const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod11_odd, vprod11_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod11_odd, vprod11_even)); + + const __m128i vi12 = _mm_loadl_epi64((const __m128i*)i12); + i12 += 8; + const __m128i vxi12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi12, vzero), vinput_zero_point); + const __m128i vk12 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40)); + const __m128i vxk12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point); + const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12); + const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod12_odd, vprod12_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod12_odd, vprod12_even)); + + const __m128i vi20 = _mm_loadl_epi64((const __m128i*)i20); + i20 += 8; + const __m128i vxi20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi20, vzero), vinput_zero_point); + const __m128i vk20 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48)); + const __m128i vxk20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point); + const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20); + const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod20_odd, vprod20_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod20_odd, vprod20_even)); + + const __m128i vi21 = _mm_loadl_epi64((const __m128i*)i21); + i21 += 8; + const __m128i vxi21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi21, vzero), vinput_zero_point); + const __m128i vk21 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56)); + const __m128i vxk21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point); + const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21); + const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod21_odd, vprod21_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod21_odd, vprod21_even)); + + const __m128i vi22 = _mm_loadl_epi64((const __m128i*)i22); + i22 += 8; + const __m128i vxi22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi22, vzero), vinput_zero_point); + const __m128i vk22 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64)); + const __m128i vxk22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point); + const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22); + const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod22_odd, vprod22_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod22_odd, vprod22_even)); + + const __m128i vi23 = _mm_loadl_epi64((const __m128i*)i23); + i23 += 8; + const __m128i vxi23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi23, vzero), vinput_zero_point); + const __m128i vk23 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72)); + const __m128i vxk23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point); + const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23); + const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod23_odd, vprod23_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod23_odd, vprod23_even)); + + w = (const void*)((uintptr_t)w + 80); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_loadu_si128((__m128i*)outacc)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_loadu_si128((__m128i*)(outacc + 4))); + _mm_storeu_si128((__m128i*)outacc, vacc_lo); + outacc += 4; + _mm_storeu_si128((__m128i*)outacc, vacc_hi); + outacc += 4; + } + if (c != 0) { + const size_t i_predecrement = 8 - c; + const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement); + i00 -= i_predecrement; + i01 -= i_predecrement; + i02 -= i_predecrement; + i10 -= i_predecrement; + i11 -= i_predecrement; + i12 -= i_predecrement; + i20 -= i_predecrement; + i21 -= i_predecrement; + i22 -= i_predecrement; + i23 -= i_predecrement; + + const __m128i vi00 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i00), vi_shift); + const __m128i vxi00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point); + const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w)); + const __m128i vxk00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point); + const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00); + const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00); + __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even); + __m128i vacc_hi = _mm_unpackhi_epi16(vprod00_odd, vprod00_even); + + const __m128i vi01 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i01), vi_shift); + const __m128i vxi01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi01, vzero), vinput_zero_point); + const __m128i vk01 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxk01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point); + const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01); + const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod01_odd, vprod01_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod01_odd, vprod01_even)); + + const __m128i vi02 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i02), vi_shift); + const __m128i vxi02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi02, vzero), vinput_zero_point); + const __m128i vk02 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxk02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point); + const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02); + const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod02_odd, vprod02_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod02_odd, vprod02_even)); + + const __m128i vi10 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i10), vi_shift); + const __m128i vxi10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi10, vzero), vinput_zero_point); + const __m128i vk10 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxk10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point); + const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10); + const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod10_odd, vprod10_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod10_odd, vprod10_even)); + + const __m128i vi11 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i11), vi_shift); + const __m128i vxi11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi11, vzero), vinput_zero_point); + const __m128i vk11 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point); + const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11); + const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod11_odd, vprod11_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod11_odd, vprod11_even)); + + const __m128i vi12 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i12), vi_shift); + const __m128i vxi12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi12, vzero), vinput_zero_point); + const __m128i vk12 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40)); + const __m128i vxk12 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point); + const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12); + const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod12_odd, vprod12_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod12_odd, vprod12_even)); + + const __m128i vi20 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i20), vi_shift); + const __m128i vxi20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi20, vzero), vinput_zero_point); + const __m128i vk20 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48)); + const __m128i vxk20 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point); + const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20); + const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod20_odd, vprod20_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod20_odd, vprod20_even)); + + const __m128i vi21 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i21), vi_shift); + const __m128i vxi21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi21, vzero), vinput_zero_point); + const __m128i vk21 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56)); + const __m128i vxk21 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point); + const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21); + const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod21_odd, vprod21_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod21_odd, vprod21_even)); + + const __m128i vi22 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i22), vi_shift); + const __m128i vxi22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi22, vzero), vinput_zero_point); + const __m128i vk22 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64)); + const __m128i vxk22 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point); + const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22); + const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod22_odd, vprod22_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod22_odd, vprod22_even)); + + const __m128i vi23 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i23), vi_shift); + const __m128i vxi23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi23, vzero), vinput_zero_point); + const __m128i vk23 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72)); + const __m128i vxk23 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point); + const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23); + const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod23_odd, vprod23_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod23_odd, vprod23_even)); + + w = (const void*)((uintptr_t)w + 80); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_loadu_si128((__m128i*)outacc)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_loadu_si128((__m128i*)(outacc + 4))); + _mm_storeu_si128((__m128i*)outacc, vacc_lo); + outacc += 4; + _mm_storeu_si128((__m128i*)outacc, vacc_hi); + outacc += 4; + } + } + { + const uint8_t* i00 = input[20]; + const uint8_t* i01 = input[21]; + const uint8_t* i02 = input[22]; + const uint8_t* i10 = input[23]; + const uint8_t* i11 = input[24]; + input = (const uint8_t**)((uintptr_t)input + input_stride); + outacc = outacc32; + size_t c = channels; + for (; c >= 8; c -= 8) { + const __m128i vi00 = _mm_loadl_epi64((const __m128i*)i00); + i00 += 8; + const __m128i vxi00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point); + const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w)); + const __m128i vxk00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point); + const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00); + const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00); + __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even); + __m128i vacc_hi = _mm_unpackhi_epi16(vprod00_odd, vprod00_even); + + const __m128i vi01 = _mm_loadl_epi64((const __m128i*)i01); + i01 += 8; + const __m128i vxi01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi01, vzero), vinput_zero_point); + const __m128i vk01 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxk01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point); + const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01); + const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod01_odd, vprod01_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod01_odd, vprod01_even)); + + const __m128i vi02 = _mm_loadl_epi64((const __m128i*)i02); + i02 += 8; + const __m128i vxi02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi02, vzero), vinput_zero_point); + const __m128i vk02 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxk02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point); + const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02); + const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod02_odd, vprod02_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod02_odd, vprod02_even)); + + const __m128i vi10 = _mm_loadl_epi64((const __m128i*)i10); + i10 += 8; + const __m128i vxi10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi10, vzero), vinput_zero_point); + const __m128i vk10 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxk10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point); + const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10); + const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod10_odd, vprod10_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod10_odd, vprod10_even)); + + const __m128i vi11 = _mm_loadl_epi64((const __m128i*)i11); + i11 += 8; + const __m128i vxi11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi11, vzero), vinput_zero_point); + const __m128i vk11 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point); + const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11); + const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod11_odd, vprod11_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod11_odd, vprod11_even)); + + w = (const void*)((uintptr_t)w + 40); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_loadu_si128((__m128i*)outacc)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_loadu_si128((__m128i*)(outacc + 4))); + outacc += 8; + + const __m128i vmultiplier = _mm_load_si128( + (const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask_lo0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vnmask_hi0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabsacc_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vnmask_lo0123), vnmask_lo0123); + const __m128i vabsacc_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vnmask_hi0123), vnmask_hi0123); + + const __m128i vabsacc_lo1032 = + _mm_shuffle_epi32(vabsacc_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc_hi1032 = + _mm_shuffle_epi32(vabsacc_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod_lo02 = + _mm_mul_epu32(vabsacc_lo0123, vmultiplier); + const __m128i vabsprod_hi02 = + _mm_mul_epu32(vabsacc_hi0123, vmultiplier); + + const __m128i vnmask_lo02 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask_hi02 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod_lo02 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_lo02, vnmask_lo02), vnmask_lo02); + const __m128i vprod_hi02 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_hi02, vnmask_hi02), vnmask_hi02); + + const __m128i vq31prod_lo02 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo02, vrounding), 31); + const __m128i vq31prod_hi02 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi02, vrounding), 31); + + const __m128i vabsprod_lo13 = + _mm_mul_epu32(vabsacc_lo1032, vmultiplier); + const __m128i vabsprod_hi13 = + _mm_mul_epu32(vabsacc_hi1032, vmultiplier); + + const __m128i vnmask_lo13 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask_hi13 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod_lo13 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_lo13, vnmask_lo13), vnmask_lo13); + const __m128i vprod_hi13 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_hi13, vnmask_hi13), vnmask_hi13); + + const __m128i vq31prod_lo13 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo13, vrounding), 31); + const __m128i vq31prod_hi13 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi13, vrounding), 31); + + const __m128i vq31prod_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_lo02), + _mm_castsi128_ps(vq31prod_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_hi02), + _mm_castsi128_ps(vq31prod_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod_lo0123 = + _mm_shuffle_epi32(vq31prod_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod_hi0123 = + _mm_shuffle_epi32(vq31prod_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem_lo0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_lo0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_lo0123)); + const __m128i vrem_hi0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_hi0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_hi0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + const __m128i vout_lo = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_lo0123, vshift), + _mm_cmpgt_epi32(vrem_lo0123, vremainder_threshold)); + const __m128i vout_hi = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_hi0123, vshift), + _mm_cmpgt_epi32(vrem_hi0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + __m128i vout = _mm_adds_epi16( + _mm_packs_epi32(vout_lo, vout_hi), voutput_zero_point); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_max_epu8( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_min)); + vout = _mm_min_epu8( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_max)); + + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + } + if (c != 0) { + const size_t i_predecrement = 8 - c; + const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement); + i00 -= i_predecrement; + i01 -= i_predecrement; + i02 -= i_predecrement; + i10 -= i_predecrement; + i11 -= i_predecrement; + + const __m128i vi00 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i00), vi_shift); + const __m128i vxi00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point); + const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w)); + const __m128i vxk00 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point); + const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00); + const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00); + __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even); + __m128i vacc_hi = _mm_unpackhi_epi16(vprod00_odd, vprod00_even); + + const __m128i vi01 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i01), vi_shift); + const __m128i vxi01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi01, vzero), vinput_zero_point); + const __m128i vk01 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxk01 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point); + const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01); + const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod01_odd, vprod01_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod01_odd, vprod01_even)); + + const __m128i vi02 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i02), vi_shift); + const __m128i vxi02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi02, vzero), vinput_zero_point); + const __m128i vk02 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxk02 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point); + const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02); + const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod02_odd, vprod02_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod02_odd, vprod02_even)); + + const __m128i vi10 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i10), vi_shift); + const __m128i vxi10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi10, vzero), vinput_zero_point); + const __m128i vk10 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxk10 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point); + const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10); + const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod10_odd, vprod10_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod10_odd, vprod10_even)); + + const __m128i vi11 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i11), vi_shift); + const __m128i vxi11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vi11, vzero), vinput_zero_point); + const __m128i vk11 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk11 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point); + const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11); + const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11); + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vprod11_odd, vprod11_even)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vprod11_odd, vprod11_even)); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_loadu_si128((__m128i*)outacc)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_loadu_si128((__m128i*)(outacc + 4))); + outacc += 8; + + const __m128i vmultiplier = _mm_load_si128( + (const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask_lo0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vnmask_hi0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabsacc_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vnmask_lo0123), vnmask_lo0123); + const __m128i vabsacc_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vnmask_hi0123), vnmask_hi0123); + + const __m128i vabsacc_lo1032 = + _mm_shuffle_epi32(vabsacc_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc_hi1032 = + _mm_shuffle_epi32(vabsacc_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod_lo02 = + _mm_mul_epu32(vabsacc_lo0123, vmultiplier); + const __m128i vabsprod_hi02 = + _mm_mul_epu32(vabsacc_hi0123, vmultiplier); + + const __m128i vnmask_lo02 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask_hi02 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod_lo02 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_lo02, vnmask_lo02), vnmask_lo02); + const __m128i vprod_hi02 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_hi02, vnmask_hi02), vnmask_hi02); + + const __m128i vq31prod_lo02 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo02, vrounding), 31); + const __m128i vq31prod_hi02 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi02, vrounding), 31); + + const __m128i vabsprod_lo13 = + _mm_mul_epu32(vabsacc_lo1032, vmultiplier); + const __m128i vabsprod_hi13 = + _mm_mul_epu32(vabsacc_hi1032, vmultiplier); + + const __m128i vnmask_lo13 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask_hi13 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod_lo13 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_lo13, vnmask_lo13), vnmask_lo13); + const __m128i vprod_hi13 = _mm_sub_epi64( + _mm_xor_si128(vabsprod_hi13, vnmask_hi13), vnmask_hi13); + + const __m128i vq31prod_lo13 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo13, vrounding), 31); + const __m128i vq31prod_hi13 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi13, vrounding), 31); + + const __m128i vq31prod_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_lo02), + _mm_castsi128_ps(vq31prod_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_hi02), + _mm_castsi128_ps(vq31prod_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod_lo0123 = + _mm_shuffle_epi32(vq31prod_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod_hi0123 = + _mm_shuffle_epi32(vq31prod_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem_lo0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_lo0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_lo0123)); + const __m128i vrem_hi0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_hi0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_hi0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + const __m128i vout_lo = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_lo0123, vshift), + _mm_cmpgt_epi32(vrem_lo0123, vremainder_threshold)); + const __m128i vout_hi = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_hi0123, vshift), + _mm_cmpgt_epi32(vrem_hi0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + __m128i vout = _mm_adds_epi16( + _mm_packs_epi32(vout_lo, vout_hi), voutput_zero_point); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_max_epu8( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_min)); + vout = _mm_min_epu8( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_max)); + + if (c & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (c & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (c & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + output += 1; + } + } + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--output_width != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-aarch32-neon.S new file mode 100644 index 0000000000000..f02e1ae6274c7 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-aarch32-neon.S @@ -0,0 +1,397 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +.syntax unified + +# void pytorch_q8dwconv_ukernel_up8x9__aarch32_neon( +# size_t channels, +# size_t output_width, +# const uint8_t** input, +# const void* weights, +# uint8_t* output, +# size_t input_stride, +# size_t output_increment, +# const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static 1]) +BEGIN_FUNCTION pytorch_q8dwconv_ukernel_up8x9__aarch32_neon + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + + # Load params + # - r12 = quantization_params + LDR r12, [sp, 12] + + PUSH {r4, r5, r6, r7, r8, r9, r10, r11, lr} + VPUSH {d8-d15} + + STR r0, [sp, #-8] + STR r3, [sp, #-4] + + MOV r4, 2 + + # Load o: + # - lr = o = output + LDR lr, [sp, 100] + + # Load kernel zero point: + # - d31 = vkernel_zero_point + VLD1.8 {d31[]}, [r12], r4 + + # Load input zero point: + # - d30 = vinput_zero_point + VLD1.8 {d30[]}, [r12], r4 + + # Load multiplier: + # - q14 = d28:d29 = vmultiplier + VLD1.32 {d28[], d29[]}, [r12]! + + # Load right shift: + # - q13 = d26:d27 = vright_shift + VLD1.32 {d26[], d27[]}, [r12]! + + # Load output zero point: + # - q12 = d24:d25 = voutput_zero_point + VLD1.16 {d24[], d25[]}, [r12]! + + # Compute vzero_shift_mask + # - q11 = vzero_shift_mask + VCEQ.S32 q11, q13, 0 + + # Load output max: + # - d20 = voutput_max + VLD1.8 {d20[]}, [r12]! + + # Load output min: + # - d21 = voutput_min + VLD1.8 {d21[]}, [r12] + + .p2align 3 +0: + # Load input stride + # - r3 = input_stride + LDR r3, [sp, 104] + + # Load c: + # - r0 = c = channels + LDR r0, [sp, #-8] + + # Load i0, i1, i2, i3, i4, i5, i6, i7, i8 + # - r4 = i0 + # - r5 = i1 + # - r6 = i2 + # - r7 = i3 + # - r8 = i4 + # - r9 = i5 + # - r10 = i6 + # - r11 = i7 + # - r12 = i8 + LDM r2, {r4, r5, r6, r7, r8, r9, r10, r11, r12} + + # Pre-decrement c + SUBS r0, r0, 8 + + # Increment input by input stride + # - input = r2 := input + input_stride + ADD r2, r2, r3 + + # Load w: + # - r3 = w = weights + LDR r3, [sp, #-4] + + BLO 2f + + .p2align 4 +1: + VLDM r3!, {d0-d3} + + VLD1.8 {d4}, [r4]! + VLD1.8 {d6}, [r3]! + + VLD1.8 {d8}, [r5]! + VLD1.8 {d10}, [r3]! + + SUB_ZERO_POINT q2, d4, d30 + VSUBL.U8 q3, d6, d31 + + VLD1.8 {d12}, [r6]! + VLD1.8 {d14}, [r3]! + + SUB_ZERO_POINT q4, d8, d30 + VSUBL.U8 q5, d10, d31 + + VMLAL.S16 q0, d4, d6 + VMLAL.S16 q1, d5, d7 + + VLD1.8 {d4}, [r7]! + VLD1.8 {d6}, [r3]! + + SUB_ZERO_POINT q6, d12, d30 + VSUBL.U8 q7, d14, d31 + + VMLAL.S16 q0, d8, d10 + VMLAL.S16 q1, d9, d11 + + VLD1.8 {d8}, [r8]! + VLD1.8 {d10}, [r3]! + + SUB_ZERO_POINT q2, d4, d30 + VSUBL.U8 q3, d6, d31 + + VMLAL.S16 q0, d12, d14 + VMLAL.S16 q1, d13, d15 + + VLD1.8 {d12}, [r9]! + VLD1.8 {d14}, [r3]! + + SUB_ZERO_POINT q4, d8, d30 + VSUBL.U8 q5, d10, d31 + + VMLAL.S16 q0, d4, d6 + VMLAL.S16 q1, d5, d7 + + VLD1.8 {d4}, [r10]! + VLD1.8 {d6}, [r3]! + + SUB_ZERO_POINT q6, d12, d30 + VSUBL.U8 q7, d14, d31 + + VMLAL.S16 q0, d8, d10 + VMLAL.S16 q1, d9, d11 + + VLD1.8 {d8}, [r11]! + VLD1.8 {d10}, [r3]! + + SUB_ZERO_POINT q2, d4, d30 + VSUBL.U8 q3, d6, d31 + + VMLAL.S16 q0, d12, d14 + VMLAL.S16 q1, d13, d15 + + VLD1.8 {d12}, [r12]! + VLD1.8 {d14}, [r3]! + + SUB_ZERO_POINT q4, d8, d30 + VSUBL.U8 q5, d10, d31 + + VMLAL.S16 q0, d4, d6 + VMLAL.S16 q1, d5, d7 + + SUB_ZERO_POINT q6, d12, d30 + VSUBL.U8 q7, d14, d31 + + VMLAL.S16 q0, d8, d10 + VMLAL.S16 q1, d9, d11 + + VMLAL.S16 q0, d12, d14 + VMLAL.S16 q1, d13, d15 + + VQRDMULH.S32 q0, q0, q14 + VQRDMULH.S32 q1, q1, q14 + + VBIC q2, q0, q11 + VBIC q3, q1, q11 + + VSRA.S32 q0, q2, 31 + VSRA.S32 q1, q3, 31 + + VRSHL.S32 q0, q0, q13 + VRSHL.S32 q1, q1, q13 + + VQMOVN.S32 d0, q0 + VQMOVN.S32 d1, q1 + + VQADD.S16 q0, q12 + VQMOVUN.S16 d0, q0 + VMIN.U8 d0, d0, d20 + VMAX.U8 d0, d0, d21 + + VST1.8 {d0}, [lr]! + SUBS r0, r0, 8 + BHS 1b + +2: + CMP r0, -8 + BEQ 5f + + ADD r4, r4, r0 + ADD r5, r5, r0 + ADD r6, r6, r0 + ADD r7, r7, r0 + ADD r8, r8, r0 + ADD r9, r9, r0 + ADD r10, r10, r0 + ADD r11, r11, r0 + ADD r12, r12, r0 + + LSL r0, r0, 3 + VDUP.32 d22, r0 + + VLDM r3!, {d0-d3} + + VLD1.8 {d4}, [r4]! + VLD1.8 {d6}, [r3]! + VLD1.8 {d8}, [r5]! + VLD1.8 {d10}, [r3]! + + VSHL.U64 d4, d4, d22 + + VLD1.8 {d12}, [r6]! + VLD1.8 {d14}, [r3]! + + SUB_ZERO_POINT q2, d4, d30 + VSUBL.U8 q3, d6, d31 + + VSHL.U64 d8, d8, d22 + + VLD1.8 {d16}, [r7]! + VLD1.8 {d18}, [r3]! + + VSHL.U64 d12, d12, d22 + + SUB_ZERO_POINT q4, d8, d30 + VSUBL.U8 q5, d10, d31 + + VMLAL.S16 q0, d4, d6 + VMLAL.S16 q1, d5, d7 + + VLD1.8 {d4}, [r8]! + VLD1.8 {d6}, [r3]! + + VSHL.U64 d16, d16, d22 + + SUB_ZERO_POINT q6, d12, d30 + VSUBL.U8 q7, d14, d31 + + VMLAL.S16 q0, d8, d10 + VMLAL.S16 q1, d9, d11 + + VLD1.8 {d8}, [r9]! + VLD1.8 {d10}, [r3]! + + VSHL.U64 d4, d4, d22 + + SUB_ZERO_POINT q8, d16, d30 + VSUBL.U8 q9, d18, d31 + + VMLAL.S16 q0, d12, d14 + VMLAL.S16 q1, d13, d15 + + VLD1.8 {d12}, [r10]! + VLD1.8 {d14}, [r3]! + + VSHL.U64 d8, d8, d22 + + SUB_ZERO_POINT q2, d4, d30 + VSUBL.U8 q3, d6, d31 + + VMLAL.S16 q0, d16, d18 + VMLAL.S16 q1, d17, d19 + + VLD1.8 {d16}, [r11]! + VLD1.8 {d18}, [r3]! + + VSHL.U64 d12, d12, d22 + + SUB_ZERO_POINT q4, d8, d30 + VSUBL.U8 q5, d10, d31 + + VMLAL.S16 q0, d4, d6 + VMLAL.S16 q1, d5, d7 + + VLD1.8 {d4}, [r12]! + VLD1.8 {d6}, [r3]! + + VSHL.U64 d16, d16, d22 + + SUB_ZERO_POINT q6, d12, d30 + VSUBL.U8 q7, d14, d31 + + VMLAL.S16 q0, d8, d10 + VMLAL.S16 q1, d9, d11 + + VSHL.U64 d4, d4, d22 + + SUB_ZERO_POINT q8, d16, d30 + VSUBL.U8 q9, d18, d31 + + VMLAL.S16 q0, d12, d14 + VMLAL.S16 q1, d13, d15 + + SUB_ZERO_POINT q2, d4, d30 + VSUBL.U8 q3, d6, d31 + + VMLAL.S16 q0, d16, d18 + VMLAL.S16 q1, d17, d19 + + VMLAL.S16 q0, d4, d6 + VMLAL.S16 q1, d5, d7 + + VQRDMULH.S32 q0, q0, q14 + VQRDMULH.S32 q1, q1, q14 + + VCEQ.S32 q11, q13, 0 + + VBIC q2, q0, q11 + VBIC q3, q1, q11 + + VSRA.S32 q0, q2, 31 + VSRA.S32 q1, q3, 31 + + VRSHL.S32 q0, q0, q13 + VRSHL.S32 q1, q1, q13 + + VQMOVN.S32 d0, q0 + VQMOVN.S32 d1, q1 + + VQADD.S16 q0, q12 + VQMOVUN.S16 d0, q0 + VMIN.U8 d0, d0, d20 + VMAX.U8 d0, d0, d21 + + TST r0, 32 + BEQ 3f + VST1.32 {d0[0]}, [lr]! + VEXT.8 d0, d0, 4 + +3: + TST r0, 16 + BEQ 4f + VST1.16 {d0[0]}, [lr]! + VEXT.8 d0, d0, 2 + +4: + TST r0, 8 + BEQ 5f + VST1.8 {d0[0]}, [lr]! + +5: + # Load output increment + # - r3 = output_increment + LDR r3, [sp, 108] + + # Decrement output width + SUBS r1, r1, 1 + + # Increment output by output_increment + ADD lr, lr, r3 + + # If output width is non-zero, process another pixel + BNE 0b + + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11, pc} +END_FUNCTION pytorch_q8dwconv_ukernel_up8x9__aarch32_neon + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c new file mode 100644 index 0000000000000..441072abe12ae --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-neon.c @@ -0,0 +1,966 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8dwconv_ukernel_up8x9__neon( + size_t channels, + size_t output_width, + const uint8_t** input, + const void* weights, + uint8_t* output, + size_t input_stride, + size_t output_increment, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + const uint8x8_t va_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vkernel_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + +#ifdef __aarch64__ + /* Larger number of registers on AArch64 make it possible to process few + * pixels at a time */ + if (input_stride == 3 * sizeof(void*)) { + for (; output_width >= 3; output_width -= 3) { + const uint8_t* i00 = input[0]; + const uint8_t* i10 = input[1]; + const uint8_t* i20 = input[2]; + const uint8_t* i01 = input[3]; + const uint8_t* i11 = input[4]; + const uint8_t* i21 = input[5]; + const uint8_t* i02 = input[6]; + const uint8_t* i12 = input[7]; + const uint8_t* i22 = input[8]; + const uint8_t* i03 = input[9]; + const uint8_t* i13 = input[10]; + const uint8_t* i23 = input[11]; + const uint8_t* i04 = input[12]; + const uint8_t* i14 = input[13]; + const uint8_t* i24 = input[14]; + + uint8_t* output0 = output; + uint8_t* output1 = output0 + channels + output_increment; + uint8_t* output2 = output1 + channels + output_increment; + + input += 9; + + size_t c = channels; + const void* w = weights; + for (; c >= 8; c -= 8) { + int32x4_t vacc0_lo = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc0_hi = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc1_lo = vacc0_lo; + int32x4_t vacc2_lo = vacc0_lo; + int32x4_t vacc1_hi = vacc0_hi; + int32x4_t vacc2_hi = vacc0_hi; + + const uint8x8_t vk00 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi00 = vld1_u8(i00); + i00 += 8; + const uint8x8_t vi01 = vld1_u8(i01); + i01 += 8; + const uint8x8_t vi02 = vld1_u8(i02); + i02 += 8; + const int16x8_t vxk00 = + vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point)); + const int16x8_t vxi00 = + vreinterpretq_s16_u16(sub_zero_point(vi00, va_zero_point)); + const int16x8_t vxi01 = + vreinterpretq_s16_u16(sub_zero_point(vi01, va_zero_point)); + const int16x8_t vxi02 = + vreinterpretq_s16_u16(sub_zero_point(vi02, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02); + + const uint8x8_t vk10 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi10 = vld1_u8(i10); + i10 += 8; + const uint8x8_t vi11 = vld1_u8(i11); + i11 += 8; + const uint8x8_t vi12 = vld1_u8(i12); + i12 += 8; + const int16x8_t vxk10 = + vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point)); + const int16x8_t vxi10 = + vreinterpretq_s16_u16(sub_zero_point(vi10, va_zero_point)); + const int16x8_t vxi11 = + vreinterpretq_s16_u16(sub_zero_point(vi11, va_zero_point)); + const int16x8_t vxi12 = + vreinterpretq_s16_u16(sub_zero_point(vi12, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12); + + const uint8x8_t vk20 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi20 = vld1_u8(i20); + i20 += 8; + const uint8x8_t vi21 = vld1_u8(i21); + i21 += 8; + const uint8x8_t vi22 = vld1_u8(i22); + i22 += 8; + const int16x8_t vxk20 = + vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point)); + const int16x8_t vxi20 = + vreinterpretq_s16_u16(sub_zero_point(vi20, va_zero_point)); + const int16x8_t vxi21 = + vreinterpretq_s16_u16(sub_zero_point(vi21, va_zero_point)); + const int16x8_t vxi22 = + vreinterpretq_s16_u16(sub_zero_point(vi22, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22); + + const uint8x8_t vk01 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi03 = vld1_u8(i03); + i03 += 8; + const int16x8_t vxk01 = + vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point)); + const int16x8_t vxi03 = + vreinterpretq_s16_u16(sub_zero_point(vi03, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03); + + const uint8x8_t vk11 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi13 = vld1_u8(i13); + i13 += 8; + const int16x8_t vxk11 = + vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point)); + const int16x8_t vxi13 = + vreinterpretq_s16_u16(sub_zero_point(vi13, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13); + + const uint8x8_t vk21 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi23 = vld1_u8(i23); + i23 += 8; + const int16x8_t vxk21 = + vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point)); + const int16x8_t vxi23 = + vreinterpretq_s16_u16(sub_zero_point(vi23, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23); + + const uint8x8_t vk02 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi04 = vld1_u8(i04); + i04 += 8; + const int16x8_t vxk02 = + vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point)); + const int16x8_t vxi04 = + vreinterpretq_s16_u16(sub_zero_point(vi04, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04); + + const uint8x8_t vk12 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi14 = vld1_u8(i14); + i14 += 8; + const int16x8_t vxk12 = + vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point)); + const int16x8_t vxi14 = + vreinterpretq_s16_u16(sub_zero_point(vi14, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14); + + const uint8x8_t vk22 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi24 = vld1_u8(i24); + i24 += 8; + const int16x8_t vxk22 = + vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point)); + const int16x8_t vxi24 = + vreinterpretq_s16_u16(sub_zero_point(vi24, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24); + + vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier); + vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier); + vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier); + vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier); + vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier); + vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier); + + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0_lo = + vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); + vacc0_hi = + vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); + vacc1_lo = + vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); + vacc1_hi = + vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); + vacc2_lo = + vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31); + vacc2_hi = + vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31); + + vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); + vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); + vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); + vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); + vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift); + vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift); + + const int16x8_t vacc0 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), + voutput_zero_point); + const int16x8_t vacc1 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), + voutput_zero_point); + const int16x8_t vacc2 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), + voutput_zero_point); + uint8x8_t vout0 = vqmovun_s16(vacc0); + uint8x8_t vout1 = vqmovun_s16(vacc1); + uint8x8_t vout2 = vqmovun_s16(vacc2); + vout0 = vmax_u8(vout0, voutput_min); + vout1 = vmax_u8(vout1, voutput_min); + vout2 = vmax_u8(vout2, voutput_min); + vout0 = vmin_u8(vout0, voutput_max); + vout1 = vmin_u8(vout1, voutput_max); + vout2 = vmin_u8(vout2, voutput_max); + + vst1_u8(output0, vout0); + output0 += 8; + vst1_u8(output1, vout1); + output1 += 8; + vst1_u8(output2, vout2); + output2 += 8; + } + if (c != 0) { + const size_t c_predecrement = 8 - c; + const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement); + i00 -= c_predecrement; + i10 -= c_predecrement; + i20 -= c_predecrement; + i01 -= c_predecrement; + i11 -= c_predecrement; + i21 -= c_predecrement; + i02 -= c_predecrement; + i12 -= c_predecrement; + i22 -= c_predecrement; + i03 -= c_predecrement; + i13 -= c_predecrement; + i23 -= c_predecrement; + i04 -= c_predecrement; + i14 -= c_predecrement; + i24 -= c_predecrement; + + int32x4_t vacc0_lo = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc0_hi = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vacc1_lo = vacc0_lo; + int32x4_t vacc2_lo = vacc0_lo; + int32x4_t vacc1_hi = vacc0_hi; + int32x4_t vacc2_hi = vacc0_hi; + + const uint8x8_t vk00 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi00 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i00)), vi_shift)); + const uint8x8_t vi01 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i01)), vi_shift)); + const uint8x8_t vi02 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i02)), vi_shift)); + const int16x8_t vxk00 = + vreinterpretq_s16_u16(vsubl_u8(vk00, vkernel_zero_point)); + const int16x8_t vxi00 = + vreinterpretq_s16_u16(sub_zero_point(vi00, va_zero_point)); + const int16x8_t vxi01 = + vreinterpretq_s16_u16(sub_zero_point(vi01, va_zero_point)); + const int16x8_t vxi02 = + vreinterpretq_s16_u16(sub_zero_point(vi02, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk00), vget_low_s16(vxi00)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk00, vxi00); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk00), vget_low_s16(vxi01)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk00, vxi01); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk00), vget_low_s16(vxi02)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk00, vxi02); + + const uint8x8_t vk10 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi10 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i10)), vi_shift)); + const uint8x8_t vi11 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i11)), vi_shift)); + const uint8x8_t vi12 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i12)), vi_shift)); + const int16x8_t vxk10 = + vreinterpretq_s16_u16(vsubl_u8(vk10, vkernel_zero_point)); + const int16x8_t vxi10 = + vreinterpretq_s16_u16(sub_zero_point(vi10, va_zero_point)); + const int16x8_t vxi11 = + vreinterpretq_s16_u16(sub_zero_point(vi11, va_zero_point)); + const int16x8_t vxi12 = + vreinterpretq_s16_u16(sub_zero_point(vi12, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk10), vget_low_s16(vxi10)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk10, vxi10); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk10), vget_low_s16(vxi11)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk10, vxi11); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk10), vget_low_s16(vxi12)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk10, vxi12); + + const uint8x8_t vk20 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi20 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i20)), vi_shift)); + const uint8x8_t vi21 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i21)), vi_shift)); + const uint8x8_t vi22 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i22)), vi_shift)); + const int16x8_t vxk20 = + vreinterpretq_s16_u16(vsubl_u8(vk20, vkernel_zero_point)); + const int16x8_t vxi20 = + vreinterpretq_s16_u16(sub_zero_point(vi20, va_zero_point)); + const int16x8_t vxi21 = + vreinterpretq_s16_u16(sub_zero_point(vi21, va_zero_point)); + const int16x8_t vxi22 = + vreinterpretq_s16_u16(sub_zero_point(vi22, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk20), vget_low_s16(vxi20)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk20, vxi20); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk20), vget_low_s16(vxi21)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk20, vxi21); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk20), vget_low_s16(vxi22)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk20, vxi22); + + const uint8x8_t vk01 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi03 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i03)), vi_shift)); + const int16x8_t vxk01 = + vreinterpretq_s16_u16(vsubl_u8(vk01, vkernel_zero_point)); + const int16x8_t vxi03 = + vreinterpretq_s16_u16(sub_zero_point(vi03, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk01), vget_low_s16(vxi01)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk01, vxi01); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk01), vget_low_s16(vxi02)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk01, vxi02); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk01), vget_low_s16(vxi03)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk01, vxi03); + + const uint8x8_t vk11 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi13 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i13)), vi_shift)); + const int16x8_t vxk11 = + vreinterpretq_s16_u16(vsubl_u8(vk11, vkernel_zero_point)); + const int16x8_t vxi13 = + vreinterpretq_s16_u16(sub_zero_point(vi13, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk11), vget_low_s16(vxi11)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk11, vxi11); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk11), vget_low_s16(vxi12)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk11, vxi12); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk11), vget_low_s16(vxi13)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk11, vxi13); + + const uint8x8_t vk21 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi23 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i23)), vi_shift)); + const int16x8_t vxk21 = + vreinterpretq_s16_u16(vsubl_u8(vk21, vkernel_zero_point)); + const int16x8_t vxi23 = + vreinterpretq_s16_u16(sub_zero_point(vi23, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk21), vget_low_s16(vxi21)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk21, vxi21); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk21), vget_low_s16(vxi22)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk21, vxi22); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk21), vget_low_s16(vxi23)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk21, vxi23); + + const uint8x8_t vk02 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi04 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i04)), vi_shift)); + const int16x8_t vxk02 = + vreinterpretq_s16_u16(vsubl_u8(vk02, vkernel_zero_point)); + const int16x8_t vxi04 = + vreinterpretq_s16_u16(sub_zero_point(vi04, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk02), vget_low_s16(vxi02)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk02, vxi02); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk02), vget_low_s16(vxi03)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk02, vxi03); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk02), vget_low_s16(vxi04)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk02, vxi04); + + const uint8x8_t vk12 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi14 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i14)), vi_shift)); + const int16x8_t vxk12 = + vreinterpretq_s16_u16(vsubl_u8(vk12, vkernel_zero_point)); + const int16x8_t vxi14 = + vreinterpretq_s16_u16(sub_zero_point(vi14, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk12), vget_low_s16(vxi12)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk12, vxi12); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk12), vget_low_s16(vxi13)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk12, vxi13); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk12), vget_low_s16(vxi14)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk12, vxi14); + + const uint8x8_t vk22 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi24 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i24)), vi_shift)); + const int16x8_t vxk22 = + vreinterpretq_s16_u16(vsubl_u8(vk22, vkernel_zero_point)); + const int16x8_t vxi24 = + vreinterpretq_s16_u16(sub_zero_point(vi24, va_zero_point)); + vacc0_lo = + vmlal_s16(vacc0_lo, vget_low_s16(vxk22), vget_low_s16(vxi22)); + vacc0_hi = vmlal_high_s16(vacc0_hi, vxk22, vxi22); + vacc1_lo = + vmlal_s16(vacc1_lo, vget_low_s16(vxk22), vget_low_s16(vxi23)); + vacc1_hi = vmlal_high_s16(vacc1_hi, vxk22, vxi23); + vacc2_lo = + vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24)); + vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24); + + vacc0_lo = vqrdmulhq_s32(vacc0_lo, vmultiplier); + vacc0_hi = vqrdmulhq_s32(vacc0_hi, vmultiplier); + vacc1_lo = vqrdmulhq_s32(vacc1_lo, vmultiplier); + vacc1_hi = vqrdmulhq_s32(vacc1_hi, vmultiplier); + vacc2_lo = vqrdmulhq_s32(vacc2_lo, vmultiplier); + vacc2_hi = vqrdmulhq_s32(vacc2_hi, vmultiplier); + + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0_lo = + vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); + vacc0_hi = + vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); + vacc1_lo = + vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); + vacc1_hi = + vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); + vacc2_lo = + vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31); + vacc2_hi = + vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31); + + vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); + vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); + vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); + vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); + vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift); + vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift); + + const int16x8_t vacc0 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), + voutput_zero_point); + const int16x8_t vacc1 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), + voutput_zero_point); + const int16x8_t vacc2 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), + voutput_zero_point); + uint8x8_t vout0 = vqmovun_s16(vacc0); + uint8x8_t vout1 = vqmovun_s16(vacc1); + uint8x8_t vout2 = vqmovun_s16(vacc2); + vout0 = vmax_u8(vout0, voutput_min); + vout1 = vmax_u8(vout1, voutput_min); + vout2 = vmax_u8(vout2, voutput_min); + vout0 = vmin_u8(vout0, voutput_max); + vout1 = vmin_u8(vout1, voutput_max); + vout2 = vmin_u8(vout2, voutput_max); + + if (c & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output0, 1), + vreinterpret_u32_u8(vout0), + 0); + output0 += 4; + vst1_lane_u32( + __builtin_assume_aligned(output1, 1), + vreinterpret_u32_u8(vout1), + 0); + output1 += 4; + vst1_lane_u32( + __builtin_assume_aligned(output2, 1), + vreinterpret_u32_u8(vout2), + 0); + output2 += 4; + vout0 = vext_u8(vout0, vout0, 4); + vout1 = vext_u8(vout1, vout1, 4); + vout2 = vext_u8(vout2, vout2, 4); + } + if (c & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output0, 1), + vreinterpret_u16_u8(vout0), + 0); + output0 += 2; + vst1_lane_u16( + __builtin_assume_aligned(output1, 1), + vreinterpret_u16_u8(vout1), + 0); + output1 += 2; + vst1_lane_u16( + __builtin_assume_aligned(output2, 1), + vreinterpret_u16_u8(vout2), + 0); + output2 += 2; + vout0 = vext_u8(vout0, vout0, 2); + vout1 = vext_u8(vout1, vout1, 2); + vout2 = vext_u8(vout2, vout2, 2); + } + if (c & 1) { + vst1_lane_u8(__builtin_assume_aligned(output0, 1), vout0, 0); + output0++; + vst1_lane_u8(__builtin_assume_aligned(output1, 1), vout1, 0); + output1++; + vst1_lane_u8(__builtin_assume_aligned(output2, 1), vout2, 0); + output2++; + } + } + + output = (uint8_t*)((uintptr_t)output2 + output_increment); + } + if (output_width == 0) { + return; + } + } +#endif + + do { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + const uint8_t* i8 = input[8]; + + input = (const uint8_t**)((uintptr_t)input + input_stride); + + size_t c = channels; + const void* w = weights; + for (; c >= 8; c -= 8) { + int32x4_t vaccX1_lo = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vaccX1_hi = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + + const uint8x8_t vk0 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(sub_zero_point(vi0, va_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(sub_zero_point(vi1, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(sub_zero_point(vi2, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(sub_zero_point(vi3, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(sub_zero_point(vi4, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + const uint8x8_t vk5 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const int16x8_t vxk5 = + vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point)); + const int16x8_t vxi5 = + vreinterpretq_s16_u16(sub_zero_point(vi5, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5)); + + const uint8x8_t vk6 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + const int16x8_t vxk6 = + vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point)); + const int16x8_t vxi6 = + vreinterpretq_s16_u16(sub_zero_point(vi6, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6)); + + const uint8x8_t vk7 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi7 = vld1_u8(i7); + i7 += 8; + const int16x8_t vxk7 = + vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point)); + const int16x8_t vxi7 = + vreinterpretq_s16_u16(sub_zero_point(vi7, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7)); + + const uint8x8_t vk8 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi8 = vld1_u8(i8); + i8 += 8; + const int16x8_t vxk8 = + vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point)); + const int16x8_t vxi8 = + vreinterpretq_s16_u16(sub_zero_point(vi8, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier); + vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier); + + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31); + vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31); + + vacc_lo = vrshlq_s32(vacc_lo, vright_shift); + vacc_hi = vrshlq_s32(vacc_hi, vright_shift); + +#ifdef __aarch64__ + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + vst1_u8(output, vout); + output += 8; + } + if (c != 0) { + const size_t c_predecrement = 8 - c; + const int64x1_t vi_shift = vmov_n_s64(-8 * c_predecrement); + i0 -= c_predecrement; + i1 -= c_predecrement; + i2 -= c_predecrement; + i3 -= c_predecrement; + i4 -= c_predecrement; + i5 -= c_predecrement; + i6 -= c_predecrement; + i7 -= c_predecrement; + i8 -= c_predecrement; + + int32x4_t vaccX1_lo = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + int32x4_t vaccX1_hi = vld1q_s32(w); + w = (void*)((uintptr_t)w + sizeof(int32x4_t)); + + const uint8x8_t vk0 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vi_shift)); + const int16x8_t vxk0 = + vreinterpretq_s16_u16(vsubl_u8(vk0, vkernel_zero_point)); + const int16x8_t vxi0 = + vreinterpretq_s16_u16(sub_zero_point(vi0, va_zero_point)); + int32x4_t vaccX0_lo = vmull_s16(vget_low_s16(vxk0), vget_low_s16(vxi0)); + int32x4_t vaccX0_hi = vmull_s16(vget_high_s16(vxk0), vget_high_s16(vxi0)); + + const uint8x8_t vk1 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vi_shift)); + const int16x8_t vxk1 = + vreinterpretq_s16_u16(vsubl_u8(vk1, vkernel_zero_point)); + const int16x8_t vxi1 = + vreinterpretq_s16_u16(sub_zero_point(vi1, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk1), vget_low_s16(vxi1)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk1), vget_high_s16(vxi1)); + + const uint8x8_t vk2 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vi_shift)); + const int16x8_t vxk2 = + vreinterpretq_s16_u16(vsubl_u8(vk2, vkernel_zero_point)); + const int16x8_t vxi2 = + vreinterpretq_s16_u16(sub_zero_point(vi2, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk2), vget_low_s16(vxi2)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk2), vget_high_s16(vxi2)); + + const uint8x8_t vk3 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vi_shift)); + const int16x8_t vxk3 = + vreinterpretq_s16_u16(vsubl_u8(vk3, vkernel_zero_point)); + const int16x8_t vxi3 = + vreinterpretq_s16_u16(sub_zero_point(vi3, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk3), vget_low_s16(vxi3)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk3), vget_high_s16(vxi3)); + + const uint8x8_t vk4 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vi_shift)); + const int16x8_t vxk4 = + vreinterpretq_s16_u16(vsubl_u8(vk4, vkernel_zero_point)); + const int16x8_t vxi4 = + vreinterpretq_s16_u16(sub_zero_point(vi4, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk4), vget_low_s16(vxi4)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk4), vget_high_s16(vxi4)); + + const uint8x8_t vk5 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vi_shift)); + const int16x8_t vxk5 = + vreinterpretq_s16_u16(vsubl_u8(vk5, vkernel_zero_point)); + const int16x8_t vxi5 = + vreinterpretq_s16_u16(sub_zero_point(vi5, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk5), vget_low_s16(vxi5)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk5), vget_high_s16(vxi5)); + + const uint8x8_t vk6 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vi_shift)); + const int16x8_t vxk6 = + vreinterpretq_s16_u16(vsubl_u8(vk6, vkernel_zero_point)); + const int16x8_t vxi6 = + vreinterpretq_s16_u16(sub_zero_point(vi6, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk6), vget_low_s16(vxi6)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk6), vget_high_s16(vxi6)); + + const uint8x8_t vk7 = vld1_u8(w); + w = (void*)((uintptr_t)w + sizeof(uint8x8_t)); + const uint8x8_t vi7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i7)), vi_shift)); + const int16x8_t vxk7 = + vreinterpretq_s16_u16(vsubl_u8(vk7, vkernel_zero_point)); + const int16x8_t vxi7 = + vreinterpretq_s16_u16(sub_zero_point(vi7, va_zero_point)); + vaccX1_lo = vmlal_s16(vaccX1_lo, vget_low_s16(vxk7), vget_low_s16(vxi7)); + vaccX1_hi = + vmlal_s16(vaccX1_hi, vget_high_s16(vxk7), vget_high_s16(vxi7)); + + const uint8x8_t vk8 = vld1_u8(w); + const uint8x8_t vi8 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(i8)), vi_shift)); + const int16x8_t vxk8 = + vreinterpretq_s16_u16(vsubl_u8(vk8, vkernel_zero_point)); + const int16x8_t vxi8 = + vreinterpretq_s16_u16(sub_zero_point(vi8, va_zero_point)); + vaccX0_lo = vmlal_s16(vaccX0_lo, vget_low_s16(vxk8), vget_low_s16(vxi8)); + vaccX0_hi = + vmlal_s16(vaccX0_hi, vget_high_s16(vxk8), vget_high_s16(vxi8)); + + int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo); + int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi); + + vacc_lo = vqrdmulhq_s32(vacc_lo, vmultiplier); + vacc_hi = vqrdmulhq_s32(vacc_hi, vmultiplier); + + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc_lo = vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31); + vacc_hi = vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31); + + vacc_lo = vrshlq_s32(vacc_lo, vright_shift); + vacc_hi = vrshlq_s32(vacc_hi, vright_shift); + +#ifdef __aarch64__ + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (c & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (c & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (c & 1) { + vst1_lane_u8(__builtin_assume_aligned(output, 1), vout, 0); + output++; + } + } + + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--output_width != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-sse2.c new file mode 100644 index 0000000000000..a0a58803a3f0b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8dwconv/up8x9-sse2.c @@ -0,0 +1,548 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8dwconv_ukernel_up8x9__sse2( + size_t channels, + size_t output_width, + const uint8_t** input, + const void* weights, + uint8_t* output, + size_t input_stride, + size_t output_increment, + const union pytorch_qnnp_conv_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + const __m128i va_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.input_zero_point); + const __m128i vkernel_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.kernel_zero_point); + const __m128i vzero = _mm_setzero_si128(); + + do { + const uint8_t* i0 = input[0]; + const uint8_t* i1 = input[1]; + const uint8_t* i2 = input[2]; + const uint8_t* i3 = input[3]; + const uint8_t* i4 = input[4]; + const uint8_t* i5 = input[5]; + const uint8_t* i6 = input[6]; + const uint8_t* i7 = input[7]; + const uint8_t* i8 = input[8]; + + input = (const uint8_t**)((uintptr_t)input + input_stride); + + size_t c = channels; + const void* w = weights; + for (; c >= 8; c -= 8) { + __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w); + __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16)); + + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vxi0 = + sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point); + const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point); + const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0); + const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even)); + + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vxi1 = + sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point); + const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40)); + const __m128i vxk1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point); + const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1); + const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even)); + + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vxi2 = + sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point); + const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48)); + const __m128i vxk2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point); + const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2); + const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even)); + + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vxi3 = + sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point); + const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56)); + const __m128i vxk3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point); + const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3); + const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even)); + + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vxi4 = + sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point); + const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64)); + const __m128i vxk4 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point); + const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4); + const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even)); + + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vxi5 = + sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point); + const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72)); + const __m128i vxk5 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point); + const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5); + const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even)); + + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + const __m128i vxi6 = + sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point); + const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80)); + const __m128i vxk6 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point); + const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6); + const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even)); + + const __m128i vi7 = _mm_loadl_epi64((const __m128i*)i7); + i7 += 8; + const __m128i vxi7 = + sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point); + const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88)); + const __m128i vxk7 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point); + const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7); + const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even)); + + const __m128i vi8 = _mm_loadl_epi64((const __m128i*)i8); + i8 += 8; + const __m128i vxi8 = + sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point); + const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96)); + const __m128i vxk8 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point); + const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8); + const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even)); + + w = (void*)((uintptr_t)w + 104); + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask_lo0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vnmask_hi0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabsacc_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vnmask_lo0123), vnmask_lo0123); + const __m128i vabsacc_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vnmask_hi0123), vnmask_hi0123); + + const __m128i vabsacc_lo1032 = + _mm_shuffle_epi32(vabsacc_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc_hi1032 = + _mm_shuffle_epi32(vabsacc_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod_lo02 = _mm_mul_epu32(vabsacc_lo0123, vmultiplier); + const __m128i vabsprod_hi02 = _mm_mul_epu32(vabsacc_hi0123, vmultiplier); + + const __m128i vnmask_lo02 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask_hi02 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod_lo02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_lo02, vnmask_lo02), vnmask_lo02); + const __m128i vprod_hi02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_hi02, vnmask_hi02), vnmask_hi02); + + const __m128i vq31prod_lo02 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo02, vrounding), 31); + const __m128i vq31prod_hi02 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi02, vrounding), 31); + + const __m128i vabsprod_lo13 = _mm_mul_epu32(vabsacc_lo1032, vmultiplier); + const __m128i vabsprod_hi13 = _mm_mul_epu32(vabsacc_hi1032, vmultiplier); + + const __m128i vnmask_lo13 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask_hi13 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod_lo13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_lo13, vnmask_lo13), vnmask_lo13); + const __m128i vprod_hi13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_hi13, vnmask_hi13), vnmask_hi13); + + const __m128i vq31prod_lo13 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo13, vrounding), 31); + const __m128i vq31prod_hi13 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi13, vrounding), 31); + + const __m128i vq31prod_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_lo02), + _mm_castsi128_ps(vq31prod_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_hi02), + _mm_castsi128_ps(vq31prod_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod_lo0123 = + _mm_shuffle_epi32(vq31prod_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod_hi0123 = + _mm_shuffle_epi32(vq31prod_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem_lo0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_lo0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_lo0123)); + const __m128i vrem_hi0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_hi0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_hi0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + const __m128i vout_lo = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_lo0123, vshift), + _mm_cmpgt_epi32(vrem_lo0123, vremainder_threshold)); + const __m128i vout_hi = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_hi0123, vshift), + _mm_cmpgt_epi32(vrem_hi0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + __m128i vout = + _mm_adds_epi16(_mm_packs_epi32(vout_lo, vout_hi), voutput_zero_point); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + } + if (c != 0) { + const size_t i_predecrement = 8 - c; + const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement); + i0 -= i_predecrement; + i1 -= i_predecrement; + i2 -= i_predecrement; + i3 -= i_predecrement; + i4 -= i_predecrement; + i5 -= i_predecrement; + i6 -= i_predecrement; + i7 -= i_predecrement; + i8 -= i_predecrement; + + __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w); + __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16)); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift); + const __m128i vxi0 = + sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point); + const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32)); + const __m128i vxk0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point); + const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0); + const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod0_odd, vprod0_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod0_odd, vprod0_even)); + + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift); + const __m128i vxi1 = + sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point); + const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40)); + const __m128i vxk1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point); + const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1); + const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod1_odd, vprod1_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod1_odd, vprod1_even)); + + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift); + const __m128i vxi2 = + sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point); + const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48)); + const __m128i vxk2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point); + const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2); + const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod2_odd, vprod2_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod2_odd, vprod2_even)); + + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift); + const __m128i vxi3 = + sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point); + const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56)); + const __m128i vxk3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point); + const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3); + const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod3_odd, vprod3_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod3_odd, vprod3_even)); + + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift); + const __m128i vxi4 = + sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point); + const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64)); + const __m128i vxk4 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point); + const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4); + const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod4_odd, vprod4_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod4_odd, vprod4_even)); + + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift); + const __m128i vxi5 = + sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point); + const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72)); + const __m128i vxk5 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point); + const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5); + const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod5_odd, vprod5_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod5_odd, vprod5_even)); + + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift); + const __m128i vxi6 = + sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point); + const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80)); + const __m128i vxk6 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point); + const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6); + const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod6_odd, vprod6_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod6_odd, vprod6_even)); + + const __m128i vi7 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i7), vi_shift); + const __m128i vxi7 = + sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point); + const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88)); + const __m128i vxk7 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point); + const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7); + const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod7_odd, vprod7_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod7_odd, vprod7_even)); + + const __m128i vi8 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i8), vi_shift); + const __m128i vxi8 = + sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point); + const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96)); + const __m128i vxk8 = + _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point); + const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8); + const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8); + vacc_lo = + _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vprod8_odd, vprod8_even)); + vacc_hi = + _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even)); + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask_lo0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vnmask_hi0123 = + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabsacc_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vnmask_lo0123), vnmask_lo0123); + const __m128i vabsacc_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vnmask_hi0123), vnmask_hi0123); + + const __m128i vabsacc_lo1032 = + _mm_shuffle_epi32(vabsacc_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc_hi1032 = + _mm_shuffle_epi32(vabsacc_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod_lo02 = _mm_mul_epu32(vabsacc_lo0123, vmultiplier); + const __m128i vabsprod_hi02 = _mm_mul_epu32(vabsacc_hi0123, vmultiplier); + + const __m128i vnmask_lo02 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask_hi02 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod_lo02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_lo02, vnmask_lo02), vnmask_lo02); + const __m128i vprod_hi02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_hi02, vnmask_hi02), vnmask_hi02); + + const __m128i vq31prod_lo02 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo02, vrounding), 31); + const __m128i vq31prod_hi02 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi02, vrounding), 31); + + const __m128i vabsprod_lo13 = _mm_mul_epu32(vabsacc_lo1032, vmultiplier); + const __m128i vabsprod_hi13 = _mm_mul_epu32(vabsacc_hi1032, vmultiplier); + + const __m128i vnmask_lo13 = + _mm_shuffle_epi32(vnmask_lo0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask_hi13 = + _mm_shuffle_epi32(vnmask_hi0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod_lo13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_lo13, vnmask_lo13), vnmask_lo13); + const __m128i vprod_hi13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod_hi13, vnmask_hi13), vnmask_hi13); + + const __m128i vq31prod_lo13 = + _mm_srli_epi64(_mm_add_epi64(vprod_lo13, vrounding), 31); + const __m128i vq31prod_hi13 = + _mm_srli_epi64(_mm_add_epi64(vprod_hi13, vrounding), 31); + + const __m128i vq31prod_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_lo02), + _mm_castsi128_ps(vq31prod_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod_hi02), + _mm_castsi128_ps(vq31prod_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod_lo0123 = + _mm_shuffle_epi32(vq31prod_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod_hi0123 = + _mm_shuffle_epi32(vq31prod_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem_lo0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_lo0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_lo0123)); + const __m128i vrem_hi0123 = _mm_add_epi32( + _mm_and_si128(vq31prod_hi0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod_hi0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + const __m128i vout_lo = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_lo0123, vshift), + _mm_cmpgt_epi32(vrem_lo0123, vremainder_threshold)); + const __m128i vout_hi = _mm_sub_epi32( + _mm_sra_epi32(vq31prod_hi0123, vshift), + _mm_cmpgt_epi32(vrem_hi0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + __m128i vout = + _mm_adds_epi16(_mm_packs_epi32(vout_lo, vout_hi), voutput_zero_point); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + if (c & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (c & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (c & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + output += 1; + } + } + + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--output_width != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c new file mode 100644 index 0000000000000..6375d9b2c7c9a --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-neon.c @@ -0,0 +1,395 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8gavgpool_ukernel_mp8x7p7q__neon( + size_t m, + size_t n, + const uint8_t* input, + size_t input_stride, + const uint8_t* zero, + int32_t* buffer, + uint8_t* output, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[restrict static 1]) { + assert(m > 7); + assert(n >= 8); + + const uint8_t* i0 = input; + const uint8_t* i1 = i0 + input_stride; + const uint8_t* i2 = i1 + input_stride; + const uint8_t* i3 = i2 + input_stride; + const uint8_t* i4 = i3 + input_stride; + const uint8_t* i5 = i4 + input_stride; + const uint8_t* i6 = i5 + input_stride; + const size_t packed_n = (n + 7) & -8; + const size_t input_increment = 7 * input_stride - packed_n; + const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); + + /* note: goes up to 7 elements over bound */ + int32_t* acc = buffer; + for (size_t k = 0; k < n; k += 8) { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + + const int16x8_t vsum016 = + vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + + int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23)); + int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016)); + vst1q_s32(acc, vacc_lo); + acc += 4; + vst1q_s32(acc, vacc_hi); + acc += 4; + } + for (m -= 7; m > 7; m -= 7) { + acc = buffer; + i0 = (const uint8_t*)((uintptr_t)i0 + input_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + input_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + input_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + input_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + input_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + input_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + input_increment); + + /* note: goes up to 7 elements over bound */ + for (size_t k = 0; k < n; k += 8) { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + int32x4_t vacc_lo = vld1q_s32(acc); + int32x4_t vacc_hi = vld1q_s32(acc + 4); + + const int16x8_t vsum016 = + vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum23)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum23)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016)); + vst1q_s32(acc, vacc_lo); + acc += 4; + vst1q_s32(acc, vacc_hi); + acc += 4; + } + } + +#ifdef __aarch64__ + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); +#else + const int32x2_t vmultiplier = + vld1_dup_s32(&quantization_params->neon.multiplier); +#endif + const int64x2_t vleft_shift = + vld1q_dup_s64(&quantization_params->neon.left_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + + i0 = (const uint8_t*)((uintptr_t)i0 + input_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + input_increment); + if (m < 2) { + i1 = zero; + } + i2 = (const uint8_t*)((uintptr_t)i2 + input_increment); + if (m <= 2) { + i2 = zero; + } + i3 = (const uint8_t*)((uintptr_t)i3 + input_increment); + if (m < 4) { + i3 = zero; + } + i4 = (const uint8_t*)((uintptr_t)i4 + input_increment); + if (m <= 4) { + i4 = zero; + } + i5 = (const uint8_t*)((uintptr_t)i5 + input_increment); + if (m < 6) { + i5 = zero; + } + i6 = (const uint8_t*)((uintptr_t)i6 + input_increment); + if (m <= 6) { + i6 = zero; + } + + acc = buffer; + do { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + int32x4_t vacc_lo = vld1q_s32(acc); + acc += 4; + int32x4_t vacc_hi = vld1q_s32(acc); + acc += 4; + + const int16x8_t vsum016 = + vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum23)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum23)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016)); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + vst1_u8(output, vout); + output += 8; + + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_increment = n - 8; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint8x8_t vi0 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift)); + const uint8x8_t vi1 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift)); + const uint8x8_t vi2 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift)); + const uint8x8_t vi3 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift)); + const uint8x8_t vi4 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift)); + const uint8x8_t vi5 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift)); + const uint8x8_t vi6 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift)); + int32x4_t vacc_lo = vld1q_s32(acc); + acc += 4; + int32x4_t vacc_hi = vld1q_s32(acc); + + const int16x8_t vsum016 = + vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum23)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum23)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016)); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (n & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (n & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (n & 1) { + vst1_lane_u8(output, vout, 0); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-sse2.c new file mode 100644 index 0000000000000..8d6658033f14b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/mp8x7p7q-sse2.c @@ -0,0 +1,410 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2( + size_t m, + size_t n, + const uint8_t* input, + size_t input_stride, + const uint8_t* zero, + int32_t* buffer, + uint8_t* output, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + assert(m > 7); + assert(n >= 8); + + const uint8_t* i0 = input; + const uint8_t* i1 = i0 + input_stride; + const uint8_t* i2 = i1 + input_stride; + const uint8_t* i3 = i2 + input_stride; + const uint8_t* i4 = i3 + input_stride; + const uint8_t* i5 = i4 + input_stride; + const uint8_t* i6 = i5 + input_stride; + const size_t packed_n = (n + 7) & -8; + const size_t input_increment = 7 * input_stride - packed_n; + const __m128i vbias = + _mm_load_si128((const __m128i*)&quantization_params->sse2.bias); + const __m128i vzero = _mm_setzero_si128(); + + /* note: goes up to 7 elements over bound */ + int32_t* acc = buffer; + for (size_t k = 0; k < n; k += 8) { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + + __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero)); + __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero)); + + _mm_store_si128((__m128i*)acc, vacc_lo); + _mm_store_si128((__m128i*)acc + 1, vacc_hi); + acc += 8; + } + for (m -= 7; m > 7; m -= 7) { + acc = buffer; + i0 = (const uint8_t*)((uintptr_t)i0 + input_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + input_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + input_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + input_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + input_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + input_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + input_increment); + + /* note: goes up to 7 elements over bound */ + for (size_t k = 0; k < n; k += 8) { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero)); + + _mm_store_si128((__m128i*)acc, vacc_lo); + _mm_store_si128((__m128i*)acc + 1, vacc_hi); + acc += 8; + } + } + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + const __m128i vright_shift = + _mm_loadl_epi64((const __m128i*)quantization_params->sse2.right_shift); + + i0 = (const uint8_t*)((uintptr_t)i0 + input_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + input_increment); + if (m < 2) { + i1 = zero; + } + i2 = (const uint8_t*)((uintptr_t)i2 + input_increment); + if (m <= 2) { + i2 = zero; + } + i3 = (const uint8_t*)((uintptr_t)i3 + input_increment); + if (m < 4) { + i3 = zero; + } + i4 = (const uint8_t*)((uintptr_t)i4 + input_increment); + if (m <= 4) { + i4 = zero; + } + i5 = (const uint8_t*)((uintptr_t)i5 + input_increment); + if (m < 6) { + i5 = zero; + } + i6 = (const uint8_t*)((uintptr_t)i6 + input_increment); + if (m <= 6) { + i6 = zero; + } + + acc = buffer; + do { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + acc += 8; + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero)); + + const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_decrement = 8 - n; + i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement); + i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement); + i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement); + i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement); + i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement); + i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement); + i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement); + const __m128i vi_shift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift); + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift); + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift); + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift); + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift); + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift); + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift); + __m128i vacc_lo = _mm_load_si128((const __m128i*)acc); + __m128i vacc_hi = _mm_load_si128((const __m128i*)acc + 1); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi0, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi0, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero)); + + const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + if (n & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (n & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (n & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c new file mode 100644 index 0000000000000..0d0c81c1f1ac3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-neon.c @@ -0,0 +1,295 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8gavgpool_ukernel_up8x7__neon( + size_t m, + size_t n, + const uint8_t* input, + size_t input_stride, + const uint8_t* zero, + uint8_t* output, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[restrict static 1]) { + assert(m >= 1); + assert(m <= 7); + assert(n >= 8); + + const uint8_t* i0 = input; + const uint8_t* i1 = i0 + input_stride; + if (m < 2) { + i1 = zero; + } + const uint8_t* i2 = i1 + input_stride; + if (m <= 2) { + i2 = zero; + } + const uint8_t* i3 = i2 + input_stride; + if (m < 4) { + i3 = zero; + } + const uint8_t* i4 = i3 + input_stride; + if (m <= 4) { + i4 = zero; + } + const uint8_t* i5 = i4 + input_stride; + if (m < 6) { + i5 = zero; + } + const uint8_t* i6 = i5 + input_stride; + if (m <= 6) { + i6 = zero; + } + const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); +#ifdef __aarch64__ + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); +#else + const int32x2_t vmultiplier = + vld1_dup_s32(&quantization_params->neon.multiplier); +#endif + const int64x2_t vleft_shift = + vld1q_dup_s64(&quantization_params->neon.left_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + + do { + const uint8x8_t vi0 = vld1_u8(i0); + i0 += 8; + const uint8x8_t vi1 = vld1_u8(i1); + i1 += 8; + const uint8x8_t vi2 = vld1_u8(i2); + i2 += 8; + const uint8x8_t vi3 = vld1_u8(i3); + i3 += 8; + const uint8x8_t vi4 = vld1_u8(i4); + i4 += 8; + const uint8x8_t vi5 = vld1_u8(i5); + i5 += 8; + const uint8x8_t vi6 = vld1_u8(i6); + i6 += 8; + + const int16x8_t vsum016 = + vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + + int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23)); + int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016)); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + vst1_u8(output, vout); + output += 8; + + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_increment = n - 8; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint8x8_t vi0 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i0)), vshift)); + const uint8x8_t vi1 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i1)), vshift)); + const uint8x8_t vi2 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i2)), vshift)); + const uint8x8_t vi3 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i3)), vshift)); + const uint8x8_t vi4 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i4)), vshift)); + const uint8x8_t vi5 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i5)), vshift)); + const uint8x8_t vi6 = + vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(i6)), vshift)); + + const int16x8_t vsum016 = + vreinterpretq_s16_u16(vaddw_u8(vaddl_u8(vi0, vi1), vi6)); + const int16x8_t vsum23 = vreinterpretq_s16_u16(vaddl_u8(vi2, vi3)); + const int16x8_t vsum45 = vreinterpretq_s16_u16(vaddl_u8(vi4, vi5)); + + int32x4_t vacc_lo = vaddw_s16(vbias, vget_low_s16(vsum23)); + int32x4_t vacc_hi = vaddw_s16(vbias, vget_high_s16(vsum23)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum45)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum45)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vsum016)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vsum016)); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = + vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = + vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = + vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = + vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (n & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (n & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (n & 1) { + vst1_lane_u8(output, vout, 0); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-sse2.c new file mode 100644 index 0000000000000..0aa84b39fe05c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8x7-sse2.c @@ -0,0 +1,291 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8gavgpool_ukernel_up8x7__sse2( + size_t m, + size_t n, + const uint8_t* input, + size_t input_stride, + const uint8_t* zero, + uint8_t* output, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + assert(m >= 1); + assert(m <= 7); + assert(n >= 8); + + const uint8_t* i0 = input; + const uint8_t* i1 = i0 + input_stride; + if (m < 2) { + i1 = zero; + } + const uint8_t* i2 = i1 + input_stride; + if (m <= 2) { + i2 = zero; + } + const uint8_t* i3 = i2 + input_stride; + if (m < 4) { + i3 = zero; + } + const uint8_t* i4 = i3 + input_stride; + if (m <= 4) { + i4 = zero; + } + const uint8_t* i5 = i4 + input_stride; + if (m < 6) { + i5 = zero; + } + const uint8_t* i6 = i5 + input_stride; + if (m <= 6) { + i6 = zero; + } + const __m128i vbias = + _mm_load_si128((const __m128i*)&quantization_params->sse2.bias); + const __m128i vzero = _mm_setzero_si128(); + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + const __m128i vright_shift = + _mm_loadl_epi64((const __m128i*)quantization_params->sse2.right_shift); + + do { + const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0); + i0 += 8; + const __m128i vi1 = _mm_loadl_epi64((const __m128i*)i1); + i1 += 8; + const __m128i vi2 = _mm_loadl_epi64((const __m128i*)i2); + i2 += 8; + const __m128i vi3 = _mm_loadl_epi64((const __m128i*)i3); + i3 += 8; + const __m128i vi4 = _mm_loadl_epi64((const __m128i*)i4); + i4 += 8; + const __m128i vi5 = _mm_loadl_epi64((const __m128i*)i5); + i5 += 8; + const __m128i vi6 = _mm_loadl_epi64((const __m128i*)i6); + i6 += 8; + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + + __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero)); + __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero)); + + const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_decrement = 8 - n; + i0 = (const uint8_t*)((uintptr_t)i0 - address_decrement); + i1 = (const uint8_t*)((uintptr_t)i1 - address_decrement); + i2 = (const uint8_t*)((uintptr_t)i2 - address_decrement); + i3 = (const uint8_t*)((uintptr_t)i3 - address_decrement); + i4 = (const uint8_t*)((uintptr_t)i4 - address_decrement); + i5 = (const uint8_t*)((uintptr_t)i5 - address_decrement); + i6 = (const uint8_t*)((uintptr_t)i6 - address_decrement); + const __m128i vi_shift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vi0 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i0), vi_shift); + const __m128i vi1 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i1), vi_shift); + const __m128i vi2 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i2), vi_shift); + const __m128i vi3 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i3), vi_shift); + const __m128i vi4 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i4), vi_shift); + const __m128i vi5 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i5), vi_shift); + const __m128i vi6 = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)i6), vi_shift); + + const __m128i vxi0 = _mm_unpacklo_epi8(vi0, vzero); + const __m128i vxi1 = _mm_unpacklo_epi8(vi1, vzero); + const __m128i vxi2 = _mm_unpacklo_epi8(vi2, vzero); + const __m128i vxi3 = _mm_unpacklo_epi8(vi3, vzero); + const __m128i vxi4 = _mm_unpacklo_epi8(vi4, vzero); + const __m128i vxi5 = _mm_unpacklo_epi8(vi5, vzero); + const __m128i vxi6 = _mm_unpacklo_epi8(vi6, vzero); + + __m128i vacc_lo = _mm_add_epi32(vbias, _mm_unpacklo_epi16(vxi0, vzero)); + __m128i vacc_hi = _mm_add_epi32(vbias, _mm_unpackhi_epi16(vxi0, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi1, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi1, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi2, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi2, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi3, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi3, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi4, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi4, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi5, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi5, vzero)); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi16(vxi6, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vxi6, vzero)); + + const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = _mm_sub_epi32( + _mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + if (n & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (n & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (n & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-neon.c new file mode 100644 index 0000000000000..71ab645c39d80 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-neon.c @@ -0,0 +1,161 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8gavgpool_ukernel_up8xm__neon( + size_t m, + size_t n, + const uint8_t* input, + size_t input_stride, + const uint8_t* zero, + uint8_t* output, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[restrict static 1]) { + assert(m >= 1); + assert(n < 8); + + const int32x4_t vbias = vld1q_dup_s32(&quantization_params->neon.bias); + int32x4_t vacc_lo = vbias; + int32x4_t vacc_hi = vbias; + while (m >= 8) { + const uint8x8_t vinput = vld1_u8(input); + input += input_stride; + const int16x8_t vxinput = vreinterpretq_s16_u16(vmovl_u8(vinput)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vxinput)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vxinput)); + + m--; + } + while (m-- != 0) { + input += n; + uint8x8_t vinput = vmov_n_u8(0); + if (n & 1) { + input -= 1; + vinput = vld1_lane_u8(input, vinput, 0); + } + if (n & 2) { + vinput = vext_u8(vinput, vinput, 6); + input -= 2; + vinput = vreinterpret_u8_u16(vld1_lane_u16( + __builtin_assume_aligned(input, 1), vreinterpret_u16_u8(vinput), 0)); + } + if (n & 4) { + vinput = vext_u8(vinput, vinput, 4); + input -= 4; + vinput = vreinterpret_u8_u32(vld1_lane_u32( + __builtin_assume_aligned(input, 1), vreinterpret_u32_u8(vinput), 0)); + } + input += input_stride; + + const int16x8_t vxinput = vreinterpretq_s16_u16(vmovl_u8(vinput)); + vacc_lo = vaddw_s16(vacc_lo, vget_low_s16(vxinput)); + vacc_hi = vaddw_s16(vacc_hi, vget_high_s16(vxinput)); + } + +#ifdef __aarch64__ + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); +#else + const int32x2_t vmultiplier = + vld1_dup_s32(&quantization_params->neon.multiplier); +#endif + const int64x2_t vleft_shift = + vld1q_dup_s64(&quantization_params->neon.left_shift); + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); + const uint8x8_t voutput_min = + vld1_dup_u8(&quantization_params->neon.output_min); + const uint8x8_t voutput_max = + vld1_dup_u8(&quantization_params->neon.output_max); + + const int32x4_t vneg_mask_lo = + vreinterpretq_s32_u32(vcltq_s32(vacc_lo, vmovq_n_s32(0))); + const int32x4_t vneg_mask_hi = + vreinterpretq_s32_u32(vcltq_s32(vacc_hi, vmovq_n_s32(0))); + +#if defined(__aarch64__) + const int64x2_t vproduct01 = + vmull_s32(vget_low_s32(vacc_lo), vget_low_s32(vmultiplier)); + const int64x2_t vproduct23 = vmull_high_s32(vacc_lo, vmultiplier); + const int64x2_t vproduct45 = + vmull_s32(vget_low_s32(vacc_hi), vget_low_s32(vmultiplier)); + const int64x2_t vproduct67 = vmull_high_s32(vacc_hi, vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_high_s32(vproduct23, vneg_mask_lo); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_high_s32(vproduct67, vneg_mask_hi); +#else + const int64x2_t vproduct01 = vmull_s32(vget_low_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct23 = vmull_s32(vget_high_s32(vacc_lo), vmultiplier); + const int64x2_t vproduct45 = vmull_s32(vget_low_s32(vacc_hi), vmultiplier); + const int64x2_t vproduct67 = vmull_s32(vget_high_s32(vacc_hi), vmultiplier); + + const int64x2_t vadjusted_product01 = + vaddw_s32(vproduct01, vget_low_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product23 = + vaddw_s32(vproduct23, vget_high_s32(vneg_mask_lo)); + const int64x2_t vadjusted_product45 = + vaddw_s32(vproduct45, vget_low_s32(vneg_mask_hi)); + const int64x2_t vadjusted_product67 = + vaddw_s32(vproduct67, vget_high_s32(vneg_mask_hi)); +#endif + + const int64x2_t vscaled_acc01 = vrshlq_s64(vadjusted_product01, vleft_shift); + const int64x2_t vscaled_acc23 = vrshlq_s64(vadjusted_product23, vleft_shift); + const int64x2_t vscaled_acc45 = vrshlq_s64(vadjusted_product45, vleft_shift); + const int64x2_t vscaled_acc67 = vrshlq_s64(vadjusted_product67, vleft_shift); + +#ifdef __aarch64__ + vacc_lo = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc01), + vreinterpretq_s32_s64(vscaled_acc23)); + vacc_hi = vuzp1q_s32( + vreinterpretq_s32_s64(vscaled_acc45), + vreinterpretq_s32_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), voutput_zero_point); +#else + vacc_lo = vcombine_s32(vmovn_s64(vscaled_acc01), vmovn_s64(vscaled_acc23)); + vacc_hi = vcombine_s32(vmovn_s64(vscaled_acc45), vmovn_s64(vscaled_acc67)); + + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + voutput_zero_point); +#endif + + uint8x8_t vout = vqmovun_s16(vacc); + vout = vmax_u8(vout, voutput_min); + vout = vmin_u8(vout, voutput_max); + + if (n & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (n & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (n & 1) { + vst1_lane_u8(output, vout, 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-sse2.c new file mode 100644 index 0000000000000..1798282963f57 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gavgpool/up8xm-sse2.c @@ -0,0 +1,145 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_q8gavgpool_ukernel_up8xm__sse2( + size_t m, + size_t n, + const uint8_t* input, + size_t input_stride, + const uint8_t* zero, + uint8_t* output, + const union pytorch_qnnp_avgpool_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + assert(m >= 1); + assert(n < 8); + + const __m128i vbias = + _mm_loadu_si128((const __m128i*)&quantization_params->sse2.bias); + __m128i vacc_lo = vbias; + __m128i vacc_hi = vbias; + __m128i vzero = _mm_setzero_si128(); + while (m >= 8) { + const __m128i vinput = _mm_loadl_epi64((const __m128i*)input); + const __m128i vxinput = _mm_unpacklo_epi8(vinput, vzero); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi8(vxinput, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi8(vxinput, vzero)); + + input += input_stride; + m--; + } + while (m-- != 0) { + input += n; + __m128i vinput = _mm_setzero_si128(); + if (n & 1) { + input -= 1; + vinput = _mm_cvtsi32_si128((int)(uint32_t)*input); + } + if (n & 2) { + vinput = _mm_slli_epi32(vinput, 16); + input -= 2; + vinput = _mm_insert_epi16(vinput, *((const uint16_t*)input), 0); + } + if (n & 4) { + input -= 4; + vinput = _mm_unpacklo_epi32( + _mm_cvtsi32_si128((int)*((const uint32_t*)input)), vinput); + } + input += input_stride; + + const __m128i vxinput = _mm_unpacklo_epi8(vinput, vzero); + vacc_lo = _mm_add_epi32(vacc_lo, _mm_unpacklo_epi8(vxinput, vzero)); + vacc_hi = _mm_add_epi32(vacc_hi, _mm_unpackhi_epi8(vxinput, vzero)); + } + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + const __m128i vright_shift = + _mm_loadl_epi64((const __m128i*)quantization_params->sse2.right_shift); + + const __m128i vneg_mask_lo = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo); + const __m128i vneg_mask_hi = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi); + + const __m128i vabs_lo0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vabs_hi0123 = + _mm_sub_epi32(_mm_xor_si128(vacc_hi, vneg_mask_hi), vneg_mask_hi); + + const __m128i vabs_lo1032 = + _mm_shuffle_epi32(vabs_lo0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabs_hi1032 = + _mm_shuffle_epi32(vabs_hi0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsmul_lo02 = _mm_mul_epu32(vabs_lo0123, vmultiplier); + const __m128i vabsmul_hi02 = _mm_mul_epu32(vabs_hi0123, vmultiplier); + + const __m128i vabsmul_lo13 = _mm_mul_epu32(vabs_lo1032, vmultiplier); + const __m128i vabsmul_hi13 = _mm_mul_epu32(vabs_hi1032, vmultiplier); + + const __m128i vabs_scaled_lo02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo02, vrounding), vright_shift); + const __m128i vabs_scaled_lo13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_lo13, vrounding), vright_shift); + const __m128i vabs_scaled_hi02 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi02, vrounding), vright_shift); + const __m128i vabs_scaled_hi13 = + _mm_srl_epi64(_mm_add_epi64(vabsmul_hi13, vrounding), vright_shift); + + const __m128i vabs_scaled_lo0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_lo02), + _mm_castsi128_ps(vabs_scaled_lo13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vabs_scaled_hi0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vabs_scaled_hi02), + _mm_castsi128_ps(vabs_scaled_hi13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vabs_scaled_lo = + _mm_shuffle_epi32(vabs_scaled_lo0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vabs_scaled_hi = + _mm_shuffle_epi32(vabs_scaled_hi0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vscaled_lo = + _mm_sub_epi32(_mm_xor_si128(vabs_scaled_lo, vneg_mask_lo), vneg_mask_lo); + const __m128i vscaled_hi = + _mm_sub_epi32(_mm_xor_si128(vabs_scaled_hi, vneg_mask_hi), vneg_mask_hi); + + __m128i vout = _mm_packs_epi32(vscaled_lo, vscaled_hi); + vout = _mm_adds_epi16( + vout, + _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point)); + vout = _mm_packus_epi16(vout, vout); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + if (n & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (n & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (n & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/2x4c8-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/2x4c8-sse2.c new file mode 100644 index 0000000000000..41712d2627e3c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/2x4c8-sse2.c @@ -0,0 +1,291 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +static inline __m128i pytorch_sse_reduce4_i32( + __m128i x, + __m128i y, + __m128i z, + __m128i w) { +#if defined(__SSSE3__) && !defined(__ANDROID__) + /* xxyy = ( y2 + y3, y0 + y1, x2 + x3, x0 + x1 ) */ + const __m128i xxyy = _mm_hadd_epi32(x, y); + /* zzww = ( w2 + w3, w0 + w1, z2 + z3, z0 + z1 ) */ + const __m128i zzww = _mm_hadd_epi32(z, w); + /* xyzw = ( w0 + w1 + w2 + w3, y0 + y1 + y2 + y3, z0 + z1 + z2 + z3, x0 + x1 + + * x2 + x3 ) */ + return _mm_hadd_epi32(xxyy, zzww); +#else + /* xzxz = ( z1 + z3, x1 + x3, z0 + z2, x0 + x2 ) */ + const __m128i xzxz = + _mm_add_epi32(_mm_unpacklo_epi32(x, z), _mm_unpackhi_epi32(x, z)); + /* ywyw = ( w1 + w3, y1 + y3, w0 + w2, y0 + y2 ) */ + const __m128i ywyw = + _mm_add_epi32(_mm_unpacklo_epi32(y, w), _mm_unpackhi_epi32(y, w)); + /* xyzw = ( w0 + w2 + w1 + w3, y0 + y2 + y1 + y3, z0 + z2 + z1 + z3, x0 + x2 + + * x1 + x3 ) */ + return _mm_add_epi32( + _mm_unpacklo_epi32(xzxz, ywyw), _mm_unpackhi_epi32(xzxz, ywyw)); +#endif +} + +void pytorch_q8gemm_ukernel_2x4c8__sse2( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + __m128i vacc00 = _mm_cvtsi32_si128((int)((const int32_t*)w)[0]); + __m128i vacc01 = _mm_cvtsi32_si128((int)((const int32_t*)w)[1]); + __m128i vacc02 = _mm_cvtsi32_si128((int)((const int32_t*)w)[2]); + __m128i vacc03 = _mm_cvtsi32_si128((int)((const int32_t*)w)[3]); + __m128i vacc10 = vacc00; + __m128i vacc11 = vacc01; + __m128i vacc12 = vacc02; + __m128i vacc13 = vacc03; + w = (const void*)((uintptr_t)w + 16); + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride); + if (mr != 2) { + a1 = a0; + } + + const uint8_t* b0 = w; + const uint8_t* b1 = b0 + 8; + if (nr < 2) { + b1 = b0; + } + const uint8_t* b2 = b1 + 8; + if (nr <= 2) { + b2 = b1; + } + const uint8_t* b3 = b2 + 8; + if (nr != 4) { + b3 = b2; + } + const size_t b_stride = nr * 8; + + const __m128i va_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.input_zero_point); + const __m128i vb_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.kernel_zero_point); + const __m128i vzero = _mm_setzero_si128(); + for (; k >= 8; k -= 8) { + const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0); + const __m128i vxa0 = + sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point); + a0 += 8; + const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1); + const __m128i vxa1 = + sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point); + a1 += 8; + + const __m128i vb0 = _mm_loadl_epi64((const __m128i*)b0); + const __m128i vxb0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point); + b0 += b_stride; + const __m128i vb1 = _mm_loadl_epi64((const __m128i*)b1); + const __m128i vxb1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point); + b1 += b_stride; + const __m128i vb2 = _mm_loadl_epi64((const __m128i*)b2); + const __m128i vxb2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point); + b2 += b_stride; + const __m128i vb3 = _mm_loadl_epi64((const __m128i*)b3); + const __m128i vxb3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); + b3 += b_stride; + + vacc00 = _mm_add_epi32(vacc00, _mm_madd_epi16(vxa0, vxb0)); + vacc01 = _mm_add_epi32(vacc01, _mm_madd_epi16(vxa0, vxb1)); + vacc02 = _mm_add_epi32(vacc02, _mm_madd_epi16(vxa0, vxb2)); + vacc03 = _mm_add_epi32(vacc03, _mm_madd_epi16(vxa0, vxb3)); + vacc10 = _mm_add_epi32(vacc10, _mm_madd_epi16(vxa1, vxb0)); + vacc11 = _mm_add_epi32(vacc11, _mm_madd_epi16(vxa1, vxb1)); + vacc12 = _mm_add_epi32(vacc12, _mm_madd_epi16(vxa1, vxb2)); + vacc13 = _mm_add_epi32(vacc13, _mm_madd_epi16(vxa1, vxb3)); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement); + + const __m128i va_zero_point_partial = _mm_unpacklo_epi8( + _mm_srl_epi64(_mm_packus_epi16(va_zero_point, va_zero_point), va_shift), + vzero); + + const __m128i va0 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift); + const __m128i vxa0 = + sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point_partial); + const __m128i va1 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift); + const __m128i vxa1 = + sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point_partial); + + const __m128i vb0 = _mm_loadl_epi64((const __m128i*)b0); + const __m128i vxb0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point); + const __m128i vb1 = _mm_loadl_epi64((const __m128i*)b1); + const __m128i vxb1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point); + const __m128i vb2 = _mm_loadl_epi64((const __m128i*)b2); + const __m128i vxb2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point); + const __m128i vb3 = _mm_loadl_epi64((const __m128i*)b3); + const __m128i vxb3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); + + vacc00 = _mm_add_epi32(vacc00, _mm_madd_epi16(vxa0, vxb0)); + vacc01 = _mm_add_epi32(vacc01, _mm_madd_epi16(vxa0, vxb1)); + vacc02 = _mm_add_epi32(vacc02, _mm_madd_epi16(vxa0, vxb2)); + vacc03 = _mm_add_epi32(vacc03, _mm_madd_epi16(vxa0, vxb3)); + vacc10 = _mm_add_epi32(vacc10, _mm_madd_epi16(vxa1, vxb0)); + vacc11 = _mm_add_epi32(vacc11, _mm_madd_epi16(vxa1, vxb1)); + vacc12 = _mm_add_epi32(vacc12, _mm_madd_epi16(vxa1, vxb2)); + vacc13 = _mm_add_epi32(vacc13, _mm_madd_epi16(vxa1, vxb3)); + } + + __m128i vacc0x0123 = pytorch_sse_reduce4_i32(vacc00, vacc01, vacc02, vacc03); + __m128i vacc1x0123 = pytorch_sse_reduce4_i32(vacc10, vacc11, vacc12, vacc13); + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask0x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc0x0123); + const __m128i vnmask1x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc1x0123); + + const __m128i vabsacc0x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc0x0123, vnmask0x0123), vnmask0x0123); + const __m128i vabsacc1x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc1x0123, vnmask1x0123), vnmask1x0123); + + const __m128i vabsacc0x1032 = + _mm_shuffle_epi32(vabsacc0x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc1x1032 = + _mm_shuffle_epi32(vabsacc1x0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod0x02 = _mm_mul_epu32(vabsacc0x0123, vmultiplier); + const __m128i vabsprod1x02 = _mm_mul_epu32(vabsacc1x0123, vmultiplier); + + const __m128i vnmask0x02 = + _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask1x02 = + _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod0x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod0x02, vnmask0x02), vnmask0x02); + const __m128i vprod1x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod1x02, vnmask1x02), vnmask1x02); + + const __m128i vq31prod0x02 = + _mm_srli_epi64(_mm_add_epi64(vprod0x02, vrounding), 31); + const __m128i vq31prod1x02 = + _mm_srli_epi64(_mm_add_epi64(vprod1x02, vrounding), 31); + + const __m128i vabsprod0x13 = _mm_mul_epu32(vabsacc0x1032, vmultiplier); + const __m128i vabsprod1x13 = _mm_mul_epu32(vabsacc1x1032, vmultiplier); + + const __m128i vnmask0x13 = + _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask1x13 = + _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod0x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod0x13, vnmask0x13), vnmask0x13); + const __m128i vprod1x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod1x13, vnmask1x13), vnmask1x13); + + const __m128i vq31prod0x13 = + _mm_srli_epi64(_mm_add_epi64(vprod0x13, vrounding), 31); + const __m128i vq31prod1x13 = + _mm_srli_epi64(_mm_add_epi64(vprod1x13, vrounding), 31); + + const __m128i vq31prod0x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod0x02), + _mm_castsi128_ps(vq31prod0x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod1x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod1x02), + _mm_castsi128_ps(vq31prod1x13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod0x0123 = + _mm_shuffle_epi32(vq31prod0x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod1x0123 = + _mm_shuffle_epi32(vq31prod1x0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = + _mm_load_si128((const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem0x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod0x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod0x0123)); + const __m128i vrem1x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod1x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod1x0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + vacc0x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod0x0123, vshift), + _mm_cmpgt_epi32(vrem0x0123, vremainder_threshold)); + vacc1x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod1x0123, vshift), + _mm_cmpgt_epi32(vrem1x0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + const __m128i vacc01x0123 = _mm_adds_epi16( + _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point); + __m128i vout = _mm_packus_epi16(vacc01x0123, vacc01x0123); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr != 2) { + c1 = c0; + } + if (nr == 4) { + *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout); + *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32)); + } else { + if (nr >= 2) { + *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0); + c0 += 2; + *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2); + c1 += 2; + vout = _mm_srli_epi32(vout, 16); + nr -= 2; + } + if (nr != 0) { + *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout); + *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x-sumrows-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x-sumrows-neon.c new file mode 100644 index 0000000000000..4d6f7201d1289 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x-sumrows-neon.c @@ -0,0 +1,154 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_q8sumrows_ukernel_4x__neon( + const uint8_t* restrict a, + size_t m, + size_t k, + size_t stride, + const int32_t multiplier, + int32_t* restrict a_sum) { + const uint8_t* a0 = a; + const uint8_t* a1 = a0; + if (m >= 2) { + a1 += stride; + } + const uint8_t* a2 = a1; + if (m > 2) { + a2 += stride; + } + const uint8_t* a3 = a2; + if (m == 4) { + a3 += stride; + } + + uint32x4_t vacc0x0123 = vmovq_n_u32(0); // row 0 + uint32x4_t vacc1x0123 = vmovq_n_u32(0); // row 1 + uint32x4_t vacc2x0123 = vmovq_n_u32(0); // row 2 + uint32x4_t vacc3x0123 = vmovq_n_u32(0); // row 3 + for (; k >= 16; k -= 16) { + // row 0 + const uint8x16_t va0x0_15 = vld1q_u8(a0); + a0 += 16; + vacc0x0123 = vpadalq_u16( + vacc0x0123, vaddl_u8(vget_low_u8(va0x0_15), vget_high_u8(va0x0_15))); + + // row 1 + const uint8x16_t va1x0_15 = vld1q_u8(a1); + a1 += 16; + vacc1x0123 = vpadalq_u16( + vacc1x0123, vaddl_u8(vget_low_u8(va1x0_15), vget_high_u8(va1x0_15))); + + // row 2 + const uint8x16_t va2x0_15 = vld1q_u8(a2); + a2 += 16; + vacc2x0123 = vpadalq_u16( + vacc2x0123, vaddl_u8(vget_low_u8(va2x0_15), vget_high_u8(va2x0_15))); + + // row 3 + const uint8x16_t va3x0_15 = vld1q_u8(a3); + a3 += 16; + vacc3x0123 = vpadalq_u16( + vacc3x0123, vaddl_u8(vget_low_u8(va3x0_15), vget_high_u8(va3x0_15))); + } + + if (k >= 8) { + vacc0x0123 = vaddw_u16(vacc0x0123, vpaddl_u8(vld1_u8(a0))); + a0 += 8; + vacc1x0123 = vaddw_u16(vacc1x0123, vpaddl_u8(vld1_u8(a1))); + a1 += 8; + vacc2x0123 = vaddw_u16(vacc2x0123, vpaddl_u8(vld1_u8(a2))); + a2 += 8; + vacc3x0123 = vaddw_u16(vacc3x0123, vpaddl_u8(vld1_u8(a3))); + a3 += 8; + k -= 8; + } + + if (k >= 4) { + vacc0x0123 = vaddw_u16( + vacc0x0123, + vget_low_u16(vmovl_u8(vreinterpret_u8_u32( + vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a0, 1)))))); + a0 += 4; + vacc1x0123 = vaddw_u16( + vacc1x0123, + vget_low_u16(vmovl_u8(vreinterpret_u8_u32( + vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a1, 1)))))); + a1 += 4; + vacc2x0123 = vaddw_u16( + vacc2x0123, + vget_low_u16(vmovl_u8(vreinterpret_u8_u32( + vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a2, 1)))))); + a2 += 4; + vacc3x0123 = vaddw_u16( + vacc3x0123, + vget_low_u16(vmovl_u8(vreinterpret_u8_u32( + vld1_dup_u32(__builtin_assume_aligned((const uint32_t*)a3, 1)))))); + a3 += 4; + k -= 4; + } + + const uint32x2_t vsum0x01 = + vpadd_u32(vget_low_u32(vacc0x0123), vget_high_u32(vacc0x0123)); + const uint32x2_t vsum1x01 = + vpadd_u32(vget_low_u32(vacc1x0123), vget_high_u32(vacc1x0123)); + const uint32x2_t vsum2x01 = + vpadd_u32(vget_low_u32(vacc2x0123), vget_high_u32(vacc2x0123)); + const uint32x2_t vsum3x01 = + vpadd_u32(vget_low_u32(vacc3x0123), vget_high_u32(vacc3x0123)); + uint32x4_t vacc0123 = vcombine_u32( + vpadd_u32(vsum0x01, vsum1x01), vpadd_u32(vsum2x01, vsum3x01)); + + if (k >= 2) { + const uint8x8_t va0x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1))); + a0 += 2; + const uint8x8_t va1x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1))); + a1 += 2; + const uint8x8_t va2x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1))); + a2 += 2; + const uint8x8_t va3x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1))); + a3 += 2; + const uint8x8_t va0x01_1x010101 = vext_u8(va0x01010101, va1x01010101, 2); + const uint8x8_t va2x01_3x010101 = vext_u8(va2x01010101, va3x01010101, 6); + const uint8x8_t va0123x01 = vext_u8(va0x01_1x010101, va2x01_3x010101, 4); + vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(va0123x01)); + k -= 2; + } + + if (k > 0) { + uint8x8_t vax0x1x2x3 = vmov_n_u8(0); + vax0x1x2x3 = vld1_lane_u8(a0, vax0x1x2x3, 0); + vax0x1x2x3 = vld1_lane_u8(a1, vax0x1x2x3, 2); + vax0x1x2x3 = vld1_lane_u8(a2, vax0x1x2x3, 4); + vax0x1x2x3 = vld1_lane_u8(a3, vax0x1x2x3, 6); + vacc0123 = vaddw_u16(vacc0123, vpaddl_u8(vax0x1x2x3)); + } + + int32x4_t vsum0123 = vmulq_n_s32(vreinterpretq_s32_u32(vacc0123), multiplier); + if (m == 4) { + vst1q_s32(a_sum, vsum0123); + } else { + if (m >= 2) { + vst1_s32(a_sum, vget_low_s32(vsum0123)); + a_sum += 2; + vsum0123 = vextq_s32(vsum0123, vsum0123, 2); + m -= 2; + } + if (m != 0) { + vst1q_lane_s32(a_sum, vsum0123, 0); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c new file mode 100644 index 0000000000000..d7b55d837723b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x4c2-sse2.c @@ -0,0 +1,452 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8gemm_ukernel_4x4c2__sse2( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + __m128i vacc0x0123 = _mm_loadu_si128((const __m128i*)w); + __m128i vacc1x0123 = vacc0x0123; + __m128i vacc2x0123 = vacc0x0123; + __m128i vacc3x0123 = vacc0x0123; + w = (const void*)((uintptr_t)w + 16); + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride); + if (mr != 4) { + a3 = a2; + } + + const __m128i va_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.input_zero_point); + const __m128i vb_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.kernel_zero_point); + const __m128i vzero = _mm_setzero_si128(); + for (; k >= 8; k -= 8) { + const __m128i va0 = _mm_loadl_epi64((const __m128i*)a0); + const __m128i vxa0 = + sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point); + a0 += 8; + const __m128i va1 = _mm_loadl_epi64((const __m128i*)a1); + const __m128i vxa1 = + sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point); + a1 += 8; + const __m128i va2 = _mm_loadl_epi64((const __m128i*)a2); + const __m128i vxa2 = + sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point); + a2 += 8; + const __m128i va3 = _mm_loadl_epi64((const __m128i*)a3); + const __m128i vxa3 = + sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point); + a3 += 8; + + const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + + const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxb1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + + const __m128i vb2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxb2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + + const __m128i vb3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxb3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); + w = (const void*)((uintptr_t)w + 32); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const __m128i va_shift = _mm_cvtsi32_si128(8 * a_predecrement); + + const __m128i va0 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a0 - a_predecrement)), va_shift); + const __m128i vxa0 = + sub_zero_point(_mm_unpacklo_epi8(va0, vzero), va_zero_point); + const __m128i va1 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a1 - a_predecrement)), va_shift); + const __m128i vxa1 = + sub_zero_point(_mm_unpacklo_epi8(va1, vzero), va_zero_point); + const __m128i va2 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a2 - a_predecrement)), va_shift); + const __m128i vxa2 = + sub_zero_point(_mm_unpacklo_epi8(va2, vzero), va_zero_point); + const __m128i va3 = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a3 - a_predecrement)), va_shift); + const __m128i vxa3 = + sub_zero_point(_mm_unpacklo_epi8(va3, vzero), va_zero_point); + + const __m128i vb0 = _mm_loadl_epi64((const __m128i*)w); + const __m128i vxb0 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb0, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa0, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa1, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa2, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16(_mm_shuffle_epi32(vxa3, _MM_SHUFFLE(0, 0, 0, 0)), vxb0)); + + if (k > 2) { + const __m128i vb1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8)); + const __m128i vxb1 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb1, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(1, 1, 1, 1)), vxb1)); + + if (k > 4) { + const __m128i vb2 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16)); + const __m128i vxb2 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb2, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(2, 2, 2, 2)), vxb2)); + + if (k > 6) { + const __m128i vb3 = + _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24)); + const __m128i vxb3 = + _mm_sub_epi16(_mm_unpacklo_epi8(vb3, vzero), vb_zero_point); + + vacc0x0123 = _mm_add_epi32( + vacc0x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa0, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc1x0123 = _mm_add_epi32( + vacc1x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa1, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc2x0123 = _mm_add_epi32( + vacc2x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa2, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + vacc3x0123 = _mm_add_epi32( + vacc3x0123, + _mm_madd_epi16( + _mm_shuffle_epi32(vxa3, _MM_SHUFFLE(3, 3, 3, 3)), vxb3)); + } + } + } + } + + const __m128i vmultiplier = + _mm_load_si128((const __m128i*)quantization_params->sse2.multiplier); + const __m128i vrounding = + _mm_load_si128((const __m128i*)quantization_params->sse2.rounding); + + const __m128i vnmask0x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc0x0123); + const __m128i vnmask1x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc1x0123); + const __m128i vnmask2x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc2x0123); + const __m128i vnmask3x0123 = _mm_cmpgt_epi32(_mm_setzero_si128(), vacc3x0123); + + const __m128i vabsacc0x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc0x0123, vnmask0x0123), vnmask0x0123); + const __m128i vabsacc1x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc1x0123, vnmask1x0123), vnmask1x0123); + const __m128i vabsacc2x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc2x0123, vnmask2x0123), vnmask2x0123); + const __m128i vabsacc3x0123 = + _mm_sub_epi32(_mm_xor_si128(vacc3x0123, vnmask3x0123), vnmask3x0123); + + const __m128i vabsacc0x1032 = + _mm_shuffle_epi32(vabsacc0x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc1x1032 = + _mm_shuffle_epi32(vabsacc1x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc2x1032 = + _mm_shuffle_epi32(vabsacc2x0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i vabsacc3x1032 = + _mm_shuffle_epi32(vabsacc3x0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i vabsprod0x02 = _mm_mul_epu32(vabsacc0x0123, vmultiplier); + const __m128i vabsprod1x02 = _mm_mul_epu32(vabsacc1x0123, vmultiplier); + const __m128i vabsprod2x02 = _mm_mul_epu32(vabsacc2x0123, vmultiplier); + const __m128i vabsprod3x02 = _mm_mul_epu32(vabsacc3x0123, vmultiplier); + + const __m128i vnmask0x02 = + _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask1x02 = + _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask2x02 = + _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i vnmask3x02 = + _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i vprod0x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod0x02, vnmask0x02), vnmask0x02); + const __m128i vprod1x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod1x02, vnmask1x02), vnmask1x02); + const __m128i vprod2x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod2x02, vnmask2x02), vnmask2x02); + const __m128i vprod3x02 = + _mm_sub_epi64(_mm_xor_si128(vabsprod3x02, vnmask3x02), vnmask3x02); + + const __m128i vq31prod0x02 = + _mm_srli_epi64(_mm_add_epi64(vprod0x02, vrounding), 31); + const __m128i vq31prod1x02 = + _mm_srli_epi64(_mm_add_epi64(vprod1x02, vrounding), 31); + const __m128i vq31prod2x02 = + _mm_srli_epi64(_mm_add_epi64(vprod2x02, vrounding), 31); + const __m128i vq31prod3x02 = + _mm_srli_epi64(_mm_add_epi64(vprod3x02, vrounding), 31); + + const __m128i vabsprod0x13 = _mm_mul_epu32(vabsacc0x1032, vmultiplier); + const __m128i vabsprod1x13 = _mm_mul_epu32(vabsacc1x1032, vmultiplier); + const __m128i vabsprod2x13 = _mm_mul_epu32(vabsacc2x1032, vmultiplier); + const __m128i vabsprod3x13 = _mm_mul_epu32(vabsacc3x1032, vmultiplier); + + const __m128i vnmask0x13 = + _mm_shuffle_epi32(vnmask0x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask1x13 = + _mm_shuffle_epi32(vnmask1x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask2x13 = + _mm_shuffle_epi32(vnmask2x0123, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i vnmask3x13 = + _mm_shuffle_epi32(vnmask3x0123, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i vprod0x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod0x13, vnmask0x13), vnmask0x13); + const __m128i vprod1x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod1x13, vnmask1x13), vnmask1x13); + const __m128i vprod2x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod2x13, vnmask2x13), vnmask2x13); + const __m128i vprod3x13 = + _mm_sub_epi64(_mm_xor_si128(vabsprod3x13, vnmask3x13), vnmask3x13); + + const __m128i vq31prod0x13 = + _mm_srli_epi64(_mm_add_epi64(vprod0x13, vrounding), 31); + const __m128i vq31prod1x13 = + _mm_srli_epi64(_mm_add_epi64(vprod1x13, vrounding), 31); + const __m128i vq31prod2x13 = + _mm_srli_epi64(_mm_add_epi64(vprod2x13, vrounding), 31); + const __m128i vq31prod3x13 = + _mm_srli_epi64(_mm_add_epi64(vprod3x13, vrounding), 31); + + const __m128i vq31prod0x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod0x02), + _mm_castsi128_ps(vq31prod0x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod1x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod1x02), + _mm_castsi128_ps(vq31prod1x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod2x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod2x02), + _mm_castsi128_ps(vq31prod2x13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i vq31prod3x0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vq31prod3x02), + _mm_castsi128_ps(vq31prod3x13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i vq31prod0x0123 = + _mm_shuffle_epi32(vq31prod0x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod1x0123 = + _mm_shuffle_epi32(vq31prod1x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod2x0123 = + _mm_shuffle_epi32(vq31prod2x0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i vq31prod3x0123 = + _mm_shuffle_epi32(vq31prod3x0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i vremainder_mask = + _mm_load_si128((const __m128i*)quantization_params->sse2.remainder_mask); + + const __m128i vrem0x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod0x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod0x0123)); + const __m128i vrem1x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod1x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod1x0123)); + const __m128i vrem2x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod2x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod2x0123)); + const __m128i vrem3x0123 = _mm_add_epi32( + _mm_and_si128(vq31prod3x0123, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vq31prod3x0123)); + + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_load_si128((const __m128i*)quantization_params->sse2.shift); + + vacc0x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod0x0123, vshift), + _mm_cmpgt_epi32(vrem0x0123, vremainder_threshold)); + vacc1x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod1x0123, vshift), + _mm_cmpgt_epi32(vrem1x0123, vremainder_threshold)); + vacc2x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod2x0123, vshift), + _mm_cmpgt_epi32(vrem2x0123, vremainder_threshold)); + vacc3x0123 = _mm_sub_epi32( + _mm_sra_epi32(vq31prod3x0123, vshift), + _mm_cmpgt_epi32(vrem3x0123, vremainder_threshold)); + + const __m128i voutput_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.output_zero_point); + const __m128i vacc01x0123 = _mm_adds_epi16( + _mm_packs_epi32(vacc0x0123, vacc1x0123), voutput_zero_point); + const __m128i vacc23x0123 = _mm_adds_epi16( + _mm_packs_epi32(vacc2x0123, vacc3x0123), voutput_zero_point); + __m128i vout = _mm_packus_epi16(vacc01x0123, vacc23x0123); + vout = _mm_min_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_max)); + vout = _mm_max_epu8( + vout, + _mm_load_si128((const __m128i*)quantization_params->sse2.output_min)); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 4) { + *((uint32_t*)c0) = (uint32_t)_mm_cvtsi128_si32(vout); + *((uint32_t*)c1) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_epi64(vout, 32)); + *((uint32_t*)c2) = + (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi32(vout, vout)); + *((uint32_t*)c3) = (uint32_t)_mm_cvtsi128_si32(_mm_srli_si128(vout, 12)); + } else { + if (nr >= 2) { + *((uint16_t*)c0) = (uint16_t)_mm_extract_epi16(vout, 0); + c0 += 2; + *((uint16_t*)c1) = (uint16_t)_mm_extract_epi16(vout, 2); + c1 += 2; + *((uint16_t*)c2) = (uint16_t)_mm_extract_epi16(vout, 4); + c2 += 2; + *((uint16_t*)c3) = (uint16_t)_mm_extract_epi16(vout, 6); + c3 += 2; + vout = _mm_srli_epi32(vout, 16); + nr -= 2; + } + if (nr != 0) { + *((uint8_t*)c0) = (uint8_t)_mm_cvtsi128_si32(vout); + *((uint8_t*)c1) = (uint8_t)_mm_extract_epi16(vout, 2); + *((uint8_t*)c2) = (uint8_t)_mm_extract_epi16(vout, 4); + *((uint8_t*)c3) = (uint8_t)_mm_extract_epi16(vout, 6); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S new file mode 100644 index 0000000000000..a113e71f99685 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-aarch32-neon.S @@ -0,0 +1,795 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +.syntax unified + +# void pytorch_q8gemm_ukernel_4x8__aarch32_neon( +# size_t mr, +# size_t nr, +# size_t k, +# const uint8_t*restrict a, +# size_t a_stride, +# const void*restrict w, +# uint8_t*restrict c, +# size_t c_stride, +# const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static 1]) +BEGIN_FUNCTION pytorch_q8gemm_ukernel_4x8__aarch32_neon + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + # Load w + # - ip = w + LDR ip, [sp, 4] + PUSH {r4, r5, r6, r7} + + VPUSH {d8-d15} + # Load quantization params + # - r7 = quantization_params + LDR r7, [sp, 96] + + # Load bias0123, bias4567 + VLDM ip!, {d16-d19} + + # Load a_stride + # - r6 = a_stride + LDR r6, [sp, 80] + CMP r0, 2 + + ADD r4, r3, r6 + + # Load b_zero_point: + # - d15 = b_zero_point + VLD1.8 {d15[]}, [r7] + ADD r7, r7, 2 + + # Load a_zero_point: + # - d14 = a_zero_point + VLD1.8 {d14[]}, [r7] + + MOVLO r4, r3 + + ADD r7, r7, 2 + ADD r5, r4, r6 + + # q10 := vacc1x0123 + VMOV.I32 q10, q8 + MOVLS r5, r4 + # q11 := vacc1x4567 + VMOV.I32 q11, q9 + ADD r6, r5, r6 + # q12 := vacc2x0123 + VMOV.I32 q12, q8 + CMP r0, 4 + # q13 := vacc2x4567 + VMOV.I32 q13, q9 + MOVNE r6, r5 + # q14 := vacc3x0123 + VMOV.I32 q14, q8 + SUBS r2, r2, 8 + # q15 := vacc3x4567 + VMOV.I32 q15, q9 + # Load multiplier: + # - d12 = vmultiplier + VLD1.32 {d12[]}, [r7]! + BLO 1f + + .p2align 5 +0: + # Load a0 + # - d1 = a0 + VLD1.8 {d1}, [r3]! + + # Load a1 + # - d3 = a1 + VLD1.8 {d3}, [r4]! + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r5]! + + # q0 = va0 = a0 + SUB_ZERO_POINT q0, d1, d14 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r6]! + + # q1 = va1 = a1 + SUB_ZERO_POINT q1, d3, d14 + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + SUB_ZERO_POINT q2, d5, d14 + # q3 = va3 = a3 + SUB_ZERO_POINT q3, d7, d14 + + ### Channel 0 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + ### Channel 1 ### + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + + # Load b0-b7 (channel 3) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 3) + # - d11 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + + # Load b0-b7 (channel 4) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 4) + # - d11 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + + # Load b0-b7 (channel 5) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 5) + # - d9 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + + # Load b0-b7 (channel 7) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 7) + # - d11 = vb4567 (channel 7) + VSUBL.U8 q5, d11, d15 + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + ### Channel 8 ### + SUBS r2, r2, 8 + + # vacc0x0123 += vb0123 * va0[7] + VMLAL.S16 q8, d10, d1[3] + # vacc0x4567 += vb4567 * va0[7] + VMLAL.S16 q9, d11, d1[3] + + # vacc1x0123 += vb0123 * va1[7] + VMLAL.S16 q10, d10, d3[3] + # vacc1x4567 += vb4567 * va1[7] + VMLAL.S16 q11, d11, d3[3] + + # vacc2x0123 += vb0123 * va2[7] + VMLAL.S16 q12, d10, d5[3] + # vacc2x4567 += vb4567 * va2[7] + VMLAL.S16 q13, d11, d5[3] + + # vacc3x0123 += vb0123 * va3[7] + VMLAL.S16 q14, d10, d7[3] + # vacc3x4567 += vb4567 * va3[7] + VMLAL.S16 q15, d11, d7[3] + + BHS 0b + +1: + CMP r2, -8 + BEQ 2f + + # Adjust a0, a1, a2, a3 + ADD r3, r2 + ADD r4, r2 + ADD r5, r2 + ADD r6, r2 + + # a_shift = 8 * k - 64 + LSL r2, r2, 3 + VDUP.32 d13, r2 + + # Load a0 + # - d1 = a0 + VLD1.8 {d1}, [r3] + + # Load a1 + # - d3 = a1 + VLD1.8 {d3}, [r4] + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r5] + + # q0 = va0 = a0 + VSHL.U64 d1, d1, d13 + SUB_ZERO_POINT q0, d1, d14 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r6] + + # q1 = va1 = a1 + VSHL.U64 d3, d3, d13 + SUB_ZERO_POINT q1, d3, d14 + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + VSHL.U64 d5, d5, d13 + SUB_ZERO_POINT q2, d5, d14 + # q3 = va3 = a3 + VSHL.U64 d7, d7, d13 + SUB_ZERO_POINT q3, d7, d14 + + ### Channel 0 ### + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + CMP r2, -48 + BLO 2f + + ### Channel 1 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + BLS 2f + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + CMP r2, -32 + BLO 2f + + # Load b0-b7 (channel 3) + # - d9 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 3) + # - d9 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + BLS 2f + + # Load b0-b7 (channel 4) + # - d11 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 4) + # - d11 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + CMP r2, -16 + BLO 2f + + # Load b0-b7 (channel 5) + # - d13 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 5) + # - d11 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + BLS 2f + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + .p2align 4 +2: + # Load right_shift + # - q4 = d8:d9 = vright_shift + VLD1.32 {d8[], d9[]}, [r7]! + + VQRDMULH.S32 q8, q8, d12[0] + VQRDMULH.S32 q9, q9, d12[0] + VQRDMULH.S32 q10, q10, d12[0] + VQRDMULH.S32 q11, q11, d12[0] + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask + VCEQ.S32 q5, q4, 0 + + VQRDMULH.S32 q12, q12, d12[0] + VQRDMULH.S32 q13, q13, d12[0] + VQRDMULH.S32 q14, q14, d12[0] + VQRDMULH.S32 q15, q15, d12[0] + + VBIC q0, q8, q5 + VBIC q1, q9, q5 + VBIC q2, q10, q5 + VBIC q3, q11, q5 + + VSRA.S32 q8, q0, 31 + VSRA.S32 q9, q1, 31 + VSRA.S32 q10, q2, 31 + VSRA.S32 q11, q3, 31 + + # Load zero_point + # - q7 = d14:d15 = vzero_point + VLD1.16 {d14[], d15[]}, [r7]! + + VBIC q0, q12, q5 + VBIC q1, q13, q5 + VBIC q2, q14, q5 + VBIC q3, q15, q5 + + VSRA.S32 q12, q0, 31 + VSRA.S32 q13, q1, 31 + VSRA.S32 q14, q2, 31 + VSRA.S32 q15, q3, 31 + + # Load max: + # - q5 = d10:d11 = vmax + VLD1.8 {d10[], d11[]}, [r7]! + + VRSHL.S32 q8, q8, q4 + VRSHL.S32 q9, q9, q4 + VRSHL.S32 q10, q10, q4 + VRSHL.S32 q11, q11, q4 + VRSHL.S32 q12, q12, q4 + VRSHL.S32 q13, q13, q4 + VRSHL.S32 q14, q14, q4 + VRSHL.S32 q15, q15, q4 + + # Load c, c_stride: + # - r2 = c + # - r2 = c_stride + LDRD r2, r3, [sp, 88] + + VQMOVN.S32 d16, q8 + VQMOVN.S32 d17, q9 + VQMOVN.S32 d18, q10 + VQMOVN.S32 d19, q11 + VQMOVN.S32 d20, q12 + VQMOVN.S32 d21, q13 + VQMOVN.S32 d22, q14 + VQMOVN.S32 d23, q15 + + # Load min: + # - q4 = q8:q9 = vmin + VLD1.8 {d8[], d9[]}, [r7]! + ADD r4, r2, r3 + + VQADD.S16 q8, q8, q7 + VQADD.S16 q9, q9, q7 + CMP r0, 2 + VQADD.S16 q10, q10, q7 + VQADD.S16 q11, q11, q7 + MOVLO r4, r2 + + VQMOVUN.S16 d16, q8 + VQMOVUN.S16 d17, q9 + ADD r5, r4, r3 + VQMOVUN.S16 d18, q10 + VQMOVUN.S16 d19, q11 + MOVLS r5, r4 + + VMIN.U8 q8, q8, q5 + CMP r0, 4 + VMIN.U8 q9, q9, q5 + ADD r3, r5, r3 + + VMAX.U8 q8, q8, q4 + MOVNE r3, r5 + CMP r1, 8 + VMAX.U8 q9, q9, q4 + + BNE 4f + + VST1.8 {d16}, [r2] + VST1.8 {d17}, [r4] + VST1.8 {d18}, [r5] + VST1.8 {d19}, [r3] + + VPOP {d8-d15} + POP {r4, r5, r6, r7} + BX lr + + .p2align 3 +4: + CMP r1, 4 + BLO 5f + + VST1.32 {d16[0]}, [r2]! + VST1.32 {d17[0]}, [r4]! + VST1.32 {d18[0]}, [r5]! + VST1.32 {d19[0]}, [r3]! + + SUB r1, 4 + VEXT.8 q8, q8, q8, 4 + VEXT.8 q9, q9, q9, 4 + +5: + CMP r1, 2 + BLO 6f + + VST1.16 {d16[0]}, [r2]! + VST1.16 {d17[0]}, [r4]! + VST1.16 {d18[0]}, [r5]! + VST1.16 {d19[0]}, [r3]! + + SUB r1, 2 + VEXT.8 q8, q8, q8, 2 + VEXT.8 q9, q9, q9, 2 + +6: + TEQ r1, 0 + BEQ 7f + + VST1.8 {d16[0]}, [r2] + VST1.8 {d17[0]}, [r4] + VST1.8 {d18[0]}, [r5] + VST1.8 {d19[0]}, [r3] + +7: + VPOP {d8-d15} + POP {r4, r5, r6, r7} + BX lr +END_FUNCTION pytorch_q8gemm_ukernel_4x8__aarch32_neon + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-neon.c new file mode 100644 index 0000000000000..ea04afaf64bb7 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8-neon.c @@ -0,0 +1,673 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8gemm_ukernel_4x8__neon( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + int32x4_t vacc0x0123 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc0x4567 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride); + if (mr != 4) { + a3 = a2; + } + + const uint8x8_t va_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vb_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); + a0 += 8; + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const uint8x8_t va1 = vld1_u8(a1); + a1 += 8; + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const uint8x8_t va2 = vld1_u8(a2); + a2 += 8; + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const uint8x8_t va3 = vld1_u8(a3); + a3 += 8; + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c0 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + const uint8x8_t vb01234567c1 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c1 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + const uint8x8_t vb01234567c2 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c2 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + const uint8x8_t vb01234567c3 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c3 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + const uint8x8_t vb01234567c4 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c4 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + const uint8x8_t vb01234567c5 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c5 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + const uint8x8_t vb01234567c6 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c6 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + + const uint8x8_t vb01234567c7 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c7 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const uint8x8_t va1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const uint8x8_t va2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const uint8x8_t va3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c0 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + if (k >= 2) { + const uint8x8_t vb01234567c1 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c1 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + if (k >= 3) { + const uint8x8_t vb01234567c2 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c2 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + if (k >= 4) { + const uint8x8_t vb01234567c3 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c3 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + if (k >= 5) { + const uint8x8_t vb01234567c4 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c4 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa0), + 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa0), + 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa1), + 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa1), + 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa2), + 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa2), + 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa3), + 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa3), + 0); + + if (k >= 6) { + const uint8x8_t vb01234567c5 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c5 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa0), + 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa0), + 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa1), + 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa1), + 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa2), + 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa2), + 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa3), + 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa3), + 1); + + if (k >= 7) { + const uint8x8_t vb01234567c6 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16( + vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa0), + 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa0), + 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa1), + 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa1), + 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa2), + 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa2), + 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa3), + 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa3), + 2); + } + } + } + } + } + } + } + + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = + vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = + vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = + vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = + vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc2x0123 = + vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = + vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc3x0123 = + vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 = + vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), + voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), + voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), + voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), + voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = + vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); +#endif + const uint8x16_t voutput_min = + vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = + vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 0); + c0 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 2); + c1 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 0); + c2 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 2); + c3 += 4; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 0); + c0 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 4); + c1 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 0); + c2 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 4); + c3 += 2; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-aarch32-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-aarch32-neon.S new file mode 100644 index 0000000000000..9180621507f80 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-aarch32-neon.S @@ -0,0 +1,618 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +.syntax unified + +# void pytorch_q8gemm_xzp_ukernel_4x8c2__neon( +# size_t mr, +# size_t nr, +# size_t k, +# const uint8_t* restrict a, +# size_t a_stride, +# const int32_t* restrict a_sum, +# const void* restrict w, +# uint8_t* restrict c, +# size_t c_stride, +# const union pytorch_qnnp_q31_requantization_params requantization_params[restrict static 1]) +BEGIN_FUNCTION pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + + # Load w + # - ip = w + LDR ip, [sp, 8] + + # Load bias0123(q8), bias4567(q9) + # q8 := vacc0x0123 + # q9 := vacc0x4567 + VLD1.8 {d16-d19}, [ip]! + + # q10 := vacc1x0123 + VMOV.I32 q10, q8 + # q11 := vacc1x4567 + VMOV.I32 q11, q9 + # q12 := vacc2x0123 + VMOV.I32 q12, q8 + # q13 := vacc2x4567 + VMOV.I32 q13, q9 + # q14 := vacc3x0123 + VMOV.I32 q14, q8 + # q15 := vacc3x4567 + VMOV.I32 q15, q9 + + PUSH {r4, r5, r6, r7, r8, r9, r10, r11} + VPUSH {d8-d15} + + # r3 := a0 + # r4 := a1 + # r5 := a2 + # r6 := a3 + + # r7 := a_sum0 + # r8 := a_sum1 + # r9 := a_sum2 + # r10 := a_sum3 + + # a_sum0 := a_sum + LDR r7, [sp, 100] + + # Load a_stride + # - ip = a_stride + LDR r10, [sp, 96] + + # compare mr to 2 + CMP r0, 2 + + # a1 += a_stride + ADD r4, r3, r10 + + # mr < 2, a1 := a0 + MOVLO r4, r3 + + # r8 := a_sum1 + ADD r8, r7, 4 + + # mr < 2, a_sum1 := a_sum0 + MOVLO r8, r7 + + # r5 := a2 + ADD r5, r4, r10 + # mr <= 2, a2 := a1 + MOVLS r5, r4 + + # r9 := a_sum2 + ADD r9, r8, 4 + # mr <= 2, a_sum2 := a_sum1 + MOVLS r9, r8 + + # compare mr to 4 + CMP r0, 4 + + # r6 := a3 + ADD r6, r5, r10 + # mr != 4, a3 := a2 + MOVNE r6, r5 + + # a_sum3 := a_sum2 + 1 + # r10 := a_sum3 + ADD r10, r9, 4 + # mr != 4, a_sum3 := a_sum2 + MOVNE r10, r9 + + # load a_sum + # q0: va_sum0 + VLD1.32 {d0[], d1[]}, [r7] + # q1: va_sum1 + VLD1.32 {d2[], d3[]}, [r8] + # q2: va_sum2 + VLD1.32 {d4[], d5[]}, [r9] + # q3: va_sum3 + VLD1.32 {d6[], d7[]}, [r10] + + # accumulate a_sum into vacc + # vacc0x0123 = vaddq_s32(vacc0x0123, va_sum0) + VADD.I32 q8, q8, q0 + # vacc0x4567 = vaddq_s32(vacc0x4567, va_sum0) + VADD.I32 q9, q9, q0 + # vacc1x0123 = vaddq_s32(vacc1x0123, va_sum1) + VADD.I32 q10, q10, q1 + # vacc1x4567 = vaddq_s32(vacc1x4567, va_sum1) + VADD.I32 q11, q11, q1 + # vacc2x0123 = vaddq_s32(vacc2x0123, va_sum2) + VADD.I32 q12, q12, q2 + # vacc2x4567 = vaddq_s32(vacc2x4567, va_sum2) + VADD.I32 q13, q13, q2 + # vacc3x0123 = vaddq_s32(vacc3x0123, va_sum3) + VADD.I32 q14, q14, q3 + # vacc3x4567 = vaddq_s32(vacc3x4567, va_sum3) + VADD.I32 q15, q15, q3 + + # k -= 8 + SUBS r2, r2, 8 + + BLO 1f + +.p2align 5 +0: + # load a + # d0 := va0x01234567 + VLD1.8 {d0}, [r3]! + + # d1 := va1x01234567 + VLD1.8 {d1}, [r4]! + + # d2 := va1x01234567 + VLD1.8 {d2}, [r5]! + + # d3 := va2x01234567 + VLD1.8 {d3}, [r6]! + + ##### k = 0, 1 ##### + # load b + # q2 := vb01234567x01 + VLD1.8 {d4, d5}, [ip]! + + VMULL.U8 q4, d0, d4 + VPADAL.U16 q8, q4 + + VMULL.U8 q5, d0, d5 + VPADAL.U16 q9, q5 + + VMULL.U8 q6, d1, d4 + VPADAL.U16 q10, q6 + + VMULL.U8 q7, d1, d5 + VPADAL.U16 q11, q7 + + VMULL.U8 q4, d2, d4 + VPADAL.U16 q12, q4 + + VMULL.U8 q5, d2, d5 + VPADAL.U16 q13, q5 + + VMULL.U8 q6, d3, d4 + VPADAL.U16 q14, q6 + + VMULL.U8 q7, d3, d5 + VPADAL.U16 q15, q7 + + ##### k = 2, 3 ##### + # load b + # q2 := vb01234567x01 + VLD1.8 {d4, d5}, [ip]! + + # rotate a + VEXT.8 d0, d0, d0, 2 + VEXT.8 d1, d1, d1, 2 + VEXT.8 d2, d2, d2, 2 + VEXT.8 d3, d3, d3, 2 + + VMULL.U8 q4, d0, d4 + VPADAL.U16 q8, q4 + + VMULL.U8 q5, d0, d5 + VPADAL.U16 q9, q5 + + VMULL.U8 q6, d1, d4 + VPADAL.U16 q10, q6 + + VMULL.U8 q7, d1, d5 + VPADAL.U16 q11, q7 + + VMULL.U8 q4, d2, d4 + VPADAL.U16 q12, q4 + + VMULL.U8 q5, d2, d5 + VPADAL.U16 q13, q5 + + VMULL.U8 q6, d3, d4 + VPADAL.U16 q14, q6 + + VMULL.U8 q7, d3, d5 + VPADAL.U16 q15, q7 + + ##### k = 4, 5 ##### + # load b + # q2 := vb01234567x01 + VLD1.8 {d4, d5}, [ip]! + + # rotate a + VEXT.8 d0, d0, d0, 2 + VEXT.8 d1, d1, d1, 2 + VEXT.8 d2, d2, d2, 2 + VEXT.8 d3, d3, d3, 2 + + VMULL.U8 q4, d0, d4 + VPADAL.U16 q8, q4 + + VMULL.U8 q5, d0, d5 + VPADAL.U16 q9, q5 + + VMULL.U8 q6, d1, d4 + VPADAL.U16 q10, q6 + + VMULL.U8 q7, d1, d5 + VPADAL.U16 q11, q7 + + VMULL.U8 q4, d2, d4 + VPADAL.U16 q12, q4 + + VMULL.U8 q5, d2, d5 + VPADAL.U16 q13, q5 + + VMULL.U8 q6, d3, d4 + VPADAL.U16 q14, q6 + + VMULL.U8 q7, d3, d5 + VPADAL.U16 q15, q7 + + ##### k = 6, 7 ##### + # load b + # q2 := vb01234567x01 + VLD1.8 {d4, d5}, [ip]! + + # rotate a + VEXT.8 d0, d0, d0, 2 + VEXT.8 d1, d1, d1, 2 + VEXT.8 d2, d2, d2, 2 + VEXT.8 d3, d3, d3, 2 + + VMULL.U8 q4, d0, d4 + VPADAL.U16 q8, q4 + + VMULL.U8 q5, d0, d5 + VPADAL.U16 q9, q5 + + VMULL.U8 q6, d1, d4 + VPADAL.U16 q10, q6 + + VMULL.U8 q7, d1, d5 + VPADAL.U16 q11, q7 + + VMULL.U8 q4, d2, d4 + VPADAL.U16 q12, q4 + + VMULL.U8 q5, d2, d5 + VPADAL.U16 q13, q5 + + VMULL.U8 q6, d3, d4 + VPADAL.U16 q14, q6 + + VMULL.U8 q7, d3, d5 + VPADAL.U16 q15, q7 + + # k -= 8 + SUBS r2, r2, 8 + + # k >= 0, loop + BHS 0b + +1: + # k >= 4 + ADDS r2, 8 + CMP r2, 4 + + # branch to 2f when k < 4 + BLO 2f + SUB r2, r2, 4 + + ##### k = 0, 1 ##### + # d0 := va0x01010101 + VLD1.16 {d0[]}, [r3]! + # d1 := va1x01010101 + VLD1.16 {d1[]}, [r4]! + # d2 := va2x01010101 + VLD1.16 {d2[]}, [r5]! + # d3 := va3x01010101 + VLD1.16 {d3[]}, [r6]! + + # q7 := vb01234567x01 + VLD1.8 {d14, d15}, [ip]! + + # row 0 + VMULL.U8 q2, d0, d14 + VPADAL.U16 q8, q2 + VMULL.U8 q3, d0, d15 + VPADAL.U16 q9, q3 + # row 1 + VMULL.U8 q4, d1, d14 + VPADAL.U16 q10, q4 + VMULL.U8 q5, d1, d15 + VPADAL.U16 q11, q5 + # row 2 + VMULL.U8 q2, d2, d14 + VPADAL.U16 q12, q2 + VMULL.U8 q3, d2, d15 + VPADAL.U16 q13, q3 + # row 3 + VMULL.U8 q4, d3, d14 + VPADAL.U16 q14, q4 + VMULL.U8 q5, d3, d15 + VPADAL.U16 q15, q5 + + ##### k = 2, 3 ##### + # d0 := va0x01010101 + VLD1.16 {d0[]}, [r3]! + # d1 := va1x01010101 + VLD1.16 {d1[]}, [r4]! + # d2 := va2x01010101 + VLD1.16 {d2[]}, [r5]! + # d3 := va3x01010101 + VLD1.16 {d3[]}, [r6]! + + # q7 := vb01234567x01 + VLD1.8 {d14, d15}, [ip]! + + # row 0 + VMULL.U8 q2, d0, d14 + VPADAL.U16 q8, q2 + VMULL.U8 q3, d0, d15 + VPADAL.U16 q9, q3 + # row 1 + VMULL.U8 q4, d1, d14 + VPADAL.U16 q10, q4 + VMULL.U8 q5, d1, d15 + VPADAL.U16 q11, q5 + # row 2 + VMULL.U8 q2, d2, d14 + VPADAL.U16 q12, q2 + VMULL.U8 q3, d2, d15 + VPADAL.U16 q13, q3 + # row 3 + VMULL.U8 q4, d3, d14 + VPADAL.U16 q14, q4 + VMULL.U8 q5, d3, d15 + VPADAL.U16 q15, q5 + +2: + # k >= 2 + CMP r2, 2 + BLO 3f + SUB r2, r2, 2 + + ##### k = 0, 1 ##### + # d0 := va0x01010101 + VLD1.16 {d0[]}, [r3]! + # d1 := va1x01010101 + VLD1.16 {d1[]}, [r4]! + # d2 := va2x01010101 + VLD1.16 {d2[]}, [r5]! + # d3 := va3x01010101 + VLD1.16 {d3[]}, [r6]! + + # q7 := vb01234567x01 + VLD1.8 {d14, d15}, [ip]! + + # row 0 + VMULL.U8 q2, d0, d14 + VPADAL.U16 q8, q2 + VMULL.U8 q3, d0, d15 + VPADAL.U16 q9, q3 + # row 1 + VMULL.U8 q4, d1, d14 + VPADAL.U16 q10, q4 + VMULL.U8 q5, d1, d15 + VPADAL.U16 q11, q5 + # row 2 + VMULL.U8 q2, d2, d14 + VPADAL.U16 q12, q2 + VMULL.U8 q3, d2, d15 + VPADAL.U16 q13, q3 + # row 3 + VMULL.U8 q4, d3, d14 + VPADAL.U16 q14, q4 + VMULL.U8 q5, d3, d15 + VPADAL.U16 q15, q5 + +3: + # k == 1 + CMP r2, 1 + BLO 4f + + # d0 := va0x01010101 + VLD1.8 {d0[]}, [r3] + # d1 := va1x01010101 + VLD1.8 {d1[]}, [r4] + # d2 := va2x01010101 + VLD1.8 {d2[]}, [r5] + # d3 := va3x01010101 + VLD1.8 {d3[]}, [r6] + + # q7 := vb01234567x01 + VLD1.8 {d14, d15}, [ip] + + # row 0 + VMULL.U8 q2, d0, d14 + VPADAL.U16 q8, q2 + VMULL.U8 q3, d0, d15 + VPADAL.U16 q9, q3 + # row 1 + VMULL.U8 q4, d1, d14 + VPADAL.U16 q10, q4 + VMULL.U8 q5, d1, d15 + VPADAL.U16 q11, q5 + # row 2 + VMULL.U8 q2, d2, d14 + VPADAL.U16 q12, q2 + VMULL.U8 q3, d2, d15 + VPADAL.U16 q13, q3 + # row 3 + VMULL.U8 q4, d3, d14 + VPADAL.U16 q14, q4 + VMULL.U8 q5, d3, d15 + VPADAL.U16 q15, q5 + + .p2align 4 +4: + # Load params: + # - ip = params + LDR ip, [sp, 116] + + # Load multiplier: + # - d12 = vmultiplier + VLD1.32 {d12[]}, [ip]! + + # Load right_shift + # - q4 = d8:d9 = vright_shift + VLD1.32 {d8[], d9[]}, [ip]! + + VQRDMULH.S32 q8, q8, d12[0] + VQRDMULH.S32 q9, q9, d12[0] + VQRDMULH.S32 q10, q10, d12[0] + VQRDMULH.S32 q11, q11, d12[0] + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask + VCEQ.S32 q5, q4, 0 + + VQRDMULH.S32 q12, q12, d12[0] + VQRDMULH.S32 q13, q13, d12[0] + VQRDMULH.S32 q14, q14, d12[0] + VQRDMULH.S32 q15, q15, d12[0] + + VBIC q0, q8, q5 + VBIC q1, q9, q5 + VBIC q2, q10, q5 + VBIC q3, q11, q5 + + VSRA.S32 q8, q0, 31 + VSRA.S32 q9, q1, 31 + VSRA.S32 q10, q2, 31 + VSRA.S32 q11, q3, 31 + + # Load zero_point + # - q7 = d14:d15 = vzero_point + VLD1.16 {d14[], d15[]}, [ip]! + + VBIC q0, q12, q5 + VBIC q1, q13, q5 + VBIC q2, q14, q5 + VBIC q3, q15, q5 + + VSRA.S32 q12, q0, 31 + VSRA.S32 q13, q1, 31 + VSRA.S32 q14, q2, 31 + VSRA.S32 q15, q3, 31 + + # Load max: + # - q5 = d10:d11 = vmax + VLD1.8 {d10[], d11[]}, [ip]! + + VRSHL.S32 q8, q8, q4 + VRSHL.S32 q9, q9, q4 + VRSHL.S32 q10, q10, q4 + VRSHL.S32 q11, q11, q4 + VRSHL.S32 q12, q12, q4 + VRSHL.S32 q13, q13, q4 + VRSHL.S32 q14, q14, q4 + VRSHL.S32 q15, q15, q4 + + # Load c, c_stride: + # - r2 = c + # - r3 = c_stride + LDRD r2, r3, [sp, 108] + + VQMOVN.S32 d16, q8 + VQMOVN.S32 d17, q9 + VQMOVN.S32 d18, q10 + VQMOVN.S32 d19, q11 + VQMOVN.S32 d20, q12 + VQMOVN.S32 d21, q13 + VQMOVN.S32 d22, q14 + VQMOVN.S32 d23, q15 + + # Load min: + # - q4 = q8:q9 = vmin + VLD1.8 {d8[], d9[]}, [ip]! + ADD r4, r2, r3 + + VQADD.S16 q8, q8, q7 + VQADD.S16 q9, q9, q7 + CMP r0, 2 + VQADD.S16 q10, q10, q7 + VQADD.S16 q11, q11, q7 + MOVLO r4, r2 + + VQMOVUN.S16 d16, q8 + VQMOVUN.S16 d17, q9 + ADD r5, r4, r3 + VQMOVUN.S16 d18, q10 + VQMOVUN.S16 d19, q11 + MOVLS r5, r4 + + VMIN.U8 q8, q8, q5 + CMP r0, 4 + VMIN.U8 q9, q9, q5 + ADD r3, r5, r3 + + VMAX.U8 q8, q8, q4 + MOVNE r3, r5 + CMP r1, 8 + VMAX.U8 q9, q9, q4 + + BNE 5f + + VST1.8 {d16}, [r2] + VST1.8 {d17}, [r4] + VST1.8 {d18}, [r5] + VST1.8 {d19}, [r3] + + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11} + BX lr + + .p2align 3 +5: + CMP r1, 4 + BLO 6f + + VST1.32 {d16[0]}, [r2]! + VST1.32 {d17[0]}, [r4]! + VST1.32 {d18[0]}, [r5]! + VST1.32 {d19[0]}, [r3]! + + SUB r1, 4 + VEXT.8 q8, q8, q8, 4 + VEXT.8 q9, q9, q9, 4 + +6: + CMP r1, 2 + BLO 7f + + VST1.16 {d16[0]}, [r2]! + VST1.16 {d17[0]}, [r4]! + VST1.16 {d18[0]}, [r5]! + VST1.16 {d19[0]}, [r3]! + + SUB r1, 2 + VEXT.8 q8, q8, q8, 2 + VEXT.8 q9, q9, q9, 2 + +7: + TEQ r1, 0 + BEQ 8f + + VST1.8 {d16[0]}, [r2] + VST1.8 {d17[0]}, [r4] + VST1.8 {d18[0]}, [r5] + VST1.8 {d19[0]}, [r3] +8: + VPOP {d8-d15} + POP {r4, r5, r6, r7, r8, r9, r10, r11} + BX lr + +END_FUNCTION pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-neon.c new file mode 100644 index 0000000000000..a0b196b855cdb --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/4x8c2-xzp-neon.c @@ -0,0 +1,543 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_q8gemm_xzp_ukernel_4x8c2__neon( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const int32_t* restrict a_sum, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_q31_requantization_params + requantization_params[restrict static 1]) { + int32x4_t vacc0x0123 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc0x4567 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + + const uint8_t* a0 = a; + const uint8_t* a1 = a0; + const int32_t* a_sum0 = a_sum; + const int32_t* a_sum1 = a_sum0; + if (mr >= 2) { + a1 += a_stride; + a_sum1 += 1; + } + const uint8_t* a2 = a1; + const int32_t* a_sum2 = a_sum1; + if (mr > 2) { + a2 += a_stride; + a_sum2 += 1; + } + const uint8_t* a3 = a2; + const int32_t* a_sum3 = a_sum2; + if (mr == 4) { + a3 += a_stride; + a_sum3 += 1; + } + + const int32x4_t va_sum0 = vld1q_dup_s32(a_sum0); + const int32x4_t va_sum1 = vld1q_dup_s32(a_sum1); + const int32x4_t va_sum2 = vld1q_dup_s32(a_sum2); + const int32x4_t va_sum3 = vld1q_dup_s32(a_sum3); + vacc0x0123 = vaddq_s32(vacc0x0123, va_sum0); + vacc0x4567 = vaddq_s32(vacc0x4567, va_sum0); + vacc1x0123 = vaddq_s32(vacc1x0123, va_sum1); + vacc1x4567 = vaddq_s32(vacc1x4567, va_sum1); + vacc2x0123 = vaddq_s32(vacc2x0123, va_sum2); + vacc2x4567 = vaddq_s32(vacc2x4567, va_sum2); + vacc3x0123 = vaddq_s32(vacc3x0123, va_sum3); + vacc3x4567 = vaddq_s32(vacc3x4567, va_sum3); + + for (; k >= 8; k -= 8) { + uint8x8_t va0x01234567 = vld1_u8(a0); + a0 += 8; + uint8x8_t va1x01234567 = vld1_u8(a1); + a1 += 8; + uint8x8_t va2x01234567 = vld1_u8(a2); + a2 += 8; + uint8x8_t va3x01234567 = vld1_u8(a3); + a3 += 8; + + /* k = 0, 1 */ + const uint8x16_t vb01234567x01 = vld1q_u8(w); + w += 16; + + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x01234567, vget_low_u8(vb01234567x01)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x01234567, vget_high_u8(vb01234567x01)))); + + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x01234567, vget_low_u8(vb01234567x01)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x01234567, vget_high_u8(vb01234567x01)))); + + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x01234567, vget_low_u8(vb01234567x01)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x01234567, vget_high_u8(vb01234567x01)))); + + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x01234567, vget_low_u8(vb01234567x01)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x01234567, vget_high_u8(vb01234567x01)))); + + /* k = 2, 3 */ + va0x01234567 = vext_u8(va0x01234567, va0x01234567, 2); + va1x01234567 = vext_u8(va1x01234567, va1x01234567, 2); + va2x01234567 = vext_u8(va2x01234567, va2x01234567, 2); + va3x01234567 = vext_u8(va3x01234567, va3x01234567, 2); + + const uint8x16_t vb01234567x23 = vld1q_u8(w); + w += 16; + + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x01234567, vget_low_u8(vb01234567x23)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x01234567, vget_high_u8(vb01234567x23)))); + + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x01234567, vget_low_u8(vb01234567x23)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x01234567, vget_high_u8(vb01234567x23)))); + + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x01234567, vget_low_u8(vb01234567x23)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x01234567, vget_high_u8(vb01234567x23)))); + + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x01234567, vget_low_u8(vb01234567x23)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x01234567, vget_high_u8(vb01234567x23)))); + + /* k = 4, 5 */ + va0x01234567 = vext_u8(va0x01234567, va0x01234567, 2); + va1x01234567 = vext_u8(va1x01234567, va1x01234567, 2); + va2x01234567 = vext_u8(va2x01234567, va2x01234567, 2); + va3x01234567 = vext_u8(va3x01234567, va3x01234567, 2); + + const uint8x16_t vb01234567x45 = vld1q_u8(w); + w += 16; + + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x01234567, vget_low_u8(vb01234567x45)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x01234567, vget_high_u8(vb01234567x45)))); + + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x01234567, vget_low_u8(vb01234567x45)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x01234567, vget_high_u8(vb01234567x45)))); + + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x01234567, vget_low_u8(vb01234567x45)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x01234567, vget_high_u8(vb01234567x45)))); + + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x01234567, vget_low_u8(vb01234567x45)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x01234567, vget_high_u8(vb01234567x45)))); + + /* k = 6, 7 */ + va0x01234567 = vext_u8(va0x01234567, va0x01234567, 2); + va1x01234567 = vext_u8(va1x01234567, va1x01234567, 2); + va2x01234567 = vext_u8(va2x01234567, va2x01234567, 2); + va3x01234567 = vext_u8(va3x01234567, va3x01234567, 2); + + const uint8x16_t vb01234567x67 = vld1q_u8(w); + w += 16; + + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x01234567, vget_low_u8(vb01234567x67)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x01234567, vget_high_u8(vb01234567x67)))); + + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x01234567, vget_low_u8(vb01234567x67)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x01234567, vget_high_u8(vb01234567x67)))); + + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x01234567, vget_low_u8(vb01234567x67)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x01234567, vget_high_u8(vb01234567x67)))); + + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x01234567, vget_low_u8(vb01234567x67)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x01234567, vget_high_u8(vb01234567x67)))); + } + + /* for k < 8, reuse the packing scheme for the original xzp ukernel */ + if (k & 4) { + /* k = 0, 1 */ + const uint8x8_t va0x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1))); + a0 += 2; + const uint8x8_t va1x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1))); + a1 += 2; + const uint8x8_t va2x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1))); + a2 += 2; + const uint8x8_t va3x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1))); + a3 += 2; + const uint8x16_t vb01234567x01 = vld1q_u8(w); + w += 16; + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x01010101, vget_low_u8(vb01234567x01)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x01010101, vget_high_u8(vb01234567x01)))); + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x01010101, vget_low_u8(vb01234567x01)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x01010101, vget_high_u8(vb01234567x01)))); + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x01010101, vget_low_u8(vb01234567x01)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x01010101, vget_high_u8(vb01234567x01)))); + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x01010101, vget_low_u8(vb01234567x01)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x01010101, vget_high_u8(vb01234567x01)))); + + /* k = 2, 3 */ + const uint8x8_t va0x23232323 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1))); + a0 += 2; + const uint8x8_t va1x23232323 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1))); + a1 += 2; + const uint8x8_t va2x23232323 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1))); + a2 += 2; + const uint8x8_t va3x23232323 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1))); + a3 += 2; + const uint8x16_t vb01234567x23 = vld1q_u8(w); + w += 16; + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x23232323, vget_low_u8(vb01234567x23)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x23232323, vget_high_u8(vb01234567x23)))); + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x23232323, vget_low_u8(vb01234567x23)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x23232323, vget_high_u8(vb01234567x23)))); + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x23232323, vget_low_u8(vb01234567x23)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x23232323, vget_high_u8(vb01234567x23)))); + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x23232323, vget_low_u8(vb01234567x23)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x23232323, vget_high_u8(vb01234567x23)))); + } + if (k & 2) { + /* k = 0, 1 */ + const uint8x8_t va0x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a0, 1))); + a0 += 2; + const uint8x8_t va1x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a1, 1))); + a1 += 2; + const uint8x8_t va2x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a2, 1))); + a2 += 2; + const uint8x8_t va3x01010101 = vreinterpret_u8_u16( + vld1_dup_u16(__builtin_assume_aligned((const uint16_t*)a3, 1))); + a3 += 2; + const uint8x16_t vb01234567x01 = vld1q_u8(w); + w += 16; + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x01010101, vget_low_u8(vb01234567x01)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x01010101, vget_high_u8(vb01234567x01)))); + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x01010101, vget_low_u8(vb01234567x01)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x01010101, vget_high_u8(vb01234567x01)))); + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x01010101, vget_low_u8(vb01234567x01)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x01010101, vget_high_u8(vb01234567x01)))); + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x01010101, vget_low_u8(vb01234567x01)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x01010101, vget_high_u8(vb01234567x01)))); + } + if (k & 1) { + const uint8x8_t va0x00000000 = vld1_dup_u8(a0); + const uint8x8_t va1x00000000 = vld1_dup_u8(a1); + const uint8x8_t va2x00000000 = vld1_dup_u8(a2); + const uint8x8_t va3x00000000 = vld1_dup_u8(a3); + const uint8x16_t vb01234567x0 = vld1q_u8(w); + vacc0x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x0123), + vmull_u8(va0x00000000, vget_low_u8(vb01234567x0)))); + vacc0x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc0x4567), + vmull_u8(va0x00000000, vget_high_u8(vb01234567x0)))); + vacc1x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x0123), + vmull_u8(va1x00000000, vget_low_u8(vb01234567x0)))); + vacc1x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc1x4567), + vmull_u8(va1x00000000, vget_high_u8(vb01234567x0)))); + vacc2x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x0123), + vmull_u8(va2x00000000, vget_low_u8(vb01234567x0)))); + vacc2x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc2x4567), + vmull_u8(va2x00000000, vget_high_u8(vb01234567x0)))); + vacc3x0123 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x0123), + vmull_u8(va3x00000000, vget_low_u8(vb01234567x0)))); + vacc3x4567 = vreinterpretq_s32_u32(vpadalq_u16( + vreinterpretq_u32_s32(vacc3x4567), + vmull_u8(va3x00000000, vget_high_u8(vb01234567x0)))); + } + + const int32x4_t vmultiplier = + vld1q_dup_s32(&requantization_params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + + const int32x4_t vright_shift = + vld1q_dup_s32(&requantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = + vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = + vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = + vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = + vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc2x0123 = + vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = + vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc3x0123 = + vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 = + vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + + const int16x8_t vzero_point = + vld1q_dup_s16(&requantization_params->neon.zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), vzero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), vzero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), vzero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), vzero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), + vzero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), + vzero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), + vzero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), + vzero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = + vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); +#endif + const uint8x16_t vmin = vld1q_dup_u8(&requantization_params->neon.min); + const uint8x16_t vmax = vld1q_dup_u8(&requantization_params->neon.max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, vmin); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, vmin); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, vmax); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, vmax); + + uint8_t* c0 = c; + uint8_t* c1 = c0; + if (mr >= 2) { + c1 += c_stride; + } + uint8_t* c2 = c1; + if (mr > 2) { + c2 += c_stride; + } + uint8_t* c3 = c2; + if (mr == 4) { + c3 += c_stride; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 0); + c0 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 2); + c1 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 0); + c2 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 2); + c3 += 4; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 0); + c0 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 4); + c1 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 0); + c2 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 4); + c3 += 2; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/6x4-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/6x4-neon.c new file mode 100644 index 0000000000000..aab297060460a --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/6x4-neon.c @@ -0,0 +1,547 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8gemm_ukernel_6x4__neon( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + int32x4_t vacc0x0123 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc4x0123 = vacc0x0123; + int32x4_t vacc5x0123 = vacc0x0123; + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride); + if (mr < 4) { + a3 = a2; + } + const uint8_t* a4 = (const uint8_t*)((uintptr_t)a3 + a_stride); + if (mr <= 4) { + a4 = a3; + }; + const uint8_t* a5 = (const uint8_t*)((uintptr_t)a4 + a_stride); + if (mr != 6) { + a5 = a4; + } + + const uint8x8_t va_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vb_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); + a0 += 8; + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const uint8x8_t va1 = vld1_u8(a1); + a1 += 8; + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const uint8x8_t va2 = vld1_u8(a2); + a2 += 8; + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const uint8x8_t va3 = vld1_u8(a3); + a3 += 8; + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + const uint8x8_t va4 = vld1_u8(a4); + a4 += 8; + const int16x8_t vxa4 = + vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point)); + const uint8x8_t va5 = vld1_u8(a5); + a5 += 8; + const int16x8_t vxa5 = + vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point)); + + const uint8x8_t vb0123c01 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb0123c01 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c01, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c01), vget_low_s16(vxa5), 0); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_high_s16(vxb0123c01), vget_low_s16(vxa5), 1); + + const uint8x8_t vb0123c23 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb0123c23 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c23, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c23), vget_low_s16(vxa5), 2); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_high_s16(vxb0123c23), vget_low_s16(vxa5), 3); + + const uint8x8_t vb0123c45 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb0123c45 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c45, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c45), vget_high_s16(vxa5), 0); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_high_s16(vxb0123c45), vget_high_s16(vxa5), 1); + + const uint8x8_t vb0123c67 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb0123c67 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c67, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c67), vget_high_s16(vxa5), 2); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_high_s16(vxb0123c67), vget_high_s16(vxa5), 3); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const uint8x8_t va1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const uint8x8_t va2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const uint8x8_t va3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + const uint8x8_t va4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a4 - a_predecrement)), va_shift)); + const int16x8_t vxa4 = + vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point)); + const uint8x8_t va5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a5 - a_predecrement)), va_shift)); + const int16x8_t vxa5 = + vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point)); + + const uint8x8_t vb0123c0 = vreinterpret_u8_u32(vld1_dup_u32(w)); + w = (const void*)((uintptr_t)w + 4); + const int16x8_t vxb0123c0 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c0), vget_low_s16(vxa5), 0); + + if (k >= 2) { + const uint8x8_t vb0123c1 = vreinterpret_u8_u32(vld1_dup_u32(w)); + w = (const void*)((uintptr_t)w + 4); + const int16x8_t vxb0123c1 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c1), vget_low_s16(vxa5), 1); + + if (k > 2) { + const uint8x8_t vb0123c2 = vreinterpret_u8_u32(vld1_dup_u32(w)); + w = (const void*)((uintptr_t)w + 4); + const int16x8_t vxb0123c2 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c2), vget_low_s16(vxa5), 2); + + if (k >= 4) { + const uint8x8_t vb0123c3 = vreinterpret_u8_u32(vld1_dup_u32(w)); + w = (const void*)((uintptr_t)w + 4); + const int16x8_t vxb0123c3 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c3), vget_low_s16(vxa5), 3); + + if (k > 4) { + const uint8x8_t vb0123c4 = vreinterpret_u8_u32(vld1_dup_u32(w)); + w = (const void*)((uintptr_t)w + 4); + const int16x8_t vxb0123c4 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c4), vget_high_s16(vxa5), 0); + + if (k >= 6) { + const uint8x8_t vb0123c5 = vreinterpret_u8_u32(vld1_dup_u32(w)); + w = (const void*)((uintptr_t)w + 4); + const int16x8_t vxb0123c5 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb0123c5), vget_high_s16(vxa5), 1); + + if (k > 6) { + const uint8x8_t vb0123c6 = vreinterpret_u8_u32(vld1_dup_u32(w)); + const int16x8_t vxb0123c6 = + vreinterpretq_s16_u16(vsubl_u8(vb0123c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb0123c6), + vget_high_s16(vxa0), + 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb0123c6), + vget_high_s16(vxa1), + 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb0123c6), + vget_high_s16(vxa2), + 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb0123c6), + vget_high_s16(vxa3), + 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb0123c6), + vget_high_s16(vxa4), + 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb0123c6), + vget_high_s16(vxa5), + 2); + } + } + } + } + } + } + } + + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc4x0123 = vqrdmulhq_s32(vacc4x0123, vmultiplier); + vacc5x0123 = vqrdmulhq_s32(vacc5x0123, vmultiplier); + + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = + vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc1x0123 = + vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc2x0123 = + vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc3x0123 = + vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc4x0123 = + vsraq_n_s32(vacc4x0123, vbicq_s32(vacc4x0123, vzero_shift_mask), 31); + vacc5x0123 = + vsraq_n_s32(vacc5x0123, vbicq_s32(vacc5x0123, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc4x0123 = vrshlq_s32(vacc4x0123, vright_shift); + vacc5x0123 = vrshlq_s32(vacc5x0123, vright_shift); + + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc01x0123 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc1x0123), voutput_zero_point); + const int16x8_t vacc23x0123 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc3x0123), voutput_zero_point); + const int16x8_t vacc45x0123 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc5x0123), voutput_zero_point); + + uint8x16_t vout0123x0123 = + vqmovun_high_s16(vqmovun_s16(vacc01x0123), vacc23x0123); + uint8x8_t vout45x0123 = vqmovun_s16(vacc45x0123); +#else + const int16x8_t vacc01x0123 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc1x0123)), + voutput_zero_point); + const int16x8_t vacc23x0123 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc3x0123)), + voutput_zero_point); + const int16x8_t vacc45x0123 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc5x0123)), + voutput_zero_point); + + uint8x16_t vout0123x0123 = + vcombine_u8(vqmovun_s16(vacc01x0123), vqmovun_s16(vacc23x0123)); + uint8x8_t vout45x0123 = vqmovun_s16(vacc45x0123); +#endif + const uint8x16_t voutput_min = + vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = + vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0123x0123 = vmaxq_u8(vout0123x0123, voutput_min); + vout45x0123 = vmax_u8(vout45x0123, vget_low_u8(voutput_min)); + vout0123x0123 = vminq_u8(vout0123x0123, voutput_max); + vout45x0123 = vmin_u8(vout45x0123, vget_low_u8(voutput_max)); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + uint8_t* c4 = (uint8_t*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + uint8_t* c5 = (uint8_t*)((uintptr_t)c4 + c_stride); + if (mr != 6) { + c5 = c4; + } + if (nr == 4) { + vst1q_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u32_u8(vout0123x0123), + 0); + vst1q_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u32_u8(vout0123x0123), + 1); + vst1q_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u32_u8(vout0123x0123), + 2); + vst1q_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u32_u8(vout0123x0123), + 3); + vst1_lane_u32( + __builtin_assume_aligned(c4, 1), vreinterpret_u32_u8(vout45x0123), 0); + vst1_lane_u32( + __builtin_assume_aligned(c5, 1), vreinterpret_u32_u8(vout45x0123), 1); + } else { + if (nr >= 2) { + vst1q_lane_u16( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u16_u8(vout0123x0123), + 0); + c0 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u16_u8(vout0123x0123), + 2); + c1 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u16_u8(vout0123x0123), + 4); + c2 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u16_u8(vout0123x0123), + 6); + c3 += 2; + vst1_lane_u16( + __builtin_assume_aligned(c4, 1), vreinterpret_u16_u8(vout45x0123), 0); + c4 += 2; + vst1_lane_u16( + __builtin_assume_aligned(c5, 1), vreinterpret_u16_u8(vout45x0123), 2); + c5 += 2; + vout0123x0123 = vextq_u8(vout0123x0123, vout0123x0123, 2); + vout45x0123 = vext_u8(vout45x0123, vout45x0123, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(__builtin_assume_aligned(c0, 1), vout0123x0123, 0); + vst1q_lane_u8(__builtin_assume_aligned(c1, 1), vout0123x0123, 4); + vst1q_lane_u8(__builtin_assume_aligned(c2, 1), vout0123x0123, 8); + vst1q_lane_u8(__builtin_assume_aligned(c3, 1), vout0123x0123, 12); + vst1_lane_u8(__builtin_assume_aligned(c4, 1), vout45x0123, 0); + vst1_lane_u8(__builtin_assume_aligned(c5, 1), vout45x0123, 4); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S new file mode 100644 index 0000000000000..8e57b40cb2142 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-aarch64-neon.S @@ -0,0 +1,780 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + + +# void pytorch_q8gemm_ukernel_8x8__aarch64_neon( +# size_t mr, +# size_t nr, +# size_t k, +# const uint8_t*restrict a, +# size_t a_stride, +# const void*restrict w, +# uint8_t*restrict c, +# size_t c_stride, +# const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static 1]) +BEGIN_FUNCTION pytorch_q8gemm_ukernel_8x8__aarch64_neon + + STP d15, d14, [sp, -16] + STP d13, d12, [sp, -32] + STP d11, d10, [sp, -48] + STP d9, d8, [sp, -64] + + # Load bias0123, bias4567 + LD1 {v8.4s, v9.4s}, [x5], 32 + + MOV x9, 2 + + # Load params + LDR x8, [sp] + + # Load b_zero_point + LD1R {v25.8b}, [x8], x9 + # Load a_zero_point + LD1R {v24.8b}, [x8], x9 + + # v10 := vacc1x0123 + MOV v10.16b, v8.16b + # v11 := vacc1x4567 + MOV v11.16b, v9.16b + + # v12 := vacc2x0123 + MOV v12.16b, v8.16b + # v13 := vacc2x4567 + MOV v13.16b, v9.16b + + # v14 := vacc3x0123 + MOV v14.16b, v8.16b + # v15 := vacc3x4567 + MOV v15.16b, v9.16b + + # v16 := vacc4x0123 + MOV v16.16b, v8.16b + # v17 := vacc4x4567 + MOV v17.16b, v9.16b + + # v18 := vacc5x0123 + MOV v18.16b, v8.16b + # v19 := vacc5x4567 + MOV v19.16b, v9.16b + + # v20 := vacc6x0123 + MOV v20.16b, v8.16b + # v21 := vacc6x4567 + MOV v21.16b, v9.16b + + # v22 := vacc7x0123 + MOV v22.16b, v8.16b + # v23 := vacc7x4567 + MOV v23.16b, v9.16b + + // Load multiplier + // - v26 = vmultiplier + LD1R {v26.4s}, [x8], 4 + + # a1 + CMP x0, 2 + ADD x9, x3, x4 + CSEL x9, x3, x9, LO + + # a2 + ADD x10, x9, x4 + CSEL x10, x9, x10, LS + + # a3 + CMP x0, 4 + ADD x11, x10, x4 + CSEL x11, x10, x11, LO + + # a4 + ADD x12, x11, x4 + CSEL x12, x11, x12, LS + + # a5 + CMP x0, 6 + ADD x13, x12, x4 + CSEL x13, x12, x13, LO + + # a6 + ADD x14, x13, x4 + CSEL x14, x13, x14, LS + + # a7 + CMP x0, 8 + ADD x15, x14, x4 + CSEL x15, x14, x15, NE + + SUBS x2, x2, 8 + B.LO 1f + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 5 +#endif +0: + // b0-7 (channel 0) + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + # va0 - va7 := va - va_zero_point + LD1 {v0.8b}, [x3], 8 + SUB_ZERO_POINT v0.8h, v0.8b, v24.8b + LD1 {v1.8b}, [x9], 8 + SUB_ZERO_POINT v1.8h, v1.8b, v24.8b + LD1 {v2.8b}, [x10], 8 + SUB_ZERO_POINT v2.8h, v2.8b, v24.8b + LD1 {v3.8b}, [x11], 8 + SUB_ZERO_POINT v3.8h, v3.8b, v24.8b + LD1 {v4.8b}, [x12], 8 + SUB_ZERO_POINT v4.8h, v4.8b, v24.8b + LD1 {v5.8b}, [x13], 8 + SUB_ZERO_POINT v5.8h, v5.8b, v24.8b + LD1 {v6.8b}, [x14], 8 + SUB_ZERO_POINT v6.8h, v6.8b, v24.8b + LD1 {v7.8b}, [x15], 8 + SUB_ZERO_POINT v7.8h, v7.8b, v24.8b + + // b0-7 (channel 1) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[0] // vacc0x0123 += vb0123 * va0[0] + SMLAL2 v9.4s, v27.8h, v0.h[0] // vacc0x4567 += vb4567 * va0[0] + SMLAL v10.4s, v27.4h, v1.h[0] // vacc1x0123 += vb0123 * va1[0] + SMLAL2 v11.4s, v27.8h, v1.h[0] // vacc1x4567 += vb4567 * va1[0] + SMLAL v12.4s, v27.4h, v2.h[0] // vacc2x0123 += vb0123 * va2[0] + SMLAL2 v13.4s, v27.8h, v2.h[0] // vacc2x4567 += vb4567 * va2[0] + SMLAL v14.4s, v27.4h, v3.h[0] // vacc3x0123 += vb0123 * va3[0] + SMLAL2 v15.4s, v27.8h, v3.h[0] // vacc3x4567 += vb4567 * va3[0] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[0] // vacc4x0123 += vb0123 * va4[0] + SMLAL2 v17.4s, v27.8h, v4.h[0] // vacc4x4567 += vb4567 * va4[0] + SMLAL v18.4s, v27.4h, v5.h[0] // vacc5x0123 += vb0123 * va5[0] + SMLAL2 v19.4s, v27.8h, v5.h[0] // vacc5x4567 += vb4567 * va5[0] + SMLAL v20.4s, v27.4h, v6.h[0] // vacc6x0123 += vb0123 * va6[0] + SMLAL2 v21.4s, v27.8h, v6.h[0] // vacc6x4567 += vb4567 * va6[0] + SMLAL v22.4s, v27.4h, v7.h[0] // vacc7x0123 += vb0123 * va7[0] + SMLAL2 v23.4s, v27.8h, v7.h[0] // vacc7x4567 += vb4567 * va7[0] + + // b0-7 (channel 2) + LD1 {v27.8b}, [x5], 8 + + SMLAL v8.4s, v28.4h, v0.h[1] // vacc0x0123 += vb0123 * va0[1] + SMLAL2 v9.4s, v28.8h, v0.h[1] // vacc0x4567 += vb4567 * va0[1] + SMLAL v10.4s, v28.4h, v1.h[1] // vacc1x0123 += vb0123 * va1[1] + SMLAL2 v11.4s, v28.8h, v1.h[1] // vacc1x4567 += vb4567 * va1[1] + SMLAL v12.4s, v28.4h, v2.h[1] // vacc2x0123 += vb0123 * va2[1] + SMLAL2 v13.4s, v28.8h, v2.h[1] // vacc2x4567 += vb4567 * va2[1] + SMLAL v14.4s, v28.4h, v3.h[1] // vacc3x0123 += vb0123 * va3[1] + SMLAL2 v15.4s, v28.8h, v3.h[1] // vacc3x4567 += vb4567 * va3[1] + USUBL v27.8h, v27.8b, v25.8b + SMLAL v16.4s, v28.4h, v4.h[1] // vacc4x0123 += vb0123 * va4[1] + SMLAL2 v17.4s, v28.8h, v4.h[1] // vacc4x4567 += vb4567 * va4[1] + SMLAL v18.4s, v28.4h, v5.h[1] // vacc5x0123 += vb0123 * va5[1] + SMLAL2 v19.4s, v28.8h, v5.h[1] // vacc5x4567 += vb4567 * va5[1] + SMLAL v20.4s, v28.4h, v6.h[1] // vacc6x0123 += vb0123 * va6[1] + SMLAL2 v21.4s, v28.8h, v6.h[1] // vacc6x4567 += vb4567 * va6[1] + SMLAL v22.4s, v28.4h, v7.h[1] // vacc7x0123 += vb0123 * va7[1] + SMLAL2 v23.4s, v28.8h, v7.h[1] // vacc7x4567 += vb4567 * va7[1] + + // b0-7 (channel 3) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[2] // vacc0x0123 += vb0123 * va0[2] + SMLAL2 v9.4s, v27.8h, v0.h[2] // vacc0x4567 += vb4567 * va0[2] + SMLAL v10.4s, v27.4h, v1.h[2] // vacc1x0123 += vb0123 * va1[2] + SMLAL2 v11.4s, v27.8h, v1.h[2] // vacc1x4567 += vb4567 * va1[2] + SMLAL v12.4s, v27.4h, v2.h[2] // vacc2x0123 += vb0123 * va2[2] + SMLAL2 v13.4s, v27.8h, v2.h[2] // vacc2x4567 += vb4567 * va2[2] + SMLAL v14.4s, v27.4h, v3.h[2] // vacc3x0123 += vb0123 * va3[2] + SMLAL2 v15.4s, v27.8h, v3.h[2] // vacc3x4567 += vb4567 * va3[2] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[2] // vacc4x0123 += vb0123 * va4[2] + SMLAL2 v17.4s, v27.8h, v4.h[2] // vacc4x4567 += vb4567 * va4[2] + SMLAL v18.4s, v27.4h, v5.h[2] // vacc5x0123 += vb0123 * va5[2] + SMLAL2 v19.4s, v27.8h, v5.h[2] // vacc5x4567 += vb4567 * va5[2] + SMLAL v20.4s, v27.4h, v6.h[2] // vacc6x0123 += vb0123 * va6[2] + SMLAL2 v21.4s, v27.8h, v6.h[2] // vacc6x4567 += vb4567 * va6[2] + SMLAL v22.4s, v27.4h, v7.h[2] // vacc7x0123 += vb0123 * va7[2] + SMLAL2 v23.4s, v27.8h, v7.h[2] // vacc7x4567 += vb4567 * va7[2] + + // b0-7 (channel 4) + LD1 {v27.8b}, [x5], 8 + + SMLAL v8.4s, v28.4h, v0.h[3] // vacc0x0123 += vb0123 * va0[3] + SMLAL2 v9.4s, v28.8h, v0.h[3] // vacc0x4567 += vb4567 * va0[3] + SMLAL v10.4s, v28.4h, v1.h[3] // vacc1x0123 += vb0123 * va1[3] + SMLAL2 v11.4s, v28.8h, v1.h[3] // vacc1x4567 += vb4567 * va1[3] + SMLAL v12.4s, v28.4h, v2.h[3] // vacc2x0123 += vb0123 * va2[3] + SMLAL2 v13.4s, v28.8h, v2.h[3] // vacc2x4567 += vb4567 * va2[3] + SMLAL v14.4s, v28.4h, v3.h[3] // vacc3x0123 += vb0123 * va3[3] + SMLAL2 v15.4s, v28.8h, v3.h[3] // vacc3x4567 += vb4567 * va3[3] + USUBL v27.8h, v27.8b, v25.8b + SMLAL v16.4s, v28.4h, v4.h[3] // vacc4x0123 += vb0123 * va4[3] + SMLAL2 v17.4s, v28.8h, v4.h[3] // vacc4x4567 += vb4567 * va4[3] + SMLAL v18.4s, v28.4h, v5.h[3] // vacc5x0123 += vb0123 * va5[3] + SMLAL2 v19.4s, v28.8h, v5.h[3] // vacc5x4567 += vb4567 * va5[3] + SMLAL v20.4s, v28.4h, v6.h[3] // vacc6x0123 += vb0123 * va6[3] + SMLAL2 v21.4s, v28.8h, v6.h[3] // vacc6x4567 += vb4567 * va6[3] + SMLAL v22.4s, v28.4h, v7.h[3] // vacc7x0123 += vb0123 * va7[3] + SMLAL2 v23.4s, v28.8h, v7.h[3] // vacc7x4567 += vb4567 * va7[3] + + // b0-7 (channel 5) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[4] // vacc0x0123 += vb0123 * va0[4] + SMLAL2 v9.4s, v27.8h, v0.h[4] // vacc0x4567 += vb4567 * va0[4] + SMLAL v10.4s, v27.4h, v1.h[4] // vacc1x0123 += vb0123 * va1[4] + SMLAL2 v11.4s, v27.8h, v1.h[4] // vacc1x4567 += vb4567 * va1[4] + SMLAL v12.4s, v27.4h, v2.h[4] // vacc2x0123 += vb0123 * va2[4] + SMLAL2 v13.4s, v27.8h, v2.h[4] // vacc2x4567 += vb4567 * va2[4] + SMLAL v14.4s, v27.4h, v3.h[4] // vacc3x0123 += vb0123 * va3[4] + SMLAL2 v15.4s, v27.8h, v3.h[4] // vacc3x4567 += vb4567 * va3[4] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[4] // vacc4x0123 += vb0123 * va4[4] + SMLAL2 v17.4s, v27.8h, v4.h[4] // vacc4x4567 += vb4567 * va4[4] + SMLAL v18.4s, v27.4h, v5.h[4] // vacc5x0123 += vb0123 * va5[4] + SMLAL2 v19.4s, v27.8h, v5.h[4] // vacc5x4567 += vb4567 * va5[4] + SMLAL v20.4s, v27.4h, v6.h[4] // vacc6x0123 += vb0123 * va6[4] + SMLAL2 v21.4s, v27.8h, v6.h[4] // vacc6x4567 += vb4567 * va6[4] + SMLAL v22.4s, v27.4h, v7.h[4] // vacc7x0123 += vb0123 * va7[4] + SMLAL2 v23.4s, v27.8h, v7.h[4] // vacc7x4567 += vb4567 * va7[4] + + // b0-7 (channel 6) + LD1 {v27.8b}, [x5], 8 + + SMLAL v8.4s, v28.4h, v0.h[5] // vacc0x0123 += vb0123 * va0[5] + SMLAL2 v9.4s, v28.8h, v0.h[5] // vacc0x4567 += vb4567 * va0[5] + SMLAL v10.4s, v28.4h, v1.h[5] // vacc1x0123 += vb0123 * va1[5] + SMLAL2 v11.4s, v28.8h, v1.h[5] // vacc1x4567 += vb4567 * va1[5] + SMLAL v12.4s, v28.4h, v2.h[5] // vacc2x0123 += vb0123 * va2[5] + SMLAL2 v13.4s, v28.8h, v2.h[5] // vacc2x4567 += vb4567 * va2[5] + SMLAL v14.4s, v28.4h, v3.h[5] // vacc3x0123 += vb0123 * va3[5] + SMLAL2 v15.4s, v28.8h, v3.h[5] // vacc3x4567 += vb4567 * va3[5] + USUBL v27.8h, v27.8b, v25.8b + SMLAL v16.4s, v28.4h, v4.h[5] // vacc4x0123 += vb0123 * va4[5] + SMLAL2 v17.4s, v28.8h, v4.h[5] // vacc4x4567 += vb4567 * va4[5] + SMLAL v18.4s, v28.4h, v5.h[5] // vacc5x0123 += vb0123 * va5[5] + SMLAL2 v19.4s, v28.8h, v5.h[5] // vacc5x4567 += vb4567 * va5[5] + SMLAL v20.4s, v28.4h, v6.h[5] // vacc6x0123 += vb0123 * va6[5] + SMLAL2 v21.4s, v28.8h, v6.h[5] // vacc6x4567 += vb4567 * va6[5] + SMLAL v22.4s, v28.4h, v7.h[5] // vacc7x0123 += vb0123 * va7[5] + SMLAL2 v23.4s, v28.8h, v7.h[5] // vacc7x4567 += vb4567 * va7[5] + + // b0-7 (channel 7) + LD1 {v28.8b}, [x5], 8 + + SMLAL v8.4s, v27.4h, v0.h[6] // vacc0x0123 += vb0123 * va0[6] + SMLAL2 v9.4s, v27.8h, v0.h[6] // vacc0x4567 += vb4567 * va0[6] + SMLAL v10.4s, v27.4h, v1.h[6] // vacc1x0123 += vb0123 * va1[6] + SMLAL2 v11.4s, v27.8h, v1.h[6] // vacc1x4567 += vb4567 * va1[6] + SMLAL v12.4s, v27.4h, v2.h[6] // vacc2x0123 += vb0123 * va2[6] + SMLAL2 v13.4s, v27.8h, v2.h[6] // vacc2x4567 += vb4567 * va2[6] + SMLAL v14.4s, v27.4h, v3.h[6] // vacc3x0123 += vb0123 * va3[6] + SMLAL2 v15.4s, v27.8h, v3.h[6] // vacc3x4567 += vb4567 * va3[6] + USUBL v28.8h, v28.8b, v25.8b + SMLAL v16.4s, v27.4h, v4.h[6] // vacc4x0123 += vb0123 * va4[6] + SMLAL2 v17.4s, v27.8h, v4.h[6] // vacc4x4567 += vb4567 * va4[6] + SMLAL v18.4s, v27.4h, v5.h[6] // vacc5x0123 += vb0123 * va5[6] + SMLAL2 v19.4s, v27.8h, v5.h[6] // vacc5x4567 += vb4567 * va5[6] + SMLAL v20.4s, v27.4h, v6.h[6] // vacc6x0123 += vb0123 * va6[6] + SMLAL2 v21.4s, v27.8h, v6.h[6] // vacc6x4567 += vb4567 * va6[6] + SMLAL v22.4s, v27.4h, v7.h[6] // vacc7x0123 += vb0123 * va7[6] + SMLAL2 v23.4s, v27.8h, v7.h[6] // vacc7x4567 += vb4567 * va7[6] + + SUBS x2, x2, 8 + + SMLAL v8.4s, v28.4h, v0.h[7] // vacc0x0123 += vb0123 * va0[7] + SMLAL2 v9.4s, v28.8h, v0.h[7] // vacc0x4567 += vb4567 * va0[7] + SMLAL v10.4s, v28.4h, v1.h[7] // vacc1x0123 += vb0123 * va1[7] + SMLAL2 v11.4s, v28.8h, v1.h[7] // vacc1x4567 += vb4567 * va1[7] + SMLAL v12.4s, v28.4h, v2.h[7] // vacc2x0123 += vb0123 * va2[7] + SMLAL2 v13.4s, v28.8h, v2.h[7] // vacc2x4567 += vb4567 * va2[7] + SMLAL v14.4s, v28.4h, v3.h[7] // vacc3x0123 += vb0123 * va3[7] + SMLAL2 v15.4s, v28.8h, v3.h[7] // vacc3x4567 += vb4567 * va3[7] + SMLAL v16.4s, v28.4h, v4.h[7] // vacc4x0123 += vb0123 * va4[7] + SMLAL2 v17.4s, v28.8h, v4.h[7] // vacc4x4567 += vb4567 * va4[7] + SMLAL v18.4s, v28.4h, v5.h[7] // vacc5x0123 += vb0123 * va5[7] + SMLAL2 v19.4s, v28.8h, v5.h[7] // vacc5x4567 += vb4567 * va5[7] + SMLAL v20.4s, v28.4h, v6.h[7] // vacc6x0123 += vb0123 * va6[7] + SMLAL2 v21.4s, v28.8h, v6.h[7] // vacc6x4567 += vb4567 * va6[7] + SMLAL v22.4s, v28.4h, v7.h[7] // vacc7x0123 += vb0123 * va7[7] + SMLAL2 v23.4s, v28.8h, v7.h[7] // vacc7x4567 += vb4567 * va7[7] + + B.HS 0b + +1: + CMP x2, -8 + B.EQ 2f + + // Adjust a0-a7 + ADD x3, x3, x2 + ADD x9, x9, x2 + ADD x10, x10, x2 + ADD x11, x11, x2 + ADD x12, x12, x2 + ADD x13, x13, x2 + ADD x14, x14, x2 + ADD x15, x15, x2 + + // a_shift = 8 * k - 64 + LSL x2, x2, 3 + FMOV d29, x2 + USHL d24, d24, d29 + + // Load x0-a7 + LD1 {v0.8b}, [x3], 8 + USHL d0, d0, d29 + SUB_ZERO_POINT v0.8h, v0.8b, v24.8b + + LD1 {v1.8b}, [x9], 8 + USHL d1, d1, d29 + SUB_ZERO_POINT v1.8h, v1.8b, v24.8b + + LD1 {v2.8b}, [x10], 8 + USHL d2, d2, d29 + SUB_ZERO_POINT v2.8h, v2.8b, v24.8b + + LD1 {v3.8b}, [x11], 8 + USHL d3, d3, d29 + SUB_ZERO_POINT v3.8h, v3.8b, v24.8b + + LD1 {v4.8b}, [x12], 8 + USHL d4, d4, d29 + SUB_ZERO_POINT v4.8h, v4.8b, v24.8b + + LD1 {v5.8b}, [x13], 8 + USHL d5, d5, d29 + SUB_ZERO_POINT v5.8h, v5.8b, v24.8b + + LD1 {v6.8b}, [x14], 8 + USHL d6, d6, d29 + SUB_ZERO_POINT v6.8h, v6.8b, v24.8b + + LD1 {v7.8b}, [x15], 8 + USHL d7, d7, d29 + SUB_ZERO_POINT v7.8h, v7.8b, v24.8b + + // Channel 0 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[0] // vacc0x0123 += vb0123 * va0[0] + SMLAL2 v9.4s, v27.8h, v0.h[0] // vacc0x4567 += vb4567 * va0[0] + SMLAL v10.4s, v27.4h, v1.h[0] // vacc1x0123 += vb0123 * va1[0] + SMLAL2 v11.4s, v27.8h, v1.h[0] // vacc1x4567 += vb4567 * va1[0] + SMLAL v12.4s, v27.4h, v2.h[0] // vacc2x0123 += vb0123 * va2[0] + SMLAL2 v13.4s, v27.8h, v2.h[0] // vacc2x4567 += vb4567 * va2[0] + SMLAL v14.4s, v27.4h, v3.h[0] // vacc3x0123 += vb0123 * va3[0] + SMLAL2 v15.4s, v27.8h, v3.h[0] // vacc3x4567 += vb4567 * va3[0] + SMLAL v16.4s, v27.4h, v4.h[0] // vacc4x0123 += vb0123 * va4[0] + SMLAL2 v17.4s, v27.8h, v4.h[0] // vacc4x4567 += vb4567 * va4[0] + SMLAL v18.4s, v27.4h, v5.h[0] // vacc5x0123 += vb0123 * va5[0] + SMLAL2 v19.4s, v27.8h, v5.h[0] // vacc5x4567 += vb4567 * va5[0] + SMLAL v20.4s, v27.4h, v6.h[0] // vacc6x0123 += vb0123 * va6[0] + SMLAL2 v21.4s, v27.8h, v6.h[0] // vacc6x4567 += vb4567 * va6[0] + SMLAL v22.4s, v27.4h, v7.h[0] // vacc7x0123 += vb0123 * va7[0] + SMLAL2 v23.4s, v27.8h, v7.h[0] // vacc7x4567 += vb4567 * va7[0] + + CMP x2, -48 + B.LO 2f + + // Channel 1 + LD1 {v28.8b}, [x5], 8 + USUBL v28.8h, v28.8b, v25.8b + + SMLAL v8.4s, v28.4h, v0.h[1] // vacc0x0123 += vb0123 * va0[1] + SMLAL2 v9.4s, v28.8h, v0.h[1] // vacc0x4567 += vb4567 * va0[1] + SMLAL v10.4s, v28.4h, v1.h[1] // vacc1x0123 += vb0123 * va1[1] + SMLAL2 v11.4s, v28.8h, v1.h[1] // vacc1x4567 += vb4567 * va1[1] + SMLAL v12.4s, v28.4h, v2.h[1] // vacc2x0123 += vb0123 * va2[1] + SMLAL2 v13.4s, v28.8h, v2.h[1] // vacc2x4567 += vb4567 * va2[1] + SMLAL v14.4s, v28.4h, v3.h[1] // vacc3x0123 += vb0123 * va3[1] + SMLAL2 v15.4s, v28.8h, v3.h[1] // vacc3x4567 += vb4567 * va3[1] + SMLAL v16.4s, v28.4h, v4.h[1] // vacc4x0123 += vb0123 * va4[1] + SMLAL2 v17.4s, v28.8h, v4.h[1] // vacc4x4567 += vb4567 * va4[1] + SMLAL v18.4s, v28.4h, v5.h[1] // vacc5x0123 += vb0123 * va5[1] + SMLAL2 v19.4s, v28.8h, v5.h[1] // vacc5x4567 += vb4567 * va5[1] + SMLAL v20.4s, v28.4h, v6.h[1] // vacc6x0123 += vb0123 * va6[1] + SMLAL2 v21.4s, v28.8h, v6.h[1] // vacc6x4567 += vb4567 * va6[1] + SMLAL v22.4s, v28.4h, v7.h[1] // vacc7x0123 += vb0123 * va7[1] + SMLAL2 v23.4s, v28.8h, v7.h[1] // vacc7x4567 += vb4567 * va7[1] + + B.LS 2f + + // Channel 2 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[2] // vacc0x0123 += vb0123 * va0[2] + SMLAL2 v9.4s, v27.8h, v0.h[2] // vacc0x4567 += vb4567 * va0[2] + SMLAL v10.4s, v27.4h, v1.h[2] // vacc1x0123 += vb0123 * va1[2] + SMLAL2 v11.4s, v27.8h, v1.h[2] // vacc1x4567 += vb4567 * va1[2] + SMLAL v12.4s, v27.4h, v2.h[2] // vacc2x0123 += vb0123 * va2[2] + SMLAL2 v13.4s, v27.8h, v2.h[2] // vacc2x4567 += vb4567 * va2[2] + SMLAL v14.4s, v27.4h, v3.h[2] // vacc3x0123 += vb0123 * va3[2] + SMLAL2 v15.4s, v27.8h, v3.h[2] // vacc3x4567 += vb4567 * va3[2] + SMLAL v16.4s, v27.4h, v4.h[2] // vacc4x0123 += vb0123 * va4[2] + SMLAL2 v17.4s, v27.8h, v4.h[2] // vacc4x4567 += vb4567 * va4[2] + SMLAL v18.4s, v27.4h, v5.h[2] // vacc5x0123 += vb0123 * va5[2] + SMLAL2 v19.4s, v27.8h, v5.h[2] // vacc5x4567 += vb4567 * va5[2] + SMLAL v20.4s, v27.4h, v6.h[2] // vacc6x0123 += vb0123 * va6[2] + SMLAL2 v21.4s, v27.8h, v6.h[2] // vacc6x4567 += vb4567 * va6[2] + SMLAL v22.4s, v27.4h, v7.h[2] // vacc7x0123 += vb0123 * va7[2] + SMLAL2 v23.4s, v27.8h, v7.h[2] // vacc7x4567 += vb4567 * va7[2] + + CMP x2, -32 + B.LO 2f + + // Channel 3 + LD1 {v28.8b}, [x5], 8 + USUBL v28.8h, v28.8b, v25.8b + + SMLAL v8.4s, v28.4h, v0.h[3] // vacc0x0123 += vb0123 * va0[3] + SMLAL2 v9.4s, v28.8h, v0.h[3] // vacc0x4567 += vb4567 * va0[3] + SMLAL v10.4s, v28.4h, v1.h[3] // vacc1x0123 += vb0123 * va1[3] + SMLAL2 v11.4s, v28.8h, v1.h[3] // vacc1x4567 += vb4567 * va1[3] + SMLAL v12.4s, v28.4h, v2.h[3] // vacc2x0123 += vb0123 * va2[3] + SMLAL2 v13.4s, v28.8h, v2.h[3] // vacc2x4567 += vb4567 * va2[3] + SMLAL v14.4s, v28.4h, v3.h[3] // vacc3x0123 += vb0123 * va3[3] + SMLAL2 v15.4s, v28.8h, v3.h[3] // vacc3x4567 += vb4567 * va3[3] + SMLAL v16.4s, v28.4h, v4.h[3] // vacc4x0123 += vb0123 * va4[3] + SMLAL2 v17.4s, v28.8h, v4.h[3] // vacc4x4567 += vb4567 * va4[3] + SMLAL v18.4s, v28.4h, v5.h[3] // vacc5x0123 += vb0123 * va5[3] + SMLAL2 v19.4s, v28.8h, v5.h[3] // vacc5x4567 += vb4567 * va5[3] + SMLAL v20.4s, v28.4h, v6.h[3] // vacc6x0123 += vb0123 * va6[3] + SMLAL2 v21.4s, v28.8h, v6.h[3] // vacc6x4567 += vb4567 * va6[3] + SMLAL v22.4s, v28.4h, v7.h[3] // vacc7x0123 += vb0123 * va7[3] + SMLAL2 v23.4s, v28.8h, v7.h[3] // vacc7x4567 += vb4567 * va7[3] + + B.LS 2f + + // Channel 4 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[4] // vacc0x0123 += vb0123 * va0[4] + SMLAL2 v9.4s, v27.8h, v0.h[4] // vacc0x4567 += vb4567 * va0[4] + SMLAL v10.4s, v27.4h, v1.h[4] // vacc1x0123 += vb0123 * va1[4] + SMLAL2 v11.4s, v27.8h, v1.h[4] // vacc1x4567 += vb4567 * va1[4] + SMLAL v12.4s, v27.4h, v2.h[4] // vacc2x0123 += vb0123 * va2[4] + SMLAL2 v13.4s, v27.8h, v2.h[4] // vacc2x4567 += vb4567 * va2[4] + SMLAL v14.4s, v27.4h, v3.h[4] // vacc3x0123 += vb0123 * va3[4] + SMLAL2 v15.4s, v27.8h, v3.h[4] // vacc3x4567 += vb4567 * va3[4] + SMLAL v16.4s, v27.4h, v4.h[4] // vacc4x0123 += vb0123 * va4[4] + SMLAL2 v17.4s, v27.8h, v4.h[4] // vacc4x4567 += vb4567 * va4[4] + SMLAL v18.4s, v27.4h, v5.h[4] // vacc5x0123 += vb0123 * va5[4] + SMLAL2 v19.4s, v27.8h, v5.h[4] // vacc5x4567 += vb4567 * va5[4] + SMLAL v20.4s, v27.4h, v6.h[4] // vacc6x0123 += vb0123 * va6[4] + SMLAL2 v21.4s, v27.8h, v6.h[4] // vacc6x4567 += vb4567 * va6[4] + SMLAL v22.4s, v27.4h, v7.h[4] // vacc7x0123 += vb0123 * va7[4] + SMLAL2 v23.4s, v27.8h, v7.h[4] // vacc7x4567 += vb4567 * va7[4] + + CMP x2, -16 + B.LO 2f + + // Channel 5 + LD1 {v28.8b}, [x5], 8 + USUBL v28.8h, v28.8b, v25.8b + + SMLAL v8.4s, v28.4h, v0.h[5] // vacc0x0123 += vb0123 * va0[5] + SMLAL2 v9.4s, v28.8h, v0.h[5] // vacc0x4567 += vb4567 * va0[5] + SMLAL v10.4s, v28.4h, v1.h[5] // vacc1x0123 += vb0123 * va1[5] + SMLAL2 v11.4s, v28.8h, v1.h[5] // vacc1x4567 += vb4567 * va1[5] + SMLAL v12.4s, v28.4h, v2.h[5] // vacc2x0123 += vb0123 * va2[5] + SMLAL2 v13.4s, v28.8h, v2.h[5] // vacc2x4567 += vb4567 * va2[5] + SMLAL v14.4s, v28.4h, v3.h[5] // vacc3x0123 += vb0123 * va3[5] + SMLAL2 v15.4s, v28.8h, v3.h[5] // vacc3x4567 += vb4567 * va3[5] + SMLAL v16.4s, v28.4h, v4.h[5] // vacc4x0123 += vb0123 * va4[5] + SMLAL2 v17.4s, v28.8h, v4.h[5] // vacc4x4567 += vb4567 * va4[5] + SMLAL v18.4s, v28.4h, v5.h[5] // vacc5x0123 += vb0123 * va5[5] + SMLAL2 v19.4s, v28.8h, v5.h[5] // vacc5x4567 += vb4567 * va5[5] + SMLAL v20.4s, v28.4h, v6.h[5] // vacc6x0123 += vb0123 * va6[5] + SMLAL2 v21.4s, v28.8h, v6.h[5] // vacc6x4567 += vb4567 * va6[5] + SMLAL v22.4s, v28.4h, v7.h[5] // vacc7x0123 += vb0123 * va7[5] + SMLAL2 v23.4s, v28.8h, v7.h[5] // vacc7x4567 += vb4567 * va7[5] + + B.LS 2f + + // Channel 6 + LD1 {v27.8b}, [x5], 8 + USUBL v27.8h, v27.8b, v25.8b + + SMLAL v8.4s, v27.4h, v0.h[6] // vacc0x0123 += vb0123 * va0[6] + SMLAL2 v9.4s, v27.8h, v0.h[6] // vacc0x4567 += vb4567 * va0[6] + SMLAL v10.4s, v27.4h, v1.h[6] // vacc1x0123 += vb0123 * va1[6] + SMLAL2 v11.4s, v27.8h, v1.h[6] // vacc1x4567 += vb4567 * va1[6] + SMLAL v12.4s, v27.4h, v2.h[6] // vacc2x0123 += vb0123 * va2[6] + SMLAL2 v13.4s, v27.8h, v2.h[6] // vacc2x4567 += vb4567 * va2[6] + SMLAL v14.4s, v27.4h, v3.h[6] // vacc3x0123 += vb0123 * va3[6] + SMLAL2 v15.4s, v27.8h, v3.h[6] // vacc3x4567 += vb4567 * va3[6] + SMLAL v16.4s, v27.4h, v4.h[6] // vacc4x0123 += vb0123 * va4[6] + SMLAL2 v17.4s, v27.8h, v4.h[6] // vacc4x4567 += vb4567 * va4[6] + SMLAL v18.4s, v27.4h, v5.h[6] // vacc5x0123 += vb0123 * va5[6] + SMLAL2 v19.4s, v27.8h, v5.h[6] // vacc5x4567 += vb4567 * va5[6] + SMLAL v20.4s, v27.4h, v6.h[6] // vacc6x0123 += vb0123 * va6[6] + SMLAL2 v21.4s, v27.8h, v6.h[6] // vacc6x4567 += vb4567 * va6[6] + SMLAL v22.4s, v27.4h, v7.h[6] // vacc7x0123 += vb0123 * va7[6] + SMLAL2 v23.4s, v27.8h, v7.h[6] // vacc7x4567 += vb4567 * va7[6] + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 4 +#endif +2: + // Load right_shift: + // - v27 = vright_shift + LD1R {v27.4s}, [x8], 4 + + SQRDMULH v8.4s, v8.4s, v26.4s + SQRDMULH v9.4s, v9.4s, v26.4s + SQRDMULH v10.4s, v10.4s, v26.4s + SQRDMULH v11.4s, v11.4s, v26.4s + SQRDMULH v12.4s, v12.4s, v26.4s + SQRDMULH v13.4s, v13.4s, v26.4s + SQRDMULH v14.4s, v14.4s, v26.4s + SQRDMULH v15.4s, v15.4s, v26.4s + + // Compute vzero_shift_mask + // - v28 = vzero_shift_mask + CMEQ v28.4s, v27.4s, 0 + + SQRDMULH v16.4s, v16.4s, v26.4s + SQRDMULH v17.4s, v17.4s, v26.4s + SQRDMULH v18.4s, v18.4s, v26.4s + SQRDMULH v19.4s, v19.4s, v26.4s + SQRDMULH v20.4s, v20.4s, v26.4s + SQRDMULH v21.4s, v21.4s, v26.4s + SQRDMULH v22.4s, v22.4s, v26.4s + SQRDMULH v23.4s, v23.4s, v26.4s + + // Load zero_point: + // - v29 = vzero_point + LD1R {v29.8h}, [x8], 2 + + BIC v0.16b, v8.16b, v28.16b + BIC v1.16b, v9.16b, v28.16b + BIC v2.16b, v10.16b, v28.16b + BIC v3.16b, v11.16b, v28.16b + BIC v4.16b, v12.16b, v28.16b + BIC v5.16b, v13.16b, v28.16b + BIC v6.16b, v14.16b, v28.16b + BIC v7.16b, v15.16b, v28.16b + + SSRA v8.4s, v0.4s, 31 + SSRA v9.4s, v1.4s, 31 + SSRA v10.4s, v2.4s, 31 + SSRA v11.4s, v3.4s, 31 + SSRA v12.4s, v4.4s, 31 + SSRA v13.4s, v5.4s, 31 + SSRA v14.4s, v6.4s, 31 + SSRA v15.4s, v7.4s, 31 + + // Load max: + // - v30 = vmax + LD1R {v30.16b}, [x8], 1 + + BIC v0.16b, v16.16b, v28.16b + BIC v1.16b, v17.16b, v28.16b + BIC v2.16b, v18.16b, v28.16b + BIC v3.16b, v19.16b, v28.16b + BIC v4.16b, v20.16b, v28.16b + BIC v5.16b, v21.16b, v28.16b + BIC v6.16b, v22.16b, v28.16b + BIC v7.16b, v23.16b, v28.16b + + SSRA v16.4s, v0.4s, 31 + SSRA v17.4s, v1.4s, 31 + SSRA v18.4s, v2.4s, 31 + SSRA v19.4s, v3.4s, 31 + SSRA v20.4s, v4.4s, 31 + SSRA v21.4s, v5.4s, 31 + SSRA v22.4s, v6.4s, 31 + SSRA v23.4s, v7.4s, 31 + + // Load min: + // - v31 = vmin + LD1R {v31.16b}, [x8] + + SRSHL v8.4s, v8.4s, v27.4s + SRSHL v9.4s, v9.4s, v27.4s + SRSHL v10.4s, v10.4s, v27.4s + SRSHL v11.4s, v11.4s, v27.4s + SRSHL v12.4s, v12.4s, v27.4s + SRSHL v13.4s, v13.4s, v27.4s + SRSHL v14.4s, v14.4s, v27.4s + SRSHL v15.4s, v15.4s, v27.4s + SRSHL v16.4s, v16.4s, v27.4s + SRSHL v17.4s, v17.4s, v27.4s + SRSHL v18.4s, v18.4s, v27.4s + SRSHL v19.4s, v19.4s, v27.4s + SRSHL v20.4s, v20.4s, v27.4s + SRSHL v21.4s, v21.4s, v27.4s + SRSHL v22.4s, v22.4s, v27.4s + SRSHL v23.4s, v23.4s, v27.4s + + SQXTN v8.4h, v8.4s + SQXTN v10.4h, v10.4s + SQXTN v12.4h, v12.4s + SQXTN v14.4h, v14.4s + SQXTN v16.4h, v16.4s + SQXTN v18.4h, v18.4s + SQXTN v20.4h, v20.4s + SQXTN v22.4h, v22.4s + + SQXTN2 v8.8h, v9.4s + SQXTN2 v10.8h, v11.4s + SQXTN2 v12.8h, v13.4s + SQXTN2 v14.8h, v15.4s + SQXTN2 v16.8h, v17.4s + SQXTN2 v18.8h, v19.4s + SQXTN2 v20.8h, v21.4s + SQXTN2 v22.8h, v23.4s + + SQADD v8.8h, v8.8h, v29.8h + SQADD v10.8h, v10.8h, v29.8h + SQADD v12.8h, v12.8h, v29.8h + SQADD v14.8h, v14.8h, v29.8h + SQADD v16.8h, v16.8h, v29.8h + SQADD v18.8h, v18.8h, v29.8h + SQADD v20.8h, v20.8h, v29.8h + SQADD v22.8h, v22.8h, v29.8h + + SQXTUN v8.8b, v8.8h + SQXTUN v12.8b, v12.8h + SQXTUN v16.8b, v16.8h + SQXTUN v20.8b, v20.8h + + SQXTUN2 v8.16b, v10.8h + SQXTUN2 v12.16b, v14.8h + SQXTUN2 v16.16b, v18.8h + SQXTUN2 v20.16b, v22.8h + + UMIN v8.16b, v8.16b, v30.16b + UMIN v12.16b, v12.16b, v30.16b + UMIN v16.16b, v16.16b, v30.16b + UMIN v20.16b, v20.16b, v30.16b + + UMAX v8.16b, v8.16b, v31.16b + UMAX v12.16b, v12.16b, v31.16b + UMAX v16.16b, v16.16b, v31.16b + UMAX v20.16b, v20.16b, v31.16b + + // Compute c0-c7 + + ADD x9, x6, x7 + CMP x0, 2 + CSEL x9, x6, x9, LO + + ADD x10, x9, x7 + CSEL x10, x9, x10, LS + + ADD x11, x10, x7 + CMP x0, 4 + CSEL x11, x10, x11, LO + + ADD x12, x11, x7 + CSEL x12, x11, x12, LS + + ADD x13, x12, x7 + CMP x0, 6 + CSEL x13, x12, x13, LO + + ADD x14, x13, x7 + CSEL x14, x13, x14, LS + + ADD x15, x14, x7 + CMP x0, 8 + CSEL x15, x14, x15, NE + + CMP x1, 8 + B.NE 4f + + // Store results + ST1 {v8.d}[0], [x6] + ST1 {v8.d}[1], [x9] + ST1 {v12.d}[0], [x10] + ST1 {v12.d}[1], [x11] + ST1 {v16.d}[0], [x12] + ST1 {v16.d}[1], [x13] + ST1 {v20.d}[0], [x14] + ST1 {v20.d}[1], [x15] + + LDP d9, d8, [sp, -64] + LDP d11, d10, [sp, -48] + LDP d13, d12, [sp, -32] + LDP d15, d14, [sp, -16] + + RET + +#ifndef IGNORE_CODE_ALIGN_DIRECTIVES + .p2align 3 +#endif +4: + CMP x1, 4 + B.LO 5f + + ST1 {v8.s}[0], [x6], 4 + ST1 {v8.s}[2], [x9], 4 + ST1 {v12.s}[0], [x10], 4 + ST1 {v12.s}[2], [x11], 4 + ST1 {v16.s}[0], [x12], 4 + ST1 {v16.s}[2], [x13], 4 + ST1 {v20.s}[0], [x14], 4 + ST1 {v20.s}[2], [x15], 4 + + SUB x1, x1, 4 + EXT v8.16b, v8.16b, v8.16b, 4 + EXT v12.16b, v12.16b, v12.16b, 4 + EXT v16.16b, v16.16b, v16.16b, 4 + EXT v20.16b, v20.16b, v20.16b, 4 + +5: + CMP x1, 2 + B.LO 6f + + ST1 {v8.h}[0], [x6], 2 + ST1 {v8.h}[4], [x9], 2 + ST1 {v12.h}[0], [x10], 2 + ST1 {v12.h}[4], [x11], 2 + ST1 {v16.h}[0], [x12], 2 + ST1 {v16.h}[4], [x13], 2 + ST1 {v20.h}[0], [x14], 2 + ST1 {v20.h}[4], [x15], 2 + + SUB x1, x1, 2 + EXT v8.16b, v8.16b, v8.16b, 2 + EXT v12.16b, v12.16b, v12.16b, 2 + EXT v16.16b, v16.16b, v16.16b, 2 + EXT v20.16b, v20.16b, v20.16b, 2 + +6: + CMP x1, 1 + B.LO 7f + + ST1 {v8.b}[0], [x6] + ST1 {v8.b}[8], [x9] + ST1 {v12.b}[0], [x10] + ST1 {v12.b}[8], [x11] + ST1 {v16.b}[0], [x12] + ST1 {v16.b}[8], [x13] + ST1 {v20.b}[0], [x14] + ST1 {v20.b}[8], [x15] + +7: + LDP d9, d8, [sp, -64] + LDP d11, d10, [sp, -48] + LDP d13, d12, [sp, -32] + LDP d15, d14, [sp, -16] + + RET + +END_FUNCTION pytorch_q8gemm_ukernel_8x8__aarch64_neon + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-neon.c new file mode 100644 index 0000000000000..ae4430a995f90 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8gemm/8x8-neon.c @@ -0,0 +1,1177 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8gemm_ukernel_8x8__neon( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params + quantization_params[restrict static 1]) { + int32x4_t vacc0x0123 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc0x4567 = vld1q_s32(w); + w = (const void*)((uintptr_t)w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + int32x4_t vacc4x0123 = vacc0x0123; + int32x4_t vacc4x4567 = vacc0x4567; + int32x4_t vacc5x0123 = vacc0x0123; + int32x4_t vacc5x4567 = vacc0x4567; + int32x4_t vacc6x0123 = vacc0x0123; + int32x4_t vacc6x4567 = vacc0x4567; + int32x4_t vacc7x0123 = vacc0x0123; + int32x4_t vacc7x4567 = vacc0x4567; + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*)((uintptr_t)a2 + a_stride); + if (mr < 4) { + a3 = a2; + } + const uint8_t* a4 = (const uint8_t*)((uintptr_t)a3 + a_stride); + if (mr <= 4) { + a4 = a3; + } + const uint8_t* a5 = (const uint8_t*)((uintptr_t)a4 + a_stride); + if (mr < 6) { + a5 = a4; + } + const uint8_t* a6 = (const uint8_t*)((uintptr_t)a5 + a_stride); + if (mr <= 6) { + a6 = a5; + } + const uint8_t* a7 = (const uint8_t*)((uintptr_t)a6 + a_stride); + if (mr != 8) { + a7 = a6; + } + + const uint8x8_t va_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point); + const uint8x8_t vb_zero_point = + vld1_dup_u8((const uint8_t*)&quantization_params->neon.kernel_zero_point); + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + a0 += 8; + const uint8x8_t va1 = vld1_u8(a1); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + a1 += 8; + const uint8x8_t va2 = vld1_u8(a2); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + a2 += 8; + const uint8x8_t va3 = vld1_u8(a3); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + a3 += 8; + const uint8x8_t va4 = vld1_u8(a4); + const int16x8_t vxa4 = + vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point)); + a4 += 8; + const uint8x8_t va5 = vld1_u8(a5); + const int16x8_t vxa5 = + vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point)); + a5 += 8; + const uint8x8_t va6 = vld1_u8(a6); + const int16x8_t vxa6 = + vreinterpretq_s16_u16(sub_zero_point(va6, va_zero_point)); + a6 += 8; + const uint8x8_t va7 = vld1_u8(a7); + const int16x8_t vxa7 = + vreinterpretq_s16_u16(sub_zero_point(va7, va_zero_point)); + a7 += 8; + + const uint8x8_t vb01234567c0 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c0 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa4), 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa5), 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa5), 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa6), 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa6), 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa7), 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa7), 0); + + const uint8x8_t vb01234567c1 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c1 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa4), 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa5), 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa5), 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa6), 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa6), 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa7), 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa7), 1); + + const uint8x8_t vb01234567c2 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c2 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa4), 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa5), 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa5), 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa6), 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa6), 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa7), 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa7), 2); + + const uint8x8_t vb01234567c3 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c3 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa4), 3); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa5), 3); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa5), 3); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa6), 3); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa6), 3); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa7), 3); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa7), 3); + + const uint8x8_t vb01234567c4 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c4 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa4), 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa5), 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa5), 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa6), 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa6), 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa7), 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa7), 0); + + const uint8x8_t vb01234567c5 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c5 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa4), 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa5), 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa5), 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa6), 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa6), 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa7), 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa7), 1); + + const uint8x8_t vb01234567c6 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c6 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa4), 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa5), 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa5), 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa6), 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa6), 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa7), 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa7), 2); + + const uint8x8_t vb01234567c7 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c7 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa4), 3); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa5), 3); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa5), 3); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa6), 3); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa6), 3); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa7), 3); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa7), 3); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = + vreinterpretq_s16_u16(sub_zero_point(va0, va_zero_point)); + const uint8x8_t va1 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(sub_zero_point(va1, va_zero_point)); + const uint8x8_t va2 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(sub_zero_point(va2, va_zero_point)); + const uint8x8_t va3 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(sub_zero_point(va3, va_zero_point)); + const uint8x8_t va4 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a4 - a_predecrement)), va_shift)); + const int16x8_t vxa4 = + vreinterpretq_s16_u16(sub_zero_point(va4, va_zero_point)); + const uint8x8_t va5 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a5 - a_predecrement)), va_shift)); + const int16x8_t vxa5 = + vreinterpretq_s16_u16(sub_zero_point(va5, va_zero_point)); + const uint8x8_t va6 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a6 - a_predecrement)), va_shift)); + const int16x8_t vxa6 = + vreinterpretq_s16_u16(sub_zero_point(va6, va_zero_point)); + const uint8x8_t va7 = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a7 - a_predecrement)), va_shift)); + const int16x8_t vxa7 = + vreinterpretq_s16_u16(sub_zero_point(va7, va_zero_point)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c0 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa4), 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa4), 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa5), 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa5), 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa6), 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa6), 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa7), 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa7), 0); + + if (k >= 2) { + const uint8x8_t vb01234567c1 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c1 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa4), 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa4), 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa5), 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa5), 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa6), 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa6), 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa7), 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa7), 1); + + if (k >= 3) { + const uint8x8_t vb01234567c2 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c2 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa4), 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa4), 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa5), 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa5), 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa6), 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa6), 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa7), 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa7), 2); + + if (k >= 4) { + const uint8x8_t vb01234567c3 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c3 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa4), 3); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa4), 3); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa5), 3); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa5), 3); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa6), 3); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa6), 3); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa7), 3); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa7), 3); + + if (k >= 5) { + const uint8x8_t vb01234567c4 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c4 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa0), + 0); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa0), + 0); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa1), + 0); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa1), + 0); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa2), + 0); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa2), + 0); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa3), + 0); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa3), + 0); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa4), + 0); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa4), + 0); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa5), + 0); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa5), + 0); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa6), + 0); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa6), + 0); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, + vget_low_s16(vxb01234567c4), + vget_high_s16(vxa7), + 0); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, + vget_high_s16(vxb01234567c4), + vget_high_s16(vxa7), + 0); + + if (k >= 6) { + const uint8x8_t vb01234567c5 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c5 = + vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa0), + 1); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa0), + 1); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa1), + 1); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa1), + 1); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa2), + 1); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa2), + 1); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa3), + 1); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa3), + 1); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa4), + 1); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa4), + 1); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa5), + 1); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa5), + 1); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa6), + 1); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa6), + 1); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, + vget_low_s16(vxb01234567c5), + vget_high_s16(vxa7), + 1); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, + vget_high_s16(vxb01234567c5), + vget_high_s16(vxa7), + 1); + + if (k >= 7) { + const uint8x8_t vb01234567c6 = vld1_u8(w); + w = (const void*)((uintptr_t)w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16( + vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16( + vacc0x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa0), + 2); + vacc0x4567 = vmlal_lane_s16( + vacc0x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa0), + 2); + vacc1x0123 = vmlal_lane_s16( + vacc1x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa1), + 2); + vacc1x4567 = vmlal_lane_s16( + vacc1x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa1), + 2); + vacc2x0123 = vmlal_lane_s16( + vacc2x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa2), + 2); + vacc2x4567 = vmlal_lane_s16( + vacc2x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa2), + 2); + vacc3x0123 = vmlal_lane_s16( + vacc3x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa3), + 2); + vacc3x4567 = vmlal_lane_s16( + vacc3x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa3), + 2); + vacc4x0123 = vmlal_lane_s16( + vacc4x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa4), + 2); + vacc4x4567 = vmlal_lane_s16( + vacc4x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa4), + 2); + vacc5x0123 = vmlal_lane_s16( + vacc5x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa5), + 2); + vacc5x4567 = vmlal_lane_s16( + vacc5x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa5), + 2); + vacc6x0123 = vmlal_lane_s16( + vacc6x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa6), + 2); + vacc6x4567 = vmlal_lane_s16( + vacc6x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa6), + 2); + vacc7x0123 = vmlal_lane_s16( + vacc7x0123, + vget_low_s16(vxb01234567c6), + vget_high_s16(vxa7), + 2); + vacc7x4567 = vmlal_lane_s16( + vacc7x4567, + vget_high_s16(vxb01234567c6), + vget_high_s16(vxa7), + 2); + } + } + } + } + } + } + } + + const int32x4_t vmultiplier = + vld1q_dup_s32(&quantization_params->neon.multiplier); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier); + vacc4x0123 = vqrdmulhq_s32(vacc4x0123, vmultiplier); + vacc4x4567 = vqrdmulhq_s32(vacc4x4567, vmultiplier); + vacc5x0123 = vqrdmulhq_s32(vacc5x0123, vmultiplier); + vacc5x4567 = vqrdmulhq_s32(vacc5x4567, vmultiplier); + vacc6x0123 = vqrdmulhq_s32(vacc6x0123, vmultiplier); + vacc6x4567 = vqrdmulhq_s32(vacc6x4567, vmultiplier); + vacc7x0123 = vqrdmulhq_s32(vacc7x0123, vmultiplier); + vacc7x4567 = vqrdmulhq_s32(vacc7x4567, vmultiplier); + + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + vacc0x0123 = + vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask), 31); + vacc0x4567 = + vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask), 31); + vacc1x0123 = + vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask), 31); + vacc1x4567 = + vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask), 31); + vacc2x0123 = + vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask), 31); + vacc2x4567 = + vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask), 31); + vacc3x0123 = + vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask), 31); + vacc3x4567 = + vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask), 31); + vacc4x0123 = + vsraq_n_s32(vacc4x0123, vbicq_s32(vacc4x0123, vzero_shift_mask), 31); + vacc4x4567 = + vsraq_n_s32(vacc4x4567, vbicq_s32(vacc4x4567, vzero_shift_mask), 31); + vacc5x0123 = + vsraq_n_s32(vacc5x0123, vbicq_s32(vacc5x0123, vzero_shift_mask), 31); + vacc5x4567 = + vsraq_n_s32(vacc5x4567, vbicq_s32(vacc5x4567, vzero_shift_mask), 31); + vacc6x0123 = + vsraq_n_s32(vacc6x0123, vbicq_s32(vacc6x0123, vzero_shift_mask), 31); + vacc6x4567 = + vsraq_n_s32(vacc6x4567, vbicq_s32(vacc6x4567, vzero_shift_mask), 31); + vacc7x0123 = + vsraq_n_s32(vacc7x0123, vbicq_s32(vacc7x0123, vzero_shift_mask), 31); + vacc7x4567 = + vsraq_n_s32(vacc7x4567, vbicq_s32(vacc7x4567, vzero_shift_mask), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift); + vacc4x0123 = vrshlq_s32(vacc4x0123, vright_shift); + vacc4x4567 = vrshlq_s32(vacc4x4567, vright_shift); + vacc5x0123 = vrshlq_s32(vacc5x0123, vright_shift); + vacc5x4567 = vrshlq_s32(vacc5x4567, vright_shift); + vacc6x0123 = vrshlq_s32(vacc6x0123, vright_shift); + vacc6x4567 = vrshlq_s32(vacc6x4567, vright_shift); + vacc7x0123 = vrshlq_s32(vacc7x0123, vright_shift); + vacc7x4567 = vrshlq_s32(vacc7x4567, vright_shift); + + const int16x8_t voutput_zero_point = + vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + const int16x8_t vacc4x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc4x4567), voutput_zero_point); + const int16x8_t vacc5x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc5x0123), vacc5x4567), voutput_zero_point); + const int16x8_t vacc6x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc6x0123), vacc6x4567), voutput_zero_point); + const int16x8_t vacc7x01234567 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc7x0123), vacc7x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); + uint8x16_t vout4x01234567_5x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc4x01234567), vacc5x01234567); + uint8x16_t vout6x01234567_7x01234567 = + vqmovun_high_s16(vqmovun_s16(vacc6x01234567), vacc7x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), + voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), + voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), + voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), + voutput_zero_point); + const int16x8_t vacc4x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc4x4567)), + voutput_zero_point); + const int16x8_t vacc5x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc5x0123), vqmovn_s32(vacc5x4567)), + voutput_zero_point); + const int16x8_t vacc6x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc6x0123), vqmovn_s32(vacc6x4567)), + voutput_zero_point); + const int16x8_t vacc7x01234567 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc7x0123), vqmovn_s32(vacc7x4567)), + voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = + vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = + vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); + uint8x16_t vout4x01234567_5x01234567 = + vcombine_u8(vqmovun_s16(vacc4x01234567), vqmovun_s16(vacc5x01234567)); + uint8x16_t vout6x01234567_7x01234567 = + vcombine_u8(vqmovun_s16(vacc6x01234567), vqmovun_s16(vacc7x01234567)); +#endif + const uint8x16_t voutput_min = + vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = + vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout4x01234567_5x01234567 = vmaxq_u8(vout4x01234567_5x01234567, voutput_min); + vout6x01234567_7x01234567 = vmaxq_u8(vout6x01234567_7x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); + vout4x01234567_5x01234567 = vminq_u8(vout4x01234567_5x01234567, voutput_max); + vout6x01234567_7x01234567 = vminq_u8(vout6x01234567_7x01234567, voutput_max); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + uint8_t* c4 = (uint8_t*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + uint8_t* c5 = (uint8_t*)((uintptr_t)c4 + c_stride); + if (mr < 6) { + c5 = c4; + } + uint8_t* c6 = (uint8_t*)((uintptr_t)c5 + c_stride); + if (mr <= 6) { + c6 = c5; + } + uint8_t* c7 = (uint8_t*)((uintptr_t)c6 + c_stride); + if (mr != 8) { + c7 = c6; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + vst1_u8(c4, vget_low_u8(vout4x01234567_5x01234567)); + vst1_u8(c5, vget_high_u8(vout4x01234567_5x01234567)); + vst1_u8(c6, vget_low_u8(vout6x01234567_7x01234567)); + vst1_u8(c7, vget_high_u8(vout6x01234567_7x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 0); + c0 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u32_u8(vout0x01234567_1x01234567), + 2); + c1 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 0); + c2 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u32_u8(vout2x01234567_3x01234567), + 2); + c3 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c4, 1), + vreinterpretq_u32_u8(vout4x01234567_5x01234567), + 0); + c4 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c5, 1), + vreinterpretq_u32_u8(vout4x01234567_5x01234567), + 2); + c5 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c6, 1), + vreinterpretq_u32_u8(vout6x01234567_7x01234567), + 0); + c6 += 4; + vst1q_lane_u32( + __builtin_assume_aligned(c7, 1), + vreinterpretq_u32_u8(vout6x01234567_7x01234567), + 2); + c7 += 4; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + vout4x01234567_5x01234567 = + vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 4); + vout6x01234567_7x01234567 = + vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16( + __builtin_assume_aligned(c0, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 0); + c0 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c1, 1), + vreinterpretq_u16_u8(vout0x01234567_1x01234567), + 4); + c1 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c2, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 0); + c2 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c3, 1), + vreinterpretq_u16_u8(vout2x01234567_3x01234567), + 4); + c3 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c4, 1), + vreinterpretq_u16_u8(vout4x01234567_5x01234567), + 0); + c4 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c5, 1), + vreinterpretq_u16_u8(vout4x01234567_5x01234567), + 4); + c5 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c6, 1), + vreinterpretq_u16_u8(vout6x01234567_7x01234567), + 0); + c6 += 2; + vst1q_lane_u16( + __builtin_assume_aligned(c7, 1), + vreinterpretq_u16_u8(vout6x01234567_7x01234567), + 4); + c7 += 2; + vout0x01234567_1x01234567 = + vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = + vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + vout4x01234567_5x01234567 = + vextq_u8(vout4x01234567_5x01234567, vout4x01234567_5x01234567, 2); + vout6x01234567_7x01234567 = + vextq_u8(vout6x01234567_7x01234567, vout6x01234567_7x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + vst1q_lane_u8(c4, vout4x01234567_5x01234567, 0); + vst1q_lane_u8(c5, vout4x01234567_5x01234567, 8); + vst1q_lane_u8(c6, vout6x01234567_7x01234567, 0); + vst1q_lane_u8(c7, vout6x01234567_7x01234567, 8); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8vadd/neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8vadd/neon.c new file mode 100644 index 0000000000000..867708b7615fa --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8vadd/neon.c @@ -0,0 +1,382 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +void pytorch_q8vadd_ukernel__neon( + size_t n, + const uint8_t* a, + const uint8_t* b, + uint8_t* y, + const union pytorch_qnnp_add_quantization_params + quantization_params[restrict static 1]) { + const uint8x8_t va_zero_point = + vld1_dup_u8(&quantization_params->neon.a_zero_point); + const uint8x8_t vb_zero_point = + vld1_dup_u8(&quantization_params->neon.b_zero_point); + const int16x8_t vy_zero_point = + vld1q_dup_s16(&quantization_params->neon.y_zero_point); + const int32x4_t va_multiplier = + vld1q_dup_s32(&quantization_params->neon.a_multiplier); + const int32x4_t vb_multiplier = + vld1q_dup_s32(&quantization_params->neon.b_multiplier); + const int32x4_t vright_shift = + vld1q_dup_s32(&quantization_params->neon.right_shift); + const int32x4_t vzero_shift_mask = + vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0))); + const uint8x16_t vy_max = vld1q_dup_u8(&quantization_params->neon.y_max); + const uint8x16_t vy_min = vld1q_dup_u8(&quantization_params->neon.y_min); + if + PYTORCH_QNNP_LIKELY(n >= 8) { +#ifdef __aarch64__ + for (; n >= 32; n -= 32) { + const uint8x16_t va01 = vld1q_u8(a); + a += 16; + const uint8x16_t vb01 = vld1q_u8(b); + b += 16; + const uint8x16_t va23 = vld1q_u8(a); + a += 16; + const uint8x16_t vb23 = vld1q_u8(b); + b += 16; + + /* Subtract zero point */ + const int16x8_t vxa0 = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point)); + const int16x8_t vxb0 = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point)); + const int16x8_t vxb1 = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point)); + const int16x8_t vxa2 = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va23), va_zero_point)); + const int16x8_t vxb2 = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb23), vb_zero_point)); + const int16x8_t vxa3 = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va23), va_zero_point)); + const int16x8_t vxb3 = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb23), vb_zero_point)); + + /* Multiply by factors and accumulate products */ + int32x4_t vacc0_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier); + int32x4_t vacc1_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier); + int32x4_t vacc2_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa2)), va_multiplier); + int32x4_t vacc3_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa3)), va_multiplier); + int32x4_t vacc0_hi = vmulq_s32(vmovl_high_s16(vxa0), va_multiplier); + int32x4_t vacc1_hi = vmulq_s32(vmovl_high_s16(vxa1), va_multiplier); + int32x4_t vacc2_hi = vmulq_s32(vmovl_high_s16(vxa2), va_multiplier); + int32x4_t vacc3_hi = vmulq_s32(vmovl_high_s16(vxa3), va_multiplier); + + vacc0_lo = + vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier); + vacc1_lo = + vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier); + vacc2_lo = + vmlaq_s32(vacc2_lo, vmovl_s16(vget_low_s16(vxb2)), vb_multiplier); + vacc3_lo = + vmlaq_s32(vacc3_lo, vmovl_s16(vget_low_s16(vxb3)), vb_multiplier); + vacc0_hi = vmlaq_s32(vacc0_hi, vmovl_high_s16(vxb0), vb_multiplier); + vacc1_hi = vmlaq_s32(vacc1_hi, vmovl_high_s16(vxb1), vb_multiplier); + vacc2_hi = vmlaq_s32(vacc2_hi, vmovl_high_s16(vxb2), vb_multiplier); + vacc3_hi = vmlaq_s32(vacc3_hi, vmovl_high_s16(vxb3), vb_multiplier); + + /* Shift right and round */ + vacc0_lo = + vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); + vacc1_lo = + vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); + vacc2_lo = + vsraq_n_s32(vacc2_lo, vbicq_s32(vacc2_lo, vzero_shift_mask), 31); + vacc3_lo = + vsraq_n_s32(vacc3_lo, vbicq_s32(vacc3_lo, vzero_shift_mask), 31); + vacc0_hi = + vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); + vacc1_hi = + vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); + vacc2_hi = + vsraq_n_s32(vacc2_hi, vbicq_s32(vacc2_hi, vzero_shift_mask), 31); + vacc3_hi = + vsraq_n_s32(vacc3_hi, vbicq_s32(vacc3_hi, vzero_shift_mask), 31); + + vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); + vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); + vacc2_lo = vrshlq_s32(vacc2_lo, vright_shift); + vacc3_lo = vrshlq_s32(vacc3_lo, vright_shift); + vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); + vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); + vacc2_hi = vrshlq_s32(vacc2_hi, vright_shift); + vacc3_hi = vrshlq_s32(vacc3_hi, vright_shift); + + /* Pack, saturate, and add output zero point */ + const int16x8_t vacc0 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi), vy_zero_point); + const int16x8_t vacc1 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc1_lo), vacc1_hi), vy_zero_point); + const int16x8_t vacc2 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc2_lo), vacc2_hi), vy_zero_point); + const int16x8_t vacc3 = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc3_lo), vacc3_hi), vy_zero_point); + + uint8x16_t vy01 = vqmovun_high_s16(vqmovun_s16(vacc0), vacc1); + uint8x16_t vy23 = vqmovun_high_s16(vqmovun_s16(vacc2), vacc3); + + vy01 = vmaxq_u8(vy01, vy_min); + vy23 = vmaxq_u8(vy23, vy_min); + vy01 = vminq_u8(vy01, vy_max); + vy23 = vminq_u8(vy23, vy_max); + + vst1q_u8(y, vy01); + y += 16; + vst1q_u8(y, vy23); + y += 16; + } +#else + for (; n >= 16; n -= 16) { + const uint8x16_t va01 = vld1q_u8(a); + a += 16; + const uint8x16_t vb01 = vld1q_u8(b); + b += 16; + + /* Subtract zero point */ + const int16x8_t vxa0 = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(va01), va_zero_point)); + const int16x8_t vxb0 = + vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vb01), vb_zero_point)); + const int16x8_t vxa1 = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(va01), va_zero_point)); + const int16x8_t vxb1 = + vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vb01), vb_zero_point)); + + /* Multiply by factors and accumulate products */ + int32x4_t vacc0_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa0)), va_multiplier); + int32x4_t vacc1_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa1)), va_multiplier); + int32x4_t vacc0_hi = + vmulq_s32(vmovl_s16(vget_high_s16(vxa0)), va_multiplier); + int32x4_t vacc1_hi = + vmulq_s32(vmovl_s16(vget_high_s16(vxa1)), va_multiplier); + + __builtin_prefetch(a + 640); + __builtin_prefetch(b + 640); + + vacc0_lo = + vmlaq_s32(vacc0_lo, vmovl_s16(vget_low_s16(vxb0)), vb_multiplier); + vacc1_lo = + vmlaq_s32(vacc1_lo, vmovl_s16(vget_low_s16(vxb1)), vb_multiplier); + vacc0_hi = + vmlaq_s32(vacc0_hi, vmovl_s16(vget_high_s16(vxb0)), vb_multiplier); + vacc1_hi = + vmlaq_s32(vacc1_hi, vmovl_s16(vget_high_s16(vxb1)), vb_multiplier); + + /* Shift right and round */ + vacc0_lo = + vsraq_n_s32(vacc0_lo, vbicq_s32(vacc0_lo, vzero_shift_mask), 31); + vacc1_lo = + vsraq_n_s32(vacc1_lo, vbicq_s32(vacc1_lo, vzero_shift_mask), 31); + vacc0_hi = + vsraq_n_s32(vacc0_hi, vbicq_s32(vacc0_hi, vzero_shift_mask), 31); + vacc1_hi = + vsraq_n_s32(vacc1_hi, vbicq_s32(vacc1_hi, vzero_shift_mask), 31); + + vacc0_lo = vrshlq_s32(vacc0_lo, vright_shift); + vacc1_lo = vrshlq_s32(vacc1_lo, vright_shift); + vacc0_hi = vrshlq_s32(vacc0_hi, vright_shift); + vacc1_hi = vrshlq_s32(vacc1_hi, vright_shift); + + /* Pack, saturate, and add output zero point */ + const int16x8_t vacc0 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc0_lo), vqmovn_s32(vacc0_hi)), + vy_zero_point); + const int16x8_t vacc1 = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc1_lo), vqmovn_s32(vacc1_hi)), + vy_zero_point); + + uint8x16_t vy01 = vcombine_u8(vqmovun_s16(vacc0), vqmovun_s16(vacc1)); + vy01 = vmaxq_u8(vy01, vy_min); + vy01 = vminq_u8(vy01, vy_max); + + vst1q_u8(y, vy01); + y += 16; + } +#endif + for (; n >= 8; n -= 8) { + const uint8x8_t va = vld1_u8(a); + a += 8; + const uint8x8_t vb = vld1_u8(b); + b += 8; + + /* Subtract zero point */ + const int16x8_t vxa = + vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point)); + const int16x8_t vxb = + vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point)); + + /* Multiply by factors and accumulate products */ + int32x4_t vacc_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier); +#ifdef __aarch64__ + int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier); +#else + int32x4_t vacc_hi = + vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier); +#endif + + vacc_lo = + vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier); +#ifdef __aarch64__ + vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier); +#else + vacc_hi = + vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier); +#endif + + /* Shift right and round */ + vacc_lo = + vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31); + vacc_hi = + vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31); + + vacc_lo = vrshlq_s32(vacc_lo, vright_shift); + vacc_hi = vrshlq_s32(vacc_hi, vright_shift); + + /* Pack, saturate, and add output zero point */ +#ifdef __aarch64__ + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point); +#else + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + vy_zero_point); +#endif + + uint8x8_t vy = vqmovun_s16(vacc); + vy = vmax_u8(vy, vget_low_u8(vy_min)); + vy = vmin_u8(vy, vget_low_u8(vy_max)); + + vst1_u8(y, vy); + y += 8; + } + if (n != 0) { + const size_t n_increment = n - 8; + const int64x1_t vld_shift = vmov_n_s64(8 * n_increment); + const uint8x8_t va = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(a + n_increment)), vld_shift)); + const uint8x8_t vb = vreinterpret_u8_u64( + vshl_u64(vreinterpret_u64_u8(vld1_u8(b + n_increment)), vld_shift)); + + /* Subtract zero point */ + const int16x8_t vxa = + vreinterpretq_s16_u16(vsubl_u8(va, va_zero_point)); + const int16x8_t vxb = + vreinterpretq_s16_u16(vsubl_u8(vb, vb_zero_point)); + + /* Multiply by factors and accumulate products */ + int32x4_t vacc_lo = + vmulq_s32(vmovl_s16(vget_low_s16(vxa)), va_multiplier); +#ifdef __aarch64__ + int32x4_t vacc_hi = vmulq_s32(vmovl_high_s16(vxa), va_multiplier); +#else + int32x4_t vacc_hi = + vmulq_s32(vmovl_s16(vget_high_s16(vxa)), va_multiplier); +#endif + + vacc_lo = + vmlaq_s32(vacc_lo, vmovl_s16(vget_low_s16(vxb)), vb_multiplier); +#ifdef __aarch64__ + vacc_hi = vmlaq_s32(vacc_hi, vmovl_high_s16(vxb), vb_multiplier); +#else + vacc_hi = + vmlaq_s32(vacc_hi, vmovl_s16(vget_high_s16(vxb)), vb_multiplier); +#endif + + /* Shift right and round */ + vacc_lo = + vsraq_n_s32(vacc_lo, vbicq_s32(vacc_lo, vzero_shift_mask), 31); + vacc_hi = + vsraq_n_s32(vacc_hi, vbicq_s32(vacc_hi, vzero_shift_mask), 31); + + vacc_lo = vrshlq_s32(vacc_lo, vright_shift); + vacc_hi = vrshlq_s32(vacc_hi, vright_shift); + + /* Pack, saturate, and add output zero point */ +#ifdef __aarch64__ + const int16x8_t vacc = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(vacc_lo), vacc_hi), vy_zero_point); +#else + const int16x8_t vacc = vqaddq_s16( + vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)), + vy_zero_point); +#endif + + uint8x8_t vy = vqmovun_s16(vacc); + vy = vmax_u8(vy, vget_low_u8(vy_min)); + vy = vmin_u8(vy, vget_low_u8(vy_max)); + + if (n & 4) { + vst1_lane_u32( + __builtin_assume_aligned(y, 1), vreinterpret_u32_u8(vy), 0); + y += 4; + vy = vext_u8(vy, vy, 4); + } + if (n & 2) { + vst1_lane_u16( + __builtin_assume_aligned(y, 1), vreinterpret_u16_u8(vy), 0); + y += 2; + vy = vext_u8(vy, vy, 2); + } + if (n & 1) { + vst1_lane_u8(y, vy, 0); + } + } + } + else { + for (; n != 0; n--) { + const uint8x8_t va = vld1_dup_u8(a); + a += 1; + const uint8x8_t vb = vld1_dup_u8(b); + b += 1; + + /* Subtract zero point */ + const int16x4_t vxa = + vreinterpret_s16_u16(vget_low_u16(vsubl_u8(va, va_zero_point))); + const int16x4_t vxb = + vreinterpret_s16_u16(vget_low_u16(vsubl_u8(vb, vb_zero_point))); + + /* Multiply by factors and accumulate products */ + int32x2_t vacc = + vmul_s32(vget_low_s32(vmovl_s16(vxa)), vget_low_s32(va_multiplier)); + vacc = vmla_s32( + vacc, vget_low_s32(vmovl_s16(vxb)), vget_low_s32(vb_multiplier)); + + /* Shift right and round */ + vacc = + vsra_n_s32(vacc, vbic_s32(vacc, vget_low_s32(vzero_shift_mask)), 31); + + vacc = vrshl_s32(vacc, vget_low_s32(vright_shift)); + + const int16x4_t vacc16 = vqadd_s16( + vqmovn_s32(vcombine_s32(vacc, vacc)), vget_low_s16(vy_zero_point)); + + /* Pack, saturate, and add output zero point */ + uint8x8_t vy = vqmovun_s16(vcombine_s16(vacc16, vacc16)); + vy = vmin_u8(vy, vget_low_u8(vy_max)); + vy = vmax_u8(vy, vget_low_u8(vy_min)); + + vst1_lane_u8(y, vy, 0); + y += 1; + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8vadd/sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8vadd/sse2.c new file mode 100644 index 0000000000000..6a829c442fd65 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/q8vadd/sse2.c @@ -0,0 +1,224 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +void pytorch_q8vadd_ukernel__sse2( + size_t n, + const uint8_t* a, + const uint8_t* b, + uint8_t* y, + const union pytorch_qnnp_add_quantization_params + quantization_params[RESTRICT_STATIC 1]) { + if + PYTORCH_QNNP_LIKELY(n >= 8) { + const __m128i vzero_point_product = _mm_load_si128( + (const __m128i*)&quantization_params->sse2.zero_point_product); + const __m128i va_multiplier_lo = _mm_load_si128( + (const __m128i*)&quantization_params->sse2.a_multiplier_lo); + const __m128i va_multiplier_hi = _mm_load_si128( + (const __m128i*)&quantization_params->sse2.a_multiplier_hi); + const __m128i vb_multiplier_lo = _mm_load_si128( + (const __m128i*)&quantization_params->sse2.b_multiplier_lo); + const __m128i vb_multiplier_hi = _mm_load_si128( + (const __m128i*)&quantization_params->sse2.b_multiplier_hi); + const __m128i vremainder_mask = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_mask); + const __m128i vremainder_threshold = _mm_load_si128( + (const __m128i*)quantization_params->sse2.remainder_threshold); + const __m128i vshift = + _mm_cvtsi32_si128((int)quantization_params->sse2.shift); + + const __m128i vzero = _mm_setzero_si128(); + do { + const __m128i va = _mm_loadl_epi64((const __m128i*)a); + a += 8; + const __m128i vb = _mm_loadl_epi64((const __m128i*)b); + b += 8; + + const __m128i vxa = _mm_unpacklo_epi8(va, vzero); + const __m128i vxb = _mm_unpacklo_epi8(vb, vzero); + + /* Multiply by factors */ + const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo); + const __m128i va_product_hi = _mm_add_epi16( + _mm_mulhi_epu16(vxa, va_multiplier_lo), + _mm_mullo_epi16(vxa, va_multiplier_hi)); + + const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo); + const __m128i vb_product_hi = _mm_add_epi16( + _mm_mulhi_epu16(vxb, vb_multiplier_lo), + _mm_mullo_epi16(vxb, vb_multiplier_hi)); + + /* Accumulate products */ + __m128i vacc_lo = _mm_add_epi32( + vzero_point_product, + _mm_unpacklo_epi16(va_product_lo, va_product_hi)); + __m128i vacc_hi = _mm_add_epi32( + vzero_point_product, + _mm_unpackhi_epi16(va_product_lo, va_product_hi)); + + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi)); + + /* Shift right and round */ + const __m128i vrem_lo = _mm_add_epi32( + _mm_and_si128(vacc_lo, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo)); + const __m128i vrem_hi = _mm_add_epi32( + _mm_and_si128(vacc_hi, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi)); + + vacc_lo = _mm_sub_epi32( + _mm_sra_epi32(vacc_lo, vshift), + _mm_cmpgt_epi32(vrem_lo, vremainder_threshold)); + vacc_hi = _mm_sub_epi32( + _mm_sra_epi32(vacc_hi, vshift), + _mm_cmpgt_epi32(vrem_hi, vremainder_threshold)); + + /* Pack, saturate, and add output zero point */ + const __m128i vy_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.y_zero_point); + const __m128i vacc = + _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point); + __m128i vy = _mm_packus_epi16(vacc, vacc); + vy = _mm_max_epu8( + vy, + _mm_load_si128((const __m128i*)quantization_params->sse2.y_min)); + vy = _mm_min_epu8( + vy, + _mm_load_si128((const __m128i*)quantization_params->sse2.y_max)); + + _mm_storel_epi64((__m128i*)y, vy); + y += 8; + + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t n_decrement = 8 - n; + const __m128i vload_shift = _mm_cvtsi32_si128(8 * (int32_t)n_decrement); + + const __m128i va = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(a - n_decrement)), vload_shift); + const __m128i vb = _mm_srl_epi64( + _mm_loadl_epi64((const __m128i*)(b - n_decrement)), vload_shift); + + const __m128i vxa = _mm_unpacklo_epi8(va, vzero); + const __m128i vxb = _mm_unpacklo_epi8(vb, vzero); + + /* Multiply by factors */ + const __m128i va_product_lo = _mm_mullo_epi16(vxa, va_multiplier_lo); + const __m128i va_product_hi = _mm_add_epi16( + _mm_mulhi_epu16(vxa, va_multiplier_lo), + _mm_mullo_epi16(vxa, va_multiplier_hi)); + + const __m128i vb_product_lo = _mm_mullo_epi16(vxb, vb_multiplier_lo); + const __m128i vb_product_hi = _mm_add_epi16( + _mm_mulhi_epu16(vxb, vb_multiplier_lo), + _mm_mullo_epi16(vxb, vb_multiplier_hi)); + + /* Accumulate products */ + __m128i vacc_lo = _mm_add_epi32( + vzero_point_product, + _mm_unpacklo_epi16(va_product_lo, va_product_hi)); + __m128i vacc_hi = _mm_add_epi32( + vzero_point_product, + _mm_unpackhi_epi16(va_product_lo, va_product_hi)); + + vacc_lo = _mm_add_epi32( + vacc_lo, _mm_unpacklo_epi16(vb_product_lo, vb_product_hi)); + vacc_hi = _mm_add_epi32( + vacc_hi, _mm_unpackhi_epi16(vb_product_lo, vb_product_hi)); + + /* Shift right and round */ + const __m128i vrem_lo = _mm_add_epi32( + _mm_and_si128(vacc_lo, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_lo)); + const __m128i vrem_hi = _mm_add_epi32( + _mm_and_si128(vacc_hi, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), vacc_hi)); + + vacc_lo = _mm_sub_epi32( + _mm_sra_epi32(vacc_lo, vshift), + _mm_cmpgt_epi32(vrem_lo, vremainder_threshold)); + vacc_hi = _mm_sub_epi32( + _mm_sra_epi32(vacc_hi, vshift), + _mm_cmpgt_epi32(vrem_hi, vremainder_threshold)); + + /* Pack, saturate, and add output zero point */ + const __m128i vy_zero_point = _mm_load_si128( + (const __m128i*)quantization_params->sse2.y_zero_point); + const __m128i vacc = + _mm_adds_epi16(_mm_packs_epi32(vacc_lo, vacc_hi), vy_zero_point); + __m128i vy = _mm_packus_epi16(vacc, vacc); + vy = _mm_max_epu8( + vy, + _mm_load_si128((const __m128i*)quantization_params->sse2.y_min)); + vy = _mm_min_epu8( + vy, + _mm_load_si128((const __m128i*)quantization_params->sse2.y_max)); + + if (n & 4) { + *((uint32_t*)y) = (uint32_t)_mm_cvtsi128_si32(vy); + vy = _mm_shuffle_epi32(vy, _MM_SHUFFLE(3, 2, 1, 1)); + y += 4; + } + if (n & 2) { + *((uint16_t*)y) = (uint16_t)_mm_extract_epi16(vy, 0); + vy = _mm_srli_epi32(vy, 16); + y += 2; + } + if (n & 1) { + *((uint8_t*)y) = (uint8_t)_mm_cvtsi128_si32(vy); + } + } + } + else { + const int32_t vzero_point_product = + quantization_params->sse2.zero_point_product[0]; + const uint32_t va_multiplier = quantization_params->sse2.a_multiplier; + const uint32_t vb_multiplier = quantization_params->sse2.b_multiplier; + const int32_t vremainder_mask = quantization_params->sse2.remainder_mask[0]; + const int32_t vremainder_threshold = + quantization_params->sse2.remainder_threshold[0]; + const uint32_t vshift = quantization_params->sse2.shift; + const int32_t vy_zero_point = + (int32_t)quantization_params->sse2.y_zero_point[0]; + const int32_t vy_max = + (int32_t)(uint32_t)quantization_params->sse2.y_max[0]; + const int32_t vy_min = + (int32_t)(uint32_t)quantization_params->sse2.y_min[0]; + + while (n-- != 0) { + const uint32_t vxa = (uint32_t)*a++; + const uint32_t vxb = (uint32_t)*b++; + + /* Multiply by factors and accumulate products */ + int32_t vacc = vzero_point_product + (int32_t)(vxa * va_multiplier) + + (int32_t)(vxb * vb_multiplier); + + /* Shift right and round */ + const int32_t vrem = (vacc & vremainder_mask) - (int32_t)(vacc < 0); + + vacc = asr_s32(vacc, vshift) + (int32_t)(vrem > vremainder_threshold); + + /* Clamp and add output zero point */ + int32_t vy = vacc + vy_zero_point; + vy = vy >= vy_min ? vy : vy_min; + vy = vy <= vy_max ? vy : vy_max; + + *y++ = (uint8_t)vy; + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/AlignedAllocator.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/AlignedAllocator.h new file mode 100644 index 0000000000000..dd29682462f21 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/AlignedAllocator.h @@ -0,0 +1,104 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +template +class AlignedAllocator; + +template +class AlignedAllocator { + public: + typedef void* pointer; + typedef const void* const_pointer; + typedef void value_type; + + template + struct rebind { + typedef AlignedAllocator other; + }; +}; + +template +class AlignedAllocator { + public: + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + +#if __cplusplus >= 201402L + typedef std::true_type propagate_on_container_move_assignment; +#endif + + template + struct rebind { + typedef AlignedAllocator other; + }; + + public: + inline AlignedAllocator() noexcept {} + + template + inline AlignedAllocator( + const AlignedAllocator& other) noexcept {} + + inline size_type max_size() const noexcept { + return (std::numeric_limits::max() - size_type(Alignment)) / + sizeof(T); + } + + inline pointer address(reference x) const noexcept { + return std::addressof(x); + } + + inline const_pointer address(const_reference x) const noexcept { + return std::addressof(x); + } + + inline pointer allocate( + size_type n, + typename AlignedAllocator::const_pointer hint = 0) { +#if defined(__ANDROID__) + void* memory = memalign(Alignment, n * sizeof(T)); + if (memory == 0) { +#if !defined(__GNUC__) || defined(__EXCEPTIONS) + throw std::bad_alloc(); +#endif + } +#else + void* memory = nullptr; + if (posix_memalign(&memory, Alignment, n * sizeof(T)) != 0) { +#if !defined(__GNUC__) || defined(__EXCEPTIONS) + throw std::bad_alloc(); +#endif + } +#endif + return static_cast(memory); + } + + inline void deallocate(pointer p, size_type n) noexcept { + free(static_cast(p)); + } + + template + inline void construct(U* p, Args&&... args) { + ::new (static_cast(p)) U(std::forward(args)...); + } + + template + inline void destroy(U* p) { + p->~U(); + } +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/assembly.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/assembly.h new file mode 100644 index 0000000000000..8f8f351b73653 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/assembly.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// clang-format off +#ifdef __ELF__ + .macro BEGIN_FUNCTION name + .text + .align 2 + .global \name + .type \name, %function + \name: + .endm + + .macro END_FUNCTION name + .size \name, .-\name + .endm +#elif defined(__MACH__) + .macro BEGIN_FUNCTION name + .text + .align 2 + .global _\name + .private_extern _\name + _\name: + .endm + + .macro END_FUNCTION name + .endm +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h new file mode 100644 index 0000000000000..14bcc01d21ed0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/common.h @@ -0,0 +1,82 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#if defined(__GNUC__) +#if defined(__clang__) || (__GNUC__ > 4 || __GNUC__ == 4 && __GNUC_MINOR__ >= 5) +#define PYTORCH_QNNP_UNREACHABLE \ + do { \ + __builtin_unreachable(); \ + } while (0) +#else +#define PYTORCH_QNNP_UNREACHABLE \ + do { \ + __builtin_trap(); \ + } while (0) +#endif +#elif defined(_MSC_VER) +#define PYTORCH_QNNP_UNREACHABLE __assume(0) +#else +#define PYTORCH_QNNP_UNREACHABLE \ + do { \ + } while (0) +#endif + +#if defined(_MSC_VER) +#define PYTORCH_QNNP_ALIGN(alignment) __declspec(align(alignment)) +#else +#define PYTORCH_QNNP_ALIGN(alignment) __attribute__((__aligned__(alignment))) +#endif + +#define PYTORCH_QNNP_COUNT_OF(array) (sizeof(array) / sizeof(0 [array])) + +#if defined(__GNUC__) +#define PYTORCH_QNNP_LIKELY(condition) (__builtin_expect(!!(condition), 1)) +#define PYTORCH_QNNP_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) +#else +#define PYTORCH_QNNP_LIKELY(condition) (!!(condition)) +#define PYTORCH_QNNP_UNLIKELY(condition) (!!(condition)) +#endif + +#if defined(__GNUC__) +#define PYTORCH_QNNP_INLINE inline __attribute__((__always_inline__)) +#else +#define PYTORCH_QNNP_INLINE inline +#endif + +#ifndef PYTORCH_QNNP_INTERNAL +#if defined(__ELF__) +#define PYTORCH_QNNP_INTERNAL __attribute__((__visibility__("internal"))) +#elif defined(__MACH__) +#define PYTORCH_QNNP_INTERNAL __attribute__((__visibility__("hidden"))) +#else +#define PYTORCH_QNNP_INTERNAL +#endif +#endif + +#ifndef PYTORCH_QNNP_PRIVATE +#if defined(__ELF__) +#define PYTORCH_QNNP_PRIVATE __attribute__((__visibility__("hidden"))) +#elif defined(__MACH__) +#define PYTORCH_QNNP_PRIVATE __attribute__((__visibility__("hidden"))) +#else +#define PYTORCH_QNNP_PRIVATE +#endif +#endif + +#if defined(_MSC_VER) +#define RESTRICT_STATIC +#define restrict +#else +#define RESTRICT_STATIC restrict static +#endif + +#if defined(_MSC_VER) +#define __builtin_prefetch +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/hgemm.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/hgemm.h new file mode 100644 index 0000000000000..46e792cca359b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/hgemm.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_HGEMM_UKERNEL_FUNCTION(fn_name) \ + void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const void* a, \ + size_t a_stride, \ + const void* w, \ + void* c, \ + size_t c_stride, \ + const struct pytorch_qnnp_fp16_clamping_params* clamping_params); + +DECLARE_PYTORCH_HGEMM_UKERNEL_FUNCTION(pytorch_hgemm_ukernel_8x8__neonfp16arith) +DECLARE_PYTORCH_HGEMM_UKERNEL_FUNCTION(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/indirection.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/indirection.h new file mode 100644 index 0000000000000..20cc13bb2db60 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/indirection.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + + PYTORCH_QNNP_INTERNAL void pytorch_qnnp_indirection_init_conv2d( + pytorch_qnnp_operator_t op, + size_t output_tile_size, + size_t tiled_output_size); + + PYTORCH_QNNP_INTERNAL void pytorch_qnnp_indirection_init_dwconv2d( + pytorch_qnnp_operator_t convolution, + size_t batch_start, + size_t step_height, + size_t step_width); + + PYTORCH_QNNP_INTERNAL void pytorch_qnnp_indirection_init_deconv2d( + pytorch_qnnp_operator_t op, + size_t output_tile_size, + size_t tiled_output_size); + + PYTORCH_QNNP_INTERNAL void pytorch_qnnp_indirection_init_maxpool2d( + pytorch_qnnp_operator_t op, + size_t batch_start, + size_t step_height, + size_t step_width); + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/isa-checks.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/isa-checks.h new file mode 100644 index 0000000000000..68414f873b5be --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/isa-checks.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#define TEST_REQUIRES_X86_SSE2 \ + do { \ + if (!cpuinfo_initialize() || !cpuinfo_has_x86_sse2()) { \ + return; \ + } \ + } while (0) + +#define TEST_REQUIRES_ARM_NEON \ + do { \ + if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon()) { \ + return; \ + } \ + } while (0) + +#define TEST_REQUIRES_ARM_NEON_FP16_ARITH \ + do { \ + if (!cpuinfo_initialize() || !cpuinfo_has_arm_neon_fp16_arith()) { \ + return; \ + } \ + } while (0) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/log.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/log.h new file mode 100644 index 0000000000000..64ef72e2ca94f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/log.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +#ifndef PYTORCH_QNNP_LOG_LEVEL +#define PYTORCH_QNNP_LOG_LEVEL CLOG_DEBUG +#endif + +CLOG_DEFINE_LOG_DEBUG( + pytorch_qnnp_log_debug, + "QNNPACK", + PYTORCH_QNNP_LOG_LEVEL); +CLOG_DEFINE_LOG_INFO(pytorch_qnnp_log_info, "QNNPACK", PYTORCH_QNNP_LOG_LEVEL); +CLOG_DEFINE_LOG_WARNING( + pytorch_qnnp_log_warning, + "QNNPACK", + PYTORCH_QNNP_LOG_LEVEL); +CLOG_DEFINE_LOG_ERROR( + pytorch_qnnp_log_error, + "QNNPACK", + PYTORCH_QNNP_LOG_LEVEL); +CLOG_DEFINE_LOG_FATAL( + pytorch_qnnp_log_fatal, + "QNNPACK", + PYTORCH_QNNP_LOG_LEVEL); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/math.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/math.h new file mode 100644 index 0000000000000..82423546f9b0e --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/math.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#ifdef _MSC_VER +#undef min +#undef max +#endif + +inline static size_t min(size_t a, size_t b) { + return a < b ? a : b; +} + +inline static size_t max(size_t a, size_t b) { + return a > b ? a : b; +} + +inline static size_t doz(size_t a, size_t b) { + return a < b ? 0 : a - b; +} + +inline static size_t divide_round_up(size_t n, size_t q) { + return n % q == 0 ? n / q : n / q + 1; +} + +inline static size_t round_up(size_t n, size_t q) { + return divide_round_up(n, q) * q; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/operator.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/operator.h new file mode 100644 index 0000000000000..e973bb4fa9f25 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/operator.h @@ -0,0 +1,121 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +enum pytorch_qnnp_format { + pytorch_qnnp_format_quint8 = 0x02000000, + pytorch_qnnp_format_float32 = 0x02020202, + pytorch_qnnp_format_float16 = 0x01010101, +}; + +enum pytorch_qnnp_ukernel_type { + pytorch_qnnp_ukernel_type_none = 0, + pytorch_qnnp_ukernel_type_add, + pytorch_qnnp_ukernel_type_average_pooling, + pytorch_qnnp_ukernel_type_channel_shuffle, + pytorch_qnnp_ukernel_type_clamp, + pytorch_qnnp_ukernel_type_conv, + pytorch_qnnp_ukernel_type_dwconv, + pytorch_qnnp_ukernel_type_gemm, + pytorch_qnnp_ukernel_type_global_average_pooling, + pytorch_qnnp_ukernel_type_lut, + pytorch_qnnp_ukernel_type_max_pooling, + pytorch_qnnp_ukernel_type_softargmax, + pytorch_qnnp_ukernel_type_xzp_gemm, +}; + +struct pytorch_qnnp_operator { + size_t batch_size; + uint32_t input_padding_top; + uint32_t input_padding_right; + uint32_t input_padding_bottom; + uint32_t input_padding_left; + uint32_t adjustment_height; + uint32_t adjustment_width; + uint32_t kernel_height; + uint32_t kernel_width; + uint32_t stride_height; + uint32_t stride_width; + uint32_t dilation_height; + uint32_t dilation_width; + uint32_t groups; + size_t group_stride; + size_t group_channels; + size_t group_input_channels; + size_t group_output_channels; + size_t channels; + + size_t input_height; + size_t input_width; + size_t input_pixel_stride; + const void* input; + const void** indirection_buffer; + void* a_sum; + + size_t input2_pixel_stride; + const void* input2; + + size_t output_height; + size_t output_width; + size_t output_pixel_stride; + void* output; + + void* packed_weights; + float input_scale; + float output_scale; + uint8_t input_zero_point; + uint8_t kernel_zero_point; + uint8_t output_zero_point; + uint8_t output_min; + uint8_t output_max; + + size_t valid_batch_size; + size_t last_input_height; + size_t last_input_width; + const void* last_input; + + void* zero_buffer; + void* zero_pointer; + void* lookup_table; + + union { + union pytorch_qnnp_q31_requantization_params requantization_params; + union pytorch_qnnp_conv_quantization_params conv_quantization_params; + union pytorch_qnnp_add_quantization_params add_quantization_params; + union pytorch_qnnp_avgpool_quantization_params avgpool_quantization_params; + union pytorch_qnnp_u8_clamping_params u8_clamping_params; + }; + enum pytorch_qnnp_ukernel_type ukernel_type; + enum pytorch_qnnp_format format; +}; + +static inline uint32_t pytorch_qnnp_operator_get_log2_output_element_size( + const struct pytorch_qnnp_operator* convolution) { + return (uint32_t)(convolution->format & UINT32_C(0xFF)); +} + +static inline uint32_t pytorch_qnnp_operator_get_log2_input_element_size( + const struct pytorch_qnnp_operator* convolution) { + return (uint32_t)((convolution->format >> 8) & UINT32_C(0xFF)); +} + +static inline uint32_t pytorch_qnnp_operator_get_log2_kernel_element_size( + const struct pytorch_qnnp_operator* convolution) { + return (uint32_t)((convolution->format >> 16) & UINT32_C(0xFF)); +} + +static inline uint32_t pytorch_qnnp_operator_get_log2_bias_element_size( + const struct pytorch_qnnp_operator* convolution) { + return (uint32_t)((convolution->format >> 24) & UINT32_C(0xFF)); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h new file mode 100644 index 0000000000000..028a3095492db --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/pack.h @@ -0,0 +1,640 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +// Legend: +// dq: Design-time Quantization +// rq: Run-time Quantization + +static inline void pytorch_pack_q8gemm_wdq( + size_t nc, + size_t kc, + uint32_t nr, + uint32_t np, + uint32_t kr, + uint8_t izp, + uint8_t kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) { + const int32_t boff = (int32_t)kc * (int32_t)izp * (int32_t)kzp; + for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { + const size_t nr_block_size = min(nc - nr_block_start, nr); + int32_t* packed_b = (int32_t*)packed_w; + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *((int32_t*)packed_w) = b[nr_block_start + nr_block_offset] + boff; + packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t)); + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + int32_t ksum = 0; + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start + kr_block_offset)]; + ksum += (int32_t)kv; + *((uint8_t*)packed_w) = kv; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + packed_b[nr_block_offset] -= ksum * (int32_t)izp; + packed_w = + (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + ((nr - nr_block_size) & (np - 1)) * kr * sizeof(uint8_t)); + } + } +} + +static inline void pytorch_pack_q8gemm_wrq( + const size_t nc, + const size_t kc, + const uint32_t nr, + const uint32_t np, + const uint32_t kr, + const uint8_t* const k, + const int32_t* const b, + void* const packed_w) { + union { + void* const as_void_ptr; + uint8_t* as_uint8_ptr; + int32_t* as_int32_ptr; + } packed = {packed_w}; + + for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { + const size_t nr_block_size = min(nc - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *(packed.as_int32_ptr++) = b[nr_block_start + nr_block_offset]; + } + packed.as_int32_ptr += (nr - nr_block_size); + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start + kr_block_offset)]; + *(packed.as_uint8_ptr++) = kv; + } + packed.as_uint8_ptr += (kr - kr_block_size); + } + packed.as_uint8_ptr += ((nr - nr_block_size) & (np - 1)) * kr; + } + } +} + +static inline void pytorch_pack_q8conv_wdq( + size_t n, + size_t ks, + size_t kc, + uint32_t nr, + uint32_t kr, + uint8_t izp, + uint8_t kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) { + const int32_t boff = (int32_t)ks * (int32_t)kc * (int32_t)izp * (int32_t)kzp; + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + int32_t* packed_b = (int32_t*)packed_w; + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *((int32_t*)packed_w) = b[nr_block_start + nr_block_offset] + boff; + packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t)); + for (size_t ki = 0; ki < ks; ki++) { + for (size_t kr_block_start = 0; kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + int32_t ksum = 0; + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[((nr_block_start + nr_block_offset) * ks + ki) * kc + + (kr_block_start + kr_block_offset)]; + ksum += (int32_t)kv; + *((uint8_t*)packed_w) = kv; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + packed_b[nr_block_offset] -= ksum * (int32_t)izp; + packed_w = + (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t)); + } + } + } +} + +static inline void pytorch_pack_q8conv_wrq( + const size_t n, + const size_t ks, + const size_t kc, + const uint32_t nr, + const uint32_t kr, + const uint8_t* const k, + const int32_t* const b, + void* const packed_w) { + union { + void* const as_void_ptr; + uint8_t* as_uint8_ptr; + int32_t* as_int32_ptr; + } packed = {packed_w}; + + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *(packed.as_int32_ptr++) = b[nr_block_start + nr_block_offset]; + } + packed.as_int32_ptr += (nr - nr_block_size); + for (size_t ki = 0; ki < ks; ki++) { + for (size_t kr_block_start = 0; kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[((nr_block_start + nr_block_offset) * ks + ki) * kc + + (kr_block_start + kr_block_offset)]; + *(packed.as_uint8_ptr++) = kv; + } + packed.as_uint8_ptr += (kr - kr_block_size); + } + packed.as_uint8_ptr += (nr - nr_block_size) * kr; + } + } + } +} + +static inline void pytorch_pack_q8deconv_wdq( + size_t n, + size_t ks, + size_t kc, + uint32_t nr, + uint32_t kr, + uint8_t izp, + uint8_t kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) { + const int32_t boff = (int32_t)ks * (int32_t)kc * (int32_t)izp * (int32_t)kzp; + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + int32_t* packed_b = (int32_t*)packed_w; + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *((int32_t*)packed_w) = b[nr_block_start + nr_block_offset] + boff; + packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t)); + for (size_t ki = 0; ki < ks; ki++) { + for (size_t kr_block_start = 0; kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + int32_t ksum = 0; + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[((kr_block_start + kr_block_offset) * ks + ki) * n + + (nr_block_start + nr_block_offset)]; + ksum += (int32_t)kv; + *((uint8_t*)packed_w) = kv; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + packed_b[nr_block_offset] -= ksum * (int32_t)izp; + packed_w = + (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t)); + } + } + } +} + +static inline void pytorch_pack_q8deconv_wrq( + const size_t n, + const size_t ks, + const size_t kc, + const uint32_t nr, + const uint32_t kr, + const uint8_t* const k, + const int32_t* const b, + void* const packed_w) { + union { + void* const as_void_ptr; + uint8_t* as_uint8_ptr; + int32_t* as_int32_ptr; + } packed = {packed_w}; + + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *(packed.as_int32_ptr++) = b[nr_block_start + nr_block_offset]; + } + packed.as_int32_ptr += (nr - nr_block_size); + for (size_t ki = 0; ki < ks; ki++) { + for (size_t kr_block_start = 0; kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[((kr_block_start + kr_block_offset) * ks + ki) * n + + (nr_block_start + nr_block_offset)]; + *(packed.as_uint8_ptr++) = kv; + } + packed.as_uint8_ptr += (kr - kr_block_size); + } + packed.as_uint8_ptr += (nr - nr_block_size) * kr; + } + } + } +} + +static inline void pytorch_pack_q8dw_wdq( + size_t h, + size_t w, + size_t c, + size_t cr, + uint8_t izp, + uint8_t kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) { + const int32_t boff = (int32_t)h * (int32_t)w * (int32_t)izp * (int32_t)kzp; + for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) { + const size_t cr_block_size = min(c - cr_block_start, cr); + int32_t* packed_b = (int32_t*)packed_w; + for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; + cr_block_offset++) { + *((int32_t*)packed_w) = b[cr_block_start + cr_block_offset] + boff; + packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(int32_t)); + for (size_t x = 0; x < w; x++) { + for (size_t y = 0; y < h; y++) { + for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; + cr_block_offset++) { + const uint8_t kv = + k[((cr_block_start + cr_block_offset) * h + y) * w + x]; + packed_b[cr_block_offset] -= (int32_t)kv * (int32_t)izp; + *((uint8_t*)packed_w) = kv; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(uint8_t)); + } + } + } +} + +static inline void pytorch_pack_q8dw_wrq( + const size_t h, + const size_t w, + const size_t c, + const size_t cr, + const uint8_t* const k, + const int32_t* const b, + void* const packed_w) { + union { + void* const as_void_ptr; + uint8_t* as_uint8_ptr; + int32_t* as_int32_ptr; + } packed = {packed_w}; + + for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) { + const size_t cr_block_size = min(c - cr_block_start, cr); + for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; + cr_block_offset++) { + *(packed.as_int32_ptr++) = b[cr_block_start + cr_block_offset]; + } + packed.as_int32_ptr += (cr - cr_block_size); + for (size_t x = 0; x < w; x++) { + for (size_t y = 0; y < h; y++) { + for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; + cr_block_offset++) { + const uint8_t kv = + k[((cr_block_start + cr_block_offset) * h + y) * w + x]; + *(packed.as_uint8_ptr++) = kv; + } + packed.as_uint8_ptr += (cr - cr_block_size); + } + } + } +} + +static inline void pytorch_pack_q8dw_w_dilation( + size_t h, + size_t w, + size_t c, + size_t cr, + size_t y_start, + size_t y_end, + size_t x_start, + size_t x_end, + const uint8_t* k, + const int32_t* b, + void* packed_w, + bool pytorch_pack_b) { + for (size_t cr_block_start = 0; cr_block_start < c; cr_block_start += cr) { + const size_t cr_block_size = min(c - cr_block_start, cr); + if (pytorch_pack_b) { + for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; + cr_block_offset++) { + *((int32_t*)packed_w) = b[cr_block_start + cr_block_offset]; + packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(int32_t)); + } + for (size_t x = x_start; x < x_end; x++) { + for (size_t y = y_start; y < y_end; y++) { + for (size_t cr_block_offset = 0; cr_block_offset < cr_block_size; + cr_block_offset++) { + *((uint8_t*)packed_w) = + k[((cr_block_start + cr_block_offset) * h + y) * w + x]; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (cr - cr_block_size) * sizeof(uint8_t)); + } + } + } +} + +static inline void pytorch_pack_swizzle_q8gemm_bdq( + size_t n, + size_t kc, + uint32_t nr, + uint32_t kr, + uint32_t sr, + uint8_t izp, + uint8_t kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) { + const int32_t boff = (int32_t)kc * (int32_t)izp * (int32_t)kzp; + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + int32_t* packed_b = (int32_t*)packed_w; + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *((int32_t*)packed_w) = b[nr_block_start + nr_block_offset] + boff; + packed_w = (void*)((uintptr_t)packed_w + sizeof(int32_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * sizeof(int32_t)); + + for (size_t kr_block_start = 0; kr_block_start < (kc & -sr); + kr_block_start += kr) { + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr; + kr_block_offset++) { + const uint8_t kv = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start & -sr) + + ((kr_block_start + nr_block_offset * kr) & (sr - 1)) + + kr_block_offset]; + packed_b[nr_block_offset] -= (int32_t)kv * (int32_t)izp; + *((uint8_t*)packed_w) = kv; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t)); + } + + for (size_t kr_block_start = (kc & -sr); kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start + kr_block_offset)]; + packed_b[nr_block_offset] -= (int32_t)kv * (int32_t)izp; + *((uint8_t*)packed_w) = kv; + packed_w = (void*)((uintptr_t)packed_w + sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (kr - kr_block_size) * sizeof(uint8_t)); + } + packed_w = + (void*)((uintptr_t)packed_w + (nr - nr_block_size) * kr * sizeof(uint8_t)); + } + } +} + +static inline void pytorch_pack_swizzle_q8gemm_brq( + const size_t n, + const size_t kc, + const uint32_t nr, + const uint32_t kr, + const uint32_t sr, + const uint8_t* const k, + const int32_t* const b, + void* const packed_w) { + union { + void* const as_void_ptr; + uint8_t* as_uint8_ptr; + int32_t* as_int32_ptr; + } packed = {packed_w}; + + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *(packed.as_int32_ptr++) = b[nr_block_start + nr_block_offset]; + } + + packed.as_int32_ptr += (nr - nr_block_size); + + for (size_t kr_block_start = 0; kr_block_start < (kc & -sr); + kr_block_start += kr) { + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr; + kr_block_offset++) { + const uint8_t kv = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start & -sr) + + ((kr_block_start + nr_block_offset * kr) & (sr - 1)) + + kr_block_offset]; + *(packed.as_uint8_ptr++) = kv; + } + } + packed.as_uint8_ptr += (nr - nr_block_size) * kr; + } + + for (size_t kr_block_start = (kc & -sr); kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + const uint8_t kv = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start + kr_block_offset)]; + *(packed.as_uint8_ptr++) = kv; + } + packed.as_uint8_ptr += (kr - kr_block_size); + } + packed.as_uint8_ptr += (nr - nr_block_size) * kr; + } + } +} + +static inline void pytorch_pack_hgemm_w( + size_t nc, + size_t kc, + size_t nr, + size_t kr, + const uint16_t* k, + const uint16_t* b, + uint16_t* packed_w) { + for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { + const size_t nr_block_size = min(nc - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *packed_w++ = b[nr_block_start + nr_block_offset]; + } + packed_w += nr - nr_block_size; + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + *packed_w++ = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start + kr_block_offset)]; + } + packed_w += kr - kr_block_size; + } + packed_w += (nr - nr_block_size) * kr; + } + } +} + +static inline void pytorch_pack_sgemm_w( + size_t nc, + size_t kc, + size_t nr, + size_t kr, + const float* k, + const float* b, + float* packed_w) { + for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { + const size_t nr_block_size = min(nc - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *packed_w++ = b[nr_block_start + nr_block_offset]; + } + packed_w += nr - nr_block_size; + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + *packed_w++ = + k[(nr_block_start + nr_block_offset) * kc + + (kr_block_start + kr_block_offset)]; + } + packed_w += kr - kr_block_size; + } + packed_w += (nr - nr_block_size) * kr; + } + } +} + +static inline void pytorch_pack_sconv_w( + size_t n, + size_t ks, + size_t kc, + size_t nr, + size_t kr, + const float* k, + const float* b, + float* packed_w) { + for (size_t nr_block_start = 0; nr_block_start < n; nr_block_start += nr) { + const size_t nr_block_size = min(n - nr_block_start, nr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + *packed_w++ = b[nr_block_start + nr_block_offset]; + } + packed_w += nr - nr_block_size; + for (size_t ki = 0; ki < ks; ki++) { + for (size_t kr_block_start = 0; kr_block_start < kc; + kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; + nr_block_offset++) { + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; + kr_block_offset++) { + *packed_w++ = + k[((nr_block_start + nr_block_offset) * ks + ki) * kc + + (kr_block_start + kr_block_offset)]; + } + packed_w += kr - kr_block_size; + } + packed_w += (nr - nr_block_size) * kr; + } + } + } +} + +#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + +#define pytorch_pack_q8gemm_w pytorch_pack_q8gemm_wrq +#define pytorch_pack_q8conv_w pytorch_pack_q8conv_wrq +#define pytorch_pack_q8deconv_w pytorch_pack_q8deconv_wrq +#define pytorch_pack_q8dw_w pytorch_pack_q8dw_wrq +#define pytorch_pack_swizzle_q8gemm_b pytorch_pack_swizzle_q8gemm_brq + +#else + +#define pytorch_pack_q8gemm_w pytorch_pack_q8gemm_wdq +#define pytorch_pack_q8conv_w pytorch_pack_q8conv_wdq +#define pytorch_pack_q8deconv_w pytorch_pack_q8deconv_wdq +#define pytorch_pack_q8dw_w pytorch_pack_q8dw_wdq +#define pytorch_pack_swizzle_q8gemm_b pytorch_pack_swizzle_q8gemm_bdq + +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/params.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/params.h new file mode 100644 index 0000000000000..ab4a316863034 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/params.h @@ -0,0 +1,529 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include + +#include + +struct pytorch_qnnp_fp16_clamping_params { + uint16_t scale; + uint16_t max; + uint16_t min; +}; + +struct pytorch_qnnp_fp32_clamping_params { + float max; + float min; +}; + +union pytorch_qnnp_fp32_requantization_params { + struct { + float scale; + float min_less_zero_point; + float max_less_zero_point; + float magic; + int32_t magic_less_zero_point; + } scalar; + struct { + float scale; + float max; + float min; + float magic; + int32_t magic_less_zero_point; + } neon; + struct { + float scale; + int16_t zero_point; + uint8_t max; + uint8_t min; + } neonv8; + struct { + PYTORCH_QNNP_ALIGN(16) float scale[4]; + PYTORCH_QNNP_ALIGN(16) int16_t zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint8_t max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t min[16]; + } sse2; + struct { + PYTORCH_QNNP_ALIGN(16) float scale[4]; + PYTORCH_QNNP_ALIGN(16) float min_less_zero_point[4]; + PYTORCH_QNNP_ALIGN(16) float max_less_zero_point[4]; + PYTORCH_QNNP_ALIGN(16) float magic[4]; + PYTORCH_QNNP_ALIGN(16) int32_t magic_less_zero_point[4]; + } psimd; +}; + +union pytorch_qnnp_precise_requantization_params { + struct { + uint32_t multiplier; + uint32_t rounding_lo; + uint32_t rounding_hi; + uint32_t shift_less_32; + int32_t min_less_zero_point; + int32_t max_less_zero_point; + int32_t zero_point; + } scalar; + struct { + int32_t multiplier; + int32_t right_shift; + int16_t zero_point; + uint8_t max; + uint8_t min; + } neon; + struct { + PYTORCH_QNNP_ALIGN(16) uint32_t multiplier[4]; + PYTORCH_QNNP_ALIGN(16) uint64_t rounding[2]; + PYTORCH_QNNP_ALIGN(16) uint32_t shift[4]; + PYTORCH_QNNP_ALIGN(16) int16_t zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint8_t max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t min[16]; + } sse2; +}; + +union pytorch_qnnp_q31_requantization_params { + struct { + int32_t multiplier; + int32_t remainder_mask; + int32_t remainder_threshold; + uint32_t shift; + int32_t min_less_zero_point; + int32_t max_less_zero_point; + int32_t zero_point; + } scalar; +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + struct { + int32_t multiplier; + int32_t right_shift; + int16_t zero_point; + uint8_t max; + uint8_t min; + } neon; +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + struct { + PYTORCH_QNNP_ALIGN(16) uint32_t multiplier[4]; + PYTORCH_QNNP_ALIGN(16) uint64_t rounding[2]; + PYTORCH_QNNP_ALIGN(16) int32_t remainder_mask[4]; + PYTORCH_QNNP_ALIGN(16) int32_t remainder_threshold[4]; + PYTORCH_QNNP_ALIGN(16) uint64_t shift[2]; + PYTORCH_QNNP_ALIGN(16) int16_t zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint8_t max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t min[16]; + } sse2; +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ +}; + +union pytorch_qnnp_conv_quantization_params { + struct { + int32_t kernel_zero_point; + int32_t input_zero_point; + int32_t multiplier; + int32_t remainder_mask; + int32_t remainder_threshold; + uint32_t shift; + int32_t output_min_less_zero_point; + int32_t output_max_less_zero_point; + int32_t output_zero_point; + } scalar; +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + struct { + int16_t kernel_zero_point; + int16_t input_zero_point; + int32_t multiplier; + int32_t right_shift; + int16_t output_zero_point; + uint8_t output_max; + uint8_t output_min; + } neon; +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + struct { + PYTORCH_QNNP_ALIGN(16) int16_t kernel_zero_point[8]; + PYTORCH_QNNP_ALIGN(16) int16_t input_zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint32_t multiplier[4]; + PYTORCH_QNNP_ALIGN(16) uint64_t rounding[2]; + PYTORCH_QNNP_ALIGN(16) int32_t remainder_mask[4]; + PYTORCH_QNNP_ALIGN(16) int32_t remainder_threshold[4]; + PYTORCH_QNNP_ALIGN(16) uint64_t shift[2]; + PYTORCH_QNNP_ALIGN(16) int16_t output_zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint8_t output_max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t output_min[16]; + } sse2; +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ +}; + +union pytorch_qnnp_requantization_params { + union pytorch_qnnp_precise_requantization_params precise; + union pytorch_qnnp_fp32_requantization_params fp32; + union pytorch_qnnp_q31_requantization_params q31; +}; + +union pytorch_qnnp_add_quantization_params { + struct { + int32_t zero_point_product; + uint32_t a_multiplier; + uint32_t b_multiplier; + uint32_t shift; + int32_t remainder_mask; + int32_t remainder_threshold; + int32_t y_zero_point; + int32_t y_max; + int32_t y_min; + } scalar; +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + struct { + uint8_t a_zero_point; + uint8_t b_zero_point; + int16_t y_zero_point; + int32_t a_multiplier; + int32_t b_multiplier; + int32_t right_shift; + uint8_t y_max; + uint8_t y_min; + } neon; +#endif +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + struct { + PYTORCH_QNNP_ALIGN(16) int32_t zero_point_product[4]; + PYTORCH_QNNP_ALIGN(16) uint16_t a_multiplier_lo[8]; + PYTORCH_QNNP_ALIGN(16) uint16_t a_multiplier_hi[8]; + PYTORCH_QNNP_ALIGN(16) uint16_t b_multiplier_lo[8]; + PYTORCH_QNNP_ALIGN(16) uint16_t b_multiplier_hi[8]; + PYTORCH_QNNP_ALIGN(16) int32_t remainder_mask[4]; + PYTORCH_QNNP_ALIGN(16) int32_t remainder_threshold[4]; + PYTORCH_QNNP_ALIGN(16) int16_t y_zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint8_t y_max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t y_min[16]; + uint32_t shift; + uint32_t a_multiplier; + uint32_t b_multiplier; + } sse2; +#endif +}; + +union pytorch_qnnp_avgpool_quantization_params { + struct { + int32_t bias; + int32_t multiplier; + int64_t rounding; + uint32_t right_shift; + int32_t output_min_less_zero_point; + int32_t output_max_less_zero_point; + int32_t output_zero_point; + } scalar; +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + struct { + int32_t bias; + int32_t multiplier; + int64_t left_shift; + int16_t output_zero_point; + uint8_t output_max; + uint8_t output_min; + } neon; +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + struct { + PYTORCH_QNNP_ALIGN(16) int32_t bias[4]; + PYTORCH_QNNP_ALIGN(16) uint32_t multiplier[4]; + PYTORCH_QNNP_ALIGN(16) uint64_t rounding[2]; + PYTORCH_QNNP_ALIGN(16) uint64_t right_shift[2]; + PYTORCH_QNNP_ALIGN(16) int16_t output_zero_point[8]; + PYTORCH_QNNP_ALIGN(16) uint8_t output_max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t output_min[16]; + } sse2; +#endif +}; + +union pytorch_qnnp_u8_clamping_params { + struct { + int32_t output_max; + int32_t output_min; + } scalar; +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + struct { + uint8_t output_max; + uint8_t output_min; + } neon; +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + struct { + PYTORCH_QNNP_ALIGN(16) uint8_t output_max[16]; + PYTORCH_QNNP_ALIGN(16) uint8_t output_min[16]; + } sse2; +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ +}; + +typedef void (*pytorch_q8gemm_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const uint8_t* a, + size_t a_stride, + const void* w, + uint8_t* c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +typedef void (*pytorch_q8conv_ukernel_function)( + size_t mr, + size_t nr, + size_t kc, + size_t ks, + const uint8_t** a, + const void* w, + uint8_t* c, + size_t c_stride, + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +typedef void (*pytorch_q8gemm_xzp_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const uint8_t* a, + size_t a_stride, + const int32_t* a_sum, + const void* w, + uint8_t* c, + size_t c_stride, + const union pytorch_qnnp_q31_requantization_params* requantization_params); + +typedef void (*pytorch_q8sum_rows_ukernel_function)( + const uint8_t* a, + size_t m, + size_t k, + size_t stride, + int32_t multiplier, + int32_t* sums); + +typedef void (*pytorch_xzipc_ukernel_function)(size_t n, const void* x, void* y); + +typedef void ( + *pytorch_xzipv_ukernel_function)(size_t n, size_t m, const void* x, void* y); + +typedef void (*pytorch_x8lut_ukernel_function)( + size_t n, + const uint8_t* x, + const uint8_t* t, + uint8_t* y); + +typedef void (*pytorch_sgemm_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const float* a, + size_t a_stride, + const float* w, + float* c, + size_t c_stride, + const struct pytorch_qnnp_fp32_clamping_params* clamping_params); + +typedef void (*pytorch_sconv_ukernel_function)( + size_t mr, + size_t nr, + size_t kc, + size_t ks, + const float** a, + const float* w, + float* c, + size_t c_stride, + const struct pytorch_qnnp_fp32_clamping_params* clamping_params); + +typedef void (*pytorch_hgemm_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const void* a, + size_t a_stride, + const void* w, + void* c, + size_t c_stride, + const struct pytorch_qnnp_fp16_clamping_params* clamping_params); + +typedef void (*pytorch_q8dwconv_up_ukernel_function)( + size_t channels, + size_t output_width, + const uint8_t** input, + const void* weights, + uint8_t* output, + size_t input_stride, + size_t output_increment, + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +typedef void (*pytorch_q8dwconv_mp_ukernel_function)( + size_t channels, + size_t output_width, + const uint8_t** input, + const void* weights, + int32_t* buffer, + uint8_t* output, + size_t input_stride, + size_t output_increment, + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +typedef void (*pytorch_q8gavgpool_up_ukernel_function)( + size_t m, + size_t n, + const uint8_t* x, + size_t x_stride, + const uint8_t* zero, + uint8_t* y, + const union pytorch_qnnp_avgpool_quantization_params* quantization_params); + +typedef void (*pytorch_q8gavgpool_mp_ukernel_function)( + size_t m, + size_t n, + const uint8_t* x, + size_t x_stride, + const uint8_t* zero, + int32_t* buffer, + uint8_t* y, + const union pytorch_qnnp_avgpool_quantization_params* quantization_params); + +typedef void (*pytorch_q8avgpool_up_ukernel_function)( + size_t n, + size_t ks, + size_t kc, + const uint8_t** x, + const uint8_t* zero, + uint8_t* y, + size_t x_increment, + size_t y_increment, + const union pytorch_qnnp_avgpool_quantization_params* quantization_params); + +typedef void (*pytorch_q8avgpool_mp_ukernel_function)( + size_t n, + size_t ks, + size_t kc, + const uint8_t** x, + const uint8_t* zero, + int32_t* buffer, + uint8_t* y, + size_t x_increment, + size_t y_increment, + const union pytorch_qnnp_avgpool_quantization_params* quantization_params); + +typedef void (*pytorch_u8maxpool_ukernel_function)( + size_t n, + size_t ks, + size_t kc, + const uint8_t** x, + uint8_t* y, + size_t x_increment, + size_t y_increment, + const union pytorch_qnnp_u8_clamping_params* params); + +typedef void (*pytorch_u8clamp_ukernel_function)( + size_t n, + const uint8_t* x, + uint8_t* y, + const union pytorch_qnnp_u8_clamping_params* params); + +typedef uint8_t (*pytorch_u8rmax_ukernel_function)(size_t n, const uint8_t* x); + +typedef void (*pytorch_u8lut32norm_ukernel_function)( + size_t n, + const uint8_t* x, + const uint32_t* t, + uint8_t* y); + +typedef void (*pytorch_q8vadd_ukernel_function)( + size_t n, + const uint8_t* a, + const uint8_t* b, + uint8_t* y, + const union pytorch_qnnp_add_quantization_params* quantization_params); + +struct pytorch_q8conv_parameters { + pytorch_q8gemm_ukernel_function gemm; + pytorch_q8conv_ukernel_function conv; + uint8_t mr; + uint8_t nr; + uint8_t kr; +}; + +struct pytorch_q8conv_xzp_parameters { + pytorch_q8gemm_xzp_ukernel_function gemm; + /* no conv ukernel */ + uint8_t mr; + uint8_t nr; + uint8_t kr; + uint8_t kc; + size_t kthreshold; +}; + +struct pytorch_q8dwconv_up_parameters { + pytorch_q8dwconv_up_ukernel_function updw; + uint8_t cr; +}; + +struct pytorch_q8dwconv_mp_parameters { + pytorch_q8dwconv_mp_ukernel_function mpdw; + uint8_t cr; +}; + +struct pytorch_q8sum_rows_parameters { + pytorch_q8sum_rows_ukernel_function sum_rows; + uint32_t m; +}; + +struct pytorch_q8gavgpool_parameters { + pytorch_q8gavgpool_up_ukernel_function ltnr; + pytorch_q8gavgpool_up_ukernel_function genr_lemr; + pytorch_q8gavgpool_mp_ukernel_function genr_gtmr; + uint8_t mr; + uint8_t nr; +}; + +struct pytorch_q8avgpool_parameters { + pytorch_q8avgpool_up_ukernel_function ltkr; + pytorch_q8avgpool_up_ukernel_function gekr_lemr; + pytorch_q8avgpool_mp_ukernel_function gekr_gtmr; + uint8_t mr; + uint8_t qr; + uint8_t kr; +}; + +struct pytorch_u8maxpool_parameters { + pytorch_u8maxpool_ukernel_function ltkr; + pytorch_u8maxpool_ukernel_function gekr; + uint8_t mr; + uint8_t qr; + uint8_t kr; +}; + +struct pytorch_x8zip_parameters { + pytorch_xzipc_ukernel_function x2; + pytorch_xzipc_ukernel_function x3; + pytorch_xzipc_ukernel_function x4; + pytorch_xzipv_ukernel_function xm; +}; + +struct pytorch_qnnp_parameters { + struct pytorch_q8conv_parameters q8conv; + struct pytorch_q8conv_xzp_parameters q8conv_xzp; + struct pytorch_q8dwconv_up_parameters q8dw9; + struct pytorch_q8dwconv_mp_parameters q8dw25; + struct pytorch_q8sum_rows_parameters q8sum_rows; + pytorch_q8vadd_ukernel_function q8vadd; + struct pytorch_q8gavgpool_parameters q8gavgpool; + struct pytorch_q8avgpool_parameters q8avgpool; + struct pytorch_u8maxpool_parameters u8maxpool; + pytorch_u8lut32norm_ukernel_function u8lut32norm; + pytorch_u8clamp_ukernel_function u8clamp; + pytorch_u8rmax_ukernel_function u8rmax; + struct pytorch_x8zip_parameters x8zip; + pytorch_x8lut_ukernel_function x8lut; + bool initialized; +}; + +extern PYTORCH_QNNP_INTERNAL struct pytorch_qnnp_parameters pytorch_qnnp_params; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8avgpool.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8avgpool.h new file mode 100644 index 0000000000000..949bf67654acc --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8avgpool.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_Q8MPAVGPOOL_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, \ + size_t ks, \ + size_t kc, \ + const uint8_t** x, \ + const uint8_t* zero, \ + int32_t* buffer, \ + uint8_t* y, \ + size_t x_increment, \ + size_t y_increment, \ + const union pytorch_qnnp_avgpool_quantization_params* \ + quantization_params); + +DECLARE_PYTORCH_Q8MPAVGPOOL_UKERNEL_FUNCTION(pytorch_q8avgpool_ukernel_mp8x9p8q__neon) +DECLARE_PYTORCH_Q8MPAVGPOOL_UKERNEL_FUNCTION(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2) + +#define DECLARE_PYTORCH_Q8UPAVGPOOL_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, \ + size_t ks, \ + size_t kc, \ + const uint8_t** x, \ + const uint8_t* zero, \ + uint8_t* y, \ + size_t x_increment, \ + size_t y_increment, \ + const union pytorch_qnnp_avgpool_quantization_params* \ + quantization_params); + +DECLARE_PYTORCH_Q8UPAVGPOOL_UKERNEL_FUNCTION(pytorch_q8avgpool_ukernel_up8x9__neon) +DECLARE_PYTORCH_Q8UPAVGPOOL_UKERNEL_FUNCTION(pytorch_q8avgpool_ukernel_up8xm__neon) +DECLARE_PYTORCH_Q8UPAVGPOOL_UKERNEL_FUNCTION(pytorch_q8avgpool_ukernel_up8x9__sse2) +DECLARE_PYTORCH_Q8UPAVGPOOL_UKERNEL_FUNCTION(pytorch_q8avgpool_ukernel_up8xm__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8conv.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8conv.h new file mode 100644 index 0000000000000..340853dd7fb14 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8conv.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_Q8CONV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t kc, \ + size_t ks, \ + const uint8_t** a, \ + const void* w, \ + uint8_t* c, \ + size_t c_stride, \ + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +DECLARE_PYTORCH_Q8CONV_UKERNEL_FUNCTION(pytorch_q8conv_ukernel_4x8__neon) +DECLARE_PYTORCH_Q8CONV_UKERNEL_FUNCTION(pytorch_q8conv_ukernel_4x8__aarch32_neon) +DECLARE_PYTORCH_Q8CONV_UKERNEL_FUNCTION(pytorch_q8conv_ukernel_8x8__aarch64_neon) +DECLARE_PYTORCH_Q8CONV_UKERNEL_FUNCTION(pytorch_q8conv_ukernel_8x8__neon) +DECLARE_PYTORCH_Q8CONV_UKERNEL_FUNCTION(pytorch_q8conv_ukernel_4x4c2__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8dwconv.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8dwconv.h new file mode 100644 index 0000000000000..b7fea2907a272 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8dwconv.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t channels, \ + size_t output_width, \ + const uint8_t** input, \ + const void* weights, \ + uint8_t* output, \ + size_t input_stride, \ + size_t output_increment, \ + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_up8x9__neon) +DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon) +DECLARE_PYTORCH_Q8UPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_up8x9__sse2) + +#define DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t channels, \ + size_t output_width, \ + const uint8_t** input, \ + const void* weights, \ + int32_t* buffer, \ + uint8_t* output, \ + size_t input_stride, \ + size_t output_increment, \ + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_mp8x25__neon) +DECLARE_PYTORCH_Q8MPDWCONV_UKERNEL_FUNCTION(pytorch_q8dwconv_ukernel_mp8x25__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8gavgpool.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8gavgpool.h new file mode 100644 index 0000000000000..f6835a994c1ec --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8gavgpool.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_Q8MPGAVGPOOL_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t m, \ + size_t n, \ + const uint8_t* x, \ + size_t x_stride, \ + const uint8_t* zero, \ + int32_t* buffer, \ + uint8_t* y, \ + const union pytorch_qnnp_avgpool_quantization_params* \ + quantization_params); + +DECLARE_PYTORCH_Q8MPGAVGPOOL_UKERNEL_FUNCTION(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon) +DECLARE_PYTORCH_Q8MPGAVGPOOL_UKERNEL_FUNCTION(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2) + +#define DECLARE_PYTORCH_Q8UPGAVGPOOL_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t m, \ + size_t n, \ + const uint8_t* x, \ + size_t x_stride, \ + const uint8_t* zero, \ + uint8_t* y, \ + const union pytorch_qnnp_avgpool_quantization_params* \ + quantization_params); + +DECLARE_PYTORCH_Q8UPGAVGPOOL_UKERNEL_FUNCTION(pytorch_q8gavgpool_ukernel_up8x7__neon) +DECLARE_PYTORCH_Q8UPGAVGPOOL_UKERNEL_FUNCTION(pytorch_q8gavgpool_ukernel_up8xm__neon) +DECLARE_PYTORCH_Q8UPGAVGPOOL_UKERNEL_FUNCTION(pytorch_q8gavgpool_ukernel_up8x7__sse2) +DECLARE_PYTORCH_Q8UPGAVGPOOL_UKERNEL_FUNCTION(pytorch_q8gavgpool_ukernel_up8xm__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8gemm.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8gemm.h new file mode 100644 index 0000000000000..f6a750fc63376 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8gemm.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const uint8_t* a, \ + size_t a_stride, \ + const void* w, \ + uint8_t* c, \ + size_t c_stride, \ + const union pytorch_qnnp_conv_quantization_params* quantization_params); + +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_3x3c8__neon) +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_2x4c8__neon) +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_4x8__neon) +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_6x4__neon) +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_8x8__neon) + +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_4x8__aarch32_neon) + +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_8x8__aarch64_neon) + +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_2x4c8__sse2) +DECLARE_PYTORCH_Q8GEMM_UKERNEL_FUNCTION(pytorch_q8gemm_ukernel_4x4c2__sse2) + +#define DECLARE_PYTORCH_Q8GEMM_XZP_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const uint8_t* a, \ + size_t a_stride, \ + const int32_t* a_sum, \ + const void* w, \ + uint8_t* c, \ + size_t c_stride, \ + const union pytorch_qnnp_q31_requantization_params* \ + requantization_params); +DECLARE_PYTORCH_Q8GEMM_XZP_UKERNEL_FUNCTION(pytorch_q8gemm_xzp_ukernel_4x8c2__neon) +DECLARE_PYTORCH_Q8GEMM_XZP_UKERNEL_FUNCTION(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon) + +PYTORCH_QNNP_INTERNAL void pytorch_q8sumrows_ukernel_4x__neon( + const uint8_t* a, + size_t m, + size_t k, + size_t stride, + const int32_t multiplier, + int32_t* row_sum); + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8vadd.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8vadd.h new file mode 100644 index 0000000000000..06b12ee8e6a80 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/q8vadd.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_Q8VADD_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, \ + const uint8_t* a, \ + const uint8_t* b, \ + uint8_t* y, \ + const union pytorch_qnnp_add_quantization_params* quantization_params); + +DECLARE_PYTORCH_Q8VADD_UKERNEL_FUNCTION(pytorch_q8vadd_ukernel__neon) +DECLARE_PYTORCH_Q8VADD_UKERNEL_FUNCTION(pytorch_q8vadd_ukernel__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization-stubs.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization-stubs.h new file mode 100644 index 0000000000000..fc4e607fa3e77 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization-stubs.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (*pytorch_requantization_function)( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output); + +#define DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(fn_name) \ + void fn_name( \ + size_t n, \ + const int32_t* input, \ + float scale, \ + uint8_t zero_point, \ + uint8_t qmin, \ + uint8_t qmax, \ + uint8_t* output); + +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION( + pytorch_qnnp_requantize_precise__scalar_unsigned32) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION( + pytorch_qnnp_requantize_precise__scalar_unsigned64) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION( + pytorch_qnnp_requantize_precise__scalar_signed64) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_precise__sse2) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_precise__ssse3) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_precise__sse4) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_precise__neon) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_precise__psimd) + +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_fp32__scalar_lrintf) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_fp32__scalar_magic) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_fp32__sse2) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_fp32__neon) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_fp32__psimd) + +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_q31__scalar) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_q31__sse2) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_q31__ssse3) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_q31__sse4) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_q31__neon) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_q31__psimd) + +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_gemmlowp__scalar) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_gemmlowp__sse2) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_gemmlowp__ssse3) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_gemmlowp__sse4) +DECLARE_PYTORCH_REQUANTIZATION_FUNCTION(pytorch_qnnp_requantize_gemmlowp__neon) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization.h new file mode 100644 index 0000000000000..2c0c8e89a5805 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/requantization.h @@ -0,0 +1,549 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#include +#include + +static inline union pytorch_qnnp_q31_requantization_params +pytorch_qnnp_compute_scalar_requantization_params( + float scale, + uint8_t zero_point, + uint8_t min, + uint8_t max) { + /* Compute requantization parameters */ + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + union pytorch_qnnp_q31_requantization_params params; + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.multiplier = multiplier; + params.scalar.remainder_mask = (int32_t)remainder_mask; + params.scalar.remainder_threshold = (int32_t)remainder_threshold; + params.scalar.shift = (uint32_t)shift; + params.scalar.min_less_zero_point = + (int32_t)(uint32_t)min - (int32_t)(uint32_t)zero_point; + params.scalar.max_less_zero_point = + (int32_t)(uint32_t)max - (int32_t)(uint32_t)zero_point; + params.scalar.zero_point = (int32_t)(uint32_t)zero_point; + return params; +} + +static inline union pytorch_qnnp_q31_requantization_params +pytorch_qnnp_compute_requantization_params( + float scale, + uint8_t zero_point, + uint8_t min, + uint8_t max) { + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + union pytorch_qnnp_q31_requantization_params params; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.sse2.multiplier[0] = multiplier; + params.sse2.multiplier[1] = multiplier; + params.sse2.multiplier[2] = multiplier; + params.sse2.multiplier[3] = multiplier; + params.sse2.rounding[0] = UINT64_C(0x40000000); + params.sse2.rounding[1] = UINT64_C(0x40000000); + params.sse2.remainder_mask[0] = (int32_t)remainder_mask; + params.sse2.remainder_mask[1] = (int32_t)remainder_mask; + params.sse2.remainder_mask[2] = (int32_t)remainder_mask; + params.sse2.remainder_mask[3] = (int32_t)remainder_mask; + params.sse2.remainder_threshold[0] = (int32_t)remainder_threshold; + params.sse2.remainder_threshold[1] = (int32_t)remainder_threshold; + params.sse2.remainder_threshold[2] = (int32_t)remainder_threshold; + params.sse2.remainder_threshold[3] = (int32_t)remainder_threshold; + params.sse2.shift[0] = (uint64_t)(uint32_t)shift; + params.sse2.shift[1] = (uint64_t)(uint32_t)shift; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.zero_point[i] = (int16_t)(uint16_t)zero_point; + } + for (uint32_t i = 0; i < 16; i++) { + params.sse2.max[i] = max; + params.sse2.min[i] = min; + } +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.multiplier = multiplier; + params.neon.right_shift = -shift; + params.neon.zero_point = (int16_t)(uint16_t)zero_point; + params.neon.max = max; + params.neon.min = min; +#else + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.multiplier = multiplier; + params.scalar.remainder_mask = (int32_t)remainder_mask; + params.scalar.remainder_threshold = (int32_t)remainder_threshold; + params.scalar.shift = (uint32_t)shift; + params.scalar.min_less_zero_point = + (int32_t)(uint32_t)min - (int32_t)(uint32_t)zero_point; + params.scalar.max_less_zero_point = + (int32_t)(uint32_t)max - (int32_t)(uint32_t)zero_point; + params.scalar.zero_point = (int32_t)(uint32_t)zero_point; +#endif + return params; +} + +static inline union pytorch_qnnp_conv_quantization_params +pytorch_qnnp_compute_conv_quantization_params( + uint8_t input_zero_point, + uint8_t kernel_zero_point, + float scale, + uint8_t output_zero_point, + uint8_t output_min, + uint8_t output_max) { + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + union pytorch_qnnp_conv_quantization_params params; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.input_zero_point[i] = (int16_t)(uint16_t)input_zero_point; + params.sse2.kernel_zero_point[i] = (int16_t)(uint16_t)kernel_zero_point; + } + params.sse2.multiplier[0] = multiplier; + params.sse2.multiplier[1] = multiplier; + params.sse2.multiplier[2] = multiplier; + params.sse2.multiplier[3] = multiplier; + params.sse2.rounding[0] = UINT64_C(0x40000000); + params.sse2.rounding[1] = UINT64_C(0x40000000); + params.sse2.remainder_mask[0] = (int32_t)remainder_mask; + params.sse2.remainder_mask[1] = (int32_t)remainder_mask; + params.sse2.remainder_mask[2] = (int32_t)remainder_mask; + params.sse2.remainder_mask[3] = (int32_t)remainder_mask; + params.sse2.remainder_threshold[0] = (int32_t)remainder_threshold; + params.sse2.remainder_threshold[1] = (int32_t)remainder_threshold; + params.sse2.remainder_threshold[2] = (int32_t)remainder_threshold; + params.sse2.remainder_threshold[3] = (int32_t)remainder_threshold; + params.sse2.shift[0] = (uint64_t)(uint32_t)shift; + params.sse2.shift[1] = (uint64_t)(uint32_t)shift; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.output_zero_point[i] = (int16_t)(uint16_t)output_zero_point; + } + for (uint32_t i = 0; i < 16; i++) { + params.sse2.output_max[i] = output_max; + params.sse2.output_min[i] = output_min; + } +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.input_zero_point = (int16_t)(uint16_t)input_zero_point; + params.neon.kernel_zero_point = (int16_t)(uint16_t)kernel_zero_point; + params.neon.multiplier = multiplier; + params.neon.right_shift = -shift; + params.neon.output_zero_point = (int16_t)(uint16_t)output_zero_point; + params.neon.output_max = output_max; + params.neon.output_min = output_min; +#else + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.input_zero_point = (int32_t)(uint32_t)input_zero_point; + params.scalar.kernel_zero_point = (int32_t)(uint32_t)kernel_zero_point; + params.scalar.multiplier = multiplier; + params.scalar.remainder_mask = (int32_t)remainder_mask; + params.scalar.remainder_threshold = (int32_t)remainder_threshold; + params.scalar.shift = (uint32_t)shift; + params.scalar.output_min_less_zero_point = + (int32_t)(uint32_t)output_min - (int32_t)(uint32_t)output_zero_point; + params.scalar.output_max_less_zero_point = + (int32_t)(uint32_t)output_max - (int32_t)(uint32_t)output_zero_point; + params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point; +#endif + return params; +} + +static inline union pytorch_qnnp_avgpool_quantization_params +pytorch_qnnp_compute_avgpool_quantization_params( + int32_t bias, + float scale, + uint8_t output_zero_point, + uint8_t output_min, + uint8_t output_max) { + /* Compute requantization parameters */ + assert(scale >= 0x1.0p-32f); + assert(scale < 256.0f); + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x00800000, 0x00FFFFFF] range */ + const int32_t multiplier = + ((int32_t)scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000); + assert(multiplier >= INT32_C(0x00800000)); + assert(multiplier <= INT32_C(0x00FFFFFF)); + + /* Shift is in [16, 55] range */ + const int32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 16); + assert(shift < 64); + + union pytorch_qnnp_avgpool_quantization_params params; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + const uint32_t right_shift = (uint32_t)shift; + const uint64_t rounding = UINT64_C(1) << (right_shift - 1); + params.sse2.bias[0] = bias; + params.sse2.bias[1] = bias; + params.sse2.bias[2] = bias; + params.sse2.bias[3] = bias; + params.sse2.multiplier[0] = (uint32_t)multiplier; + params.sse2.multiplier[1] = (uint32_t)multiplier; + params.sse2.multiplier[2] = (uint32_t)multiplier; + params.sse2.multiplier[3] = (uint32_t)multiplier; + params.sse2.rounding[0] = rounding; + params.sse2.rounding[1] = rounding; + params.sse2.right_shift[0] = (uint64_t)right_shift; + params.sse2.right_shift[1] = (uint64_t)right_shift; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.output_zero_point[i] = (int16_t)(uint16_t)output_zero_point; + } + for (uint32_t i = 0; i < 16; i++) { + params.sse2.output_max[i] = output_max; + params.sse2.output_min[i] = output_min; + } +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.bias = bias; + params.neon.multiplier = multiplier; + params.neon.left_shift = (int64_t)-shift; + params.neon.output_zero_point = (int16_t)(uint16_t)output_zero_point; + params.neon.output_max = output_max; + params.neon.output_min = output_min; +#else + const uint32_t right_shift = (uint32_t)shift; + const int64_t rounding = INT64_C(1) << (right_shift - 1); + params.scalar.bias = bias; + params.scalar.multiplier = multiplier; + params.scalar.rounding = rounding; + params.scalar.right_shift = right_shift; + params.scalar.output_min_less_zero_point = + (int32_t)(uint32_t)output_min - (int32_t)(uint32_t)output_zero_point; + params.scalar.output_max_less_zero_point = + (int32_t)(uint32_t)output_max - (int32_t)(uint32_t)output_zero_point; + params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point; +#endif + return params; +} + +static inline union pytorch_qnnp_avgpool_quantization_params +pytorch_qnnp_compute_scalar_avgpool_quantization_params( + int32_t bias, + float scale, + uint8_t output_zero_point, + uint8_t output_min, + uint8_t output_max) { + /* Compute requantization parameters */ + assert(scale >= 0x1.0p-32f); + assert(scale < 256.0f); + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x00800000, 0x00FFFFFF] range */ + const int32_t multiplier = + ((int32_t)scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000); + assert(multiplier >= INT32_C(0x00800000)); + assert(multiplier <= INT32_C(0x00FFFFFF)); + + /* Shift is in [16, 55] range */ + const int32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 16); + assert(shift < 64); + + union pytorch_qnnp_avgpool_quantization_params params; + const uint32_t right_shift = (uint32_t)shift; + const int64_t rounding = INT64_C(1) << (right_shift - 1); + params.scalar.bias = bias; + params.scalar.rounding = rounding; + params.scalar.multiplier = multiplier; + params.scalar.right_shift = right_shift; + params.scalar.output_min_less_zero_point = + (int32_t)(uint32_t)output_min - (int32_t)(uint32_t)output_zero_point; + params.scalar.output_max_less_zero_point = + (int32_t)(uint32_t)output_max - (int32_t)(uint32_t)output_zero_point; + params.scalar.output_zero_point = (int32_t)(uint32_t)output_zero_point; + return params; +} + +static inline union pytorch_qnnp_u8_clamping_params +pytorch_qnnp_compute_u8_clamping_params( + uint8_t output_min, + uint8_t output_max) { + assert(output_min <= output_max); + + union pytorch_qnnp_u8_clamping_params params; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + for (uint32_t i = 0; i < 16; i++) { + params.sse2.output_max[i] = output_max; + params.sse2.output_min[i] = output_min; + } +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.output_max = output_max; + params.neon.output_min = output_min; +#else + params.scalar.output_min = (int32_t)(uint32_t)output_min; + params.scalar.output_max = (int32_t)(uint32_t)output_max; +#endif + return params; +} + +static inline union pytorch_qnnp_add_quantization_params +pytorch_qnnp_compute_add_quantization_params( + uint8_t a_zero_point, + uint8_t b_zero_point, + uint8_t output_zero_point, + float a_output_scale, + float b_output_scale, + uint8_t output_min, + uint8_t output_max) { + assert(a_output_scale >= 0x1.0p-14f); + assert(b_output_scale >= 0x1.0p-14f); + assert(a_output_scale < 0x1.0p+8f); + assert(b_output_scale < 0x1.0p+8f); + + /* Compute requantization parameters */ + const float max_output_scale = + a_output_scale > b_output_scale ? a_output_scale : b_output_scale; + assert(max_output_scale >= 0x1.0p-14f); + assert(max_output_scale < 0x1.0p+8f); + const uint32_t max_scale_bits = fp32_to_bits(max_output_scale); + const int32_t max_scale_exponent = (int32_t)(max_scale_bits >> 23) - 127; + /* Shift is in [13, 31] range */ + const uint32_t shift = (uint32_t)(21 - max_scale_exponent); + assert(shift < 32); + assert(shift >= 13); + + const float scale_multiplier = + fp32_from_bits((uint32_t)(21 - max_scale_exponent + 127) << 23); + + /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, + * 2**22) range */ + const uint32_t a_multiplier = + (uint32_t)(int32_t)lrintf(a_output_scale * scale_multiplier); + const uint32_t b_multiplier = + (uint32_t)(int32_t)lrintf(b_output_scale * scale_multiplier); + assert( + (a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= + UINT32_C(0x00200000)); + assert(a_multiplier < UINT32_C(0x00400000)); + assert(b_multiplier < UINT32_C(0x00400000)); + + union pytorch_qnnp_add_quantization_params params; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + const int32_t zero_point_product = (int32_t) - + (a_multiplier * (uint32_t)a_zero_point + + b_multiplier * (uint32_t)b_zero_point); + for (uint32_t i = 0; i < 4; i++) { + params.sse2.zero_point_product[i] = zero_point_product; + } + for (uint32_t i = 0; i < 8; i++) { + params.sse2.y_zero_point[i] = (int16_t)(uint16_t)output_zero_point; + } + for (uint32_t i = 0; i < 8; i++) { + params.sse2.a_multiplier_lo[i] = (uint16_t)(uint32_t)a_multiplier; + params.sse2.a_multiplier_hi[i] = (uint16_t)((uint32_t)a_multiplier >> 16); + params.sse2.b_multiplier_lo[i] = (uint16_t)(uint32_t)b_multiplier; + params.sse2.b_multiplier_hi[i] = (uint16_t)((uint32_t)b_multiplier >> 16); + } + params.sse2.a_multiplier = a_multiplier; + params.sse2.b_multiplier = b_multiplier; + for (uint32_t i = 0; i < 4; i++) { + params.sse2.remainder_mask[i] = remainder_mask; + params.sse2.remainder_threshold[i] = remainder_threshold; + } + params.sse2.shift = shift; + for (uint32_t i = 0; i < 16; i++) { + params.sse2.y_max[i] = output_max; + params.sse2.y_min[i] = output_min; + } +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.a_zero_point = a_zero_point; + params.neon.b_zero_point = b_zero_point; + params.neon.y_zero_point = (int16_t)(uint16_t)output_zero_point; + params.neon.a_multiplier = (int32_t)a_multiplier; + params.neon.b_multiplier = (int32_t)b_multiplier; + params.neon.right_shift = (int32_t)-shift; + params.neon.y_max = output_max; + params.neon.y_min = output_min; +#else + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.zero_point_product = (int32_t) - + (a_multiplier * (uint32_t)a_zero_point + + b_multiplier * (uint32_t)b_zero_point); + params.scalar.a_multiplier = a_multiplier; + params.scalar.b_multiplier = b_multiplier; + params.scalar.remainder_mask = (int32_t)remainder_mask; + params.scalar.remainder_threshold = (int32_t)remainder_threshold; + params.scalar.shift = shift; + params.scalar.y_zero_point = (int32_t)(uint32_t)output_zero_point; + params.scalar.y_max = (int32_t)(uint32_t)output_max; + params.scalar.y_min = (int32_t)(uint32_t)output_min; +#endif + return params; +} + +static inline union pytorch_qnnp_add_quantization_params +pytorch_qnnp_compute_scalar_add_quantization_params( + uint8_t a_zero_point, + uint8_t b_zero_point, + uint8_t output_zero_point, + float a_output_scale, + float b_output_scale, + uint8_t output_min, + uint8_t output_max) { + assert(a_output_scale >= 0x1.0p-10f); + assert(b_output_scale >= 0x1.0p-10f); + assert(a_output_scale < 0x1.0p+8f); + assert(b_output_scale < 0x1.0p+8f); + + /* Compute requantization parameters */ + const float max_output_scale = + a_output_scale > b_output_scale ? a_output_scale : b_output_scale; + assert(max_output_scale >= 0x1.0p-10f); + assert(max_output_scale < 0x1.0p+8f); + const uint32_t max_scale_bits = fp32_to_bits(max_output_scale); + const int32_t max_scale_exponent = (int32_t)(max_scale_bits >> 23) - 127; + /* Shift is in [13, 31] range */ + const uint32_t shift = (uint32_t)(21 - max_scale_exponent); + assert(shift < 32); + assert(shift >= 13); + + /* Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, + * 2**22) range */ + const uint32_t a_multiplier = (uint32_t)(int32_t)lrintf( + fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23))); + const uint32_t b_multiplier = (uint32_t)(int32_t)lrintf( + fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23))); + assert( + (a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= + UINT32_C(0x00200000)); + assert(a_multiplier < UINT32_C(0x00400000)); + assert(b_multiplier < UINT32_C(0x00400000)); + + union pytorch_qnnp_add_quantization_params params; + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.zero_point_product = (int32_t) - + (a_multiplier * (uint32_t)a_zero_point + + b_multiplier * (uint32_t)b_zero_point); + params.scalar.a_multiplier = a_multiplier; + params.scalar.b_multiplier = b_multiplier; + params.scalar.remainder_mask = (int32_t)remainder_mask; + params.scalar.remainder_threshold = (int32_t)remainder_threshold; + params.scalar.shift = shift; + params.scalar.y_zero_point = (int32_t)(uint32_t)output_zero_point; + params.scalar.y_max = (int32_t)(uint32_t)output_max; + params.scalar.y_min = (int32_t)(uint32_t)output_min; + return params; +} + +static inline uint8_t pytorch_qnnp_q31_requantize( + int32_t n, + union pytorch_qnnp_q31_requantization_params params) { + const int64_t product = (int64_t)n * (int64_t)params.scalar.multiplier; + const int32_t q31product = + (int32_t)(uint32_t)((uint64_t)(product + INT64_C(0x40000000)) >> 31); + const int32_t remainder = + (q31product & params.scalar.remainder_mask) - (int32_t)(n < 0); + n = asr_s32(q31product, params.scalar.shift) + + (int32_t)(remainder > params.scalar.remainder_threshold); + if (n < params.scalar.min_less_zero_point) { + n = params.scalar.min_less_zero_point; + } + if (n > params.scalar.max_less_zero_point) { + n = params.scalar.max_less_zero_point; + } + + return (uint8_t)(n + params.scalar.zero_point); +} + +static inline uint8_t pytorch_qnnp_avgpool_quantize( + int32_t n, + union pytorch_qnnp_avgpool_quantization_params params) { + const int64_t product = (int64_t)n * (int64_t)params.scalar.multiplier; + const int64_t adjusted_product = product - (int64_t)(n < 0); + + n = (int32_t)asr_s64( + adjusted_product + params.scalar.rounding, params.scalar.right_shift); + if (n < params.scalar.output_min_less_zero_point) { + n = params.scalar.output_min_less_zero_point; + } + if (n > params.scalar.output_max_less_zero_point) { + n = params.scalar.output_max_less_zero_point; + } + + return (uint8_t)(n + params.scalar.output_zero_point); +} + +static inline uint8_t pytorch_qnnp_add_quantize( + uint8_t a, + uint8_t b, + union pytorch_qnnp_add_quantization_params params) { + /* Multiply by factors and accumulate products */ + int32_t acc = params.scalar.zero_point_product + + (int32_t)((uint32_t)a * params.scalar.a_multiplier) + + (int32_t)((uint32_t)b * params.scalar.b_multiplier); + + /* Shift right and round */ + const int32_t rem = (acc & params.scalar.remainder_mask) - (int32_t)(acc < 0); + acc = asr_s32(acc, params.scalar.shift) + + (int32_t)(rem > params.scalar.remainder_threshold); + + /* Clamp and add output zero point */ + int32_t y = acc + params.scalar.y_zero_point; + if (y >= params.scalar.y_max) { + y = params.scalar.y_max; + } + if (y <= params.scalar.y_min) { + y = params.scalar.y_min; + } + return (uint8_t)y; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/scalar-utils.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/scalar-utils.h new file mode 100644 index 0000000000000..b2368dcb5e910 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/scalar-utils.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include + +#if defined(__clang__) +#if __clang_major__ == 3 && __clang_minor__ >= 7 || __clang_major__ > 3 +#define PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB \ + __attribute__((__no_sanitize__("shift-base"))) +#else +#define PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB +#endif +#elif defined(__GNUC__) +#if __GNUC__ >= 8 +#define PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB \ + __attribute__((__no_sanitize__("shift-base"))) +#elif __GNUC__ == 4 && __GNUC_MINOR__ >= 9 || __GNUC__ > 4 +/* 4.9 <= gcc < 8 support ubsan, but doesn't support no_sanitize attribute */ +#define PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB +#ifndef PYTORCH_QNNP_USE_SHIFT_BASE_UB_WORKAROUND +#define PYTORCH_QNNP_USE_SHIFT_BASE_UB_WORKAROUND 1 +#endif +#else +#define PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB +#endif +#else +#define PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB +#endif + +PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB +inline static int32_t asr_s32(int32_t x, uint32_t n) { +#ifdef PYTORCH_QNNP_USE_SHIFT_BASE_UB_WORKAROUND +#if defined(__x86_64__) || defined(__aarch64__) + return (int32_t)((uint64_t)(int64_t)x >> n); +#else + return x >= 0 ? x >> n : ~(~x >> n); +#endif +#else + return x >> n; +#endif +} + +PYTORCH_QNNP_IGNORE_SHIFT_BASE_UB +inline static int64_t asr_s64(int64_t x, uint32_t n) { +#ifdef PYTORCH_QNNP_USE_SHIFT_BASE_UB_WORKAROUND + return x >= 0 ? x >> n : ~(~x >> n); +#else + return x >> n; +#endif +} + +inline static uint8_t pytorch_scalar_requantize_precise( + int32_t value, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax) { + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = + (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000); + const uint32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 24); + assert(shift < 56); + + /* + * Compute absolute value of input as unsigned 32-bit int. + * All further computations will work with unsigned values to avoid undefined + * behaviour on signed operations. + */ + const uint32_t abs_value = (value >= 0) ? (uint32_t)value : -(uint32_t)value; + + /* Compute full 64-bit product of 32-bit factors */ + const uint64_t product = (uint64_t)abs_value * (uint64_t)multiplier; + + /* + * Shift the full 64-bit product right with rounding. + * Rounding is performed towards closest integer, with midpoints rounded up + * (same as away from zero). + */ + const uint64_t rounding = UINT64_C(1) << (shift - 1); + const uint32_t abs_scaled_value = (uint32_t)((product + rounding) >> shift); + + /* + * Copy the sign of input to scaled absolute input value. + */ + const int32_t scaled_value = + (int32_t)(value >= 0 ? abs_scaled_value : -abs_scaled_value); + + /* Clamp scaled value with zero point between smin and smax */ + int32_t clamped_value = scaled_value; + const int32_t smin = (int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point; + if (clamped_value < smin) { + clamped_value = smin; + } + const int32_t smax = (int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point; + if (clamped_value > smax) { + clamped_value = smax; + } + + /* Add zero point to clamped value */ + const int32_t biased_value = clamped_value + (int32_t)(uint32_t)zero_point; + + return biased_value; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sconv.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sconv.h new file mode 100644 index 0000000000000..5e4eb843855b3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sconv.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_SCONV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t kc, \ + size_t ks, \ + const float** a, \ + const float* w, \ + float* c, \ + size_t c_stride, \ + const struct pytorch_qnnp_fp32_clamping_params* params); + +DECLARE_PYTORCH_SCONV_UKERNEL_FUNCTION(pytorch_sconv_ukernel_6x8__psimd) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sdwconv.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sdwconv.h new file mode 100644 index 0000000000000..aa783e88070e9 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sdwconv.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_SUPDWCONV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t channels, \ + size_t output_width, \ + const float** input, \ + const float* weights, \ + float* output, \ + size_t input_stride, \ + size_t output_increment, \ + const struct pytorch_qnnp_fp32_clamping_params* clamping_params); + +DECLARE_PYTORCH_SUPDWCONV_UKERNEL_FUNCTION(pytorch_sdwconv_ukernel_up4x9__psimd) + +#define DECLARE_PYTORCH_SMPDWCONV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t channels, \ + size_t output_width, \ + const uint8_t** input, \ + const void* weights, \ + int32_t* buffer, \ + uint8_t* output, \ + size_t input_stride, \ + size_t output_increment, \ + const struct pytorch_qnnp_fp32_clamping_params* clamping_params); + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sgemm.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sgemm.h new file mode 100644 index 0000000000000..a28ab4802c823 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/sgemm.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_SGEMM_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const float* a, \ + size_t a_stride, \ + const float* w, \ + float* c, \ + size_t c_stride, \ + const struct pytorch_qnnp_fp32_clamping_params* clamping_params); + +DECLARE_PYTORCH_SGEMM_UKERNEL_FUNCTION(pytorch_sgemm_ukernel_5x8__neon) +DECLARE_PYTORCH_SGEMM_UKERNEL_FUNCTION(pytorch_sgemm_ukernel_6x8__neon) +DECLARE_PYTORCH_SGEMM_UKERNEL_FUNCTION(pytorch_sgemm_ukernel_6x8__psimd) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8clamp.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8clamp.h new file mode 100644 index 0000000000000..c4e223bf651f6 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8clamp.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_U8CLAMP_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, \ + const uint8_t* x, \ + uint8_t* y, \ + const union pytorch_qnnp_u8_clamping_params* params); + +DECLARE_PYTORCH_U8CLAMP_UKERNEL_FUNCTION(pytorch_u8clamp_ukernel__neon) +DECLARE_PYTORCH_U8CLAMP_UKERNEL_FUNCTION(pytorch_u8clamp_ukernel__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8lut32norm.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8lut32norm.h new file mode 100644 index 0000000000000..9cdf0f6f07044 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8lut32norm.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_X8LUT32NORM_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, const uint8_t* x, const uint32_t* t, uint8_t* y); + +DECLARE_PYTORCH_X8LUT32NORM_UKERNEL_FUNCTION(pytorch_u8lut32norm_ukernel__scalar) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8maxpool.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8maxpool.h new file mode 100644 index 0000000000000..8231c8f16c832 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8maxpool.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_U8MAXPOOL_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, \ + size_t ks, \ + size_t kc, \ + const uint8_t** x, \ + uint8_t* y, \ + size_t x_increment, \ + size_t y_increment, \ + const union pytorch_qnnp_u8_clamping_params* params); + +DECLARE_PYTORCH_U8MAXPOOL_UKERNEL_FUNCTION(pytorch_u8maxpool_ukernel_16x9p8q__neon) +DECLARE_PYTORCH_U8MAXPOOL_UKERNEL_FUNCTION(pytorch_u8maxpool_ukernel_16x9p8q__sse2) +DECLARE_PYTORCH_U8MAXPOOL_UKERNEL_FUNCTION(pytorch_u8maxpool_ukernel_sub16__neon) +DECLARE_PYTORCH_U8MAXPOOL_UKERNEL_FUNCTION(pytorch_u8maxpool_ukernel_sub16__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8rmax.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8rmax.h new file mode 100644 index 0000000000000..24d3cd00874fe --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/u8rmax.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_U8RMAX_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL uint8_t fn_name(size_t n, const uint8_t* x); + +DECLARE_PYTORCH_U8RMAX_UKERNEL_FUNCTION(pytorch_u8rmax_ukernel__neon) +DECLARE_PYTORCH_U8RMAX_UKERNEL_FUNCTION(pytorch_u8rmax_ukernel__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/x8lut.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/x8lut.h new file mode 100644 index 0000000000000..b2a87c538fd8a --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/x8lut.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_X8LUT_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, const uint8_t* x, const uint8_t* t, uint8_t* y); + +DECLARE_PYTORCH_X8LUT_UKERNEL_FUNCTION(pytorch_x8lut_ukernel__scalar) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/x8zip.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/x8zip.h new file mode 100644 index 0000000000000..7f3430dc50b48 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/qnnpack/x8zip.h @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name(size_t n, const void* x, void* y); + +DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_x2__neon) +DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_x2__sse2) +DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_x3__neon) +DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_x3__sse2) +DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_x4__neon) +DECLARE_PYTORCH_XZIPC_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_x4__sse2) + +#define DECLARE_PYTORCH_XZIPV_UKERNEL_FUNCTION(fn_name) \ + PYTORCH_QNNP_INTERNAL void fn_name( \ + size_t n, size_t m, const void* x, void* y); + +DECLARE_PYTORCH_XZIPV_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_xm__neon) +DECLARE_PYTORCH_XZIPV_UKERNEL_FUNCTION(pytorch_qnnp_x8zip_xm__sse2) + +#ifdef __cplusplus +} /* extern "C" */ +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-neon.c new file mode 100644 index 0000000000000..53677e7418c0e --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-neon.c @@ -0,0 +1,171 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include + +void pytorch_qnnp_requantize_fp32__neon( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const float32x4_t vscale = vdupq_n_f32(scale); +#ifdef __aarch64__ + const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point); + const uint8x16_t vqmin = vdupq_n_u8(qmin); + const uint8x16_t vqmax = vdupq_n_u8(qmax); +#else + const float32x4_t vfmin = vdupq_n_f32( + (float)((int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point)); + const float32x4_t vfmax = vdupq_n_f32( + (float)((int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point)); + const float32x4_t vfmagic = vdupq_n_f32(12582912.0f); + const int32x4_t vimagic = + vdupq_n_s32(INT32_C(0x4B400000) - (int32_t)(uint32_t)zero_point); +#endif + for (; n != 0; n -= 16) { + const int32x4_t x = vld1q_s32(input); + const int32x4_t y = vld1q_s32(input + 4); + const int32x4_t z = vld1q_s32(input + 8); + const int32x4_t w = vld1q_s32(input + 12); + input += 16; + + /* + * Convert int32_t input to FP32 and multiply by FP32 scale. + * Both operations involve statistically unbiased roundings: + * - Large int32_t values can't be exactly represented as FP32. The + * conversion instruction in ARM NEON would round it to nearest FP32 value + * with ties to even. + * - Product of two FP32 values is generally not exactly representation as + * an FP32 value, and will be rounded to nearest FP32 value with ties to + * even. + */ + const float32x4_t x_scaled = vmulq_f32(vcvtq_f32_s32(x), vscale); + const float32x4_t y_scaled = vmulq_f32(vcvtq_f32_s32(y), vscale); + const float32x4_t z_scaled = vmulq_f32(vcvtq_f32_s32(z), vscale); + const float32x4_t w_scaled = vmulq_f32(vcvtq_f32_s32(w), vscale); + +#ifdef __aarch64__ + /* + * Leverage "Floating-point Convert to Signed integer, rouding to nearest + * with ties to even" instruction. This is an ARMv8 instruction (always + * available in AArch64), which saturates result on overflow. We don't need + * to specifically consider saturated results, they will be clamped at the + * last stage. + */ + const int32x4_t x_rounded = vcvtnq_s32_f32(x_scaled); + const int32x4_t y_rounded = vcvtnq_s32_f32(y_scaled); + const int32x4_t z_rounded = vcvtnq_s32_f32(z_scaled); + const int32x4_t w_rounded = vcvtnq_s32_f32(w_scaled); + + /* + * Standard final sequence on ARM NEON: + * - Pack to int16_t and saturate + * - Add zero point + * - Pack to uint8_t and saturate + * - Clamp between qmin and qmax + */ + const int16x8_t xy_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(x_rounded), y_rounded), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(z_rounded), w_rounded), vzero_point); + const uint8x16_t xyzw_packed = + vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed); + const uint8x16_t xyzw_clamped = + vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin); + + vst1q_u8(output, xyzw_clamped); + output += 16; +#else + /* + * ARMv7 NEON offers only a floating-point to integer conversion instruction + * with rounding towards zero. In lieu of conversion instruction with + * rounding-to-nearest-even, we use a magic trick of adding a large number + * (1.5 * 2**23) to scaled value to cause rounding to integer, and then + * substracing this magic number as integer. This trick works only in a + * limited range (absolute value of input must be less than 2**22), so + * generally we have to clamp input to this range before using the magic. + * However, clamping to any smaller range works just as well, and thus we + * clamp to [qmin - zero point, qmax - zero point] range so that after we + * add zero point to the result, it gets into target [qmin, qmax] range. + */ + const float32x4_t x_clamped = vminq_f32(vmaxq_f32(x_scaled, vfmin), vfmax); + const float32x4_t y_clamped = vminq_f32(vmaxq_f32(y_scaled, vfmin), vfmax); + const float32x4_t z_clamped = vminq_f32(vmaxq_f32(z_scaled, vfmin), vfmax); + const float32x4_t w_clamped = vminq_f32(vmaxq_f32(w_scaled, vfmin), vfmax); + + /* + * Conversion to integer using the "magic trick". Rounding is performed in + * the output of addition operation, and result is rounded to nearest even + * integer with ties to even. + */ + const int32x4_t x_biased = vsubq_s32( + vreinterpretq_s32_f32(vaddq_f32(x_clamped, vfmagic)), vimagic); + const int32x4_t y_biased = vsubq_s32( + vreinterpretq_s32_f32(vaddq_f32(y_clamped, vfmagic)), vimagic); + const int32x4_t z_biased = vsubq_s32( + vreinterpretq_s32_f32(vaddq_f32(z_clamped, vfmagic)), vimagic); + const int32x4_t w_biased = vsubq_s32( + vreinterpretq_s32_f32(vaddq_f32(w_clamped, vfmagic)), vimagic); + + /* + * Select low 8 bits of each 32-bit integer in the vectors for the output. + * Since result is already clamped to [qmin, qmax] subrange of [0, 255], + * saturation is not needed. + */ + const int16x8_t xy_packed = + vcombine_s16(vmovn_s32(x_biased), vmovn_s32(y_biased)); + const int16x8_t zw_packed = + vcombine_s16(vmovn_s32(z_biased), vmovn_s32(w_biased)); + const uint8x16_t xyzw_packed = vreinterpretq_u8_s8( + vcombine_s8(vmovn_s16(xy_packed), vmovn_s16(zw_packed))); + + /* + * AArch32 version: + * 4x VCVT.F32.S32 Qd, Qm + * 4x VMUL.F32 Qd, Qm, Qn + * 4x VMIN.F32 Qd, Qm, Qn + * 4x VMAX.F32 Qd, Qm, Qn + * 4x VADD.F32 Qd, Qm, Qn + * 4x VSUB.S32 Qd, Qm, Qn + * 4x VMOVN.I32 Dd, Qm + * 2x VMOVN.I16 Dd, Qm + * --------------------- + * 30 instructions total + * + * AArch64 version: + * 4x SCVTF Vd.4S, Vn.4S + * 4x FMUL Vd.4S, Vn.4S, Vm.4S + * 4x FCVTNS Vd.4S, Vn.4S + * 2x SQXTN Vd.4H, Vn.4S + * 2x SQXTN2 Vd.8H, Vn.4S + * 2x ADD Vd.8H, Vn.8H, Vm.8H + * 1x SQXTUN Vd.8B, Vn.8H + * 1x SQXTUN2 Vd.16B, Vn.8H + * 1x UMIN Vd.16B, Vn.16B, Vm.16B + * 1x UMAX Vd.16B, Vn.16B, Vm.16B + * --------------------- + * 22 instructions total + */ + + vst1q_u8(output, xyzw_packed); + output += 16; +#endif + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-psimd.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-psimd.c new file mode 100644 index 0000000000000..4c8e296a0d899 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-psimd.c @@ -0,0 +1,107 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include + +void pytorch_qnnp_requantize_fp32__psimd( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const psimd_f32 vscale = psimd_splat_f32(scale); + const psimd_f32 vfmin = psimd_splat_f32( + (float)((int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point)); + const psimd_f32 vfmax = psimd_splat_f32( + (float)((int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point)); + const psimd_f32 vfmagic = psimd_splat_f32(12582912.0f); + const psimd_s32 vimagic = + psimd_splat_s32(INT32_C(0x4B400000) - (int32_t)(uint32_t)zero_point); + for (; n != 0; n -= 16) { + const psimd_s32 x = psimd_load_s32(input); + const psimd_s32 y = psimd_load_s32(input + 4); + const psimd_s32 z = psimd_load_s32(input + 8); + const psimd_s32 w = psimd_load_s32(input + 12); + input += 16; + + /* + * Convert int32_t input to FP32 and multiply by FP32 scale. + * Both operations involve roundings: + * - Large int32_t values can't be exactly represented as FP32. We expect + * that conversion instruction would round it to nearest FP32 value with + * ties to even, but Clang documentation for __builtin_convertvector does + * not guaratee that. + * - Product of two FP32 values is generally not exactly representation as + * an FP32 value, and will be rounded to nearest FP32 value with ties to + * even. + */ + const psimd_f32 x_scaled = psimd_cvt_s32_f32(x) * vscale; + const psimd_f32 y_scaled = psimd_cvt_s32_f32(y) * vscale; + const psimd_f32 z_scaled = psimd_cvt_s32_f32(z) * vscale; + const psimd_f32 w_scaled = psimd_cvt_s32_f32(w) * vscale; + + /* + * Clang/gcc vector extension does not provide an intrinsics for a + * floating-point to integer conversion operation with + * rounding-to-nearest-even. In lieu of such intrinsic, we use a magic trick + * of adding a large number (1.5 * 2**23) to scaled value to cause rounding + * to integer, and then substracing this magic number as integer. This trick + * works only in a limited range (absolute value of input must be less than + * 2**22), so generally we have to clamp input to this range before using + * the magic. However, clamping to any smaller range works just as well, and + * thus we clamp to [qmin - zero point, qmax - zero point] range so that + * after we add zero point to the result, it gets into target [qmin, qmax] + * range. + */ + const psimd_f32 x_clamped = + psimd_min_f32(psimd_max_f32(x_scaled, vfmin), vfmax); + const psimd_f32 y_clamped = + psimd_min_f32(psimd_max_f32(y_scaled, vfmin), vfmax); + const psimd_f32 z_clamped = + psimd_min_f32(psimd_max_f32(z_scaled, vfmin), vfmax); + const psimd_f32 w_clamped = + psimd_min_f32(psimd_max_f32(w_scaled, vfmin), vfmax); + + /* + * Conversion to integer using the "magic trick". Rounding is performed in + * the output of addition operation, and result is rounded to nearest even + * integer with ties to even. + */ + const psimd_s32 x_biased = (psimd_s32)(x_clamped + vfmagic) - vimagic; + const psimd_s32 y_biased = (psimd_s32)(y_clamped + vfmagic) - vimagic; + const psimd_s32 z_biased = (psimd_s32)(z_clamped + vfmagic) - vimagic; + const psimd_s32 w_biased = (psimd_s32)(w_clamped + vfmagic) - vimagic; + + /* + * Select low 8 bits of each 32-bit integer in the vectors for the output. + * Since result is already clamped to [qmin, qmax] subrange of [0, 255], + * saturation is not needed. + */ + const psimd_u16 xy_packed = + psimd_concat_even_u16((psimd_u16)x_biased, (psimd_u16)y_biased); + const psimd_u16 zw_packed = + psimd_concat_even_u16((psimd_u16)z_biased, (psimd_u16)w_biased); + + const psimd_u8 xyzw_packed = + psimd_concat_even_u8((psimd_u8)xy_packed, (psimd_u8)zw_packed); + + psimd_store_u8(output, xyzw_packed); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-scalar.c new file mode 100644 index 0000000000000..ce2431f15c9c0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-scalar.c @@ -0,0 +1,121 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include + +void pytorch_qnnp_requantize_fp32__scalar_lrintf( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const long lmin = + (long)((int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point); + const long lmax = + (long)((int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point); + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + const float x_scaled = (float)x * scale; + const float y_scaled = (float)y * scale; + const float z_scaled = (float)z * scale; + const float w_scaled = (float)w * scale; + + const long x_rounded = lrintf(x_scaled); + const long y_rounded = lrintf(y_scaled); + const long z_rounded = lrintf(z_scaled); + const long w_rounded = lrintf(w_scaled); + + const int32_t x_clamped = (int32_t)( + x_rounded < lmin ? lmin : x_rounded > lmax ? lmax : x_rounded); + const int32_t y_clamped = (int32_t)( + y_rounded < lmin ? lmin : y_rounded > lmax ? lmax : y_rounded); + const int32_t z_clamped = (int32_t)( + z_rounded < lmin ? lmin : z_rounded > lmax ? lmax : z_rounded); + const int32_t w_clamped = (int32_t)( + w_rounded < lmin ? lmin : w_rounded > lmax ? lmax : w_rounded); + + const int32_t x_biased = x_clamped + (int32_t)(uint32_t)zero_point; + const int32_t y_biased = y_clamped + (int32_t)(uint32_t)zero_point; + const int32_t z_biased = z_clamped + (int32_t)(uint32_t)zero_point; + const int32_t w_biased = w_clamped + (int32_t)(uint32_t)zero_point; + + output[0] = (uint8_t)x_biased; + output[1] = (uint8_t)y_biased; + output[2] = (uint8_t)z_biased; + output[3] = (uint8_t)w_biased; + output += 4; + } +} + +void pytorch_qnnp_requantize_fp32__scalar_magic( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const float fmin = + (float)((int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point); + const float fmax = + (float)((int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point); + const float fmagic = 12582912.0f; + const int32_t imagic = INT32_C(0x4B400000) - (int32_t)(uint32_t)zero_point; + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + const float x_scaled = (float)x * scale; + const float y_scaled = (float)y * scale; + const float z_scaled = (float)z * scale; + const float w_scaled = (float)w * scale; + + const float x_clamped = + x_scaled < fmin ? fmin : x_scaled > fmax ? fmax : x_scaled; + const float y_clamped = + y_scaled < fmin ? fmin : y_scaled > fmax ? fmax : y_scaled; + const float z_clamped = + z_scaled < fmin ? fmin : z_scaled > fmax ? fmax : z_scaled; + const float w_clamped = + w_scaled < fmin ? fmin : w_scaled > fmax ? fmax : w_scaled; + + const int32_t x_biased = (int32_t)fp32_to_bits(x_clamped + fmagic) - imagic; + const int32_t y_biased = (int32_t)fp32_to_bits(y_clamped + fmagic) - imagic; + const int32_t z_biased = (int32_t)fp32_to_bits(z_clamped + fmagic) - imagic; + const int32_t w_biased = (int32_t)fp32_to_bits(w_clamped + fmagic) - imagic; + + output[0] = (uint8_t)x_biased; + output[1] = (uint8_t)y_biased; + output[2] = (uint8_t)z_biased; + output[3] = (uint8_t)w_biased; + output += 4; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-sse2.c new file mode 100644 index 0000000000000..0a487e2194b28 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/fp32-sse2.c @@ -0,0 +1,109 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include + +void pytorch_qnnp_requantize_fp32__sse2( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const __m128 vscale = _mm_set1_ps(scale); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + /* + * Convert int32_t input to FP32 and multiply by FP32 scale. + * Both operations involve statistically unbiased roundings (with default + * MXCSR rounding mode): + * - Large int32_t values can't be exactly represented as FP32. CVTDQ2PS + * instruction on x86 would round it according to nearest FP32 value with + * ties to even (assuming default MXCSR rounding mode). + * - Product of two FP32 values is generally not exactly representation as + * an FP32 value, and will be rounded to nearest FP32 value with ties to + * even with default MXCSR rounding mode. + */ + const __m128 x_scaled = _mm_mul_ps(_mm_cvtepi32_ps(x), vscale); + const __m128 y_scaled = _mm_mul_ps(_mm_cvtepi32_ps(y), vscale); + const __m128 z_scaled = _mm_mul_ps(_mm_cvtepi32_ps(z), vscale); + const __m128 w_scaled = _mm_mul_ps(_mm_cvtepi32_ps(w), vscale); + + /* + * Convert scaled FP32 result to int32_t using CVTPS2DQ instruction from x86 + * SSE2. CVTPS2DQ instruction rounds result according to nearest FP32 value + * with ties to even (assuming default MXCSR rounding mode). However, when + * conversion overflows, it produces INT32_MIN as a result. For large + * positive inputs the result of conversion can become negative, which + * affects the final requantization result. Note that on x86 SSE2 we have + * e.g. int32_t(float(INT32_MAX)) == INT32_MIN! This happens because + * float(INT32_MAX) rounds to 2**31, which overflows int32_t when it is + * converted back to integer. + * + * Thankfully, we can prove that overflow never happens in this + * requantization scheme. The largest positive input is INT32_MAX (2**31 - + * 1), which turns into 2**31 when converted to float. The largest scale + * value is 0x1.FFFFFEp-1. When multiplied together, the result is + * 2147483520 (compare to INT32_MAX = 2147483647), which fits into int32_t + * without overflow. + */ + const __m128i x_rounded = _mm_cvtps_epi32(x_scaled); + const __m128i y_rounded = _mm_cvtps_epi32(y_scaled); + const __m128i z_rounded = _mm_cvtps_epi32(z_scaled); + const __m128i w_rounded = _mm_cvtps_epi32(w_scaled); + + /* + * Standard final sequence on x86 SSE2: + * - Pack to int16_t and saturate + * - Add zero point + * - Pack to uint8_t and saturate + * - Clamp between qmin and qmax + */ + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_rounded, y_rounded), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_rounded, w_rounded), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 4x CVTDQ2PS + * 4x MULPS + * 4x CVTPS2DQ + * 2x PACKSSDW + * 1x PACKUSWB + * 2x PADDW + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 19 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-neon.c new file mode 100644 index 0000000000000..51a435aa8d967 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-neon.c @@ -0,0 +1,113 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +/* + * The requantization implementation below is adapted from Google's gemmlowp + * library. It is only used in QNNPACK unit tests and comparative benchmarks, + * but not the library itself. + */ + +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +void pytorch_qnnp_requantize_gemmlowp__neon( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Compute requantization parameters */ + const uint32_t multiplier = + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7; + const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7; + const int32_t shift = + -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + + exponent); + + const int32x4_t vmultiplier = vdupq_n_s32(multiplier); + const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point); + const int32x4_t vshift = vdupq_n_s32(-shift); + const uint8x16_t vqmin = vdupq_n_u8(qmin); + const uint8x16_t vqmax = vdupq_n_u8(qmax); + for (; n != 0; n -= 16) { + const int32x4_t x = vld1q_s32(input); + const int32x4_t y = vld1q_s32(input + 4); + const int32x4_t z = vld1q_s32(input + 8); + const int32x4_t w = vld1q_s32(input + 12); + input += 16; + + const int32x4_t x_product = vqrdmulhq_s32(x, vmultiplier); + const int32x4_t y_product = vqrdmulhq_s32(y, vmultiplier); + const int32x4_t z_product = vqrdmulhq_s32(z, vmultiplier); + const int32x4_t w_product = vqrdmulhq_s32(w, vmultiplier); + + const int32x4_t x_product_fixup = vshrq_n_s32(vandq_s32(x, vshift), 31); + const int32x4_t y_product_fixup = vshrq_n_s32(vandq_s32(y, vshift), 31); + const int32x4_t z_product_fixup = vshrq_n_s32(vandq_s32(z, vshift), 31); + const int32x4_t w_product_fixup = vshrq_n_s32(vandq_s32(w, vshift), 31); + + const int32x4_t x_adjusted_product = vqaddq_s32(x_product, x_product_fixup); + const int32x4_t y_adjusted_product = vqaddq_s32(y_product, y_product_fixup); + const int32x4_t z_adjusted_product = vqaddq_s32(z_product, z_product_fixup); + const int32x4_t w_adjusted_product = vqaddq_s32(w_product, w_product_fixup); + + const int32x4_t x_scaled = vrshlq_s32(x_adjusted_product, vshift); + const int32x4_t y_scaled = vrshlq_s32(y_adjusted_product, vshift); + const int32x4_t z_scaled = vrshlq_s32(z_adjusted_product, vshift); + const int32x4_t w_scaled = vrshlq_s32(w_adjusted_product, vshift); + +#ifdef __aarch64__ + const int16x8_t xy_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point); + const uint8x16_t xyzw_packed = + vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed); +#else + const int16x8_t xy_packed = vqaddq_s16( + vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point); + const uint8x16_t xyzw_packed = + vcombine_u8(vqmovun_s16(xy_packed), vqmovun_s16(zw_packed)); +#endif + + const uint8x16_t xyzw_clamped = + vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin); + + vst1q_u8(output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-scalar.c new file mode 100644 index 0000000000000..84e2fc54e3cf7 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-scalar.c @@ -0,0 +1,81 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +#include "gemmlowp-scalar.h" + +void pytorch_qnnp_requantize_gemmlowp__scalar( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Compute requantization parameters */ + const uint32_t multiplier = + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7; + const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7; + const int32_t shift = + -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + + exponent); + + const int32_t smin = (int32_t)(uint32_t)qmin; + const int32_t smax = (int32_t)(uint32_t)qmax; + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + const int32_t x_product = gemmlowp_scalar_vqrdmulh_s32(x, multiplier); + const int32_t y_product = gemmlowp_scalar_vqrdmulh_s32(y, multiplier); + const int32_t z_product = gemmlowp_scalar_vqrdmulh_s32(z, multiplier); + const int32_t w_product = gemmlowp_scalar_vqrdmulh_s32(w, multiplier); + + const int32_t x_scaled = gemmlowp_scalar_rdivbypo2_s32(x_product, shift); + const int32_t y_scaled = gemmlowp_scalar_rdivbypo2_s32(y_product, shift); + const int32_t z_scaled = gemmlowp_scalar_rdivbypo2_s32(z_product, shift); + const int32_t w_scaled = gemmlowp_scalar_rdivbypo2_s32(w_product, shift); + + /* Add zero point to scaled value */ + const int32_t x_biased = x_scaled + zero_point; + const int32_t y_biased = y_scaled + zero_point; + const int32_t z_biased = z_scaled + zero_point; + const int32_t w_biased = w_scaled + zero_point; + + /* Clamp scaled value with zero point between smin and smax */ + const int32_t x_clamped = + x_biased < smin ? smin : x_biased > smax ? smax : x_biased; + const int32_t y_clamped = + y_biased < smin ? smin : y_biased > smax ? smax : y_biased; + const int32_t z_clamped = + z_biased < smin ? smin : z_biased > smax ? smax : z_biased; + const int32_t w_clamped = + w_biased < smin ? smin : w_biased > smax ? smax : w_biased; + + output[0] = (uint8_t)x_clamped; + output[1] = (uint8_t)y_clamped; + output[2] = (uint8_t)z_clamped; + output[3] = (uint8_t)w_clamped; + output += 4; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-scalar.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-scalar.h new file mode 100644 index 0000000000000..e4fdb50c373c8 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-scalar.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +/* + * The code below is adapted from Google's gemmlowp library. + * It is only used in QNNPACK unit tests and comparative benchmarks, + * but not the library itself. + */ + +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +inline static int32_t gemmlowp_scalar_vqrdmulh_s32(int32_t a, int32_t b) { + const bool overflow = a == b && a == INT32_MIN; + const int64_t ab_64 = (int64_t)a * (int64_t)b; + const int32_t nudge = + (a ^ b) >= 0 ? INT32_C(0x40000000) : -INT32_C(0x3FFFFFFF); + const int32_t ab_x2_high32 = (int32_t)((ab_64 + nudge) / INT64_C(0x80000000)); + return overflow ? INT32_MAX : ab_x2_high32; +} + +inline static int32_t gemmlowp_scalar_rdivbypo2_s32(int32_t x, int exponent) { + const int32_t mask = ((1 << exponent) - 1); + const int32_t remainder = x & mask; + const int32_t threshold = (mask >> 1) + (int32_t)(x < 0); + return asr_s32(x, exponent) + (int32_t)(remainder > threshold); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse.h new file mode 100644 index 0000000000000..1a79fcf59ad4a --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +/* + * The code below is adapted from Google's gemmlowp library. + * It is only used in QNNPACK unit tests and comparative benchmarks, + * but not the library itself. + */ + +// Copyright 2015 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +static inline __m128i gemmlowp_sse_rdivbypo2_s32(__m128i x, int exponent) { + const __m128i mask = + _mm_set1_epi32((int32_t)((UINT64_C(1) << exponent) - UINT64_C(1))); + const __m128i remainder = _mm_and_si128(x, mask); + const __m128i threshold = _mm_sub_epi32( + _mm_srli_epi32(mask, 1), _mm_cmplt_epi32(x, _mm_setzero_si128())); + return _mm_sub_epi32( + _mm_sra_epi32(x, _mm_cvtsi32_si128(exponent)), + _mm_cmpgt_epi32(remainder, threshold)); +} + +static inline __m128i gemmlowp_sse_mul_s32(__m128i a, __m128i b) { +#ifdef __SSE4_1__ + return _mm_mul_epi32(a, b); +#else + __m128i sign, zero, mul_us, a_neg, b_neg, mul_us_neg; + sign = _mm_xor_si128(a, b); + sign = _mm_srai_epi32(sign, 31); // promote sign bit to all fields, all fff if + // negative and all 0 if positive + sign = _mm_shuffle_epi32( + sign, + _MM_SHUFFLE(2, 2, 0, 0)); // promote sign bit to 3 and 1st data lanes + zero = _mm_setzero_si128(); +#ifdef __SSSE3__ + a_neg = _mm_abs_epi32(a); // negate a and b + b_neg = _mm_abs_epi32(b); // negate a and b +#else /* pre-SSSE3 */ + const __m128i a_neg_mask = _mm_cmplt_epi32(a, zero); + a_neg = _mm_sub_epi32(_mm_xor_si128(a, a_neg_mask), a_neg_mask); + const __m128i b_neg_mask = _mm_cmplt_epi32(b, zero); + b_neg = _mm_sub_epi32(_mm_xor_si128(b, b_neg_mask), b_neg_mask); +#endif /* pre-SSSE3 */ + mul_us = _mm_mul_epu32(a_neg, b_neg); // uses 0 and 2nd data lanes, (abs), the + // multiplication gives 64 bit result + mul_us_neg = _mm_sub_epi64(zero, mul_us); + mul_us_neg = _mm_and_si128(sign, mul_us_neg); + mul_us = _mm_andnot_si128(sign, mul_us); + return _mm_or_si128(mul_us, mul_us_neg); +#endif +} + +static inline __m128i gemmlowp_sse_vqrdmulh_s32(__m128i a, __m128i b) { + // saturation only happen if a == b == INT32_MIN + const __m128i min = _mm_set1_epi32(INT32_MIN); + const __m128i saturation_mask = + _mm_and_si128(_mm_cmpeq_epi32(a, b), _mm_cmpeq_epi32(a, min)); + + // a = a0 | a1 | a2 | a3 + // b = b0 | b1 | b2 | b3 + const __m128i a0_a2 = a; + const __m128i a1_a3 = _mm_srli_si128(a, 4); + const __m128i b0_b2 = b; + const __m128i b1_b3 = _mm_srli_si128(b, 4); + + const __m128i a0b0_a2b2 = gemmlowp_sse_mul_s32(a0_a2, b0_b2); + const __m128i a1b1_a3b3 = gemmlowp_sse_mul_s32(a1_a3, b1_b3); + + // do the rounding and take into account that it will be doubled + const __m128i nudge = _mm_set1_epi64x(1 << 30); + const __m128i a0b0_a2b2_rounded = _mm_add_epi64(a0b0_a2b2, nudge); + const __m128i a1b1_a3b3_rounded = _mm_add_epi64(a1b1_a3b3, nudge); + + // do the doubling + const __m128i a0b0_a2b2_rounded_2x = _mm_slli_epi64(a0b0_a2b2_rounded, 1); + const __m128i a1b1_a3b3_rounded_2x = _mm_slli_epi64(a1b1_a3b3_rounded, 1); + +// get the high part of the products +#ifdef __SSE4_1__ + const __m128i result = _mm_blend_epi16( + _mm_srli_epi64(a0b0_a2b2_rounded_2x, 32), a1b1_a3b3_rounded_2x, 0xCC); +#else + const __m128i result0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(a0b0_a2b2_rounded_2x), + _mm_castsi128_ps(a1b1_a3b3_rounded_2x), + _MM_SHUFFLE(3, 1, 3, 1))); + const __m128i result = _mm_shuffle_epi32(result0213, _MM_SHUFFLE(3, 1, 2, 0)); +#endif + +// saturate those which overflowed +#ifdef __SSE4_1__ + const __m128i saturated_result = + _mm_blendv_epi8(result, min, saturation_mask); +#else + const __m128i saturated_result = _mm_or_si128( + _mm_and_si128(saturation_mask, min), + _mm_andnot_si128(saturation_mask, result)); +#endif + return saturated_result; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse2.c new file mode 100644 index 0000000000000..d6534cfa0ab75 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse2.c @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +#include "gemmlowp-sse.h" + +void pytorch_qnnp_requantize_gemmlowp__sse2( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Compute requantization parameters */ + const uint32_t multiplier = + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7; + const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7; + const int32_t shift = + -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + + exponent); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_product = gemmlowp_sse_vqrdmulh_s32(x, vmultiplier); + const __m128i y_product = gemmlowp_sse_vqrdmulh_s32(y, vmultiplier); + const __m128i z_product = gemmlowp_sse_vqrdmulh_s32(z, vmultiplier); + const __m128i w_product = gemmlowp_sse_vqrdmulh_s32(w, vmultiplier); + + const __m128i x_scaled = gemmlowp_sse_rdivbypo2_s32(x_product, shift); + const __m128i y_scaled = gemmlowp_sse_rdivbypo2_s32(y_product, shift); + const __m128i z_scaled = gemmlowp_sse_rdivbypo2_s32(z_product, shift); + const __m128i w_scaled = gemmlowp_sse_rdivbypo2_s32(w_product, shift); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse4.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse4.c new file mode 100644 index 0000000000000..5c870db9e86c5 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-sse4.c @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +#include "gemmlowp-sse.h" + +void pytorch_qnnp_requantize_gemmlowp__sse4( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Compute requantization parameters */ + const uint32_t multiplier = + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7; + const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7; + const int32_t shift = + -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + + exponent); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_product = gemmlowp_sse_vqrdmulh_s32(x, vmultiplier); + const __m128i y_product = gemmlowp_sse_vqrdmulh_s32(y, vmultiplier); + const __m128i z_product = gemmlowp_sse_vqrdmulh_s32(z, vmultiplier); + const __m128i w_product = gemmlowp_sse_vqrdmulh_s32(w, vmultiplier); + + const __m128i x_scaled = gemmlowp_sse_rdivbypo2_s32(x_product, shift); + const __m128i y_scaled = gemmlowp_sse_rdivbypo2_s32(y_product, shift); + const __m128i z_scaled = gemmlowp_sse_rdivbypo2_s32(z_product, shift); + const __m128i w_scaled = gemmlowp_sse_rdivbypo2_s32(w_product, shift); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-ssse3.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-ssse3.c new file mode 100644 index 0000000000000..adb01f720d673 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/gemmlowp-ssse3.c @@ -0,0 +1,73 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +#include "gemmlowp-sse.h" + +void pytorch_qnnp_requantize_gemmlowp__ssse3( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Compute requantization parameters */ + const uint32_t multiplier = + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7; + const int32_t exponent = (fp32_to_bits(scale) >> 23) - 127 - 23 - 7; + const int32_t shift = + -(32 /* using high 32 bits in VQRDMUL */ - 1 /* doubling in VQRDMUL */ + + exponent); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_product = gemmlowp_sse_vqrdmulh_s32(x, vmultiplier); + const __m128i y_product = gemmlowp_sse_vqrdmulh_s32(y, vmultiplier); + const __m128i z_product = gemmlowp_sse_vqrdmulh_s32(z, vmultiplier); + const __m128i w_product = gemmlowp_sse_vqrdmulh_s32(w, vmultiplier); + + const __m128i x_scaled = gemmlowp_sse_rdivbypo2_s32(x_product, shift); + const __m128i y_scaled = gemmlowp_sse_rdivbypo2_s32(y_product, shift); + const __m128i z_scaled = gemmlowp_sse_rdivbypo2_s32(z_product, shift); + const __m128i w_scaled = gemmlowp_sse_rdivbypo2_s32(w_product, shift); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-neon.c new file mode 100644 index 0000000000000..9964b9f153da5 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-neon.c @@ -0,0 +1,200 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_precise__neon( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const int32_t multiplier = + ((int32_t)scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000); + const int32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 24); + assert(shift < 56); + +#if defined(__aarch64__) + const int32x4_t vmultiplier = vdupq_n_s32(multiplier); +#else + const int32x2_t vmultiplier = vdup_n_s32(multiplier); +#endif + const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point); + const int64x2_t vshift = vdupq_n_s64(-shift); + const uint8x16_t vqmin = vdupq_n_u8(qmin); + const uint8x16_t vqmax = vdupq_n_u8(qmax); + for (; n != 0; n -= 16) { + const int32x4_t x = vld1q_s32(input); + const int32x4_t y = vld1q_s32(input + 4); + const int32x4_t z = vld1q_s32(input + 8); + const int32x4_t w = vld1q_s32(input + 12); + input += 16; + + const uint32x4_t x_neg_mask = vcltq_s32(x, vmovq_n_s32(0)); + const uint32x4_t y_neg_mask = vcltq_s32(y, vmovq_n_s32(0)); + const uint32x4_t z_neg_mask = vcltq_s32(z, vmovq_n_s32(0)); + const uint32x4_t w_neg_mask = vcltq_s32(w, vmovq_n_s32(0)); + +#if defined(__aarch64__) + const int64x2_t x01_product = + vmull_s32(vget_low_s32(x), vget_low_s32(vmultiplier)); + const int64x2_t x23_product = vmull_high_s32(x, vmultiplier); + const int64x2_t y01_product = + vmull_s32(vget_low_s32(y), vget_low_s32(vmultiplier)); + const int64x2_t y23_product = vmull_high_s32(y, vmultiplier); + const int64x2_t z01_product = + vmull_s32(vget_low_s32(z), vget_low_s32(vmultiplier)); + const int64x2_t z23_product = vmull_high_s32(z, vmultiplier); + const int64x2_t w01_product = + vmull_s32(vget_low_s32(w), vget_low_s32(vmultiplier)); + const int64x2_t w23_product = vmull_high_s32(w, vmultiplier); +#else + const int64x2_t x01_product = vmull_s32(vget_low_s32(x), vmultiplier); + const int64x2_t x23_product = vmull_s32(vget_high_s32(x), vmultiplier); + const int64x2_t y01_product = vmull_s32(vget_low_s32(y), vmultiplier); + const int64x2_t y23_product = vmull_s32(vget_high_s32(y), vmultiplier); + const int64x2_t z01_product = vmull_s32(vget_low_s32(z), vmultiplier); + const int64x2_t z23_product = vmull_s32(vget_high_s32(z), vmultiplier); + const int64x2_t w01_product = vmull_s32(vget_low_s32(w), vmultiplier); + const int64x2_t w23_product = vmull_s32(vget_high_s32(w), vmultiplier); +#endif + +#if defined(__aarch64__) + const int64x2_t x01_adjusted_product = + vaddw_s32(x01_product, vreinterpret_s32_u32(vget_low_u32(x_neg_mask))); + const int64x2_t x23_adjusted_product = + vaddw_high_s32(x23_product, vreinterpretq_s32_u32(x_neg_mask)); + const int64x2_t y01_adjusted_product = + vaddw_s32(y01_product, vreinterpret_s32_u32(vget_low_u32(y_neg_mask))); + const int64x2_t y23_adjusted_product = + vaddw_high_s32(y23_product, vreinterpretq_s32_u32(y_neg_mask)); + const int64x2_t z01_adjusted_product = + vaddw_s32(z01_product, vreinterpret_s32_u32(vget_low_u32(z_neg_mask))); + const int64x2_t z23_adjusted_product = + vaddw_high_s32(z23_product, vreinterpretq_s32_u32(z_neg_mask)); + const int64x2_t w01_adjusted_product = + vaddw_s32(w01_product, vreinterpret_s32_u32(vget_low_u32(w_neg_mask))); + const int64x2_t w23_adjusted_product = + vaddw_high_s32(w23_product, vreinterpretq_s32_u32(w_neg_mask)); +#else + const int64x2_t x01_adjusted_product = + vaddw_s32(x01_product, vreinterpret_s32_u32(vget_low_u32(x_neg_mask))); + const int64x2_t x23_adjusted_product = + vaddw_s32(x23_product, vreinterpret_s32_u32(vget_high_u32(x_neg_mask))); + const int64x2_t y01_adjusted_product = + vaddw_s32(y01_product, vreinterpret_s32_u32(vget_low_u32(y_neg_mask))); + const int64x2_t y23_adjusted_product = + vaddw_s32(y23_product, vreinterpret_s32_u32(vget_high_u32(y_neg_mask))); + const int64x2_t z01_adjusted_product = + vaddw_s32(z01_product, vreinterpret_s32_u32(vget_low_u32(z_neg_mask))); + const int64x2_t z23_adjusted_product = + vaddw_s32(z23_product, vreinterpret_s32_u32(vget_high_u32(z_neg_mask))); + const int64x2_t w01_adjusted_product = + vaddw_s32(w01_product, vreinterpret_s32_u32(vget_low_u32(w_neg_mask))); + const int64x2_t w23_adjusted_product = + vaddw_s32(w23_product, vreinterpret_s32_u32(vget_high_u32(w_neg_mask))); +#endif + + const int64x2_t x01_scaled = vrshlq_s64(x01_adjusted_product, vshift); + const int64x2_t x23_scaled = vrshlq_s64(x23_adjusted_product, vshift); + const int64x2_t y01_scaled = vrshlq_s64(y01_adjusted_product, vshift); + const int64x2_t y23_scaled = vrshlq_s64(y23_adjusted_product, vshift); + const int64x2_t z01_scaled = vrshlq_s64(z01_adjusted_product, vshift); + const int64x2_t z23_scaled = vrshlq_s64(z23_adjusted_product, vshift); + const int64x2_t w01_scaled = vrshlq_s64(w01_adjusted_product, vshift); + const int64x2_t w23_scaled = vrshlq_s64(w23_adjusted_product, vshift); + +#ifdef __aarch64__ + const int32x4_t x_scaled = vuzp1q_s32( + vreinterpretq_s32_s64(x01_scaled), vreinterpretq_s32_s64(x23_scaled)); + const int32x4_t y_scaled = vuzp1q_s32( + vreinterpretq_s32_s64(y01_scaled), vreinterpretq_s32_s64(y23_scaled)); + const int32x4_t z_scaled = vuzp1q_s32( + vreinterpretq_s32_s64(z01_scaled), vreinterpretq_s32_s64(z23_scaled)); + const int32x4_t w_scaled = vuzp1q_s32( + vreinterpretq_s32_s64(w01_scaled), vreinterpretq_s32_s64(w23_scaled)); + + const int16x8_t xy_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point); + const uint8x16_t xyzw_packed = + vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed); +#else + const int32x4_t x_scaled = + vcombine_s32(vmovn_s64(x01_scaled), vmovn_s64(x23_scaled)); + const int32x4_t y_scaled = + vcombine_s32(vmovn_s64(y01_scaled), vmovn_s64(y23_scaled)); + const int32x4_t z_scaled = + vcombine_s32(vmovn_s64(z01_scaled), vmovn_s64(z23_scaled)); + const int32x4_t w_scaled = + vcombine_s32(vmovn_s64(w01_scaled), vmovn_s64(w23_scaled)); + + const int16x8_t xy_packed = vqaddq_s16( + vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point); + const uint8x16_t xyzw_packed = + vcombine_u8(vqmovun_s16(xy_packed), vqmovun_s16(zw_packed)); +#endif + + const uint8x16_t xyzw_clamped = + vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin); + + /* + * AArch32 version: + * 4x VCLT.S32 Qd, Qm, #0 + * 8x VMULL.S32 Qd, Dm, Dn + * 8x VADDW.S32 Qd, Qm, Dn + * 8x VRSHL.S32 Qd, Qm, Qn + * 8x VMOVN.S64 Dd, Qm + * 4x VQMOVN.S32 Dd, Qm + * 2x VADD.S16 Qd, Qm, Qn + * 2x VQMOVUN.S16 Dd, Qm + * 1x VMAX.U8 Qd, Qm, Qn + * 1x VMIN.U8 Qd, Qm, Qn + * --------------------- + * 46 instructions total + * + * AArch64 version: + * 4x CMLT Vd.4S, Vn.4S, #0 + * 4x SMULL Vd.2D, Vn.2S, Vm.2S + * 4x SMULL2 Vd.2D, Vn.4S, Vm.4S + * 4x SADDW Vd.2D, Vn.2D, Vm.2S + * 4x SADDW2 Vd.2D, Vn.2D, Vm.4S + * 8x SRSHL Vd.2D, Vn.2D, Vm.2D + * 4x UZP1 Vd.4S, Vn.4S, Vm.4S + * 2x SQXTN Vd.4H, Vn.4S + * 2x SQXTN2 Vd.8H, Vn.4S + * 2x ADD Vd.8H, Vn.8H, Vm.8H + * 1x SQXTUN Vd.8B, Vn.8H + * 1x SQXTUN2 Vd.16B, Vn.8H + * 1x UMIN Vd.16B, Vn.16B, Vm.16B + * 1x UMAX Vd.16B, Vn.16B, Vm.16B + * --------------------- + * 42 instructions total + */ + + vst1q_u8(output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-psimd.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-psimd.c new file mode 100644 index 0000000000000..557dfbf7ae716 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-psimd.c @@ -0,0 +1,165 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_precise__psimd( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000); + const uint32_t shift = 127 + 31 - (scale_bits >> 23); + assert(shift >= 32); + assert(shift < 64); + const uint64_t rounding = UINT64_C(1) << (shift - 1); + + const psimd_u32 vmultiplier_lo = + psimd_splat_u32(multiplier & UINT32_C(0x0000FFFF)); + const psimd_u32 vmultiplier_hi = psimd_splat_u32(multiplier >> 16); + const psimd_s32 vzero_point = psimd_splat_s32((int32_t)(uint32_t)zero_point); + const psimd_s32 vsmin = + psimd_splat_s32((int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point); + const psimd_s32 vsmax = + psimd_splat_s32((int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point); + const psimd_u32 vrounding_lo = psimd_splat_u32((uint32_t)rounding); + const psimd_u32 vrounding_hi = psimd_splat_u32((uint32_t)(rounding >> 32)); + const psimd_u32 vshift = psimd_splat_u32(shift - 32); + for (; n != 0; n -= 16) { + const psimd_s32 x = psimd_load_s32(input); + const psimd_s32 y = psimd_load_s32(input + 4); + const psimd_s32 z = psimd_load_s32(input + 8); + const psimd_s32 w = psimd_load_s32(input + 12); + input += 16; + + const psimd_s32 x_neg_mask = x >> psimd_splat_s32(31); + const psimd_s32 y_neg_mask = y >> psimd_splat_s32(31); + const psimd_s32 z_neg_mask = z >> psimd_splat_s32(31); + const psimd_s32 w_neg_mask = w >> psimd_splat_s32(31); + + const psimd_u32 x_abs = (psimd_u32)((x ^ x_neg_mask) - x_neg_mask); + const psimd_u32 y_abs = (psimd_u32)((y ^ y_neg_mask) - y_neg_mask); + const psimd_u32 z_abs = (psimd_u32)((z ^ z_neg_mask) - z_neg_mask); + const psimd_u32 w_abs = (psimd_u32)((w ^ w_neg_mask) - w_neg_mask); + + const psimd_u32 x_abs_lo = x_abs & psimd_splat_u32(UINT32_C(0x0000FFFF)); + const psimd_u32 x_abs_hi = x_abs >> psimd_splat_u32(16); + const psimd_u32 y_abs_lo = y_abs & psimd_splat_u32(UINT32_C(0x0000FFFF)); + const psimd_u32 y_abs_hi = y_abs >> psimd_splat_u32(16); + const psimd_u32 z_abs_lo = z_abs & psimd_splat_u32(UINT32_C(0x0000FFFF)); + const psimd_u32 z_abs_hi = z_abs >> psimd_splat_u32(16); + const psimd_u32 w_abs_lo = w_abs & psimd_splat_u32(UINT32_C(0x0000FFFF)); + const psimd_u32 w_abs_hi = w_abs >> psimd_splat_u32(16); + + const psimd_u32 x_product_ll = x_abs_lo * vmultiplier_lo; + const psimd_u32 y_product_ll = y_abs_lo * vmultiplier_lo; + const psimd_u32 z_product_ll = z_abs_lo * vmultiplier_lo; + const psimd_u32 w_product_ll = w_abs_lo * vmultiplier_lo; + + const psimd_u32 x_product_lh = + x_abs_lo * vmultiplier_hi + (x_product_ll >> psimd_splat_u32(16)); + const psimd_u32 y_product_lh = + y_abs_lo * vmultiplier_hi + (y_product_ll >> psimd_splat_u32(16)); + const psimd_u32 z_product_lh = + z_abs_lo * vmultiplier_hi + (z_product_ll >> psimd_splat_u32(16)); + const psimd_u32 w_product_lh = + w_abs_lo * vmultiplier_hi + (w_product_ll >> psimd_splat_u32(16)); + + const psimd_u32 x_product_hl = x_abs_hi * vmultiplier_lo + + (x_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF))); + const psimd_u32 y_product_hl = y_abs_hi * vmultiplier_lo + + (y_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF))); + const psimd_u32 z_product_hl = z_abs_hi * vmultiplier_lo + + (z_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF))); + const psimd_u32 w_product_hl = w_abs_hi * vmultiplier_lo + + (w_product_lh & psimd_splat_u32(UINT32_C(0x0000FFFF))); + + const psimd_u32 x_product_lo = (x_product_hl << psimd_splat_u32(16)) + + (x_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF))); + const psimd_u32 y_product_lo = (y_product_hl << psimd_splat_u32(16)) + + (y_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF))); + const psimd_u32 z_product_lo = (z_product_hl << psimd_splat_u32(16)) + + (z_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF))); + const psimd_u32 w_product_lo = (w_product_hl << psimd_splat_u32(16)) + + (w_product_ll & psimd_splat_u32(UINT32_C(0x0000FFFF))); + + const psimd_u32 x_product_hi = x_abs_hi * vmultiplier_hi + + (x_product_lh >> psimd_splat_u32(16)) + + (x_product_hl >> psimd_splat_u32(16)); + const psimd_u32 y_product_hi = y_abs_hi * vmultiplier_hi + + (y_product_lh >> psimd_splat_u32(16)) + + (y_product_hl >> psimd_splat_u32(16)); + const psimd_u32 z_product_hi = z_abs_hi * vmultiplier_hi + + (z_product_lh >> psimd_splat_u32(16)) + + (z_product_hl >> psimd_splat_u32(16)); + const psimd_u32 w_product_hi = w_abs_hi * vmultiplier_hi + + (w_product_lh >> psimd_splat_u32(16)) + + (w_product_hl >> psimd_splat_u32(16)); + + const psimd_u32 x_adjusted_product = (x_product_hi + vrounding_hi) - + ((psimd_s32)(x_product_lo & vrounding_lo) >> psimd_splat_s32(31)); + const psimd_u32 y_adjusted_product = (y_product_hi + vrounding_hi) - + ((psimd_s32)(y_product_lo & vrounding_lo) >> psimd_splat_s32(31)); + const psimd_u32 z_adjusted_product = (z_product_hi + vrounding_hi) - + ((psimd_s32)(z_product_lo & vrounding_lo) >> psimd_splat_s32(31)); + const psimd_u32 w_adjusted_product = (w_product_hi + vrounding_hi) - + ((psimd_s32)(w_product_lo & vrounding_lo) >> psimd_splat_s32(31)); + + const psimd_u32 x_abs_scaled = x_adjusted_product >> vshift; + const psimd_u32 y_abs_scaled = y_adjusted_product >> vshift; + const psimd_u32 z_abs_scaled = z_adjusted_product >> vshift; + const psimd_u32 w_abs_scaled = w_adjusted_product >> vshift; + + const psimd_s32 x_scaled = + (psimd_s32)(x_abs_scaled ^ x_neg_mask) - x_neg_mask; + const psimd_s32 y_scaled = + (psimd_s32)(y_abs_scaled ^ y_neg_mask) - y_neg_mask; + const psimd_s32 z_scaled = + (psimd_s32)(z_abs_scaled ^ z_neg_mask) - z_neg_mask; + const psimd_s32 w_scaled = + (psimd_s32)(w_abs_scaled ^ w_neg_mask) - w_neg_mask; + + const psimd_u32 x_clamped = + (psimd_u32)psimd_max_s32(psimd_min_s32(x_scaled, vsmax), vsmin) + + vzero_point; + const psimd_u32 y_clamped = + (psimd_u32)psimd_max_s32(psimd_min_s32(y_scaled, vsmax), vsmin) + + vzero_point; + const psimd_u32 z_clamped = + (psimd_u32)psimd_max_s32(psimd_min_s32(z_scaled, vsmax), vsmin) + + vzero_point; + const psimd_u32 w_clamped = + (psimd_u32)psimd_max_s32(psimd_min_s32(w_scaled, vsmax), vsmin) + + vzero_point; + + const psimd_u16 xy_clamped = + psimd_concat_even_u16((psimd_u16)x_clamped, (psimd_u16)y_clamped); + const psimd_u16 zw_clamped = + psimd_concat_even_u16((psimd_u16)z_clamped, (psimd_u16)w_clamped); + + const psimd_u8 xyzw_clamped = + psimd_concat_even_u8((psimd_u8)xy_clamped, (psimd_u8)zw_clamped); + + psimd_store_u8(output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-scalar.c new file mode 100644 index 0000000000000..5485832932b98 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-scalar.c @@ -0,0 +1,365 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +void pytorch_qnnp_requantize_precise__scalar_unsigned32( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000); + const uint32_t shift = 127 + 31 - (scale_bits >> 23); + assert(shift >= 32); + assert(shift < 64); + + const uint64_t rounding = UINT64_C(1) << (shift - 1); + const uint32_t rounding_hi = (uint32_t)(rounding >> 32); + const uint32_t rounding_lo = (uint32_t)rounding; + const uint32_t shift_minus_32 = shift - 32; + const int32_t smin = (int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point; + const int32_t smax = (int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point; + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + /* + * Compute absolute value of input as unsigned 32-bit int. + * All further computations will work with unsigned values to avoid + * undefined behaviour on signed operations. + */ + const uint32_t x_abs = (x >= 0) ? (uint32_t)x : -(uint32_t)x; + const uint32_t y_abs = (y >= 0) ? (uint32_t)y : -(uint32_t)y; + const uint32_t z_abs = (z >= 0) ? (uint32_t)z : -(uint32_t)z; + const uint32_t w_abs = (w >= 0) ? (uint32_t)w : -(uint32_t)w; + + /* Compute full 64-bit product of 32-bit factors */ + const uint64_t x_product = (uint64_t)x_abs * (uint64_t)multiplier; + const uint64_t y_product = (uint64_t)y_abs * (uint64_t)multiplier; + const uint64_t z_product = (uint64_t)z_abs * (uint64_t)multiplier; + const uint64_t w_product = (uint64_t)w_abs * (uint64_t)multiplier; + + /* + * Shift the full 64-bit product right with rounding. + * Rounding is performed towards closest integer, with midpoints rounded up + * (same as away from zero). + * + * Generally, this operation requires both 64-bit addition and 64-bit shift, + * but we use two tricks to replace 64-bit operations with 32-bit + * operations. + * + * To avoid full 64-bit addition we make use of three facts: + * - 64-bit rounding value added before the shift is a power of 2, and thus + * has only one bit set. + * - When 0x1.0p-32f <= scale < 0x1.0p-31f, then the non-zero bit in + * rounding is in the low 32 bits, and rounding is exactly 0x80000000 + * (2**31), because rounding is 2**(scale-1) and scale >= 32. In this case, + * addition of rounding can affect high 32 bits of the product only + * through overflow, which happens if low 32-bit part of the product equals + * or exceeds 0x80000000. We can reformulate the latter condition as low + * 32-bit part of the product has the bit 31 set, and then overflow happens + * if both the low 32-bit part of the product and the low 32-bit part of the + * rounding value have bit 31 set. Since 32-bit numbers with the bit 31 set + * are negative when interpreted as signed integers, we can check the + * overflow condition as (int32_t) (LOW(product) & LOW(rounding)) < 0 + * - When 0x1.0p-31f <= scale < 1.0f, then the non-zero bit is in the high + * 32 bits of rounding. We just need to do 32-bit addition of high 32 bits + * of rounding and high 32 bits of product. This addition never overflows + * because product <= 0x80000000 * 0xFFFFFF00 < 2**63 and rounding = + * 2**(scale-1) <= 2**62. + * + * To avoid full 64-bit shift, we leverage the fact that shift >= 32, and do + * it in two steps: + * - Shift by 32, which can be implemented by extacting the high 32-bit word + * on 32-bit systems. + * - Shift by (shift - 32), which can be implemented as a 32-bit shift of + * high word of addition result. + */ + const uint32_t x_carry_lo = + (uint32_t)((int32_t)((uint32_t)x_product & rounding_lo) < 0); + const uint32_t y_carry_lo = + (uint32_t)((int32_t)((uint32_t)y_product & rounding_lo) < 0); + const uint32_t z_carry_lo = + (uint32_t)((int32_t)((uint32_t)z_product & rounding_lo) < 0); + const uint32_t w_carry_lo = + (uint32_t)((int32_t)((uint32_t)w_product & rounding_lo) < 0); + + const uint32_t x_product_hi = (uint32_t)(x_product >> 32); + const uint32_t y_product_hi = (uint32_t)(y_product >> 32); + const uint32_t z_product_hi = (uint32_t)(z_product >> 32); + const uint32_t w_product_hi = (uint32_t)(w_product >> 32); + + const uint32_t x_abs_scaled = + (uint32_t)(x_product_hi + rounding_hi + x_carry_lo) >> shift_minus_32; + const uint32_t y_abs_scaled = + (uint32_t)(y_product_hi + rounding_hi + y_carry_lo) >> shift_minus_32; + const uint32_t z_abs_scaled = + (uint32_t)(z_product_hi + rounding_hi + z_carry_lo) >> shift_minus_32; + const uint32_t w_abs_scaled = + (uint32_t)(w_product_hi + rounding_hi + w_carry_lo) >> shift_minus_32; + + /* Copy the sign of input to scaled absolute input value */ + const int32_t x_scaled = (int32_t)(x >= 0 ? x_abs_scaled : -x_abs_scaled); + const int32_t y_scaled = (int32_t)(y >= 0 ? y_abs_scaled : -y_abs_scaled); + const int32_t z_scaled = (int32_t)(z >= 0 ? z_abs_scaled : -z_abs_scaled); + const int32_t w_scaled = (int32_t)(w >= 0 ? w_abs_scaled : -w_abs_scaled); + + /* + * Clamp scaled value with zero point between (qmin - zero point) and (qmax + * - zero point). + */ + const int32_t x_clamped = + x_scaled < smin ? smin : x_scaled > smax ? smax : x_scaled; + const int32_t y_clamped = + y_scaled < smin ? smin : y_scaled > smax ? smax : y_scaled; + const int32_t z_clamped = + z_scaled < smin ? smin : z_scaled > smax ? smax : z_scaled; + const int32_t w_clamped = + w_scaled < smin ? smin : w_scaled > smax ? smax : w_scaled; + + /* + * Add zero point to clamped value. + * The result is guaranteed to be in [qmin, qmax] range. + * + * This addition can not be safely done before clamping, because scaled + * values are in [-2147483520, 2147483519] range, so addition of zero point + * (which can be up to 255) can overflow signed 32-bit integer. + */ + const int32_t x_biased = x_clamped + zero_point; + const int32_t y_biased = y_clamped + zero_point; + const int32_t z_biased = z_clamped + zero_point; + const int32_t w_biased = w_clamped + zero_point; + + output[0] = (uint8_t)x_biased; + output[1] = (uint8_t)y_biased; + output[2] = (uint8_t)z_biased; + output[3] = (uint8_t)w_biased; + output += 4; + } +} + +void pytorch_qnnp_requantize_precise__scalar_unsigned64( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = + (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000); + const uint32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 24); + assert(shift < 56); + + const uint64_t rounding = UINT64_C(1) << (shift - 1); + const int32_t smin = (int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point; + const int32_t smax = (int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point; + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + /* + * Compute absolute value of input as unsigned 32-bit int. + * All further computations will work with unsigned values to avoid + * undefined behaviour on signed operations. + */ + const uint32_t x_abs = (x >= 0) ? (uint32_t)x : -(uint32_t)x; + const uint32_t y_abs = (y >= 0) ? (uint32_t)y : -(uint32_t)y; + const uint32_t z_abs = (z >= 0) ? (uint32_t)z : -(uint32_t)z; + const uint32_t w_abs = (w >= 0) ? (uint32_t)w : -(uint32_t)w; + + /* Compute full 64-bit product of 32-bit factors */ + const uint64_t x_product = (uint64_t)x_abs * (uint64_t)multiplier; + const uint64_t y_product = (uint64_t)y_abs * (uint64_t)multiplier; + const uint64_t z_product = (uint64_t)z_abs * (uint64_t)multiplier; + const uint64_t w_product = (uint64_t)w_abs * (uint64_t)multiplier; + + /* + * Shift the full 64-bit product right with rounding. + * Rounding is performed towards closest integer, with midpoints rounded up + * (same as away from zero). + * + * Note that although rounding is precomputed, it is dependent on shift + * value, and on processors with 64-bit "right shift with rounding" + * instruction each line below can be represented by just one such + * instruction (e.g. VRSHL.U64 on ARM NEON, URSHL in ARM64 Advanced SIMD). + */ + const uint32_t x_abs_scaled = (uint32_t)((x_product + rounding) >> shift); + const uint32_t y_abs_scaled = (uint32_t)((y_product + rounding) >> shift); + const uint32_t z_abs_scaled = (uint32_t)((z_product + rounding) >> shift); + const uint32_t w_abs_scaled = (uint32_t)((w_product + rounding) >> shift); + + /* + * Copy the sign of input to scaled absolute input value. + * + * On x86 processors with SSSE3 instruction set, this operation nicely maps + * to PSIGND instruction. + */ + const int32_t x_scaled = (int32_t)(x >= 0 ? x_abs_scaled : -x_abs_scaled); + const int32_t y_scaled = (int32_t)(y >= 0 ? y_abs_scaled : -y_abs_scaled); + const int32_t z_scaled = (int32_t)(z >= 0 ? z_abs_scaled : -z_abs_scaled); + const int32_t w_scaled = (int32_t)(w >= 0 ? w_abs_scaled : -w_abs_scaled); + + /* + * Clamp scaled value with zero point between (qmin - zero point) and (qmax + * - zero point). + */ + const int32_t x_clamped = + x_scaled < smin ? smin : x_scaled > smax ? smax : x_scaled; + const int32_t y_clamped = + y_scaled < smin ? smin : y_scaled > smax ? smax : y_scaled; + const int32_t z_clamped = + z_scaled < smin ? smin : z_scaled > smax ? smax : z_scaled; + const int32_t w_clamped = + w_scaled < smin ? smin : w_scaled > smax ? smax : w_scaled; + + /* + * Add zero point to clamped value. + * The result is guaranteed to be in [qmin, qmax] range. + * + * This addition can not be safely done before clamping, because scaled + * values are in [-2147483520, 2147483519] range, so addition of zero point + * (which can be up to 255) can overflow signed 32-bit integer. + */ + const int32_t x_biased = x_clamped + zero_point; + const int32_t y_biased = y_clamped + zero_point; + const int32_t z_biased = z_clamped + zero_point; + const int32_t w_biased = w_clamped + zero_point; + + output[0] = (uint8_t)x_biased; + output[1] = (uint8_t)y_biased; + output[2] = (uint8_t)z_biased; + output[3] = (uint8_t)w_biased; + output += 4; + } +} + +void pytorch_qnnp_requantize_precise__scalar_signed64( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const int32_t multiplier = + ((int32_t)scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000); + const uint32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 24); + assert(shift < 56); + + const int64_t rounding = INT64_C(1) << (shift - 1); + const int32_t smin = (int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point; + const int32_t smax = (int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point; + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + /* + * Compute full 64-bit product of signed 32-bit factors. + * + * Note: multiplier can be treated as either signed or unsigned. + */ + const int64_t x_product = (int64_t)x * (int64_t)multiplier; + const int64_t y_product = (int64_t)y * (int64_t)multiplier; + const int64_t z_product = (int64_t)z * (int64_t)multiplier; + const int64_t w_product = (int64_t)w * (int64_t)multiplier; + + /* + * Adjust product before subsequent shift with rounding up to simulate shift + * with rounding away from zero. + */ + const int64_t x_adjusted_product = x_product - (int64_t)(x < 0); + const int64_t y_adjusted_product = y_product - (int64_t)(y < 0); + const int64_t z_adjusted_product = z_product - (int64_t)(z < 0); + const int64_t w_adjusted_product = w_product - (int64_t)(w < 0); + + /* + * Arithmetically shift the full 64-bit product right with rounding. + * Rounding is performed towards closest integer, with midpoints rounded up. + * + * Note that although rounding is precomputed, it is dependent on shift + * value, and on processors with 64-bit "right shift with rounding" + * instruction each line below can be represented by just one such + * instruction (e.g. VRSHL.S64 on ARM NEON, SRSHL in ARM64 Advanced SIMD). + */ + const int32_t x_scaled = + (int32_t)asr_s64(x_adjusted_product + rounding, shift); + const int32_t y_scaled = + (int32_t)asr_s64(y_adjusted_product + rounding, shift); + const int32_t z_scaled = + (int32_t)asr_s64(z_adjusted_product + rounding, shift); + const int32_t w_scaled = + (int32_t)asr_s64(w_adjusted_product + rounding, shift); + + /* + * Clamp scaled value with zero point between (qmin - zero point) and (qmax + * - zero point). + */ + const int32_t x_clamped = + x_scaled < smin ? smin : x_scaled > smax ? smax : x_scaled; + const int32_t y_clamped = + y_scaled < smin ? smin : y_scaled > smax ? smax : y_scaled; + const int32_t z_clamped = + z_scaled < smin ? smin : z_scaled > smax ? smax : z_scaled; + const int32_t w_clamped = + w_scaled < smin ? smin : w_scaled > smax ? smax : w_scaled; + + /* + * Add zero point to clamped value. + * The result is guaranteed to be in [qmin, qmax] range. + * + * This addition can not be safely done before clamping, because scaled + * values are in [-2147483520, 2147483519] range, so addition of zero point + * (which can be up to 255) can overflow signed 32-bit integer. + */ + const int32_t x_biased = x_clamped + zero_point; + const int32_t y_biased = y_clamped + zero_point; + const int32_t z_biased = z_clamped + zero_point; + const int32_t w_biased = w_clamped + zero_point; + + output[0] = (uint8_t)x_biased; + output[1] = (uint8_t)y_biased; + output[2] = (uint8_t)z_biased; + output[3] = (uint8_t)w_biased; + output += 4; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-sse2.c new file mode 100644 index 0000000000000..4e859bf51b4a3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-sse2.c @@ -0,0 +1,164 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_precise__sse2( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = + (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000); + const uint32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 24); + assert(shift < 56); + const uint64_t rounding = UINT64_C(1) << (shift - 1); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + const __m128i vshift = _mm_cvtsi32_si128((int)shift); + const __m128i vrounding = _mm_set1_epi64x(rounding); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), x); + const __m128i y_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), y); + const __m128i z_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), z); + const __m128i w_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), w); + + const __m128i x_abs0123 = + _mm_sub_epi32(_mm_xor_si128(x, x_neg_mask), x_neg_mask); + const __m128i y_abs0123 = + _mm_sub_epi32(_mm_xor_si128(y, y_neg_mask), y_neg_mask); + const __m128i z_abs0123 = + _mm_sub_epi32(_mm_xor_si128(z, z_neg_mask), z_neg_mask); + const __m128i w_abs0123 = + _mm_sub_epi32(_mm_xor_si128(w, w_neg_mask), w_neg_mask); + + const __m128i x_abs1032 = + _mm_shuffle_epi32(x_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i y_abs1032 = + _mm_shuffle_epi32(y_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i z_abs1032 = + _mm_shuffle_epi32(z_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i w_abs1032 = + _mm_shuffle_epi32(w_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i x_absmul02 = _mm_mul_epu32(x_abs0123, vmultiplier); + const __m128i y_absmul02 = _mm_mul_epu32(y_abs0123, vmultiplier); + const __m128i z_absmul02 = _mm_mul_epu32(z_abs0123, vmultiplier); + const __m128i w_absmul02 = _mm_mul_epu32(w_abs0123, vmultiplier); + + const __m128i x_absmul13 = _mm_mul_epu32(x_abs1032, vmultiplier); + const __m128i y_absmul13 = _mm_mul_epu32(y_abs1032, vmultiplier); + const __m128i z_absmul13 = _mm_mul_epu32(z_abs1032, vmultiplier); + const __m128i w_absmul13 = _mm_mul_epu32(w_abs1032, vmultiplier); + + const __m128i x_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(x_absmul02, vrounding), vshift); + const __m128i x_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(x_absmul13, vrounding), vshift); + const __m128i y_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(y_absmul02, vrounding), vshift); + const __m128i y_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(y_absmul13, vrounding), vshift); + const __m128i z_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(z_absmul02, vrounding), vshift); + const __m128i z_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(z_absmul13, vrounding), vshift); + const __m128i w_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(w_absmul02, vrounding), vshift); + const __m128i w_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(w_absmul13, vrounding), vshift); + + const __m128i x_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(x_abs_scaled02), + _mm_castsi128_ps(x_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i y_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(y_abs_scaled02), + _mm_castsi128_ps(y_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i z_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(z_abs_scaled02), + _mm_castsi128_ps(z_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i w_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(w_abs_scaled02), + _mm_castsi128_ps(w_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i x_abs_scaled = + _mm_shuffle_epi32(x_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i y_abs_scaled = + _mm_shuffle_epi32(y_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i z_abs_scaled = + _mm_shuffle_epi32(z_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i w_abs_scaled = + _mm_shuffle_epi32(w_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i x_scaled = + _mm_sub_epi32(_mm_xor_si128(x_abs_scaled, x_neg_mask), x_neg_mask); + const __m128i y_scaled = + _mm_sub_epi32(_mm_xor_si128(y_abs_scaled, y_neg_mask), y_neg_mask); + const __m128i z_scaled = + _mm_sub_epi32(_mm_xor_si128(z_abs_scaled, z_neg_mask), z_neg_mask); + const __m128i w_scaled = + _mm_sub_epi32(_mm_xor_si128(w_abs_scaled, w_neg_mask), w_neg_mask); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 4x PXOR (setzero) + * 8x PSUBD + * 8x PXOR + * 8x PSHUFD + * 8x PMULUDQ + * 8x PSRLQ + * 8x PADDQ + * 4x SHUFPS + * 2x PACKSSDW + * 1x PACKUSWB + * 2x PADDW + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 63 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-sse4.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-sse4.c new file mode 100644 index 0000000000000..adef737e7c08d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-sse4.c @@ -0,0 +1,134 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_precise__sse4( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = (scale_bits << 8) | UINT32_C(0x80000000); + const uint32_t shift = 127 + 31 - (scale_bits >> 23); + assert(shift >= 32); + assert(shift < 64); + const uint64_t rounding = UINT64_C(1) << (shift - 1); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + const __m128i vshiftlo = _mm_cvtsi32_si128((int)shift); + const __m128i vshifthi = _mm_cvtsi32_si128((int)shift - 32); + const __m128i vrounding = _mm_set1_epi64x(rounding); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_abs0123 = _mm_abs_epi32(x); + const __m128i y_abs0123 = _mm_abs_epi32(y); + const __m128i z_abs0123 = _mm_abs_epi32(z); + const __m128i w_abs0123 = _mm_abs_epi32(w); + + const __m128i x_abs1032 = + _mm_shuffle_epi32(x_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i y_abs1032 = + _mm_shuffle_epi32(y_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i z_abs1032 = + _mm_shuffle_epi32(z_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i w_abs1032 = + _mm_shuffle_epi32(w_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i x_absmul02 = _mm_mul_epu32(x_abs0123, vmultiplier); + const __m128i y_absmul02 = _mm_mul_epu32(y_abs0123, vmultiplier); + const __m128i z_absmul02 = _mm_mul_epu32(z_abs0123, vmultiplier); + const __m128i w_absmul02 = _mm_mul_epu32(w_abs0123, vmultiplier); + + const __m128i x_absmul13 = _mm_mul_epu32(x_abs1032, vmultiplier); + const __m128i y_absmul13 = _mm_mul_epu32(y_abs1032, vmultiplier); + const __m128i z_absmul13 = _mm_mul_epu32(z_abs1032, vmultiplier); + const __m128i w_absmul13 = _mm_mul_epu32(w_abs1032, vmultiplier); + + const __m128i x_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(x_absmul02, vrounding), vshiftlo); + const __m128i x_abs_scaled13 = + _mm_srl_epi32(_mm_add_epi64(x_absmul13, vrounding), vshifthi); + const __m128i y_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(y_absmul02, vrounding), vshiftlo); + const __m128i y_abs_scaled13 = + _mm_srl_epi32(_mm_add_epi64(y_absmul13, vrounding), vshifthi); + const __m128i z_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(z_absmul02, vrounding), vshiftlo); + const __m128i z_abs_scaled13 = + _mm_srl_epi32(_mm_add_epi64(z_absmul13, vrounding), vshifthi); + const __m128i w_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(w_absmul02, vrounding), vshiftlo); + const __m128i w_abs_scaled13 = + _mm_srl_epi32(_mm_add_epi64(w_absmul13, vrounding), vshifthi); + + const __m128i x_abs_scaled = + _mm_blend_epi16(x_abs_scaled02, x_abs_scaled13, 0xCC); + const __m128i y_abs_scaled = + _mm_blend_epi16(y_abs_scaled02, y_abs_scaled13, 0xCC); + const __m128i z_abs_scaled = + _mm_blend_epi16(z_abs_scaled02, z_abs_scaled13, 0xCC); + const __m128i w_abs_scaled = + _mm_blend_epi16(w_abs_scaled02, w_abs_scaled13, 0xCC); + + const __m128i x_scaled = _mm_sign_epi32(x_abs_scaled, x); + const __m128i y_scaled = _mm_sign_epi32(y_abs_scaled, y); + const __m128i z_scaled = _mm_sign_epi32(z_abs_scaled, z); + const __m128i w_scaled = _mm_sign_epi32(w_abs_scaled, w); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 4x PABSD + * 4x PSHUFD + * 8x PMULUDQ + * 4x PSRLQ + * 4x PSRLD + * 8x PADDQ + * 4x PBLENDW + * 4x PSIGND + * 2x PACKSSDW + * 1x PACKUSWB + * 2x PADDW + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 47 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-ssse3.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-ssse3.c new file mode 100644 index 0000000000000..21c94c5acd8e2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/precise-ssse3.c @@ -0,0 +1,150 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_precise__ssse3( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + const uint32_t scale_bits = fp32_to_bits(scale); + const uint32_t multiplier = + (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000); + const uint32_t shift = 127 + 23 - (scale_bits >> 23); + assert(shift >= 24); + assert(shift < 56); + const uint64_t rounding = UINT64_C(1) << (shift - 1); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + const __m128i vshift = _mm_cvtsi32_si128((int)shift); + const __m128i vrounding = _mm_set1_epi64x(rounding); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_abs0123 = _mm_abs_epi32(x); + const __m128i y_abs0123 = _mm_abs_epi32(y); + const __m128i z_abs0123 = _mm_abs_epi32(z); + const __m128i w_abs0123 = _mm_abs_epi32(w); + + const __m128i x_abs1032 = + _mm_shuffle_epi32(x_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i y_abs1032 = + _mm_shuffle_epi32(y_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i z_abs1032 = + _mm_shuffle_epi32(z_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i w_abs1032 = + _mm_shuffle_epi32(w_abs0123, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i x_absmul02 = _mm_mul_epu32(x_abs0123, vmultiplier); + const __m128i y_absmul02 = _mm_mul_epu32(y_abs0123, vmultiplier); + const __m128i z_absmul02 = _mm_mul_epu32(z_abs0123, vmultiplier); + const __m128i w_absmul02 = _mm_mul_epu32(w_abs0123, vmultiplier); + + const __m128i x_absmul13 = _mm_mul_epu32(x_abs1032, vmultiplier); + const __m128i y_absmul13 = _mm_mul_epu32(y_abs1032, vmultiplier); + const __m128i z_absmul13 = _mm_mul_epu32(z_abs1032, vmultiplier); + const __m128i w_absmul13 = _mm_mul_epu32(w_abs1032, vmultiplier); + + const __m128i x_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(x_absmul02, vrounding), vshift); + const __m128i x_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(x_absmul13, vrounding), vshift); + const __m128i y_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(y_absmul02, vrounding), vshift); + const __m128i y_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(y_absmul13, vrounding), vshift); + const __m128i z_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(z_absmul02, vrounding), vshift); + const __m128i z_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(z_absmul13, vrounding), vshift); + const __m128i w_abs_scaled02 = + _mm_srl_epi64(_mm_add_epi64(w_absmul02, vrounding), vshift); + const __m128i w_abs_scaled13 = + _mm_srl_epi64(_mm_add_epi64(w_absmul13, vrounding), vshift); + + const __m128i x_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(x_abs_scaled02), + _mm_castsi128_ps(x_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i y_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(y_abs_scaled02), + _mm_castsi128_ps(y_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i z_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(z_abs_scaled02), + _mm_castsi128_ps(z_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i w_abs_scaled0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(w_abs_scaled02), + _mm_castsi128_ps(w_abs_scaled13), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i x_abs_scaled = + _mm_shuffle_epi32(x_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i y_abs_scaled = + _mm_shuffle_epi32(y_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i z_abs_scaled = + _mm_shuffle_epi32(z_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i w_abs_scaled = + _mm_shuffle_epi32(w_abs_scaled0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i x_scaled = _mm_sign_epi32(x_abs_scaled, x); + const __m128i y_scaled = _mm_sign_epi32(y_abs_scaled, y); + const __m128i z_scaled = _mm_sign_epi32(z_abs_scaled, z); + const __m128i w_scaled = _mm_sign_epi32(w_abs_scaled, w); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 4x PABSD + * 8x PSHUFD + * 8x PMULUDQ + * 8x PSRLQ + * 8x PADDQ + * 4x SHUFPS + * 4x PSIGND + * 2x PACKSSDW + * 1x PACKUSWB + * 2x PADDW + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 51 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-neon.c new file mode 100644 index 0000000000000..fcbedcad355ff --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-neon.c @@ -0,0 +1,144 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_q31__neon( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + const int32x4_t vmultiplier = vdupq_n_s32(multiplier); + const int16x8_t vzero_point = vdupq_n_s16((int16_t)(uint16_t)zero_point); + const int32x4_t vshift = vdupq_n_s32(-shift); + const int32x4_t vshift_eq_0_mask = + vreinterpretq_s32_u32(vceqq_s32(vshift, vmovq_n_s32(0))); + const uint8x16_t vqmin = vdupq_n_u8(qmin); + const uint8x16_t vqmax = vdupq_n_u8(qmax); + for (; n != 0; n -= 16) { + const int32x4_t x = vld1q_s32(input); + const int32x4_t y = vld1q_s32(input + 4); + const int32x4_t z = vld1q_s32(input + 8); + const int32x4_t w = vld1q_s32(input + 12); + input += 16; + + /* + * Directly use VQRDMULH/SQRDMULH instruction for Q31 multiplication with + * rounding. Although these instruction saturate out-of-range outputs, we + * never hit this case in requantization. + */ + const int32x4_t x_product = vqrdmulhq_s32(x, vmultiplier); + const int32x4_t y_product = vqrdmulhq_s32(y, vmultiplier); + const int32x4_t z_product = vqrdmulhq_s32(z, vmultiplier); + const int32x4_t w_product = vqrdmulhq_s32(w, vmultiplier); + + /* + * Shift the 32-bit product right with rounding. + * Rounding is performed towards closest integer, with midpoints rounded up + * (same as away from zero). + * + * We leverage the "right shift with rounding" instruction (VRSHL.S32 on ARM + * NEON, SRSHL in ARM64 Advanced SIMD) to do the shift. However, as this + * instruction rounds midpoints up, rather than away from zero, we adjust + * the input by subtracting 1 from negative values, but only if shift is + * non-zero. + */ + const int32x4_t x_adjusted_product = + vsraq_n_s32(x_product, vbicq_s32(x, vshift_eq_0_mask), 31); + const int32x4_t y_adjusted_product = + vsraq_n_s32(y_product, vbicq_s32(y, vshift_eq_0_mask), 31); + const int32x4_t z_adjusted_product = + vsraq_n_s32(z_product, vbicq_s32(z, vshift_eq_0_mask), 31); + const int32x4_t w_adjusted_product = + vsraq_n_s32(w_product, vbicq_s32(w, vshift_eq_0_mask), 31); + + const int32x4_t x_scaled = vrshlq_s32(x_adjusted_product, vshift); + const int32x4_t y_scaled = vrshlq_s32(y_adjusted_product, vshift); + const int32x4_t z_scaled = vrshlq_s32(z_adjusted_product, vshift); + const int32x4_t w_scaled = vrshlq_s32(w_adjusted_product, vshift); + +#ifdef __aarch64__ + const int16x8_t xy_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(x_scaled), y_scaled), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vqmovn_high_s32(vqmovn_s32(z_scaled), w_scaled), vzero_point); + const uint8x16_t xyzw_packed = + vqmovun_high_s16(vqmovun_s16(xy_packed), zw_packed); +#else + const int16x8_t xy_packed = vqaddq_s16( + vcombine_s16(vqmovn_s32(x_scaled), vqmovn_s32(y_scaled)), vzero_point); + const int16x8_t zw_packed = vqaddq_s16( + vcombine_s16(vqmovn_s32(z_scaled), vqmovn_s32(w_scaled)), vzero_point); + const uint8x16_t xyzw_packed = + vcombine_u8(vqmovun_s16(xy_packed), vqmovun_s16(zw_packed)); +#endif + + const uint8x16_t xyzw_clamped = + vmaxq_u8(vminq_u8(xyzw_packed, vqmax), vqmin); + + /* + * AArch32 version: + * 4x VQRDMULH.S32 Qd, Qm, Qn + * 4x VAND Qd, Qm, Dn + * 4x VSRA.S32 Qd, Qm, #31 + * 4x VRSHL.S32 Qd, Qm, Qn + * 4x VQMOVN.S32 Dd, Qm + * 2x VADD.S16 Qd, Qm, Qn + * 2x VQMOVUN.S16 Dd, Qm + * 1x VMAX.U8 Qd, Qm, Qn + * 1x VMIN.U8 Qd, Qm, Qn + * --------------------- + * 26 instructions total + * + * AArch64 version: + * 4x SQRDMULH Vd.4S, Vn.4S, Vm.4S + * 4x AND Vd.16B, Vn.16B, Vm.16B + * 4x SSRA Vd.4S, Vn.4S, #31 + * 4x SRSHL Vd.4S, Vn.4S, Vm.4S + * 2x SQXTN Vd.4H, Vn.4S + * 2x SQXTN2 Vd.8H, Vn.4S + * 2x ADD Vd.8H, Vn.8H, Vm.8H + * 1x SQXTUN Vd.8B, Vn.8H + * 1x SQXTUN2 Vd.16B, Vn.8H + * 1x UMIN Vd.16B, Vn.16B, Vm.16B + * 1x UMAX Vd.16B, Vn.16B, Vm.16B + * --------------------- + * 26 instructions total + */ + + vst1q_u8(output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c new file mode 100644 index 0000000000000..e86130f2ccb61 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-scalar.c @@ -0,0 +1,163 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include + +void pytorch_qnnp_requantize_q31__scalar( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 4 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + const int64_t q31rounding = INT64_C(0x40000000); + const int32_t remainder_mask = + (int32_t)((UINT32_C(1) << shift) - UINT32_C(1)); + const int32_t threshold = (int32_t)((uint32_t)remainder_mask >> 1); + const int32_t smin = (int32_t)(uint32_t)qmin - (int32_t)(uint32_t)zero_point; + const int32_t smax = (int32_t)(uint32_t)qmax - (int32_t)(uint32_t)zero_point; + for (; n != 0; n -= 4) { + const int32_t x = input[0]; + const int32_t y = input[1]; + const int32_t z = input[2]; + const int32_t w = input[3]; + input += 4; + + /* + * Compute full 64-bit product of signed 32-bit factors. + * + * Note: multiplier can be treated as either signed or unsigned. + */ + const int64_t x_product = (int64_t)x * (int64_t)multiplier; + const int64_t y_product = (int64_t)y * (int64_t)multiplier; + const int64_t z_product = (int64_t)z * (int64_t)multiplier; + const int64_t w_product = (int64_t)w * (int64_t)multiplier; + + /* + * Get the Q31 multiplication result by extracting bits 31-62 of the + * product, with rounding up. Add rounding value (0x40000000) and then shift + * right by 31 bits and extract the low 32-bit word. Note: casts to unsigned + * types are needed to avoid undefined behavior. Given the multiplier range, + * the result of Q31 multiplication is in [-2147483520, 2147483519] range. + */ + const int32_t x_q31product = + (int32_t)(uint32_t)((uint64_t)(x_product + q31rounding) >> 31); + const int32_t y_q31product = + (int32_t)(uint32_t)((uint64_t)(y_product + q31rounding) >> 31); + const int32_t z_q31product = + (int32_t)(uint32_t)((uint64_t)(z_product + q31rounding) >> 31); + const int32_t w_q31product = + (int32_t)(uint32_t)((uint64_t)(w_product + q31rounding) >> 31); + + /* + * Arithmetically shift the adjusted product right with rounding. + * Rounding is performed towards closest integer, with midpoints rounded + * away from zero. + * + * Shift with correct rounding could be efficiently implemented by + * pre-adding rounding constant, but with input in + * [-2147483520, 2147483519] range and rounding constant up to 2**30 we + * can't rule out overflow. This limitation leaves us with 3 options: + * 1. Extend input to 64-bit signed integer, perform addition and shift on + * 64-bit integers, then truncate result to 32 bits. + * 2. Detect overflow and handle this situation separately. Note that + * overflow is possible only when input is positive, and even when addition + * of a rounding constant overflows 32-bit signed integer, it still doesn't + * overflow 32-bit unsigned integer. Thus, in case of signed overflow, we + * can compute the result using unsigned arithmetics, specifically using + * logical shift right instead of arithmetic shift right. + * 3. Performs arithmetic shift as is, which will produce division result + * rounded down. Then compute remainder of this division by a power of 2, + * and adjust the result. Result needs adjustment (increment by 1) when + * - input is positive, shift is non-zero, and remainder >= 2**(shift - + * 1), e.g. 10 >> 2 needs adjustment + * - input is negative, shift is non-zero, and remainder > 2**(shift - + * 1), e.g. -10 >> 2 doesn't need adjustment These conditions can be + * generalized as remainder + (input <= 0) > 2**(shift - 1) or equivalently + * remainder - (input < 0) > ((2**shift - 1) >> 1) + * When shift is 0, remainder is 0 as well, the last condition is always + * false, and no adjustment is done. + * + * Among these options, option 3 is the most performant across the board, + * although option 1 is promising for 64-bit instruction sets. + */ + const int32_t x_remainder = + (x_q31product & remainder_mask) - (int32_t)(x_q31product < 0); + const int32_t y_remainder = + (y_q31product & remainder_mask) - (int32_t)(y_q31product < 0); + const int32_t z_remainder = + (z_q31product & remainder_mask) - (int32_t)(z_q31product < 0); + const int32_t w_remainder = + (w_q31product & remainder_mask) - (int32_t)(w_q31product < 0); + + const int32_t x_scaled = + asr_s32(x_q31product, shift) + (int32_t)(x_remainder > threshold); + const int32_t y_scaled = + asr_s32(y_q31product, shift) + (int32_t)(y_remainder > threshold); + const int32_t z_scaled = + asr_s32(z_q31product, shift) + (int32_t)(z_remainder > threshold); + const int32_t w_scaled = + asr_s32(w_q31product, shift) + (int32_t)(w_remainder > threshold); + + /* + * Clamp scaled value with zero point between (qmin - zero point) and (qmax + * - zero point). + */ + const int32_t x_clamped = + x_scaled < smin ? smin : x_scaled > smax ? smax : x_scaled; + const int32_t y_clamped = + y_scaled < smin ? smin : y_scaled > smax ? smax : y_scaled; + const int32_t z_clamped = + z_scaled < smin ? smin : z_scaled > smax ? smax : z_scaled; + const int32_t w_clamped = + w_scaled < smin ? smin : w_scaled > smax ? smax : w_scaled; + + /* + * Add zero point to clamped value. + * The result is guaranteed to be in [qmin, qmax] range. + * + * This addition can not be safely done before clamping, because scaled + * values are in [-2147483520, 2147483519] range, so addition of zero point + * (which can be up to 255) can overflow signed 32-bit integer. + */ + const int32_t x_biased = x_clamped + zero_point; + const int32_t y_biased = y_clamped + zero_point; + const int32_t z_biased = z_clamped + zero_point; + const int32_t w_biased = w_clamped + zero_point; + + output[0] = (uint8_t)x_biased; + output[1] = (uint8_t)y_biased; + output[2] = (uint8_t)z_biased; + output[3] = (uint8_t)w_biased; + output += 4; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-sse2.c new file mode 100644 index 0000000000000..5ef7d0076da1d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-sse2.c @@ -0,0 +1,241 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_q31__sse2( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + const __m128i vshift = _mm_cvtsi32_si128((int)shift); + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const __m128i vremainder_mask = _mm_set1_epi32((int)remainder_mask); + const __m128i vthreshold = _mm_set1_epi32((int)(remainder_mask >> 1)); + const __m128i vq31rounding = _mm_set1_epi64x(UINT64_C(0x40000000)); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), x); + const __m128i y_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), y); + const __m128i z_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), z); + const __m128i w_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), w); + + const __m128i x_abs = + _mm_sub_epi32(_mm_xor_si128(x, x_neg_mask), x_neg_mask); + const __m128i y_abs = + _mm_sub_epi32(_mm_xor_si128(y, y_neg_mask), y_neg_mask); + const __m128i z_abs = + _mm_sub_epi32(_mm_xor_si128(z, z_neg_mask), z_neg_mask); + const __m128i w_abs = + _mm_sub_epi32(_mm_xor_si128(w, w_neg_mask), w_neg_mask); + + const __m128i x_abs_rev = _mm_shuffle_epi32(x_abs, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i y_abs_rev = _mm_shuffle_epi32(y_abs, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i z_abs_rev = _mm_shuffle_epi32(z_abs, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i w_abs_rev = _mm_shuffle_epi32(w_abs, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i x_abs_product_even = _mm_mul_epu32(x_abs, vmultiplier); + const __m128i y_abs_product_even = _mm_mul_epu32(y_abs, vmultiplier); + const __m128i z_abs_product_even = _mm_mul_epu32(z_abs, vmultiplier); + const __m128i w_abs_product_even = _mm_mul_epu32(w_abs, vmultiplier); + + const __m128i x_neg_mask_even = + _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i y_neg_mask_even = + _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i z_neg_mask_even = + _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i w_neg_mask_even = + _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i x_product_even = _mm_sub_epi64( + _mm_xor_si128(x_abs_product_even, x_neg_mask_even), x_neg_mask_even); + const __m128i y_product_even = _mm_sub_epi64( + _mm_xor_si128(y_abs_product_even, y_neg_mask_even), y_neg_mask_even); + const __m128i z_product_even = _mm_sub_epi64( + _mm_xor_si128(z_abs_product_even, z_neg_mask_even), z_neg_mask_even); + const __m128i w_product_even = _mm_sub_epi64( + _mm_xor_si128(w_abs_product_even, w_neg_mask_even), w_neg_mask_even); + + const __m128i x_rounded_product_even = + _mm_add_epi64(x_product_even, vq31rounding); + const __m128i y_rounded_product_even = + _mm_add_epi64(y_product_even, vq31rounding); + const __m128i z_rounded_product_even = + _mm_add_epi64(z_product_even, vq31rounding); + const __m128i w_rounded_product_even = + _mm_add_epi64(w_product_even, vq31rounding); + + const __m128i x_abs_product_odd = _mm_mul_epu32(x_abs_rev, vmultiplier); + const __m128i y_abs_product_odd = _mm_mul_epu32(y_abs_rev, vmultiplier); + const __m128i z_abs_product_odd = _mm_mul_epu32(z_abs_rev, vmultiplier); + const __m128i w_abs_product_odd = _mm_mul_epu32(w_abs_rev, vmultiplier); + + const __m128i x_neg_mask_odd = + _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i y_neg_mask_odd = + _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i z_neg_mask_odd = + _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i w_neg_mask_odd = + _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i x_product_odd = _mm_sub_epi64( + _mm_xor_si128(x_abs_product_odd, x_neg_mask_odd), x_neg_mask_odd); + const __m128i y_product_odd = _mm_sub_epi64( + _mm_xor_si128(y_abs_product_odd, y_neg_mask_odd), y_neg_mask_odd); + const __m128i z_product_odd = _mm_sub_epi64( + _mm_xor_si128(z_abs_product_odd, z_neg_mask_odd), z_neg_mask_odd); + const __m128i w_product_odd = _mm_sub_epi64( + _mm_xor_si128(w_abs_product_odd, w_neg_mask_odd), w_neg_mask_odd); + + const __m128i x_rounded_product_odd = + _mm_add_epi64(x_product_odd, vq31rounding); + const __m128i y_rounded_product_odd = + _mm_add_epi64(y_product_odd, vq31rounding); + const __m128i z_rounded_product_odd = + _mm_add_epi64(z_product_odd, vq31rounding); + const __m128i w_rounded_product_odd = + _mm_add_epi64(w_product_odd, vq31rounding); + + const __m128i x_q31product_even = + _mm_srli_epi64(x_rounded_product_even, 31); + const __m128i x_q31product_odd = _mm_srli_epi64(x_rounded_product_odd, 31); + const __m128i y_q31product_even = + _mm_srli_epi64(y_rounded_product_even, 31); + const __m128i y_q31product_odd = _mm_srli_epi64(y_rounded_product_odd, 31); + const __m128i z_q31product_even = + _mm_srli_epi64(z_rounded_product_even, 31); + const __m128i z_q31product_odd = _mm_srli_epi64(z_rounded_product_odd, 31); + const __m128i w_q31product_even = + _mm_srli_epi64(w_rounded_product_even, 31); + const __m128i w_q31product_odd = _mm_srli_epi64(w_rounded_product_odd, 31); + + const __m128i x_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(x_q31product_even), + _mm_castsi128_ps(x_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i y_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(y_q31product_even), + _mm_castsi128_ps(y_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i z_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(z_q31product_even), + _mm_castsi128_ps(z_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i w_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(w_q31product_even), + _mm_castsi128_ps(w_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i x_q31product = + _mm_shuffle_epi32(x_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i y_q31product = + _mm_shuffle_epi32(y_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i z_q31product = + _mm_shuffle_epi32(z_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i w_q31product = + _mm_shuffle_epi32(w_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i x_remainder = _mm_add_epi32( + _mm_and_si128(x_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), x_q31product)); + const __m128i y_remainder = _mm_add_epi32( + _mm_and_si128(y_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), y_q31product)); + const __m128i z_remainder = _mm_add_epi32( + _mm_and_si128(z_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), z_q31product)); + const __m128i w_remainder = _mm_add_epi32( + _mm_and_si128(w_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), w_q31product)); + + const __m128i x_scaled = _mm_sub_epi32( + _mm_sra_epi32(x_q31product, vshift), + _mm_cmpgt_epi32(x_remainder, vthreshold)); + const __m128i y_scaled = _mm_sub_epi32( + _mm_sra_epi32(y_q31product, vshift), + _mm_cmpgt_epi32(y_remainder, vthreshold)); + const __m128i z_scaled = _mm_sub_epi32( + _mm_sra_epi32(z_q31product, vshift), + _mm_cmpgt_epi32(z_remainder, vthreshold)); + const __m128i w_scaled = _mm_sub_epi32( + _mm_sra_epi32(w_q31product, vshift), + _mm_cmpgt_epi32(w_remainder, vthreshold)); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 16x PSHUFD + * 4x SHUFPS + * 8x PMULUDQ + * 8x PXOR (setzero) + * 12x PXOR + * 4x PAND + * 8x PADDQ + * 4x PADDD + * 2x PADDW + * 8x PSUBQ + * 8x PSUBD + * 8x PSRLQ (immediate) + * 4x PSRAD (register) + * 12x PCMPGTD + * 2x PACKSSDW + * 1x PACKUSWB + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 111 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-sse4.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-sse4.c new file mode 100644 index 0000000000000..a57072b01979c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-sse4.c @@ -0,0 +1,162 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_q31__sse4( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + const __m128i vshift = _mm_cvtsi32_si128((int)shift); + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const __m128i vremainder_mask = _mm_set1_epi32((int)remainder_mask); + const __m128i vthreshold = _mm_set1_epi32((int)(remainder_mask >> 1)); + const __m128i vq31rounding = _mm_set1_epi64x(UINT64_C(0x40000000)); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_rev = _mm_shuffle_epi32(x, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i y_rev = _mm_shuffle_epi32(y, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i z_rev = _mm_shuffle_epi32(z, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i w_rev = _mm_shuffle_epi32(w, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i x_product_even = + _mm_add_epi64(_mm_mul_epi32(x, vmultiplier), vq31rounding); + const __m128i y_product_even = + _mm_add_epi64(_mm_mul_epi32(y, vmultiplier), vq31rounding); + const __m128i z_product_even = + _mm_add_epi64(_mm_mul_epi32(z, vmultiplier), vq31rounding); + const __m128i w_product_even = + _mm_add_epi64(_mm_mul_epi32(w, vmultiplier), vq31rounding); + + const __m128i x_product_odd = + _mm_add_epi64(_mm_mul_epi32(x_rev, vmultiplier), vq31rounding); + const __m128i y_product_odd = + _mm_add_epi64(_mm_mul_epi32(y_rev, vmultiplier), vq31rounding); + const __m128i z_product_odd = + _mm_add_epi64(_mm_mul_epi32(z_rev, vmultiplier), vq31rounding); + const __m128i w_product_odd = + _mm_add_epi64(_mm_mul_epi32(w_rev, vmultiplier), vq31rounding); + + const __m128i x_q31product_even = _mm_srli_epi64(x_product_even, 31); + const __m128i x_q31product_odd = + _mm_add_epi64(x_product_odd, x_product_odd); + const __m128i y_q31product_even = _mm_srli_epi64(y_product_even, 31); + const __m128i y_q31product_odd = + _mm_add_epi64(y_product_odd, y_product_odd); + const __m128i z_q31product_even = _mm_srli_epi64(z_product_even, 31); + const __m128i z_q31product_odd = + _mm_add_epi64(z_product_odd, z_product_odd); + const __m128i w_q31product_even = _mm_srli_epi64(w_product_even, 31); + const __m128i w_q31product_odd = + _mm_add_epi64(w_product_odd, w_product_odd); + + const __m128i x_q31product = + _mm_blend_epi16(x_q31product_even, x_q31product_odd, 0xCC); + const __m128i y_q31product = + _mm_blend_epi16(y_q31product_even, y_q31product_odd, 0xCC); + const __m128i z_q31product = + _mm_blend_epi16(z_q31product_even, z_q31product_odd, 0xCC); + const __m128i w_q31product = + _mm_blend_epi16(w_q31product_even, w_q31product_odd, 0xCC); + + const __m128i x_remainder = _mm_add_epi32( + _mm_and_si128(x_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), x_q31product)); + const __m128i y_remainder = _mm_add_epi32( + _mm_and_si128(y_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), y_q31product)); + const __m128i z_remainder = _mm_add_epi32( + _mm_and_si128(z_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), z_q31product)); + const __m128i w_remainder = _mm_add_epi32( + _mm_and_si128(w_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), w_q31product)); + + const __m128i x_scaled = _mm_sub_epi32( + _mm_sra_epi32(x_q31product, vshift), + _mm_cmpgt_epi32(x_remainder, vthreshold)); + const __m128i y_scaled = _mm_sub_epi32( + _mm_sra_epi32(y_q31product, vshift), + _mm_cmpgt_epi32(y_remainder, vthreshold)); + const __m128i z_scaled = _mm_sub_epi32( + _mm_sra_epi32(z_q31product, vshift), + _mm_cmpgt_epi32(z_remainder, vthreshold)); + const __m128i w_scaled = _mm_sub_epi32( + _mm_sra_epi32(w_q31product, vshift), + _mm_cmpgt_epi32(w_remainder, vthreshold)); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 4x PSHUFD + * 8x PMULDQ + * 12x PADDQ + * 4x PADDD + * 2x PADDW + * 4x PSUBD + * 4x PSLRQ (immediate) + * 4x PSRAD (register) + * 4x PBLENDW + * 4x PAND + * 4x PXOR (setzero) + * 8x PCMPGTD + * 2x PACKSSDW + * 1x PACKUSWB + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 67 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-ssse3.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-ssse3.c new file mode 100644 index 0000000000000..e2d147536e844 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/q31-ssse3.c @@ -0,0 +1,238 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +#include +#include + +void pytorch_qnnp_requantize_q31__ssse3( + size_t n, + const int32_t* input, + float scale, + uint8_t zero_point, + uint8_t qmin, + uint8_t qmax, + uint8_t* output) { + assert(n % 16 == 0); + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + /* Compute requantization parameters */ + const uint32_t scale_bits = fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)( + ((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + const __m128i vmultiplier = _mm_set1_epi32(multiplier); + const __m128i vzero_point = _mm_set1_epi16((short)(uint16_t)zero_point); + const __m128i vqmin = _mm_set1_epi8((char)qmin); + const __m128i vqmax = _mm_set1_epi8((char)qmax); + const __m128i vshift = _mm_cvtsi32_si128((int)shift); + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const __m128i vremainder_mask = _mm_set1_epi32((int)remainder_mask); + const __m128i vthreshold = _mm_set1_epi32((int)(remainder_mask >> 1)); + const __m128i vq31rounding = _mm_set1_epi64x(UINT64_C(0x40000000)); + for (; n != 0; n -= 16) { + const __m128i x = _mm_loadu_si128((const __m128i*)input); + const __m128i y = _mm_loadu_si128((const __m128i*)(input + 4)); + const __m128i z = _mm_loadu_si128((const __m128i*)(input + 8)); + const __m128i w = _mm_loadu_si128((const __m128i*)(input + 12)); + input += 16; + + const __m128i x_abs = _mm_abs_epi32(x); + const __m128i y_abs = _mm_abs_epi32(y); + const __m128i z_abs = _mm_abs_epi32(z); + const __m128i w_abs = _mm_abs_epi32(w); + + const __m128i x_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), x); + const __m128i y_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), y); + const __m128i z_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), z); + const __m128i w_neg_mask = _mm_cmpgt_epi32(_mm_setzero_si128(), w); + + const __m128i x_abs_rev = _mm_shuffle_epi32(x_abs, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i y_abs_rev = _mm_shuffle_epi32(y_abs, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i z_abs_rev = _mm_shuffle_epi32(z_abs, _MM_SHUFFLE(2, 3, 0, 1)); + const __m128i w_abs_rev = _mm_shuffle_epi32(w_abs, _MM_SHUFFLE(2, 3, 0, 1)); + + const __m128i x_abs_product_even = _mm_mul_epu32(x_abs, vmultiplier); + const __m128i y_abs_product_even = _mm_mul_epu32(y_abs, vmultiplier); + const __m128i z_abs_product_even = _mm_mul_epu32(z_abs, vmultiplier); + const __m128i w_abs_product_even = _mm_mul_epu32(w_abs, vmultiplier); + + const __m128i x_neg_mask_even = + _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i y_neg_mask_even = + _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i z_neg_mask_even = + _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + const __m128i w_neg_mask_even = + _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(2, 2, 0, 0)); + + const __m128i x_product_even = _mm_sub_epi64( + _mm_xor_si128(x_abs_product_even, x_neg_mask_even), x_neg_mask_even); + const __m128i y_product_even = _mm_sub_epi64( + _mm_xor_si128(y_abs_product_even, y_neg_mask_even), y_neg_mask_even); + const __m128i z_product_even = _mm_sub_epi64( + _mm_xor_si128(z_abs_product_even, z_neg_mask_even), z_neg_mask_even); + const __m128i w_product_even = _mm_sub_epi64( + _mm_xor_si128(w_abs_product_even, w_neg_mask_even), w_neg_mask_even); + + const __m128i x_rounded_product_even = + _mm_add_epi64(x_product_even, vq31rounding); + const __m128i y_rounded_product_even = + _mm_add_epi64(y_product_even, vq31rounding); + const __m128i z_rounded_product_even = + _mm_add_epi64(z_product_even, vq31rounding); + const __m128i w_rounded_product_even = + _mm_add_epi64(w_product_even, vq31rounding); + + const __m128i x_abs_product_odd = _mm_mul_epu32(x_abs_rev, vmultiplier); + const __m128i y_abs_product_odd = _mm_mul_epu32(y_abs_rev, vmultiplier); + const __m128i z_abs_product_odd = _mm_mul_epu32(z_abs_rev, vmultiplier); + const __m128i w_abs_product_odd = _mm_mul_epu32(w_abs_rev, vmultiplier); + + const __m128i x_neg_mask_odd = + _mm_shuffle_epi32(x_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i y_neg_mask_odd = + _mm_shuffle_epi32(y_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i z_neg_mask_odd = + _mm_shuffle_epi32(z_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + const __m128i w_neg_mask_odd = + _mm_shuffle_epi32(w_neg_mask, _MM_SHUFFLE(3, 3, 1, 1)); + + const __m128i x_product_odd = _mm_sub_epi64( + _mm_xor_si128(x_abs_product_odd, x_neg_mask_odd), x_neg_mask_odd); + const __m128i y_product_odd = _mm_sub_epi64( + _mm_xor_si128(y_abs_product_odd, y_neg_mask_odd), y_neg_mask_odd); + const __m128i z_product_odd = _mm_sub_epi64( + _mm_xor_si128(z_abs_product_odd, z_neg_mask_odd), z_neg_mask_odd); + const __m128i w_product_odd = _mm_sub_epi64( + _mm_xor_si128(w_abs_product_odd, w_neg_mask_odd), w_neg_mask_odd); + + const __m128i x_rounded_product_odd = + _mm_add_epi64(x_product_odd, vq31rounding); + const __m128i y_rounded_product_odd = + _mm_add_epi64(y_product_odd, vq31rounding); + const __m128i z_rounded_product_odd = + _mm_add_epi64(z_product_odd, vq31rounding); + const __m128i w_rounded_product_odd = + _mm_add_epi64(w_product_odd, vq31rounding); + + const __m128i x_q31product_even = + _mm_srli_epi64(x_rounded_product_even, 31); + const __m128i x_q31product_odd = _mm_srli_epi64(x_rounded_product_odd, 31); + const __m128i y_q31product_even = + _mm_srli_epi64(y_rounded_product_even, 31); + const __m128i y_q31product_odd = _mm_srli_epi64(y_rounded_product_odd, 31); + const __m128i z_q31product_even = + _mm_srli_epi64(z_rounded_product_even, 31); + const __m128i z_q31product_odd = _mm_srli_epi64(z_rounded_product_odd, 31); + const __m128i w_q31product_even = + _mm_srli_epi64(w_rounded_product_even, 31); + const __m128i w_q31product_odd = _mm_srli_epi64(w_rounded_product_odd, 31); + + const __m128i x_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(x_q31product_even), + _mm_castsi128_ps(x_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i y_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(y_q31product_even), + _mm_castsi128_ps(y_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i z_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(z_q31product_even), + _mm_castsi128_ps(z_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + const __m128i w_q31product_0213 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(w_q31product_even), + _mm_castsi128_ps(w_q31product_odd), + _MM_SHUFFLE(2, 0, 2, 0))); + + const __m128i x_q31product = + _mm_shuffle_epi32(x_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i y_q31product = + _mm_shuffle_epi32(y_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i z_q31product = + _mm_shuffle_epi32(z_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + const __m128i w_q31product = + _mm_shuffle_epi32(w_q31product_0213, _MM_SHUFFLE(3, 1, 2, 0)); + + const __m128i x_remainder = _mm_add_epi32( + _mm_and_si128(x_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), x_q31product)); + const __m128i y_remainder = _mm_add_epi32( + _mm_and_si128(y_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), y_q31product)); + const __m128i z_remainder = _mm_add_epi32( + _mm_and_si128(z_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), z_q31product)); + const __m128i w_remainder = _mm_add_epi32( + _mm_and_si128(w_q31product, vremainder_mask), + _mm_cmpgt_epi32(_mm_setzero_si128(), w_q31product)); + + const __m128i x_scaled = _mm_sub_epi32( + _mm_sra_epi32(x_q31product, vshift), + _mm_cmpgt_epi32(x_remainder, vthreshold)); + const __m128i y_scaled = _mm_sub_epi32( + _mm_sra_epi32(y_q31product, vshift), + _mm_cmpgt_epi32(y_remainder, vthreshold)); + const __m128i z_scaled = _mm_sub_epi32( + _mm_sra_epi32(z_q31product, vshift), + _mm_cmpgt_epi32(z_remainder, vthreshold)); + const __m128i w_scaled = _mm_sub_epi32( + _mm_sra_epi32(w_q31product, vshift), + _mm_cmpgt_epi32(w_remainder, vthreshold)); + + const __m128i xy_packed = + _mm_adds_epi16(_mm_packs_epi32(x_scaled, y_scaled), vzero_point); + const __m128i zw_packed = + _mm_adds_epi16(_mm_packs_epi32(z_scaled, w_scaled), vzero_point); + const __m128i xyzw_packed = _mm_packus_epi16(xy_packed, zw_packed); + const __m128i xyzw_clamped = + _mm_max_epu8(_mm_min_epu8(xyzw_packed, vqmax), vqmin); + + /* + * 16x PSHUFD + * 4x SHUFPS + * 8x PMULUDQ + * 8x PXOR (setzero) + * 8x PXOR + * 4x PAND + * 8x PADDQ + * 4x PADDD + * 2x PADDW + * 8x PSUBQ + * 4x PSUBD + * 8x PSRLQ (immediate) + * 4x PSRAD (register) + * 12x PCMPGTD + * 4x PABSD + * 2x PACKSSDW + * 1x PACKUSWB + * 1x PMAXUB + * 1x PMINUB + * --------------------- + * 107 instructions total + */ + + _mm_storeu_si128((__m128i*)output, xyzw_clamped); + output += 16; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-assembly.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-assembly.h new file mode 100644 index 0000000000000..aae4438e0fa64 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-assembly.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef __aarch64__ + +.macro SUB_ZERO_POINT vout, vin1, vin2 +#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + USUBL \vout, \vin1, \vin2 +#else + UXTL \vout, \vin1 +#endif +.endm + +#else /* aarch32 */ + +.macro SUB_ZERO_POINT qout, din1, din2 +#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + VSUBL.U8 \qout, \din1, \din2 +#else + VMOVL.U8 \qout, \din1 +#endif +.endm + +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-neon.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-neon.h new file mode 100644 index 0000000000000..5081aaf737621 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-neon.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +PYTORCH_QNNP_INLINE uint16x8_t +sub_zero_point(const uint8x8_t va, const uint8x8_t vzp) { +#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + // Run-time quantization + return vsubl_u8(va, vzp); +#else + // Design-time quantization + return vmovl_u8(va); +#endif +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-sse2.h b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-sse2.h new file mode 100644 index 0000000000000..40b9ffe16f949 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/requantization/runtime-sse2.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +PYTORCH_QNNP_INLINE __m128i +sub_zero_point(const __m128i va, const __m128i vzp) { +#if PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + // Run-time quantization + return _mm_sub_epi16(va, vzp); +#else + // Design-time quantization (no-op) + return va; +#endif +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/sconv/6x8-psimd.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sconv/6x8-psimd.c new file mode 100644 index 0000000000000..ee8e6723da6cf --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sconv/6x8-psimd.c @@ -0,0 +1,203 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_sconv_ukernel_6x8__psimd( + size_t mr, + size_t nr, + size_t kc, + size_t ks, + const float** restrict a, + const float* restrict w, + float* restrict c, + size_t c_stride, + const struct pytorch_qnnp_fp32_clamping_params + clamping_params[restrict static 1]) { + psimd_f32 vacc0x0123 = psimd_load_f32(w); + w += 4; + psimd_f32 vacc0x4567 = psimd_load_f32(w); + w += 4; + psimd_f32 vacc1x0123 = vacc0x0123; + psimd_f32 vacc1x4567 = vacc0x4567; + psimd_f32 vacc2x0123 = vacc0x0123; + psimd_f32 vacc2x4567 = vacc0x4567; + psimd_f32 vacc3x0123 = vacc0x0123; + psimd_f32 vacc3x4567 = vacc0x4567; + psimd_f32 vacc4x0123 = vacc0x0123; + psimd_f32 vacc4x4567 = vacc0x4567; + psimd_f32 vacc5x0123 = vacc0x0123; + psimd_f32 vacc5x4567 = vacc0x4567; + + do { + const float* restrict a0 = *a++; + const float* restrict a1 = *a++; + const float* restrict a2 = *a++; + const float* restrict a3 = *a++; + const float* restrict a4 = *a++; + const float* restrict a5 = *a++; + + size_t k = kc; + do { + const psimd_f32 va0 = psimd_splat_f32(*a0); + a0 += 1; + const psimd_f32 va1 = psimd_splat_f32(*a1); + a1 += 1; + const psimd_f32 va2 = psimd_splat_f32(*a2); + a2 += 1; + const psimd_f32 va3 = psimd_splat_f32(*a3); + a3 += 1; + const psimd_f32 va4 = psimd_splat_f32(*a4); + a4 += 1; + const psimd_f32 va5 = psimd_splat_f32(*a5); + a5 += 1; + + const psimd_f32 vb0123 = psimd_load_f32(w); + w += 4; + const psimd_f32 vb4567 = psimd_load_f32(w); + w += 4; + + vacc0x0123 += vb0123 * va0; + vacc0x4567 += vb4567 * va0; + vacc1x0123 += vb0123 * va1; + vacc1x4567 += vb4567 * va1; + vacc2x0123 += vb0123 * va2; + vacc2x4567 += vb4567 * va2; + vacc3x0123 += vb0123 * va3; + vacc3x4567 += vb4567 * va3; + vacc4x0123 += vb0123 * va4; + vacc4x4567 += vb4567 * va4; + vacc5x0123 += vb0123 * va5; + vacc5x4567 += vb4567 * va5; + } while (--k != 0); + } while (--ks != 0); + + const psimd_f32 vmax = psimd_splat_f32(clamping_params->max); + vacc0x0123 = psimd_min_f32(vacc0x0123, vmax); + vacc0x4567 = psimd_min_f32(vacc0x4567, vmax); + vacc1x0123 = psimd_min_f32(vacc1x0123, vmax); + vacc1x4567 = psimd_min_f32(vacc1x4567, vmax); + vacc2x0123 = psimd_min_f32(vacc2x0123, vmax); + vacc2x4567 = psimd_min_f32(vacc2x4567, vmax); + vacc3x0123 = psimd_min_f32(vacc3x0123, vmax); + vacc3x4567 = psimd_min_f32(vacc3x4567, vmax); + vacc4x0123 = psimd_min_f32(vacc4x0123, vmax); + vacc4x4567 = psimd_min_f32(vacc4x4567, vmax); + vacc5x0123 = psimd_min_f32(vacc5x0123, vmax); + vacc5x4567 = psimd_min_f32(vacc5x4567, vmax); + + const psimd_f32 vmin = psimd_splat_f32(clamping_params->min); + vacc0x0123 = psimd_max_f32(vacc0x0123, vmin); + vacc0x4567 = psimd_max_f32(vacc0x4567, vmin); + vacc1x0123 = psimd_max_f32(vacc1x0123, vmin); + vacc1x4567 = psimd_max_f32(vacc1x4567, vmin); + vacc2x0123 = psimd_max_f32(vacc2x0123, vmin); + vacc2x4567 = psimd_max_f32(vacc2x4567, vmin); + vacc3x0123 = psimd_max_f32(vacc3x0123, vmin); + vacc3x4567 = psimd_max_f32(vacc3x4567, vmin); + vacc4x0123 = psimd_max_f32(vacc4x0123, vmin); + vacc4x4567 = psimd_max_f32(vacc4x4567, vmin); + vacc5x0123 = psimd_max_f32(vacc5x0123, vmin); + vacc5x4567 = psimd_max_f32(vacc5x4567, vmin); + + float* c0 = c; + float* c1 = (float*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + float* c2 = (float*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + float* c3 = (float*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + float* c4 = (float*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + float* c5 = (float*)((uintptr_t)c4 + c_stride); + if (mr != 6) { + c5 = c4; + } + if (nr == 8) { + psimd_store_f32(c0, vacc0x0123); + c0 += 4; + psimd_store_f32(c1, vacc1x0123); + c1 += 4; + psimd_store_f32(c2, vacc2x0123); + c2 += 4; + psimd_store_f32(c3, vacc3x0123); + c3 += 4; + psimd_store_f32(c4, vacc4x0123); + c4 += 4; + psimd_store_f32(c5, vacc5x0123); + c5 += 4; + + psimd_store_f32(c0, vacc0x4567); + psimd_store_f32(c1, vacc1x4567); + psimd_store_f32(c2, vacc2x4567); + psimd_store_f32(c3, vacc3x4567); + psimd_store_f32(c4, vacc4x4567); + psimd_store_f32(c5, vacc5x4567); + } else { + if (nr >= 4) { + psimd_store_f32(c0, vacc0x0123); + c0 += 4; + psimd_store_f32(c1, vacc1x0123); + c1 += 4; + psimd_store_f32(c2, vacc2x0123); + c2 += 4; + psimd_store_f32(c3, vacc3x0123); + c3 += 4; + psimd_store_f32(c4, vacc4x0123); + c4 += 4; + psimd_store_f32(c5, vacc5x0123); + c5 += 4; + vacc0x0123 = vacc0x4567; + vacc1x0123 = vacc1x4567; + vacc2x0123 = vacc2x4567; + vacc3x0123 = vacc3x4567; + vacc4x0123 = vacc4x4567; + vacc5x0123 = vacc5x4567; + nr -= 4; + } + if (nr >= 2) { + psimd_store2_f32(c0, vacc0x0123); + c0 += 2; + psimd_store2_f32(c1, vacc1x0123); + c1 += 2; + psimd_store2_f32(c2, vacc2x0123); + c2 += 2; + psimd_store2_f32(c3, vacc3x0123); + c3 += 2; + psimd_store2_f32(c4, vacc4x0123); + c4 += 2; + psimd_store2_f32(c5, vacc5x0123); + c5 += 2; + vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123); + vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123); + vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123); + vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123); + vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123); + vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123); + nr -= 2; + } + if (nr != 0) { + psimd_store1_f32(c0, vacc0x0123); + psimd_store1_f32(c1, vacc1x0123); + psimd_store1_f32(c2, vacc2x0123); + psimd_store1_f32(c3, vacc3x0123); + psimd_store1_f32(c4, vacc4x0123); + psimd_store1_f32(c5, vacc5x0123); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/sdwconv/up4x9-psimd.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sdwconv/up4x9-psimd.c new file mode 100644 index 0000000000000..152b721809113 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sdwconv/up4x9-psimd.c @@ -0,0 +1,156 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_sdwconv_ukernel_up4x9__psimd( + size_t channels, + size_t output_width, + const float** input, + const float* weights, + float* output, + size_t input_stride, + size_t output_increment, + const struct pytorch_qnnp_fp32_clamping_params + clamping_params[restrict static 1]) { + const psimd_f32 vmax = psimd_splat_f32(clamping_params->max); + const psimd_f32 vmin = psimd_splat_f32(clamping_params->min); + do { + const float* i0 = input[0]; + const float* i1 = input[1]; + const float* i2 = input[2]; + const float* i3 = input[3]; + const float* i4 = input[4]; + const float* i5 = input[5]; + const float* i6 = input[6]; + const float* i7 = input[7]; + const float* i8 = input[8]; + + input = (const float**)((uintptr_t)input + input_stride); + + size_t c = channels; + const float* w = weights; + for (; c >= 4; c -= 4) { + psimd_f32 vacc = psimd_load_f32(w); + + const psimd_f32 vi0 = psimd_load_f32(i0); + i0 += 4; + const psimd_f32 vk0 = psimd_load_f32(w + 8); + vacc += vi0 * vk0; + + const psimd_f32 vi1 = psimd_load_f32(i1); + i1 += 4; + const psimd_f32 vk1 = psimd_load_f32(w + 12); + psimd_f32 vacc2 = vi1 * vk1; + + const psimd_f32 vi2 = psimd_load_f32(i2); + i2 += 4; + const psimd_f32 vk2 = psimd_load_f32(w + 16); + vacc += vi2 * vk2; + + const psimd_f32 vi3 = psimd_load_f32(i3); + i3 += 4; + const psimd_f32 vk3 = psimd_load_f32(w + 20); + vacc2 += vi3 * vk3; + + const psimd_f32 vi4 = psimd_load_f32(i4); + i4 += 4; + const psimd_f32 vk4 = psimd_load_f32(w + 24); + vacc += vi4 * vk4; + + const psimd_f32 vi5 = psimd_load_f32(i5); + i5 += 4; + const psimd_f32 vk5 = psimd_load_f32(w + 28); + vacc2 += vi5 * vk5; + + const psimd_f32 vi6 = psimd_load_f32(i6); + i6 += 4; + const psimd_f32 vk6 = psimd_load_f32(w + 32); + vacc += vi6 * vk6; + + const psimd_f32 vi7 = psimd_load_f32(i7); + i7 += 4; + const psimd_f32 vk7 = psimd_load_f32(w + 36); + vacc2 += vi7 * vk7; + + const psimd_f32 vi8 = psimd_load_f32(i8); + i8 += 4; + const psimd_f32 vk8 = psimd_load_f32(w + 40); + vacc += vi8 * vk8; + + vacc += vacc2; + + vacc = psimd_min_f32(vacc, vmax); + vacc = psimd_max_f32(vacc, vmin); + + psimd_store_f32(output, vacc); + w += 44; + } + if (c != 0) { + psimd_f32 vacc = psimd_load_f32(w); + c *= sizeof(float); + + i0 = (const float*)((uintptr_t)i0 - c); + const psimd_f32 vi0 = psimd_load_f32(i0); + const psimd_f32 vk0 = psimd_load_f32(w + 8); + vacc += vi0 * vk0; + + i1 = (const float*)((uintptr_t)i1 - c); + const psimd_f32 vi1 = psimd_load_f32(i1); + const psimd_f32 vk1 = psimd_load_f32(w + 12); + psimd_f32 vacc2 = vi1 * vk1; + + i2 = (const float*)((uintptr_t)i2 - c); + const psimd_f32 vi2 = psimd_load_f32(i2); + const psimd_f32 vk2 = psimd_load_f32(w + 16); + vacc += vi2 * vk2; + + i3 = (const float*)((uintptr_t)i3 - c); + const psimd_f32 vi3 = psimd_load_f32(i3); + const psimd_f32 vk3 = psimd_load_f32(w + 20); + vacc2 += vi3 * vk3; + + i4 = (const float*)((uintptr_t)i4 - c); + const psimd_f32 vi4 = psimd_load_f32(i4); + const psimd_f32 vk4 = psimd_load_f32(w + 24); + vacc += vi4 * vk4; + + i5 = (const float*)((uintptr_t)i5 - c); + const psimd_f32 vi5 = psimd_load_f32(i5); + const psimd_f32 vk5 = psimd_load_f32(w + 28); + vacc2 += vi5 * vk5; + + i6 = (const float*)((uintptr_t)i6 - c); + const psimd_f32 vi6 = psimd_load_f32(i6); + const psimd_f32 vk6 = psimd_load_f32(w + 32); + vacc += vi6 * vk6; + + i7 = (const float*)((uintptr_t)i7 - c); + const psimd_f32 vi7 = psimd_load_f32(i7); + const psimd_f32 vk7 = psimd_load_f32(w + 36); + vacc2 += vi7 * vk7; + + i8 = (const float*)((uintptr_t)i8 - c); + const psimd_f32 vi8 = psimd_load_f32(i8); + const psimd_f32 vk8 = psimd_load_f32(w + 40); + vacc += vi8 * vk8; + + vacc += vacc2; + + vacc = psimd_min_f32(vacc, vmax); + vacc = psimd_max_f32(vacc, vmin); + + output = (float*)((uintptr_t)output - c); + psimd_store_f32(output, vacc); + } + + output = (float*)((uintptr_t)output + output_increment); + } while (--output_width != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/5x8-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/5x8-neon.c new file mode 100644 index 0000000000000..2d416f5236eaf --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/5x8-neon.c @@ -0,0 +1,268 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_sgemm_ukernel_5x8__neon( + size_t mr, + size_t nr, + size_t k, + const float* restrict a, + size_t a_stride, + const float* restrict w, + float* restrict c, + size_t c_stride, + const struct pytorch_qnnp_fp32_clamping_params + clamping_params[restrict static 1]) { + float32x4_t vacc0x0123 = vld1q_f32(w); + w += 4; + float32x4_t vacc0x4567 = vld1q_f32(w); + w += 4; + float32x4_t vacc1x0123 = vacc0x0123; + float32x4_t vacc1x4567 = vacc0x4567; + float32x4_t vacc2x0123 = vacc0x0123; + float32x4_t vacc2x4567 = vacc0x4567; + float32x4_t vacc3x0123 = vacc0x0123; + float32x4_t vacc3x4567 = vacc0x4567; + float32x4_t vacc4x0123 = vacc0x0123; + float32x4_t vacc4x4567 = vacc0x4567; + + const float* a0 = a; + const float* a1 = (const float*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const float* a2 = (const float*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const float* a3 = (const float*)((uintptr_t)a2 + a_stride); + if (mr < 4) { + a3 = a2; + } + const float* a4 = (const float*)((uintptr_t)a3 + a_stride); + if (mr <= 4) { + a4 = a3; + } + + for (; k >= 2; k -= 2) { + const float32x2_t va0 = vld1_f32(a0); + a0 += 2; + const float32x2_t va1 = vld1_f32(a1); + a1 += 2; + const float32x2_t va2 = vld1_f32(a2); + a2 += 2; + const float32x2_t va3 = vld1_f32(a3); + a3 += 2; + const float32x2_t va4 = vld1_f32(a4); + a4 += 2; + + { + const float32x4_t vb0123 = vld1q_f32(w); + w += 4; + const float32x4_t vb4567 = vld1q_f32(w); + w += 4; + +#if defined(__aarch64__) + vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123, va0, 0); + vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567, va0, 0); + vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123, va1, 0); + vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567, va1, 0); + vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123, va2, 0); + vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567, va2, 0); + vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123, va3, 0); + vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567, va3, 0); + vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123, va4, 0); + vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567, va4, 0); +#else + vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, va0, 0); + vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, va0, 0); + vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, va1, 0); + vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, va1, 0); + vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, va2, 0); + vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, va2, 0); + vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, va3, 0); + vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, va3, 0); + vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, va4, 0); + vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, va4, 0); +#endif + } + + { + const float32x4_t vb0123 = vld1q_f32(w); + w += 4; + const float32x4_t vb4567 = vld1q_f32(w); + w += 4; + +#if defined(__aarch64__) + vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123, va0, 1); + vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567, va0, 1); + vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123, va1, 1); + vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567, va1, 1); + vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123, va2, 1); + vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567, va2, 1); + vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123, va3, 1); + vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567, va3, 1); + vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123, va4, 1); + vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567, va4, 1); +#else + vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, va0, 1); + vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, va0, 1); + vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, va1, 1); + vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, va1, 1); + vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, va2, 1); + vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, va2, 1); + vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, va3, 1); + vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, va3, 1); + vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, va4, 1); + vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, va4, 1); +#endif + } + } + if (k != 0) { + const float32x4_t va0 = vld1q_dup_f32(a0); + const float32x4_t va1 = vld1q_dup_f32(a1); + const float32x4_t va2 = vld1q_dup_f32(a2); + const float32x4_t va3 = vld1q_dup_f32(a3); + const float32x4_t va4 = vld1q_dup_f32(a4); + + const float32x4_t vb0123 = vld1q_f32(w); + w += 4; + const float32x4_t vb4567 = vld1q_f32(w); + w += 4; + +#if defined(__aarch64__) + vacc0x0123 = vfmaq_f32(vacc0x0123, vb0123, va0); + vacc0x4567 = vfmaq_f32(vacc0x4567, vb4567, va0); + vacc1x0123 = vfmaq_f32(vacc1x0123, vb0123, va1); + vacc1x4567 = vfmaq_f32(vacc1x4567, vb4567, va1); + vacc2x0123 = vfmaq_f32(vacc2x0123, vb0123, va2); + vacc2x4567 = vfmaq_f32(vacc2x4567, vb4567, va2); + vacc3x0123 = vfmaq_f32(vacc3x0123, vb0123, va3); + vacc3x4567 = vfmaq_f32(vacc3x4567, vb4567, va3); + vacc4x0123 = vfmaq_f32(vacc4x0123, vb0123, va4); + vacc4x4567 = vfmaq_f32(vacc4x4567, vb4567, va4); +#else + vacc0x0123 = vmlaq_f32(vacc0x0123, vb0123, va0); + vacc0x4567 = vmlaq_f32(vacc0x4567, vb4567, va0); + vacc1x0123 = vmlaq_f32(vacc1x0123, vb0123, va1); + vacc1x4567 = vmlaq_f32(vacc1x4567, vb4567, va1); + vacc2x0123 = vmlaq_f32(vacc2x0123, vb0123, va2); + vacc2x4567 = vmlaq_f32(vacc2x4567, vb4567, va2); + vacc3x0123 = vmlaq_f32(vacc3x0123, vb0123, va3); + vacc3x4567 = vmlaq_f32(vacc3x4567, vb4567, va3); + vacc4x0123 = vmlaq_f32(vacc4x0123, vb0123, va4); + vacc4x4567 = vmlaq_f32(vacc4x4567, vb4567, va4); +#endif + } + const float32x4_t vmax = vld1q_dup_f32(&clamping_params->max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc0x4567 = vminq_f32(vacc0x4567, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc1x4567 = vminq_f32(vacc1x4567, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc2x4567 = vminq_f32(vacc2x4567, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc3x4567 = vminq_f32(vacc3x4567, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + vacc4x4567 = vminq_f32(vacc4x4567, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&clamping_params->min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc0x4567 = vmaxq_f32(vacc0x4567, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc1x4567 = vmaxq_f32(vacc1x4567, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc2x4567 = vmaxq_f32(vacc2x4567, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc3x4567 = vmaxq_f32(vacc3x4567, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + vacc4x4567 = vmaxq_f32(vacc4x4567, vmin); + + float* c0 = c; + float* c1 = (float*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + float* c2 = (float*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + float* c3 = (float*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + float* c4 = (float*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + if (nr == 8) { + vst1q_f32(c0, vacc0x0123); + c0 += 4; + vst1q_f32(c1, vacc1x0123); + c1 += 4; + vst1q_f32(c2, vacc2x0123); + c2 += 4; + vst1q_f32(c3, vacc3x0123); + c3 += 4; + vst1q_f32(c4, vacc4x0123); + c4 += 4; + + vst1q_f32(c0, vacc0x4567); + vst1q_f32(c1, vacc1x4567); + vst1q_f32(c2, vacc2x4567); + vst1q_f32(c3, vacc3x4567); + vst1q_f32(c4, vacc4x4567); + } else { + if (nr >= 4) { + vst1q_f32(c0, vacc0x0123); + c0 += 4; + vst1q_f32(c1, vacc1x0123); + c1 += 4; + vst1q_f32(c2, vacc2x0123); + c2 += 4; + vst1q_f32(c3, vacc3x0123); + c3 += 4; + vst1q_f32(c4, vacc4x0123); + c4 += 4; + vacc0x0123 = vacc0x4567; + vacc1x0123 = vacc1x4567; + vacc2x0123 = vacc2x4567; + vacc3x0123 = vacc3x4567; + vacc4x0123 = vacc4x4567; + nr -= 4; + } + if (nr >= 2) { + vst1_f32(c0, vget_low_f32(vacc0x0123)); + c0 += 2; + vst1_f32(c1, vget_low_f32(vacc1x0123)); + c1 += 2; + vst1_f32(c2, vget_low_f32(vacc2x0123)); + c2 += 2; + vst1_f32(c3, vget_low_f32(vacc3x0123)); + c3 += 2; + vst1_f32(c4, vget_low_f32(vacc4x0123)); + c4 += 2; + vacc0x0123 = vextq_f32(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vextq_f32(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vextq_f32(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vextq_f32(vacc3x0123, vacc3x0123, 2); + vacc4x0123 = vextq_f32(vacc4x0123, vacc4x0123, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_f32(c0, vacc0x0123, 0); + vst1q_lane_f32(c1, vacc1x0123, 0); + vst1q_lane_f32(c2, vacc2x0123, 0); + vst1q_lane_f32(c3, vacc3x0123, 0); + vst1q_lane_f32(c4, vacc4x0123, 0); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-neon.c new file mode 100644 index 0000000000000..ffc596e049e23 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-neon.c @@ -0,0 +1,307 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_sgemm_ukernel_6x8__neon( + size_t mr, + size_t nr, + size_t k, + const float* restrict a, + size_t a_stride, + const float* restrict w, + float* restrict c, + size_t c_stride, + const struct pytorch_qnnp_fp32_clamping_params + clamping_params[restrict static 1]) { + float32x4_t vacc0x0123 = vld1q_f32(w); + w += 4; + float32x4_t vacc0x4567 = vld1q_f32(w); + w += 4; + float32x4_t vacc1x0123 = vacc0x0123; + float32x4_t vacc1x4567 = vacc0x4567; + float32x4_t vacc2x0123 = vacc0x0123; + float32x4_t vacc2x4567 = vacc0x4567; + float32x4_t vacc3x0123 = vacc0x0123; + float32x4_t vacc3x4567 = vacc0x4567; + float32x4_t vacc4x0123 = vacc0x0123; + float32x4_t vacc4x4567 = vacc0x4567; + float32x4_t vacc5x0123 = vacc0x0123; + float32x4_t vacc5x4567 = vacc0x4567; + + const float* a0 = a; + const float* a1 = (const float*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const float* a2 = (const float*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const float* a3 = (const float*)((uintptr_t)a2 + a_stride); + if (mr < 4) { + a3 = a2; + } + const float* a4 = (const float*)((uintptr_t)a3 + a_stride); + if (mr <= 4) { + a4 = a3; + } + const float* a5 = (const float*)((uintptr_t)a4 + a_stride); + if (mr != 6) { + a5 = a4; + } + + for (; k >= 2; k -= 2) { + const float32x2_t va0 = vld1_f32(a0); + a0 += 2; + const float32x2_t va1 = vld1_f32(a1); + a1 += 2; + const float32x2_t va2 = vld1_f32(a2); + a2 += 2; + const float32x2_t va3 = vld1_f32(a3); + a3 += 2; + const float32x2_t va4 = vld1_f32(a4); + a4 += 2; + const float32x2_t va5 = vld1_f32(a5); + a5 += 2; + + { + const float32x4_t vb0123 = vld1q_f32(w); + w += 4; + const float32x4_t vb4567 = vld1q_f32(w); + w += 4; + +#if defined(__aarch64__) + vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123, va0, 0); + vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567, va0, 0); + vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123, va1, 0); + vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567, va1, 0); + vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123, va2, 0); + vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567, va2, 0); + vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123, va3, 0); + vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567, va3, 0); + vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123, va4, 0); + vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567, va4, 0); + vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123, va5, 0); + vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567, va5, 0); +#else + vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, va0, 0); + vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, va0, 0); + vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, va1, 0); + vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, va1, 0); + vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, va2, 0); + vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, va2, 0); + vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, va3, 0); + vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, va3, 0); + vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, va4, 0); + vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, va4, 0); + vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123, va5, 0); + vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567, va5, 0); +#endif + } + + { + const float32x4_t vb0123 = vld1q_f32(w); + w += 4; + const float32x4_t vb4567 = vld1q_f32(w); + w += 4; + +#if defined(__aarch64__) + vacc0x0123 = vfmaq_lane_f32(vacc0x0123, vb0123, va0, 1); + vacc0x4567 = vfmaq_lane_f32(vacc0x4567, vb4567, va0, 1); + vacc1x0123 = vfmaq_lane_f32(vacc1x0123, vb0123, va1, 1); + vacc1x4567 = vfmaq_lane_f32(vacc1x4567, vb4567, va1, 1); + vacc2x0123 = vfmaq_lane_f32(vacc2x0123, vb0123, va2, 1); + vacc2x4567 = vfmaq_lane_f32(vacc2x4567, vb4567, va2, 1); + vacc3x0123 = vfmaq_lane_f32(vacc3x0123, vb0123, va3, 1); + vacc3x4567 = vfmaq_lane_f32(vacc3x4567, vb4567, va3, 1); + vacc4x0123 = vfmaq_lane_f32(vacc4x0123, vb0123, va4, 1); + vacc4x4567 = vfmaq_lane_f32(vacc4x4567, vb4567, va4, 1); + vacc5x0123 = vfmaq_lane_f32(vacc5x0123, vb0123, va5, 1); + vacc5x4567 = vfmaq_lane_f32(vacc5x4567, vb4567, va5, 1); +#else + vacc0x0123 = vmlaq_lane_f32(vacc0x0123, vb0123, va0, 1); + vacc0x4567 = vmlaq_lane_f32(vacc0x4567, vb4567, va0, 1); + vacc1x0123 = vmlaq_lane_f32(vacc1x0123, vb0123, va1, 1); + vacc1x4567 = vmlaq_lane_f32(vacc1x4567, vb4567, va1, 1); + vacc2x0123 = vmlaq_lane_f32(vacc2x0123, vb0123, va2, 1); + vacc2x4567 = vmlaq_lane_f32(vacc2x4567, vb4567, va2, 1); + vacc3x0123 = vmlaq_lane_f32(vacc3x0123, vb0123, va3, 1); + vacc3x4567 = vmlaq_lane_f32(vacc3x4567, vb4567, va3, 1); + vacc4x0123 = vmlaq_lane_f32(vacc4x0123, vb0123, va4, 1); + vacc4x4567 = vmlaq_lane_f32(vacc4x4567, vb4567, va4, 1); + vacc5x0123 = vmlaq_lane_f32(vacc5x0123, vb0123, va5, 1); + vacc5x4567 = vmlaq_lane_f32(vacc5x4567, vb4567, va5, 1); +#endif + } + } + if (k != 0) { + const float32x4_t va0 = vld1q_dup_f32(a0); + const float32x4_t va1 = vld1q_dup_f32(a1); + const float32x4_t va2 = vld1q_dup_f32(a2); + const float32x4_t va3 = vld1q_dup_f32(a3); + const float32x4_t va4 = vld1q_dup_f32(a4); + const float32x4_t va5 = vld1q_dup_f32(a5); + + const float32x4_t vb0123 = vld1q_f32(w); + w += 4; + const float32x4_t vb4567 = vld1q_f32(w); + w += 4; + +#if defined(__aarch64__) + vacc0x0123 = vfmaq_f32(vacc0x0123, vb0123, va0); + vacc0x4567 = vfmaq_f32(vacc0x4567, vb4567, va0); + vacc1x0123 = vfmaq_f32(vacc1x0123, vb0123, va1); + vacc1x4567 = vfmaq_f32(vacc1x4567, vb4567, va1); + vacc2x0123 = vfmaq_f32(vacc2x0123, vb0123, va2); + vacc2x4567 = vfmaq_f32(vacc2x4567, vb4567, va2); + vacc3x0123 = vfmaq_f32(vacc3x0123, vb0123, va3); + vacc3x4567 = vfmaq_f32(vacc3x4567, vb4567, va3); + vacc4x0123 = vfmaq_f32(vacc4x0123, vb0123, va4); + vacc4x4567 = vfmaq_f32(vacc4x4567, vb4567, va4); + vacc5x0123 = vfmaq_f32(vacc5x0123, vb0123, va5); + vacc5x4567 = vfmaq_f32(vacc5x4567, vb4567, va5); +#else + vacc0x0123 = vmlaq_f32(vacc0x0123, vb0123, va0); + vacc0x4567 = vmlaq_f32(vacc0x4567, vb4567, va0); + vacc1x0123 = vmlaq_f32(vacc1x0123, vb0123, va1); + vacc1x4567 = vmlaq_f32(vacc1x4567, vb4567, va1); + vacc2x0123 = vmlaq_f32(vacc2x0123, vb0123, va2); + vacc2x4567 = vmlaq_f32(vacc2x4567, vb4567, va2); + vacc3x0123 = vmlaq_f32(vacc3x0123, vb0123, va3); + vacc3x4567 = vmlaq_f32(vacc3x4567, vb4567, va3); + vacc4x0123 = vmlaq_f32(vacc4x0123, vb0123, va4); + vacc4x4567 = vmlaq_f32(vacc4x4567, vb4567, va4); + vacc5x0123 = vmlaq_f32(vacc5x0123, vb0123, va5); + vacc5x4567 = vmlaq_f32(vacc5x4567, vb4567, va5); +#endif + } + const float32x4_t vmax = vld1q_dup_f32(&clamping_params->max); + vacc0x0123 = vminq_f32(vacc0x0123, vmax); + vacc0x4567 = vminq_f32(vacc0x4567, vmax); + vacc1x0123 = vminq_f32(vacc1x0123, vmax); + vacc1x4567 = vminq_f32(vacc1x4567, vmax); + vacc2x0123 = vminq_f32(vacc2x0123, vmax); + vacc2x4567 = vminq_f32(vacc2x4567, vmax); + vacc3x0123 = vminq_f32(vacc3x0123, vmax); + vacc3x4567 = vminq_f32(vacc3x4567, vmax); + vacc4x0123 = vminq_f32(vacc4x0123, vmax); + vacc4x4567 = vminq_f32(vacc4x4567, vmax); + vacc5x0123 = vminq_f32(vacc5x0123, vmax); + vacc5x4567 = vminq_f32(vacc5x4567, vmax); + + const float32x4_t vmin = vld1q_dup_f32(&clamping_params->min); + vacc0x0123 = vmaxq_f32(vacc0x0123, vmin); + vacc0x4567 = vmaxq_f32(vacc0x4567, vmin); + vacc1x0123 = vmaxq_f32(vacc1x0123, vmin); + vacc1x4567 = vmaxq_f32(vacc1x4567, vmin); + vacc2x0123 = vmaxq_f32(vacc2x0123, vmin); + vacc2x4567 = vmaxq_f32(vacc2x4567, vmin); + vacc3x0123 = vmaxq_f32(vacc3x0123, vmin); + vacc3x4567 = vmaxq_f32(vacc3x4567, vmin); + vacc4x0123 = vmaxq_f32(vacc4x0123, vmin); + vacc4x4567 = vmaxq_f32(vacc4x4567, vmin); + vacc5x0123 = vmaxq_f32(vacc5x0123, vmin); + vacc5x4567 = vmaxq_f32(vacc5x4567, vmin); + + float* c0 = c; + float* c1 = (float*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + float* c2 = (float*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + float* c3 = (float*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + float* c4 = (float*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + float* c5 = (float*)((uintptr_t)c4 + c_stride); + if (mr != 6) { + c5 = c4; + } + if (nr == 8) { + vst1q_f32(c0, vacc0x0123); + c0 += 4; + vst1q_f32(c1, vacc1x0123); + c1 += 4; + vst1q_f32(c2, vacc2x0123); + c2 += 4; + vst1q_f32(c3, vacc3x0123); + c3 += 4; + vst1q_f32(c4, vacc4x0123); + c4 += 4; + vst1q_f32(c5, vacc5x0123); + c5 += 4; + + vst1q_f32(c0, vacc0x4567); + vst1q_f32(c1, vacc1x4567); + vst1q_f32(c2, vacc2x4567); + vst1q_f32(c3, vacc3x4567); + vst1q_f32(c4, vacc4x4567); + vst1q_f32(c5, vacc5x4567); + } else { + if (nr >= 4) { + vst1q_f32(c0, vacc0x0123); + c0 += 4; + vst1q_f32(c1, vacc1x0123); + c1 += 4; + vst1q_f32(c2, vacc2x0123); + c2 += 4; + vst1q_f32(c3, vacc3x0123); + c3 += 4; + vst1q_f32(c4, vacc4x0123); + c4 += 4; + vst1q_f32(c5, vacc5x0123); + c5 += 4; + vacc0x0123 = vacc0x4567; + vacc1x0123 = vacc1x4567; + vacc2x0123 = vacc2x4567; + vacc3x0123 = vacc3x4567; + vacc4x0123 = vacc4x4567; + vacc5x0123 = vacc5x4567; + nr -= 4; + } + if (nr >= 2) { + vst1_f32(c0, vget_low_f32(vacc0x0123)); + c0 += 2; + vst1_f32(c1, vget_low_f32(vacc1x0123)); + c1 += 2; + vst1_f32(c2, vget_low_f32(vacc2x0123)); + c2 += 2; + vst1_f32(c3, vget_low_f32(vacc3x0123)); + c3 += 2; + vst1_f32(c4, vget_low_f32(vacc4x0123)); + c4 += 2; + vst1_f32(c5, vget_low_f32(vacc5x0123)); + c5 += 2; + vacc0x0123 = vextq_f32(vacc0x0123, vacc0x0123, 2); + vacc1x0123 = vextq_f32(vacc1x0123, vacc1x0123, 2); + vacc2x0123 = vextq_f32(vacc2x0123, vacc2x0123, 2); + vacc3x0123 = vextq_f32(vacc3x0123, vacc3x0123, 2); + vacc4x0123 = vextq_f32(vacc4x0123, vacc4x0123, 2); + vacc5x0123 = vextq_f32(vacc5x0123, vacc5x0123, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_f32(c0, vacc0x0123, 0); + vst1q_lane_f32(c1, vacc1x0123, 0); + vst1q_lane_f32(c2, vacc2x0123, 0); + vst1q_lane_f32(c3, vacc3x0123, 0); + vst1q_lane_f32(c4, vacc4x0123, 0); + vst1q_lane_f32(c5, vacc5x0123, 0); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-psimd.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-psimd.c new file mode 100644 index 0000000000000..62b9a761d30ec --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sgemm/6x8-psimd.c @@ -0,0 +1,215 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_sgemm_ukernel_6x8__psimd( + size_t mr, + size_t nr, + size_t k, + const float* restrict a, + size_t a_stride, + const float* restrict w, + float* restrict c, + size_t c_stride, + const struct pytorch_qnnp_fp32_clamping_params + clamping_params[restrict static 1]) { + psimd_f32 vacc0x0123 = psimd_load_f32(w); + w += 4; + psimd_f32 vacc0x4567 = psimd_load_f32(w); + w += 4; + psimd_f32 vacc1x0123 = vacc0x0123; + psimd_f32 vacc1x4567 = vacc0x4567; + psimd_f32 vacc2x0123 = vacc0x0123; + psimd_f32 vacc2x4567 = vacc0x4567; + psimd_f32 vacc3x0123 = vacc0x0123; + psimd_f32 vacc3x4567 = vacc0x4567; + psimd_f32 vacc4x0123 = vacc0x0123; + psimd_f32 vacc4x4567 = vacc0x4567; + psimd_f32 vacc5x0123 = vacc0x0123; + psimd_f32 vacc5x4567 = vacc0x4567; + + const float* a0 = a; + const float* a1 = (const float*)((uintptr_t)a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const float* a2 = (const float*)((uintptr_t)a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const float* a3 = (const float*)((uintptr_t)a2 + a_stride); + if (mr < 4) { + a3 = a2; + } + const float* a4 = (const float*)((uintptr_t)a3 + a_stride); + if (mr <= 4) { + a4 = a3; + } + const float* a5 = (const float*)((uintptr_t)a4 + a_stride); + if (mr != 6) { + a5 = a4; + } + + do { + const psimd_f32 va0 = psimd_splat_f32(*a0); + a0 += 1; + const psimd_f32 va1 = psimd_splat_f32(*a1); + a1 += 1; + const psimd_f32 va2 = psimd_splat_f32(*a2); + a2 += 1; + const psimd_f32 va3 = psimd_splat_f32(*a3); + a3 += 1; + const psimd_f32 va4 = psimd_splat_f32(*a4); + a4 += 1; + const psimd_f32 va5 = psimd_splat_f32(*a5); + a5 += 1; + + const psimd_f32 vb0123 = psimd_load_f32(w); + w += 4; + const psimd_f32 vb4567 = psimd_load_f32(w); + w += 4; + + vacc0x0123 += vb0123 * va0; + vacc0x4567 += vb4567 * va0; + vacc1x0123 += vb0123 * va1; + vacc1x4567 += vb4567 * va1; + vacc2x0123 += vb0123 * va2; + vacc2x4567 += vb4567 * va2; + vacc3x0123 += vb0123 * va3; + vacc3x4567 += vb4567 * va3; + vacc4x0123 += vb0123 * va4; + vacc4x4567 += vb4567 * va4; + vacc5x0123 += vb0123 * va5; + vacc5x4567 += vb4567 * va5; + } while (--k != 0); + + const psimd_f32 vmax = psimd_splat_f32(clamping_params->max); + vacc0x0123 = psimd_min_f32(vacc0x0123, vmax); + vacc0x4567 = psimd_min_f32(vacc0x4567, vmax); + vacc1x0123 = psimd_min_f32(vacc1x0123, vmax); + vacc1x4567 = psimd_min_f32(vacc1x4567, vmax); + vacc2x0123 = psimd_min_f32(vacc2x0123, vmax); + vacc2x4567 = psimd_min_f32(vacc2x4567, vmax); + vacc3x0123 = psimd_min_f32(vacc3x0123, vmax); + vacc3x4567 = psimd_min_f32(vacc3x4567, vmax); + vacc4x0123 = psimd_min_f32(vacc4x0123, vmax); + vacc4x4567 = psimd_min_f32(vacc4x4567, vmax); + vacc5x0123 = psimd_min_f32(vacc5x0123, vmax); + vacc5x4567 = psimd_min_f32(vacc5x4567, vmax); + + const psimd_f32 vmin = psimd_splat_f32(clamping_params->min); + vacc0x0123 = psimd_max_f32(vacc0x0123, vmin); + vacc0x4567 = psimd_max_f32(vacc0x4567, vmin); + vacc1x0123 = psimd_max_f32(vacc1x0123, vmin); + vacc1x4567 = psimd_max_f32(vacc1x4567, vmin); + vacc2x0123 = psimd_max_f32(vacc2x0123, vmin); + vacc2x4567 = psimd_max_f32(vacc2x4567, vmin); + vacc3x0123 = psimd_max_f32(vacc3x0123, vmin); + vacc3x4567 = psimd_max_f32(vacc3x4567, vmin); + vacc4x0123 = psimd_max_f32(vacc4x0123, vmin); + vacc4x4567 = psimd_max_f32(vacc4x4567, vmin); + vacc5x0123 = psimd_max_f32(vacc5x0123, vmin); + vacc5x4567 = psimd_max_f32(vacc5x4567, vmin); + + float* c0 = c; + float* c1 = (float*)((uintptr_t)c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + float* c2 = (float*)((uintptr_t)c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + float* c3 = (float*)((uintptr_t)c2 + c_stride); + if (mr < 4) { + c3 = c2; + } + float* c4 = (float*)((uintptr_t)c3 + c_stride); + if (mr <= 4) { + c4 = c3; + } + float* c5 = (float*)((uintptr_t)c4 + c_stride); + if (mr != 6) { + c5 = c4; + } + if (nr == 8) { + psimd_store_f32(c0, vacc0x0123); + c0 += 4; + psimd_store_f32(c1, vacc1x0123); + c1 += 4; + psimd_store_f32(c2, vacc2x0123); + c2 += 4; + psimd_store_f32(c3, vacc3x0123); + c3 += 4; + psimd_store_f32(c4, vacc4x0123); + c4 += 4; + psimd_store_f32(c5, vacc5x0123); + c5 += 4; + + psimd_store_f32(c0, vacc0x4567); + psimd_store_f32(c1, vacc1x4567); + psimd_store_f32(c2, vacc2x4567); + psimd_store_f32(c3, vacc3x4567); + psimd_store_f32(c4, vacc4x4567); + psimd_store_f32(c5, vacc5x4567); + } else { + if (nr >= 4) { + psimd_store_f32(c0, vacc0x0123); + c0 += 4; + psimd_store_f32(c1, vacc1x0123); + c1 += 4; + psimd_store_f32(c2, vacc2x0123); + c2 += 4; + psimd_store_f32(c3, vacc3x0123); + c3 += 4; + psimd_store_f32(c4, vacc4x0123); + c4 += 4; + psimd_store_f32(c5, vacc5x0123); + c5 += 4; + vacc0x0123 = vacc0x4567; + vacc1x0123 = vacc1x4567; + vacc2x0123 = vacc2x4567; + vacc3x0123 = vacc3x4567; + vacc4x0123 = vacc4x4567; + vacc5x0123 = vacc5x4567; + nr -= 4; + } + if (nr >= 2) { + psimd_store2_f32(c0, vacc0x0123); + c0 += 2; + psimd_store2_f32(c1, vacc1x0123); + c1 += 2; + psimd_store2_f32(c2, vacc2x0123); + c2 += 2; + psimd_store2_f32(c3, vacc3x0123); + c3 += 2; + psimd_store2_f32(c4, vacc4x0123); + c4 += 2; + psimd_store2_f32(c5, vacc5x0123); + c5 += 2; + vacc0x0123 = psimd_concat_hi_f32(vacc0x0123, vacc0x0123); + vacc1x0123 = psimd_concat_hi_f32(vacc1x0123, vacc1x0123); + vacc2x0123 = psimd_concat_hi_f32(vacc2x0123, vacc2x0123); + vacc3x0123 = psimd_concat_hi_f32(vacc3x0123, vacc3x0123); + vacc4x0123 = psimd_concat_hi_f32(vacc4x0123, vacc4x0123); + vacc5x0123 = psimd_concat_hi_f32(vacc5x0123, vacc5x0123); + nr -= 2; + } + if (nr != 0) { + psimd_store1_f32(c0, vacc0x0123); + psimd_store1_f32(c1, vacc1x0123); + psimd_store1_f32(c2, vacc2x0123); + psimd_store1_f32(c3, vacc3x0123); + psimd_store1_f32(c4, vacc4x0123); + psimd_store1_f32(c5, vacc5x0123); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/sigmoid.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sigmoid.c new file mode 100644 index 0000000000000..8c4e422c6c626 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/sigmoid.c @@ -0,0 +1,159 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_sigmoid_nc_q8( + size_t channels, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + pytorch_qnnp_operator_t* sigmoid_out) { + pytorch_qnnp_operator_t sigmoid_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_sigmoid_nc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create Sigmoid operator with %zu channels: number of channels must be non-zero", + channels); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create Sigmoid operator with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create Sigmoid operator with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + if (output_min >= output_max) { + pytorch_qnnp_log_error( + "failed to create Sigmoid operator with [%" PRIu8 ", %" PRIu8 + "] output range: range min must be below range max", + output_min, + output_max); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + if (output_scale != 0x1.0p-8f) { + pytorch_qnnp_log_error( + "failed to create Sigmoid operator with %.7g output scale: only output scale of 1/256 is supported", + output_scale); + goto error; + } + + if (output_zero_point != 0) { + pytorch_qnnp_log_error( + "failed to create Sigmoid operator with %" PRIu8 + " output zero point: only output zero point of 0 is supported", + output_zero_point); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + sigmoid_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (sigmoid_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + sigmoid_op->lookup_table = malloc(256 * sizeof(uint8_t)); + if (sigmoid_op->lookup_table == NULL) { + pytorch_qnnp_log_error( + "failed to allocate 256 bytes for Sigmoid lookup table"); + goto error; + } + + uint8_t* lookup_table = sigmoid_op->lookup_table; + const float scaled_min = (float)(int32_t)output_min; + const float scaled_max = (float)(int32_t)output_max; + for (int32_t i = 0; i < 256; i++) { + const float x = + input_scale * (float)(i - (int32_t)(uint32_t)input_zero_point); + /* Scale sigmoid(x) by 1 / output scale = 256.0 */ + float scaled_sigmoid_x = 256.0f / (1.0f + expf(-x)); + if (scaled_sigmoid_x < scaled_min) { + scaled_sigmoid_x = scaled_min; + } + if (scaled_sigmoid_x > scaled_max) { + scaled_sigmoid_x = scaled_max; + } + lookup_table[(uint32_t)i] = (uint8_t)lrintf(scaled_sigmoid_x); + } + + sigmoid_op->channels = channels; + + sigmoid_op->ukernel_type = pytorch_qnnp_ukernel_type_lut; + sigmoid_op->format = pytorch_qnnp_format_quint8; + + *sigmoid_out = sigmoid_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(sigmoid_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_sigmoid_nc_q8( + pytorch_qnnp_operator_t sigmoid, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_sigmoid_nc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + sigmoid->batch_size = 0; + return pytorch_qnnp_status_success; + } + + sigmoid->batch_size = batch_size; + sigmoid->input = input; + sigmoid->input_pixel_stride = input_stride; + sigmoid->output = output; + sigmoid->output_pixel_stride = output_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/softargmax.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/softargmax.c new file mode 100644 index 0000000000000..2513450c5551f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/softargmax.c @@ -0,0 +1,139 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +enum pytorch_qnnp_status pytorch_qnnp_create_softargmax_nc_q8( + size_t channels, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint32_t flags, + pytorch_qnnp_operator_t* softargmax_out) { + pytorch_qnnp_operator_t softargmax_op = NULL; + enum pytorch_qnnp_status status = pytorch_qnnp_status_uninitialized; + + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_create_softargmax_nc_q8 failed because QNNPACK is not properly initialized"); + goto error; + } + + status = pytorch_qnnp_status_invalid_parameter; + + if (channels == 0) { + pytorch_qnnp_log_error( + "failed to create Soft ArgMax operator with %zu channels: number of channels must be non-zero", + channels); + goto error; + } + + if (input_scale <= 0.0f || !isnormal(input_scale)) { + pytorch_qnnp_log_error( + "failed to create Soft ArgMax operator with %.7g input scale: scale must be finite and positive", + input_scale); + goto error; + } + + if (output_scale <= 0.0f || !isnormal(output_scale)) { + pytorch_qnnp_log_error( + "failed to create Soft ArgMax operator with %.7g output scale: scale must be finite and positive", + output_scale); + goto error; + } + + status = pytorch_qnnp_status_unsupported_parameter; + + if (output_scale != 0x1.0p-8f) { + pytorch_qnnp_log_error( + "failed to create Soft ArgMax operator with %.7g output scale: only output scale of 1/256 is supported", + output_scale); + goto error; + } + + if (output_zero_point != 0) { + pytorch_qnnp_log_error( + "failed to create Soft ArgMax operator with %" PRIu8 + " output zero point: only output zero point of 0 is supported", + output_zero_point); + goto error; + } + + status = pytorch_qnnp_status_out_of_memory; + + softargmax_op = calloc(1, sizeof(struct pytorch_qnnp_operator)); + if (softargmax_op == NULL) { + pytorch_qnnp_log_error( + "failed to allocate %zu bytes for pytorch_qnnp_operator structure", + sizeof(struct pytorch_qnnp_operator)); + goto error; + } + + softargmax_op->lookup_table = malloc(256 * sizeof(uint32_t)); + if (softargmax_op->lookup_table == NULL) { + pytorch_qnnp_log_error( + "failed to allocate 256 bytes for Soft ArgMax lookup table"); + goto error; + } + + uint32_t* lookup_table = softargmax_op->lookup_table; + const double qscale = + fmin(((double)UINT32_MAX) / (double)channels, 8388607.0); + for (int32_t i = 0; i < 256; i++) { + const double scaled_exp_xi = + qscale * exp((double)(i - 255) * (double)input_scale); + lookup_table[(uint32_t)i] = (uint32_t)lrint(scaled_exp_xi); + } + + softargmax_op->channels = channels; + + softargmax_op->ukernel_type = pytorch_qnnp_ukernel_type_softargmax; + softargmax_op->format = pytorch_qnnp_format_quint8; + + *softargmax_out = softargmax_op; + return pytorch_qnnp_status_success; + +error: + pytorch_qnnp_delete_operator(softargmax_op); + return status; +} + +enum pytorch_qnnp_status pytorch_qnnp_setup_softargmax_nc_q8( + pytorch_qnnp_operator_t softargmax, + size_t batch_size, + const uint8_t* input, + size_t input_stride, + uint8_t* output, + size_t output_stride) { + if (!pytorch_qnnp_params.initialized) { + pytorch_qnnp_log_error( + "pytorch_qnnp_setup_softargmax_nc_q8 failed because QNNPACK is not properly initialized"); + return pytorch_qnnp_status_uninitialized; + } + + if (batch_size == 0) { + softargmax->batch_size = 0; + return pytorch_qnnp_status_success; + } + + softargmax->batch_size = batch_size; + softargmax->input = input; + softargmax->input_pixel_stride = input_stride; + softargmax->output = output; + softargmax->output_pixel_stride = output_stride; + + return pytorch_qnnp_status_success; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/neon.c new file mode 100644 index 0000000000000..3f063e9d9a8e1 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/neon.c @@ -0,0 +1,86 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_u8clamp_ukernel__neon( + size_t n, + const uint8_t* x, + uint8_t* y, + const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) { + assert(n != 0); + + const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max); + const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min); + + if + PYTORCH_QNNP_LIKELY(n >= 8) { + for (; n >= 64; n -= 64) { + const uint8x16_t vx0 = vld1q_u8(x); + x += 16; + const uint8x16_t vx1 = vld1q_u8(x); + x += 16; + const uint8x16_t vx2 = vld1q_u8(x); + x += 16; + const uint8x16_t vx3 = vld1q_u8(x); + x += 16; + + const uint8x16_t vy0 = + vminq_u8(vmaxq_u8(vx0, voutput_min), voutput_max); + const uint8x16_t vy1 = + vminq_u8(vmaxq_u8(vx1, voutput_min), voutput_max); + const uint8x16_t vy2 = + vminq_u8(vmaxq_u8(vx2, voutput_min), voutput_max); + const uint8x16_t vy3 = + vminq_u8(vmaxq_u8(vx3, voutput_min), voutput_max); + + __builtin_prefetch(x + 640); + + vst1q_u8(y, vy0); + y += 16; + vst1q_u8(y, vy1); + y += 16; + vst1q_u8(y, vy2); + y += 16; + vst1q_u8(y, vy3); + y += 16; + } + for (; n >= 8; n -= 8) { + uint8x8_t vout = vld1_u8(x); + x += 8; + vout = vmin_u8(vout, vget_low_u8(voutput_max)); + vout = vmax_u8(vout, vget_low_u8(voutput_min)); + vst1_u8(y, vout); + y += 8; + } + if (n != 0) { + const size_t n_increment = n - 8; + x = (const uint8_t*)((uintptr_t)x + n_increment); + y = (uint8_t*)((uintptr_t)y + n_increment); + + uint8x8_t vout = vld1_u8(x); + vout = vmin_u8(vout, vget_low_u8(voutput_max)); + vout = vmax_u8(vout, vget_low_u8(voutput_min)); + vst1_u8(y, vout); + } + } + else { + do { + uint8x8_t vout = vld1_dup_u8(x); + x += 1; + vout = vmin_u8(vout, vget_low_u8(voutput_max)); + vout = vmax_u8(vout, vget_low_u8(voutput_min)); + vst1_lane_u8(y, vout, 0); + y += 1; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/sse2.c new file mode 100644 index 0000000000000..5058842968195 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8clamp/sse2.c @@ -0,0 +1,81 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_u8clamp_ukernel__sse2( + size_t n, + const uint8_t* x, + uint8_t* y, + const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) { + assert(n != 0); + + if + PYTORCH_QNNP_LIKELY(n >= 8) { + const __m128i voutput_max = + _mm_load_si128((const __m128i*)¶ms->sse2.output_max); + const __m128i voutput_min = + _mm_load_si128((const __m128i*)¶ms->sse2.output_min); + for (; n >= 64; n -= 64) { + const __m128i vx0 = _mm_loadu_si128((const __m128i*)x); + const __m128i vx1 = _mm_loadu_si128((const __m128i*)x + 1); + const __m128i vx2 = _mm_loadu_si128((const __m128i*)x + 2); + const __m128i vx3 = _mm_loadu_si128((const __m128i*)x + 3); + x += 64; + + const __m128i vy0 = + _mm_min_epu8(_mm_max_epu8(vx0, voutput_min), voutput_max); + const __m128i vy1 = + _mm_min_epu8(_mm_max_epu8(vx1, voutput_min), voutput_max); + const __m128i vy2 = + _mm_min_epu8(_mm_max_epu8(vx2, voutput_min), voutput_max); + const __m128i vy3 = + _mm_min_epu8(_mm_max_epu8(vx3, voutput_min), voutput_max); + + __builtin_prefetch(x + 640); + + _mm_storeu_si128((__m128i*)y, vy0); + _mm_storeu_si128((__m128i*)y + 1, vy1); + _mm_storeu_si128((__m128i*)y + 2, vy2); + _mm_storeu_si128((__m128i*)y + 3, vy3); + y += 64; + } + for (; n >= 8; n -= 8) { + __m128i vout = _mm_loadl_epi64((const __m128i*)x); + x += 8; + vout = _mm_min_epu8(vout, voutput_max); + vout = _mm_max_epu8(vout, voutput_min); + _mm_storel_epi64((__m128i*)y, vout); + y += 8; + } + if (n != 0) { + const size_t n_increment = n - 8; + x = (const uint8_t*)((uintptr_t)x + n_increment); + y = (uint8_t*)((uintptr_t)y + n_increment); + + __m128i vout = _mm_loadl_epi64((const __m128i*)x); + vout = _mm_min_epu8(vout, voutput_max); + vout = _mm_max_epu8(vout, voutput_min); + _mm_storel_epi64((__m128i*)y, vout); + } + } + else { + const uint32_t voutput_max = params->sse2.output_max[0]; + const uint32_t voutput_min = params->sse2.output_min[0]; + do { + uint32_t vout = *x++; + vout = vout > voutput_max ? voutput_max : vout; + vout = vout < voutput_min ? voutput_min : vout; + *y++ = (uint8_t)vout; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8lut32norm/scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8lut32norm/scalar.c new file mode 100644 index 0000000000000..8cb20276c92d7 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8lut32norm/scalar.c @@ -0,0 +1,49 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +static inline uint32_t compute_sum( + size_t n, + const uint8_t* x, + const uint32_t* t) { + assert(n != 0); + + uint32_t vsum = 0; + do { + const size_t vx = *x++; + vsum += t[vx]; + } while (--n != 0); + return vsum; +} + +void pytorch_u8lut32norm_ukernel__scalar( + size_t n, + const uint8_t* x, + const uint32_t* t, + uint8_t* y) { + assert(n != 0); + + const uint32_t vsum = compute_sum(n, x, t); + assert(vsum != 0); + + struct fxdiv_divisor_uint32_t vsum_divisor = fxdiv_init_uint32_t(vsum); + const uint32_t vrounding = (vsum >> 1); + do { + const size_t vx = *x++; + const uint32_t vt = t[vx]; + const uint32_t vq = + fxdiv_quotient_uint32_t((vt << 8) + vrounding, vsum_divisor); + const uint8_t vy = vq > 255 ? UINT8_C(255) : (uint8_t)vq; + *y++ = vy; + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-neon.c new file mode 100644 index 0000000000000..625a114c202d0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-neon.c @@ -0,0 +1,251 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_u8maxpool_ukernel_16x9p8q__neon( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) { + assert(n != 0); + assert(ks != 0); + assert(kc >= 16); + + const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max); + const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min); + do { + uint8_t* o = output; + { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + const uint8_t* i8 = *input++; + if (ks < 2) { + i1 = i0; + } + if (ks <= 2) { + i2 = i0; + } + if (ks < 4) { + i3 = i0; + } + if (ks <= 4) { + i4 = i0; + } + if (ks < 6) { + i5 = i0; + } + if (ks <= 6) { + i6 = i0; + } + if (ks < 8) { + i7 = i0; + } + if (ks <= 8) { + i8 = i0; + } + + size_t k = kc; + while (k >= 16) { + const uint8x16_t vi0 = vld1q_u8(i0); + i0 += 16; + const uint8x16_t vi1 = vld1q_u8(i1); + i1 += 16; + const uint8x16_t vi2 = vld1q_u8(i2); + i2 += 16; + const uint8x16_t vi3 = vld1q_u8(i3); + i3 += 16; + const uint8x16_t vi4 = vld1q_u8(i4); + i4 += 16; + const uint8x16_t vi5 = vld1q_u8(i5); + i5 += 16; + const uint8x16_t vi6 = vld1q_u8(i6); + i6 += 16; + const uint8x16_t vi7 = vld1q_u8(i7); + i7 += 16; + const uint8x16_t vi8 = vld1q_u8(i8); + i8 += 16; + + const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8); + const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3); + const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5); + const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7); + + const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45); + const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67); + const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678); + const uint8x16_t vout = + vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min); + + vst1q_u8(o, vout); + o += 16; + + k -= 16; + } + if (k != 0) { + const size_t address_increment = k - 16; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + i8 = (const uint8_t*)((uintptr_t)i8 + address_increment); + o = (uint8_t*)((uintptr_t)o + address_increment); + + const uint8x16_t vi0 = vld1q_u8(i0); + const uint8x16_t vi1 = vld1q_u8(i1); + const uint8x16_t vi2 = vld1q_u8(i2); + const uint8x16_t vi3 = vld1q_u8(i3); + const uint8x16_t vi4 = vld1q_u8(i4); + const uint8x16_t vi5 = vld1q_u8(i5); + const uint8x16_t vi6 = vld1q_u8(i6); + const uint8x16_t vi7 = vld1q_u8(i7); + const uint8x16_t vi8 = vld1q_u8(i8); + + const uint8x16_t vmax018 = vmaxq_u8(vmaxq_u8(vi0, vi1), vi8); + const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3); + const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5); + const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7); + + const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45); + const uint8x16_t vmax01678 = vmaxq_u8(vmax018, vmax67); + const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax01678); + const uint8x16_t vout = + vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min); + + vst1q_u8(o, vout); + o += 16; + } + } + + for (ptrdiff_t m = (ptrdiff_t)ks - 9; m > 0; m -= 8) { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + if (m < 2) { + i1 = i0; + } + if (m <= 2) { + i2 = i0; + } + if (m < 4) { + i3 = i0; + } + if (m <= 4) { + i4 = i0; + } + if (m < 6) { + i5 = i0; + } + if (m <= 6) { + i6 = i0; + } + if (m < 8) { + i7 = i0; + } + + o = output; + size_t k = kc; + while (k >= 16) { + const uint8x16_t vi0 = vld1q_u8(i0); + i0 += 16; + const uint8x16_t vi1 = vld1q_u8(i1); + i1 += 16; + const uint8x16_t vi2 = vld1q_u8(i2); + i2 += 16; + const uint8x16_t vi3 = vld1q_u8(i3); + i3 += 16; + const uint8x16_t vi4 = vld1q_u8(i4); + i4 += 16; + const uint8x16_t vi5 = vld1q_u8(i5); + i5 += 16; + const uint8x16_t vi6 = vld1q_u8(i6); + i6 += 16; + const uint8x16_t vi7 = vld1q_u8(i7); + i7 += 16; + const uint8x16_t vo = vld1q_u8(o); + + const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo); + const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3); + const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5); + const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7); + + const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45); + const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67); + const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167); + const uint8x16_t vout = + vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min); + + vst1q_u8(o, vout); + o += 16; + + k -= 16; + } + if (k != 0) { + const size_t address_increment = k - 16; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + o = (uint8_t*)((uintptr_t)o + address_increment); + + const uint8x16_t vi0 = vld1q_u8(i0); + const uint8x16_t vi1 = vld1q_u8(i1); + const uint8x16_t vi2 = vld1q_u8(i2); + const uint8x16_t vi3 = vld1q_u8(i3); + const uint8x16_t vi4 = vld1q_u8(i4); + const uint8x16_t vi5 = vld1q_u8(i5); + const uint8x16_t vi6 = vld1q_u8(i6); + const uint8x16_t vi7 = vld1q_u8(i7); + const uint8x16_t vo = vld1q_u8(o); + + const uint8x16_t vmax01 = vmaxq_u8(vmaxq_u8(vi0, vi1), vo); + const uint8x16_t vmax23 = vmaxq_u8(vi2, vi3); + const uint8x16_t vmax45 = vmaxq_u8(vi4, vi5); + const uint8x16_t vmax67 = vmaxq_u8(vi6, vi7); + + const uint8x16_t vmax2345 = vmaxq_u8(vmax23, vmax45); + const uint8x16_t vmax0167 = vmaxq_u8(vmax01, vmax67); + const uint8x16_t vmax = vmaxq_u8(vmax2345, vmax0167); + const uint8x16_t vout = + vmaxq_u8(vminq_u8(vmax, voutput_max), voutput_min); + + vst1q_u8(o, vout); + o += 16; + } + } + input = (const uint8_t**)((uintptr_t)input + input_increment); + output = (uint8_t*)((uintptr_t)o + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-sse2.c new file mode 100644 index 0000000000000..c11dac7565b7d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/16x9p8q-sse2.c @@ -0,0 +1,254 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_u8maxpool_ukernel_16x9p8q__sse2( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) { + assert(n != 0); + assert(ks != 0); + assert(kc >= 16); + + const __m128i voutput_max = + _mm_load_si128((const __m128i*)params->sse2.output_max); + const __m128i voutput_min = + _mm_load_si128((const __m128i*)params->sse2.output_min); + + do { + uint8_t* o = output; + { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + const uint8_t* i8 = *input++; + if (ks < 2) { + i1 = i0; + } + if (ks <= 2) { + i2 = i0; + } + if (ks < 4) { + i3 = i0; + } + if (ks <= 4) { + i4 = i0; + } + if (ks < 6) { + i5 = i0; + } + if (ks <= 6) { + i6 = i0; + } + if (ks < 8) { + i7 = i0; + } + if (ks <= 8) { + i8 = i0; + } + + size_t k = kc; + while (k >= 16) { + const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0); + i0 += 16; + const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1); + i1 += 16; + const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2); + i2 += 16; + const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3); + i3 += 16; + const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4); + i4 += 16; + const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5); + i5 += 16; + const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6); + i6 += 16; + const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7); + i7 += 16; + const __m128i vi8 = _mm_loadu_si128((const __m128i*)i8); + i8 += 16; + + const __m128i vmax018 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vi8); + const __m128i vmax23 = _mm_max_epu8(vi2, vi3); + const __m128i vmax45 = _mm_max_epu8(vi4, vi5); + const __m128i vmax67 = _mm_max_epu8(vi6, vi7); + + const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45); + const __m128i vmax01678 = _mm_max_epu8(vmax018, vmax67); + const __m128i vmax = _mm_max_epu8(vmax2345, vmax01678); + const __m128i vout = + _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min); + + _mm_storeu_si128((__m128i*)o, vout); + o += 16; + + k -= 16; + } + if (k != 0) { + const size_t address_increment = k - 16; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + i8 = (const uint8_t*)((uintptr_t)i8 + address_increment); + o = (uint8_t*)((uintptr_t)o + address_increment); + + const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0); + const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1); + const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2); + const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3); + const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4); + const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5); + const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6); + const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7); + const __m128i vi8 = _mm_loadu_si128((const __m128i*)i8); + + const __m128i vmax018 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vi8); + const __m128i vmax23 = _mm_max_epu8(vi2, vi3); + const __m128i vmax45 = _mm_max_epu8(vi4, vi5); + const __m128i vmax67 = _mm_max_epu8(vi6, vi7); + + const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45); + const __m128i vmax01678 = _mm_max_epu8(vmax018, vmax67); + const __m128i vmax = _mm_max_epu8(vmax2345, vmax01678); + const __m128i vout = + _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min); + + _mm_storeu_si128((__m128i*)o, vout); + o += 16; + } + } + + for (ptrdiff_t m = (ptrdiff_t)ks - 9; m > 0; m -= 8) { + const uint8_t* i0 = *input++; + const uint8_t* i1 = *input++; + const uint8_t* i2 = *input++; + const uint8_t* i3 = *input++; + const uint8_t* i4 = *input++; + const uint8_t* i5 = *input++; + const uint8_t* i6 = *input++; + const uint8_t* i7 = *input++; + if (m < 2) { + i1 = i0; + } + if (m <= 2) { + i2 = i0; + } + if (m < 4) { + i3 = i0; + } + if (m <= 4) { + i4 = i0; + } + if (m < 6) { + i5 = i0; + } + if (m <= 6) { + i6 = i0; + } + if (m < 8) { + i7 = i0; + } + + o = output; + size_t k = kc; + while (k >= 16) { + const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0); + i0 += 16; + const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1); + i1 += 16; + const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2); + i2 += 16; + const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3); + i3 += 16; + const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4); + i4 += 16; + const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5); + i5 += 16; + const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6); + i6 += 16; + const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7); + i7 += 16; + const __m128i vo = _mm_loadu_si128((const __m128i*)o); + + const __m128i vmax01 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vo); + const __m128i vmax23 = _mm_max_epu8(vi2, vi3); + const __m128i vmax45 = _mm_max_epu8(vi4, vi5); + const __m128i vmax67 = _mm_max_epu8(vi6, vi7); + + const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45); + const __m128i vmax0167 = _mm_max_epu8(vmax01, vmax67); + const __m128i vmax = _mm_max_epu8(vmax2345, vmax0167); + const __m128i vout = + _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min); + + _mm_storeu_si128((__m128i*)o, vout); + o += 16; + + k -= 16; + } + if (k != 0) { + const size_t address_increment = k - 16; + i0 = (const uint8_t*)((uintptr_t)i0 + address_increment); + i1 = (const uint8_t*)((uintptr_t)i1 + address_increment); + i2 = (const uint8_t*)((uintptr_t)i2 + address_increment); + i3 = (const uint8_t*)((uintptr_t)i3 + address_increment); + i4 = (const uint8_t*)((uintptr_t)i4 + address_increment); + i5 = (const uint8_t*)((uintptr_t)i5 + address_increment); + i6 = (const uint8_t*)((uintptr_t)i6 + address_increment); + i7 = (const uint8_t*)((uintptr_t)i7 + address_increment); + o = (uint8_t*)((uintptr_t)o + address_increment); + + const __m128i vi0 = _mm_loadu_si128((const __m128i*)i0); + const __m128i vi1 = _mm_loadu_si128((const __m128i*)i1); + const __m128i vi2 = _mm_loadu_si128((const __m128i*)i2); + const __m128i vi3 = _mm_loadu_si128((const __m128i*)i3); + const __m128i vi4 = _mm_loadu_si128((const __m128i*)i4); + const __m128i vi5 = _mm_loadu_si128((const __m128i*)i5); + const __m128i vi6 = _mm_loadu_si128((const __m128i*)i6); + const __m128i vi7 = _mm_loadu_si128((const __m128i*)i7); + const __m128i vo = _mm_loadu_si128((const __m128i*)o); + + const __m128i vmax01 = _mm_max_epu8(_mm_max_epu8(vi0, vi1), vo); + const __m128i vmax23 = _mm_max_epu8(vi2, vi3); + const __m128i vmax45 = _mm_max_epu8(vi4, vi5); + const __m128i vmax67 = _mm_max_epu8(vi6, vi7); + + const __m128i vmax2345 = _mm_max_epu8(vmax23, vmax45); + const __m128i vmax0167 = _mm_max_epu8(vmax01, vmax67); + const __m128i vmax = _mm_max_epu8(vmax2345, vmax0167); + const __m128i vout = + _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min); + + _mm_storeu_si128((__m128i*)o, vout); + o += 16; + } + } + input = (const uint8_t**)((uintptr_t)input + input_increment); + output = (uint8_t*)((uintptr_t)o + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-neon.c new file mode 100644 index 0000000000000..8d3660d4a86b9 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-neon.c @@ -0,0 +1,91 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_u8maxpool_ukernel_sub16__neon( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_u8_clamping_params params[restrict static 1]) { + assert(n != 0); + assert(ks != 0); + assert(kc != 0); + assert(kc < 16); + + const uint8x16_t voutput_max = vld1q_dup_u8(¶ms->neon.output_max); + const uint8x16_t voutput_min = vld1q_dup_u8(¶ms->neon.output_min); + do { + uint8x16_t vmax = vmovq_n_u8(0); + + size_t m = ks; + do { + const uint8_t* i = *input++; + i += kc; + uint8x16_t vi = vmax; + if (kc & 1) { + i -= 1; + vi = vld1q_lane_u8(i, vi, 0); + } + if (kc & 2) { + vi = vextq_u8(vi, vi, 14); + i -= 2; + vi = vreinterpretq_u8_u16(vld1q_lane_u16( + __builtin_assume_aligned(i, 1), vreinterpretq_u16_u8(vi), 0)); + } + if (kc & 4) { + vi = vextq_u8(vi, vi, 12); + i -= 4; + vi = vreinterpretq_u8_u32(vld1q_lane_u32( + __builtin_assume_aligned(i, 1), vreinterpretq_u32_u8(vi), 0)); + } + if (kc & 8) { + i -= 8; + vi = vcombine_u8(vld1_u8(i), vget_low_u8(vi)); + } + vmax = vmaxq_u8(vmax, vi); + } while (--m != 0); + input = (const uint8_t**)((uintptr_t)input + input_increment); + + vmax = vminq_u8(vmax, voutput_max); + vmax = vmaxq_u8(vmax, voutput_min); + + uint8x8_t vout = vget_low_u8(vmax); + if (kc & 8) { + vst1_u8(output, vout); + output += 8; + vout = vget_high_u8(vmax); + } + if (kc & 4) { + vst1_lane_u32( + __builtin_assume_aligned(output, 1), vreinterpret_u32_u8(vout), 0); + output += 4; + vout = vext_u8(vout, vout, 4); + } + if (kc & 2) { + vst1_lane_u16( + __builtin_assume_aligned(output, 1), vreinterpret_u16_u8(vout), 0); + output += 2; + vout = vext_u8(vout, vout, 2); + } + if (kc & 1) { + vst1_lane_u8(output, vout, 0); + output += 1; + } + output = (uint8_t*)((uintptr_t)output + output_increment); + + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-sse2.c new file mode 100644 index 0000000000000..61dcd4194ac12 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8maxpool/sub16-sse2.c @@ -0,0 +1,86 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +void pytorch_u8maxpool_ukernel_sub16__sse2( + size_t n, + size_t ks, + size_t kc, + const uint8_t** input, + uint8_t* output, + size_t input_increment, + size_t output_increment, + const union pytorch_qnnp_u8_clamping_params params[RESTRICT_STATIC 1]) { + assert(n != 0); + assert(ks != 0); + assert(kc != 0); + assert(kc < 16); + + const __m128i voutput_max = + _mm_load_si128((const __m128i*)params->sse2.output_max); + const __m128i voutput_min = + _mm_load_si128((const __m128i*)params->sse2.output_min); + + do { + __m128i vmax = _mm_setzero_si128(); + + size_t m = ks; + do { + const uint8_t* i = *input++; + i += kc; + __m128i vi = vmax; + if (kc & 1) { + i -= 1; + vi = _mm_cvtsi32_si128(*i); + } + if (kc & 2) { + vi = _mm_slli_epi32(vi, 16); + i -= 2; + vi = _mm_insert_epi16(vi, *((const uint16_t*)i), 0); + } + if (kc & 4) { + i -= 4; + vi = _mm_unpacklo_epi32( + _mm_cvtsi32_si128((int)*((const uint32_t*)i)), vi); + } + if (kc & 8) { + i -= 8; + vi = _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i*)i), vi); + } + vmax = _mm_max_epu8(vmax, vi); + } while (--m != 0); + input = (const uint8_t**)((uintptr_t)input + input_increment); + __m128i vout = _mm_max_epu8(_mm_min_epu8(vmax, voutput_max), voutput_min); + + if (kc & 8) { + _mm_storel_epi64((__m128i*)output, vout); + output += 8; + vout = _mm_unpackhi_epi64(vout, vout); + } + if (kc & 4) { + *((uint32_t*)output) = (uint32_t)_mm_cvtsi128_si32(vout); + output += 4; + vout = _mm_srli_epi64(vout, 32); + } + if (kc & 2) { + *((uint16_t*)output) = (uint16_t)_mm_extract_epi16(vout, 0); + output += 2; + vout = _mm_srli_epi32(vout, 16); + } + if (kc & 1) { + *((uint8_t*)output) = (uint8_t)_mm_cvtsi128_si32(vout); + output += 1; + } + output = (uint8_t*)((uintptr_t)output + output_increment); + } while (--n != 0); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8rmax/neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8rmax/neon.c new file mode 100644 index 0000000000000..915a4e1618282 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8rmax/neon.c @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +uint8_t pytorch_u8rmax_ukernel__neon(size_t n, const uint8_t* x) { + assert(n != 0); + + if + PYTORCH_QNNP_LIKELY(n >= 16) { + uint8x16_t vmax = vmovq_n_u8(0); + do { + const uint8x16_t vx = vld1q_u8(x); + x += 16; + vmax = vmaxq_u8(vmax, vx); + n -= 16; + } while (n >= 16); + if (n != 0) { + const size_t x_increment = n - 16; + x = (const uint8_t*)((uintptr_t)x + x_increment); + const uint8x16_t vx = vld1q_u8(x); + vmax = vmaxq_u8(vmax, vx); + } + uint8x8_t vmax8 = vmax_u8(vget_low_u8(vmax), vget_high_u8(vmax)); + const uint8x8_t vmax4 = vpmax_u8(vmax8, vmax8); + const uint8x8_t vmax2 = vpmax_u8(vmax4, vmax4); + const uint8x8_t vmax1 = vpmax_u8(vmax2, vmax2); + return vget_lane_u8(vmax1, 0); + } + else { + uint8x8_t vmax = vmov_n_u8(0); + do { + const uint8x8_t vx = vld1_dup_u8(x); + x += 1; + vmax = vmax_u8(vmax, vx); + } while (--n != 0); + return vget_lane_u8(vmax, 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8rmax/sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8rmax/sse2.c new file mode 100644 index 0000000000000..63c11c0f87b09 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/u8rmax/sse2.c @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +uint8_t pytorch_u8rmax_ukernel__sse2(size_t n, const uint8_t* x) { + assert(n != 0); + + if + PYTORCH_QNNP_LIKELY(n >= 16) { + __m128i vmax = _mm_setzero_si128(); + do { + const __m128i vx = _mm_loadu_si128((const __m128i*)x); + x += 16; + vmax = _mm_max_epu8(vmax, vx); + n -= 16; + } while (n >= 16); + if (n != 0) { + const size_t x_increment = n - 16; + x = (const uint8_t*)((uintptr_t)x + x_increment); + const __m128i vx = _mm_loadu_si128((const __m128i*)x); + vmax = _mm_max_epu8(vmax, vx); + } + vmax = _mm_max_epu8(vmax, _mm_unpackhi_epi64(vmax, vmax)); + vmax = _mm_max_epu8(vmax, _mm_srli_epi64(vmax, 32)); + vmax = _mm_max_epu8(vmax, _mm_srli_epi32(vmax, 16)); + vmax = _mm_max_epu8(vmax, _mm_srli_epi16(vmax, 8)); + return (uint8_t)_mm_cvtsi128_si32(vmax); + } + else { + uint8_t vmax = 0; + do { + const uint8_t vx = *x++; + vmax = vx > vmax ? vx : vmax; + } while (--n != 0); + return vmax; + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8lut/scalar.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8lut/scalar.c new file mode 100644 index 0000000000000..3970b43887023 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8lut/scalar.c @@ -0,0 +1,47 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_x8lut_ukernel__scalar( + size_t n, + const uint8_t* x, + const uint8_t t[RESTRICT_STATIC 256], + uint8_t* y) { + assert(n != 0); + + while (n >= 4) { + const size_t vx0 = x[0]; + const size_t vx1 = x[1]; + const size_t vx2 = x[2]; + const size_t vx3 = x[3]; + x += 4; + + const uint8_t vt0 = t[vx0]; + const uint8_t vt1 = t[vx1]; + const uint8_t vt2 = t[vx2]; + const uint8_t vt3 = t[vx3]; + + y[0] = vt0; + y[1] = vt1; + y[2] = vt2; + y[3] = vt3; + y += 4; + + n -= 4; + } + while (n != 0) { + const size_t vx = *x++; + const uint8_t vt = t[vx]; + *y++ = vt; + + n--; + }; +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x2-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x2-neon.c new file mode 100644 index 0000000000000..9d0a85a51f69d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x2-neon.c @@ -0,0 +1,46 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_x2__neon(size_t n, const void* input, void* output) { + const uint8_t* x = input; + const uint8_t* y = x + n; + uint8_t* o = output; + + if (n >= 8) { + do { + uint8x8x2_t vxy; + vxy.val[0] = vld1_u8(x); + x += 8; + vxy.val[1] = vld1_u8(y); + y += 8; + vst2_u8(o, vxy); + o += 16; + ; + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_increment = n - 8; + uint8x8x2_t vxy; + vxy.val[0] = vld1_u8((const uint8_t*)((uintptr_t)x + address_increment)); + vxy.val[1] = vld1_u8((const uint8_t*)((uintptr_t)y + address_increment)); + vst2_u8((uint8_t*)((uintptr_t)o + address_increment * 2), vxy); + } + } else { + do { + const uint8_t vx = *x++; + const uint8_t vy = *y++; + o[0] = vx; + o[1] = vy; + o += 2; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x2-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x2-sse2.c new file mode 100644 index 0000000000000..f09672203805f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x2-sse2.c @@ -0,0 +1,52 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_x2__sse2(size_t n, const void* input, void* output) { + const uint8_t* x = input; + const uint8_t* y = x + n; + uint8_t* o = output; + + if (n >= 16) { + do { + const __m128i vx = _mm_loadu_si128((const __m128i*)x); + x += 16; + const __m128i vy = _mm_loadu_si128((const __m128i*)y); + y += 16; + const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); + const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); + _mm_storeu_si128((__m128i*)o, vxy_lo); + _mm_storeu_si128((__m128i*)(o + 16), vxy_hi); + o = (void*)((uintptr_t)o + 32); + n -= 16; + } while (n >= 16); + if (n != 0) { + const size_t address_increment = n - 16; + const __m128i vx = + _mm_loadu_si128((const __m128i*)((uintptr_t)x + address_increment)); + const __m128i vy = + _mm_loadu_si128((const __m128i*)((uintptr_t)y + address_increment)); + const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); + const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); + o = (void*)((uintptr_t)o + address_increment * 2); + _mm_storeu_si128((__m128i*)o, vxy_lo); + _mm_storeu_si128((__m128i*)o + 1, vxy_hi); + } + } else { + do { + const uint8_t vx = *x++; + const uint8_t vy = *y++; + o[0] = vx; + o[1] = vy; + o += 2; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x3-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x3-neon.c new file mode 100644 index 0000000000000..45e66b1d8b475 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x3-neon.c @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_x3__neon(size_t n, const void* input, void* output) { + const uint8_t* x = input; + const uint8_t* y = x + n; + const uint8_t* z = y + n; + uint8_t* o = output; + + if (n >= 8) { + do { + uint8x8x3_t vxyz; + vxyz.val[0] = vld1_u8(x); + x += 8; + vxyz.val[1] = vld1_u8(y); + y += 8; + vxyz.val[2] = vld1_u8(z); + z += 8; + vst3_u8(o, vxyz); + o += 24; + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_increment = n - 8; + uint8x8x3_t vxyz; + vxyz.val[0] = vld1_u8(x + address_increment); + vxyz.val[1] = vld1_u8(y + address_increment); + vxyz.val[2] = vld1_u8(z + address_increment); + vst3_u8((uint8_t*)((uintptr_t)o + address_increment * 3), vxyz); + } + } else { + do { + const uint8_t vx = *x++; + const uint8_t vy = *y++; + const uint8_t vz = *z++; + o[0] = vx; + o[1] = vy; + o[2] = vz; + o += 3; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x3-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x3-sse2.c new file mode 100644 index 0000000000000..c4aced8758f4c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x3-sse2.c @@ -0,0 +1,205 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_x3__sse2(size_t n, const void* input, void* output) { + const uint8_t* x = input; + const uint8_t* y = x + n; + const uint8_t* z = y + n; + uint8_t* o = output; + + if (n >= 16) { + const __m128i vmask0x00FF00FF = _mm_set1_epi16(0x00FF); + const __m128i vmask0x0000FFFF = _mm_set1_epi32(0x0000FFFF); + do { + /* vx = ( x15, x14, x13, x12, x11, x10, x9, x8, x7, x6, x5, x4, x3, + * x2, x1, x0 ) */ + const __m128i vx = _mm_loadu_si128((const __m128i*)x); + x += 16; + /* vy = ( y15, y14, y13, y12, y11, y10, y9, y8, y7, y6, y5, y4, y3, + * y2, y1, y0 ) */ + const __m128i vy = _mm_loadu_si128((const __m128i*)y); + y += 16; + /* vz = ( z15, z14, z13, z12, z11, z10, z9, z8, z7, z6, z5, z4, z3, + * z2, z1, z0 ) */ + const __m128i vz = _mm_loadu_si128((const __m128i*)z); + z += 16; + + /* vxeye = ( y14, x14, y12, x12, y10, x10, y8, x8, y6, x6, y4, + * x4, y2, x2, y0, x0 ) */ + const __m128i vxeye = _mm_or_si128( + _mm_and_si128(vx, vmask0x00FF00FF), _mm_slli_epi16(vy, 8)); + /* vyozo = ( z15, y15, z13, y13, z11, y11, z9, y9, z7, y7, z5, + * y5, z3, y3, z1, y1 ) */ + const __m128i vyozo = _mm_or_si128( + _mm_andnot_si128(vmask0x00FF00FF, vz), _mm_srli_epi16(vy, 8)); + /* vzoxo = ( x15, z14, x13, z12, x11, z10, x9, z8, x7, z6, x5, + * z4, x3, z2, x1, z0 ) */ + const __m128i vzexo = _mm_or_si128( + _mm_and_si128(vz, vmask0x00FF00FF), + _mm_andnot_si128(vmask0x00FF00FF, vx)); + + /* vxeyezexo = ( x13, z12, y12, x12, x9, z8, y8, x8, x5, z4, y4, + * x4, x1, z0, y0, x0 ) */ + const __m128i vxeyezexo = _mm_or_si128( + _mm_and_si128(vxeye, vmask0x0000FFFF), _mm_slli_epi32(vzexo, 16)); + /* vyozoxeye = ( y14, x14, z13, y13, y10, x10, z9, y9, y6, x6, z5, + * y5, y2, x2, z1, y1 ) */ + const __m128i vyozoxeye = _mm_or_si128( + _mm_and_si128(vyozo, vmask0x0000FFFF), + _mm_andnot_si128(vmask0x0000FFFF, vxeye)); + /* vzexoyozo = ( z15, y15, x15, z14, z11, y11, x11, z10, z7, y7, x7, + * z6, z3, y3, x3, z2 ) */ + const __m128i vzexoyozo = _mm_or_si128( + _mm_andnot_si128(vmask0x0000FFFF, vyozo), _mm_srli_epi32(vzexo, 16)); + + /* vtemp0 = ( x13, z12, y12, x12, x5, z4, y4, x4, z11, y11, x11, + * z10, z3, y3, x3, z2 ) */ + const __m128i vtemp0 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vzexoyozo), + _mm_castsi128_ps(vxeyezexo), + _MM_SHUFFLE(3, 1, 2, 0))); + /* vtemp1 = ( y10, x10, z9, y9, y2, x2, z1, y1, x9, z8, y8, + * x8, x1, z0, y0, x0 ) */ + const __m128i vtemp1 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vxeyezexo), + _mm_castsi128_ps(vyozoxeye), + _MM_SHUFFLE(2, 0, 2, 0))); + /* vtemp2 = ( z15, y15, x15, z14, z7, y7, x7, z6, y14, x14, z13, + * y13, y6, x6, z5, y5 ) */ + const __m128i vtemp2 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vyozoxeye), + _mm_castsi128_ps(vzexoyozo), + _MM_SHUFFLE(3, 1, 3, 1))); + + /* vxyz0 = ( x5, z4, y4, x4, z3, y3, x3, z2, y2, x2, z1, + * y1, x1, z0, y0, x0 ) */ + const __m128i vxyz0 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vtemp1), + _mm_castsi128_ps(vtemp0), + _MM_SHUFFLE(2, 0, 2, 0))); + /* vxyz1 = ( y10, x10, z9, y9, x9, z8, y8, x8, z7, y7, x7, + * z6, y6, x6, z5, y5 ) */ + const __m128i vxyz1 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vtemp2), + _mm_castsi128_ps(vtemp1), + _MM_SHUFFLE(3, 1, 2, 0))); + /* vxyz2 = ( z15, y15, x15, z14, y14, x14, z13, y13, x13, z12, y12, + * x12, z11, y11, x11, z10 ) */ + const __m128i vxyz2 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vtemp0), + _mm_castsi128_ps(vtemp2), + _MM_SHUFFLE(3, 1, 3, 1))); + + _mm_storeu_si128((__m128i*)o, vxyz0); + _mm_storeu_si128((__m128i*)o + 1, vxyz1); + _mm_storeu_si128((__m128i*)o + 2, vxyz2); + o += 48; + n -= 16; + } while (n >= 16); + if (n != 0) { + const size_t address_increment = n - 16; + /* vx = ( x15, x14, x13, x12, x11, x10, x9, x8, x7, x6, x5, x4, x3, + * x2, x1, x0 ) */ + const __m128i vx = + _mm_loadu_si128((const __m128i*)((uintptr_t)x + address_increment)); + /* vy = ( y15, y14, y13, y12, y11, y10, y9, y8, y7, y6, y5, y4, y3, + * y2, y1, y0 ) */ + const __m128i vy = + _mm_loadu_si128((const __m128i*)((uintptr_t)y + address_increment)); + /* vz = ( z15, z14, z13, z12, z11, z10, z9, z8, z7, z6, z5, z4, z3, + * z2, z1, z0 ) */ + const __m128i vz = + _mm_loadu_si128((const __m128i*)((uintptr_t)z + address_increment)); + + /* vxeye = ( y14, x14, y12, x12, y10, x10, y8, x8, y6, x6, y4, + * x4, y2, x2, y0, x0 ) */ + const __m128i vxeye = _mm_or_si128( + _mm_and_si128(vx, vmask0x00FF00FF), _mm_slli_epi16(vy, 8)); + /* vyozo = ( z15, y15, z13, y13, z11, y11, z9, y9, z7, y7, z5, + * y5, z3, y3, z1, y1 ) */ + const __m128i vyozo = _mm_or_si128( + _mm_andnot_si128(vmask0x00FF00FF, vz), _mm_srli_epi16(vy, 8)); + /* vzoxo = ( x15, z14, x13, z12, x11, z10, x9, z8, x7, z6, x5, + * z4, x3, z2, x1, z0 ) */ + const __m128i vzexo = _mm_or_si128( + _mm_and_si128(vz, vmask0x00FF00FF), + _mm_andnot_si128(vmask0x00FF00FF, vx)); + + /* vxeyezexo = ( x13, z12, y12, x12, x9, z8, y8, x8, x5, z4, y4, + * x4, x1, z0, y0, x0 ) */ + const __m128i vxeyezexo = _mm_or_si128( + _mm_and_si128(vxeye, vmask0x0000FFFF), _mm_slli_epi32(vzexo, 16)); + /* vyozoxeye = ( y14, x14, z13, y13, y10, x10, z9, y9, y6, x6, z5, + * y5, y2, x2, z1, y1 ) */ + const __m128i vyozoxeye = _mm_or_si128( + _mm_and_si128(vyozo, vmask0x0000FFFF), + _mm_andnot_si128(vmask0x0000FFFF, vxeye)); + /* vzexoyozo = ( z15, y15, x15, z14, z11, y11, x11, z10, z7, y7, x7, + * z6, z3, y3, x3, z2 ) */ + const __m128i vzexoyozo = _mm_or_si128( + _mm_andnot_si128(vmask0x0000FFFF, vyozo), _mm_srli_epi32(vzexo, 16)); + + /* vtemp0 = ( x13, z12, y12, x12, x5, z4, y4, x4, z11, y11, x11, + * z10, z3, y3, x3, z2 ) */ + const __m128i vtemp0 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vzexoyozo), + _mm_castsi128_ps(vxeyezexo), + _MM_SHUFFLE(3, 1, 2, 0))); + /* vtemp1 = ( y10, x10, z9, y9, y2, x2, z1, y1, x9, z8, y8, + * x8, x1, z0, y0, x0 ) */ + const __m128i vtemp1 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vxeyezexo), + _mm_castsi128_ps(vyozoxeye), + _MM_SHUFFLE(2, 0, 2, 0))); + /* vtemp2 = ( z15, y15, x15, z14, z7, y7, x7, z6, y14, x14, z13, + * y13, y6, x6, z5, y5 ) */ + const __m128i vtemp2 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vyozoxeye), + _mm_castsi128_ps(vzexoyozo), + _MM_SHUFFLE(3, 1, 3, 1))); + + /* vxyz0 = ( x5, z4, y4, x4, z3, y3, x3, z2, y2, x2, z1, + * y1, x1, z0, y0, x0 ) */ + const __m128i vxyz0 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vtemp1), + _mm_castsi128_ps(vtemp0), + _MM_SHUFFLE(2, 0, 2, 0))); + /* vxyz1 = ( y10, x10, z9, y9, x9, z8, y8, x8, z7, y7, x7, + * z6, y6, x6, z5, y5 ) */ + const __m128i vxyz1 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vtemp2), + _mm_castsi128_ps(vtemp1), + _MM_SHUFFLE(3, 1, 2, 0))); + /* vxyz2 = ( z15, y15, x15, z14, y14, x14, z13, y13, x13, z12, y12, + * x12, z11, y11, x11, z10 ) */ + const __m128i vxyz2 = _mm_castps_si128(_mm_shuffle_ps( + _mm_castsi128_ps(vtemp0), + _mm_castsi128_ps(vtemp2), + _MM_SHUFFLE(3, 1, 3, 1))); + + o = (uint8_t*)((uintptr_t)o + address_increment * 3); + _mm_storeu_si128((__m128i*)o, vxyz0); + _mm_storeu_si128((__m128i*)o + 1, vxyz1); + _mm_storeu_si128((__m128i*)o + 2, vxyz2); + } + } else { + do { + const uint8_t vx = *x++; + const uint8_t vy = *y++; + const uint8_t vz = *z++; + o[0] = vx; + o[1] = vy; + o[2] = vz; + o += 3; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x4-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x4-neon.c new file mode 100644 index 0000000000000..137bdae473764 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x4-neon.c @@ -0,0 +1,57 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_x4__neon(size_t n, const void* input, void* output) { + const uint8_t* x = input; + const uint8_t* y = x + n; + const uint8_t* z = y + n; + const uint8_t* w = z + n; + uint8_t* o = output; + + if (n >= 8) { + do { + uint8x8x4_t vxyzw; + vxyzw.val[0] = vld1_u8(x); + x += 8; + vxyzw.val[1] = vld1_u8(y); + y += 8; + vxyzw.val[2] = vld1_u8(z); + z += 8; + vxyzw.val[3] = vld1_u8(w); + w += 8; + vst4_u8(o, vxyzw); + o += 32; + n -= 8; + } while (n >= 8); + if (n != 0) { + const size_t address_increment = n - 8; + uint8x8x4_t vxyzw; + vxyzw.val[0] = vld1_u8(x + address_increment); + vxyzw.val[1] = vld1_u8(y + address_increment); + vxyzw.val[2] = vld1_u8(z + address_increment); + vxyzw.val[3] = vld1_u8(w + address_increment); + vst4_u8((uint8_t*)((uintptr_t)o + address_increment * 4), vxyzw); + } + } else { + do { + const uint8_t vx = *x++; + const uint8_t vy = *y++; + const uint8_t vz = *z++; + const uint8_t vw = *w++; + o[0] = vx; + o[1] = vy; + o[2] = vz; + o[3] = vw; + o += 4; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x4-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x4-sse2.c new file mode 100644 index 0000000000000..aff1cb68fc651 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/x4-sse2.c @@ -0,0 +1,82 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_x4__sse2(size_t n, const void* input, void* output) { + const uint8_t* x = input; + const uint8_t* y = x + n; + const uint8_t* z = y + n; + const uint8_t* w = z + n; + uint8_t* o = output; + + if (n >= 16) { + do { + const __m128i vx = _mm_loadu_si128((const __m128i*)x); + x += 16; + const __m128i vy = _mm_loadu_si128((const __m128i*)y); + y += 16; + const __m128i vz = _mm_loadu_si128((const __m128i*)z); + z += 16; + const __m128i vw = _mm_loadu_si128((const __m128i*)w); + w += 16; + const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); + const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); + const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw); + const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw); + const __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo); + const __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo); + const __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi); + const __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi); + _mm_storeu_si128((__m128i*)o, vxyzw0); + _mm_storeu_si128((__m128i*)o + 1, vxyzw1); + _mm_storeu_si128((__m128i*)o + 2, vxyzw2); + _mm_storeu_si128((__m128i*)o + 3, vxyzw3); + o = (void*)((uintptr_t)o + 64); + n -= 16; + } while (n >= 16); + if (n != 0) { + const size_t address_increment = n - 16; + const __m128i vx = + _mm_loadu_si128((const __m128i*)((uintptr_t)x + address_increment)); + const __m128i vy = + _mm_loadu_si128((const __m128i*)((uintptr_t)y + address_increment)); + const __m128i vz = + _mm_loadu_si128((const __m128i*)((uintptr_t)z + address_increment)); + const __m128i vw = + _mm_loadu_si128((const __m128i*)((uintptr_t)w + address_increment)); + const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); + const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); + const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw); + const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw); + const __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo); + const __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo); + const __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi); + const __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi); + o = (void*)((uintptr_t)o + address_increment * 4); + _mm_storeu_si128((__m128i*)o, vxyzw0); + _mm_storeu_si128((__m128i*)o + 1, vxyzw1); + _mm_storeu_si128((__m128i*)o + 2, vxyzw2); + _mm_storeu_si128((__m128i*)o + 3, vxyzw3); + } + } else { + do { + const uint8_t vx = *x++; + const uint8_t vy = *y++; + const uint8_t vz = *z++; + const uint8_t vw = *w++; + o[0] = vx; + o[1] = vy; + o[2] = vz; + o[3] = vw; + o += 4; + } while (--n != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/xm-neon.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/xm-neon.c new file mode 100644 index 0000000000000..24e6533d59e76 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/xm-neon.c @@ -0,0 +1,177 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_xm__neon( + size_t n, + size_t m, + const void* input, + void* output) { + const uint8_t* w = input; + const size_t input_increment = n * 3; + const size_t output_increment = 4 - m * n; + const uint8_t* last_input = w + n * (m - 1); + void* last_output = (void*)((uintptr_t)output + (m - 4)); + + if (n >= 8) { + for (size_t i = 0; i < m; i += 4) { + size_t k = n; + w = (const uint8_t*)((uintptr_t)w + input_increment); + if (w >= last_input) { + w = last_input; + } + const uint8_t* z = (const uint8_t*)((uintptr_t)w - n); + const uint8_t* y = (const uint8_t*)((uintptr_t)z - n); + const uint8_t* x = (const uint8_t*)((uintptr_t)y - n); + while (k >= 8) { + const uint8x8_t vx = vld1_u8(x); + x += 8; + const uint8x8_t vy = vld1_u8(y); + y += 8; + const uint8x8_t vz = vld1_u8(z); + z += 8; + const uint8x8_t vw = vld1_u8(w); + w += 8; + + const uint8x8x2_t vxy = vzip_u8(vx, vy); + const uint8x8x2_t vzw = vzip_u8(vz, vw); + const uint16x4x2_t vxyzw_lo = vzip_u16( + vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0])); + const uint16x4x2_t vxyzw_hi = vzip_u16( + vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1])); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_lo.val[0]), + 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_lo.val[0]), + 1); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_lo.val[1]), + 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_lo.val[1]), + 1); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_hi.val[0]), + 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_hi.val[0]), + 1); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_hi.val[1]), + 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32( + __builtin_assume_aligned(output, 1), + vreinterpret_u32_u16(vxyzw_hi.val[1]), + 1); + output = (void*)((uintptr_t)output + m); + + k -= 8; + } + if (k != 0) { + const size_t address_increment = k - 8; + x = (const uint8_t*)((uintptr_t)x + address_increment); + y = (const uint8_t*)((uintptr_t)y + address_increment); + z = (const uint8_t*)((uintptr_t)z + address_increment); + w = (const uint8_t*)((uintptr_t)w + address_increment); + const int64x1_t vshift = vmov_n_s64(8 * address_increment); + + const uint64x1_t vx = vshl_u64(vreinterpret_u64_u8(vld1_u8(x)), vshift); + const uint64x1_t vy = vshl_u64(vreinterpret_u64_u8(vld1_u8(y)), vshift); + const uint64x1_t vz = vshl_u64(vreinterpret_u64_u8(vld1_u8(z)), vshift); + const uint64x1_t vw = vshl_u64(vreinterpret_u64_u8(vld1_u8(w)), vshift); + w += 8; + const uint8x8x2_t vxy = + vzip_u8(vreinterpret_u8_u64(vx), vreinterpret_u8_u64(vy)); + const uint8x8x2_t vzw = + vzip_u8(vreinterpret_u8_u64(vz), vreinterpret_u8_u64(vw)); + const uint16x4x2_t vxyzw_lo = vzip_u16( + vreinterpret_u16_u8(vxy.val[0]), vreinterpret_u16_u8(vzw.val[0])); + const uint16x4x2_t vxyzw_hi = vzip_u16( + vreinterpret_u16_u8(vxy.val[1]), vreinterpret_u16_u8(vzw.val[1])); + + uint32x2_t vxyzw0 = vreinterpret_u32_u16(vxyzw_lo.val[0]); + uint32x2_t vxyzw1 = vreinterpret_u32_u16(vxyzw_lo.val[1]); + uint32x2_t vxyzw2 = vreinterpret_u32_u16(vxyzw_hi.val[0]); + uint32x2_t vxyzw3 = vreinterpret_u32_u16(vxyzw_hi.val[1]); + + if (k & 4) { + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 1); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw1, 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw1, 1); + output = (void*)((uintptr_t)output + m); + + vxyzw0 = vxyzw2; + vxyzw1 = vxyzw3; + } + + if (k & 2) { + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 0); + output = (void*)((uintptr_t)output + m); + + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 1); + output = (void*)((uintptr_t)output + m); + + vxyzw0 = vxyzw1; + } + if (k & 1) { + vst1_lane_u32(__builtin_assume_aligned(output, 1), vxyzw0, 0); + output = (void*)((uintptr_t)output + m); + } + } + output = (void*)((uintptr_t)output + output_increment); + if (output > last_output) { + output = last_output; + } + } + } else { + const uint8_t* i = input; + uint8_t* o = output; + size_t k = n; + do { + size_t l = m; + const uint8_t* ii = i++; + do { + *o++ = *ii; + ii += n; + } while (--l != 0); + } while (--k != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/xm-sse2.c b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/xm-sse2.c new file mode 100644 index 0000000000000..088963b664431 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/x8zip/xm-sse2.c @@ -0,0 +1,208 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +void pytorch_qnnp_x8zip_xm__sse2( + size_t n, + size_t m, + const void* input, + void* output) { + const uint8_t* w = input; + const size_t input_increment = n * 3; + const size_t output_increment = 4 - m * n; + const uint8_t* last_input = w + n * (m - 1); + void* last_output = (void*)((uintptr_t)output + (m - 4)); + + if (n >= 8) { + for (size_t i = 0; i < m; i += 4) { + size_t k = n; + w = (const uint8_t*)((uintptr_t)w + input_increment); + if (w >= last_input) { + w = last_input; + } + const uint8_t* z = (const uint8_t*)((uintptr_t)w - n); + const uint8_t* y = (const uint8_t*)((uintptr_t)z - n); + const uint8_t* x = (const uint8_t*)((uintptr_t)y - n); + while (k >= 16) { + const __m128i vx = _mm_loadu_si128((const __m128i*)x); + x += 16; + const __m128i vy = _mm_loadu_si128((const __m128i*)y); + y += 16; + const __m128i vz = _mm_loadu_si128((const __m128i*)z); + z += 16; + const __m128i vw = _mm_loadu_si128((const __m128i*)w); + w += 16; + const __m128i vxy_lo = _mm_unpacklo_epi8(vx, vy); + const __m128i vxy_hi = _mm_unpackhi_epi8(vx, vy); + const __m128i vzw_lo = _mm_unpacklo_epi8(vz, vw); + const __m128i vzw_hi = _mm_unpackhi_epi8(vz, vw); + __m128i vxyzw0 = _mm_unpacklo_epi16(vxy_lo, vzw_lo); + __m128i vxyzw1 = _mm_unpackhi_epi16(vxy_lo, vzw_lo); + __m128i vxyzw2 = _mm_unpacklo_epi16(vxy_hi, vzw_hi); + __m128i vxyzw3 = _mm_unpackhi_epi16(vxy_hi, vzw_hi); + + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + vxyzw1 = _mm_unpackhi_epi64(vxyzw1, vxyzw1); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw2); + output = (void*)((uintptr_t)output + m); + vxyzw2 = _mm_shufflelo_epi16(vxyzw2, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw2); + output = (void*)((uintptr_t)output + m); + vxyzw2 = _mm_unpackhi_epi64(vxyzw2, vxyzw2); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw2); + output = (void*)((uintptr_t)output + m); + vxyzw2 = _mm_shufflelo_epi16(vxyzw2, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw2); + output = (void*)((uintptr_t)output + m); + + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw3); + output = (void*)((uintptr_t)output + m); + vxyzw3 = _mm_shufflelo_epi16(vxyzw3, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw3); + output = (void*)((uintptr_t)output + m); + vxyzw3 = _mm_unpackhi_epi64(vxyzw3, vxyzw3); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw3); + output = (void*)((uintptr_t)output + m); + vxyzw3 = _mm_shufflelo_epi16(vxyzw3, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw3); + output = (void*)((uintptr_t)output + m); + k -= 16; + }; + if (k >= 8) { + const __m128i vx = _mm_loadl_epi64((const __m128i*)x); + x += 8; + const __m128i vy = _mm_loadl_epi64((const __m128i*)y); + y += 8; + const __m128i vz = _mm_loadl_epi64((const __m128i*)z); + z += 8; + const __m128i vw = _mm_loadl_epi64((const __m128i*)w); + w += 8; + const __m128i vxy = _mm_unpacklo_epi8(vx, vy); + const __m128i vzw = _mm_unpacklo_epi8(vz, vw); + __m128i vxyzw0 = _mm_unpacklo_epi16(vxy, vzw); + __m128i vxyzw1 = _mm_unpackhi_epi16(vxy, vzw); + + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + vxyzw1 = _mm_unpackhi_epi64(vxyzw1, vxyzw1); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + vxyzw1 = _mm_shufflelo_epi16(vxyzw1, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw1); + output = (void*)((uintptr_t)output + m); + k -= 8; + } + if (k != 0) { + const size_t address_decrement = 8 - k; + x -= address_decrement; + y -= address_decrement; + z -= address_decrement; + w -= address_decrement; + const __m128i vshift = _mm_cvtsi32_si128(8 * address_decrement); + + const __m128i vx = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)x), vshift); + const __m128i vy = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)y), vshift); + const __m128i vz = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)z), vshift); + const __m128i vw = + _mm_srl_epi64(_mm_loadl_epi64((const __m128i*)w), vshift); + w += 8; + const __m128i vxy = _mm_unpacklo_epi8(vx, vy); + const __m128i vzw = _mm_unpacklo_epi8(vz, vw); + __m128i vxyzw0 = _mm_unpacklo_epi16(vxy, vzw); + __m128i vxyzw1 = _mm_unpackhi_epi16(vxy, vzw); + + if (k & 4) { + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = vxyzw1; + } + + if (k & 2) { + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_shufflelo_epi16(vxyzw0, _MM_SHUFFLE(3, 2, 3, 2)); + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + vxyzw0 = _mm_unpackhi_epi64(vxyzw0, vxyzw0); + } + if (k & 1) { + *((uint32_t*)output) = _mm_cvtsi128_si32(vxyzw0); + output = (void*)((uintptr_t)output + m); + } + } + output = (void*)((uintptr_t)output + output_increment); + if (output > last_output) { + output = last_output; + } + } + } else { + const uint8_t* i = input; + uint8_t* o = output; + size_t k = n; + do { + size_t l = m; + const uint8_t* ii = i++; + do { + *o++ = *ii; + ii += n; + } while (--l != 0); + } while (--k != 0); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/add-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/add-operator-tester.h new file mode 100644 index 0000000000000..f594cbcae23e5 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/add-operator-tester.h @@ -0,0 +1,281 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class AddOperatorTester { + public: + inline AddOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline AddOperatorTester& aStride(size_t aStride) { + assert(aStride != 0); + this->aStride_ = aStride; + return *this; + } + + inline size_t aStride() const { + if (this->aStride_ == 0) { + return this->channels_; + } else { + assert(this->aStride_ >= this->channels_); + return this->aStride_; + } + } + + inline AddOperatorTester& bStride(size_t bStride) { + assert(bStride != 0); + this->bStride_ = bStride; + return *this; + } + + inline size_t bStride() const { + if (this->bStride_ == 0) { + return this->channels_; + } else { + assert(this->bStride_ >= this->channels_); + return this->bStride_; + } + } + + inline AddOperatorTester& yStride(size_t yStride) { + assert(yStride != 0); + this->yStride_ = yStride; + return *this; + } + + inline size_t yStride() const { + if (this->yStride_ == 0) { + return this->channels_; + } else { + assert(this->yStride_ >= this->channels_); + return this->yStride_; + } + } + + inline AddOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline AddOperatorTester& aScale(float aScale) { + assert(aScale > 0.0f); + assert(std::isnormal(aScale)); + this->aScale_ = aScale; + return *this; + } + + inline float aScale() const { + return this->aScale_; + } + + inline AddOperatorTester& aZeroPoint(uint8_t aZeroPoint) { + this->aZeroPoint_ = aZeroPoint; + return *this; + } + + inline uint8_t aZeroPoint() const { + return this->aZeroPoint_; + } + + inline AddOperatorTester& bScale(float bScale) { + assert(bScale > 0.0f); + assert(std::isnormal(bScale)); + this->bScale_ = bScale; + return *this; + } + + inline float bScale() const { + return this->bScale_; + } + + inline AddOperatorTester& bZeroPoint(uint8_t bZeroPoint) { + this->bZeroPoint_ = bZeroPoint; + return *this; + } + + inline uint8_t bZeroPoint() const { + return this->bZeroPoint_; + } + + inline AddOperatorTester& yScale(float yScale) { + assert(yScale > 0.0f); + assert(std::isnormal(yScale)); + this->yScale_ = yScale; + return *this; + } + + inline float yScale() const { + return this->yScale_; + } + + inline AddOperatorTester& yZeroPoint(uint8_t yZeroPoint) { + this->yZeroPoint_ = yZeroPoint; + return *this; + } + + inline uint8_t yZeroPoint() const { + return this->yZeroPoint_; + } + + inline AddOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline AddOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline AddOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a((batchSize() - 1) * aStride() + channels()); + std::vector b((batchSize() - 1) * bStride() + channels()); + std::vector y((batchSize() - 1) * yStride() + channels()); + std::vector yRef(batchSize() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + if (batchSize() * channels() > 3) { + ASSERT_NE( + *std::max_element(a.cbegin(), a.cend()), + *std::min_element(a.cbegin(), a.cend())); + ASSERT_NE( + *std::max_element(b.cbegin(), b.cend()), + *std::min_element(b.cbegin(), b.cend())); + } + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + yRef[i * channels() + c] = float(yZeroPoint()) + + float(int32_t(a[i * aStride() + c]) - int32_t(aZeroPoint())) * + (aScale() / yScale()) + + float(int32_t(b[i * bStride() + c]) - int32_t(bZeroPoint())) * + (bScale() / yScale()); + yRef[i * channels() + c] = + std::min(yRef[i * channels() + c], float(qmax())); + yRef[i * channels() + c] = + std::max(yRef[i * channels() + c], float(qmin())); + } + } + + /* Create, setup, run, and destroy Add operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t add_op = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_add_nc_q8( + channels(), + aZeroPoint(), + aScale(), + bZeroPoint(), + bScale(), + yZeroPoint(), + yScale(), + qmin(), + qmax(), + 0, + &add_op)); + ASSERT_NE(nullptr, add_op); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_add_nc_q8( + add_op, + batchSize(), + a.data(), + aStride(), + b.data(), + bStride(), + y.data(), + yStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(add_op, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, pytorch_qnnp_delete_operator(add_op)); + add_op = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE(uint32_t(y[i * yStride() + c]), uint32_t(qmax())); + ASSERT_GE(uint32_t(y[i * yStride() + c]), uint32_t(qmin())); + ASSERT_NEAR( + float(int32_t(y[i * yStride() + c])), + yRef[i * channels() + c], + 0.6f); + } + } + } + } + + private: + size_t batchSize_{1}; + size_t channels_{1}; + size_t aStride_{0}; + size_t bStride_{0}; + size_t yStride_{0}; + float aScale_{0.75f}; + float bScale_{1.25f}; + float yScale_{0.96875f}; + uint8_t aZeroPoint_{121}; + uint8_t bZeroPoint_{127}; + uint8_t yZeroPoint_{133}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/add.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/add.cc new file mode 100644 index 0000000000000..9f599c02b92fb --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/add.cc @@ -0,0 +1,397 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "add-operator-tester.h" + +TEST(ADD_OP, zero_batch) { + AddOperatorTester().batchSize(0).channels(2).iterations(1).testQ8(); +} + +TEST(ADD_OP, unit_batch) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester().batchSize(1).channels(channels).iterations(3).testQ8(); + } +} + +TEST(ADD_OP, unit_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, unit_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, unit_batch_with_a_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float aScale = 1.0e-2f; aScale < 1.0e+2f; aScale *= 10.0f) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .aScale(aScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, unit_batch_with_b_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float bScale = 1.0e-2f; bScale < 1.0e+2f; bScale *= 10.0f) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .bScale(bScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, unit_batch_with_y_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float yScale = 1.0e-2f; yScale < 1.0e+2f; yScale *= 10.0f) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .yScale(yScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, unit_batch_with_a_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t aZeroPoint = 0; aZeroPoint <= 255; aZeroPoint += 51) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .aZeroPoint(uint8_t(aZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, unit_batch_with_b_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t bZeroPoint = 0; bZeroPoint <= 255; bZeroPoint += 51) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .bZeroPoint(uint8_t(bZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, unit_batch_with_y_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AddOperatorTester() + .batchSize(1) + .channels(channels) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, small_batch) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester().batchSize(3).channels(channels).iterations(3).testQ8(); + } +} + +TEST(ADD_OP, small_batch_with_a_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, small_batch_with_b_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .bStride(123) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, small_batch_with_y_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .yStride(117) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, small_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, small_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, small_batch_with_a_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float aScale = 1.0e-2f; aScale < 1.0e+2f; aScale *= 10.0f) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aScale(aScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, small_batch_with_b_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float bScale = 1.0e-2f; bScale < 1.0e+2f; bScale *= 10.0f) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .bScale(bScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, small_batch_with_y_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float yScale = 1.0e-2f; yScale < 1.0e+2f; yScale *= 10.0f) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .yScale(yScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, small_batch_with_a_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t aZeroPoint = 0; aZeroPoint <= 255; aZeroPoint += 51) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aZeroPoint(uint8_t(aZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, small_batch_with_b_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t bZeroPoint = 0; bZeroPoint <= 255; bZeroPoint += 51) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .bZeroPoint(uint8_t(bZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, small_batch_with_y_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, strided_batch) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, strided_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, strided_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(ADD_OP, strided_batch_with_a_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float aScale = 1.0e-2f; aScale < 1.0e+2f; aScale *= 10.0f) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .aScale(aScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, strided_batch_with_b_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float bScale = 1.0e-2f; bScale < 1.0e+2f; bScale *= 10.0f) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .bScale(bScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, strided_batch_with_y_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float yScale = 1.0e-2f; yScale < 1.0e+2f; yScale *= 10.0f) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .yScale(yScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, strided_batch_with_a_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t aZeroPoint = 0; aZeroPoint <= 255; aZeroPoint += 51) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .aZeroPoint(uint8_t(aZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, strided_batch_with_b_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t bZeroPoint = 0; bZeroPoint <= 255; bZeroPoint += 51) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .bZeroPoint(uint8_t(bZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(ADD_OP, strided_batch_with_y_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AddOperatorTester() + .batchSize(3) + .channels(channels) + .aStride(129) + .bStride(123) + .yStride(117) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/average-pooling-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/average-pooling-operator-tester.h new file mode 100644 index 0000000000000..f1c472e16fe03 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/average-pooling-operator-tester.h @@ -0,0 +1,874 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +class AveragePoolingOperatorTester { + public: + inline AveragePoolingOperatorTester& padding(uint32_t padding) { + this->paddingTop_ = padding; + this->paddingRight_ = padding; + this->paddingBottom_ = padding; + this->paddingLeft_ = padding; + return *this; + } + + inline AveragePoolingOperatorTester& padding( + uint32_t paddingHeight, + uint32_t paddingWidth) { + this->paddingTop_ = paddingHeight; + this->paddingRight_ = paddingWidth; + this->paddingBottom_ = paddingHeight; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline AveragePoolingOperatorTester& paddingHeight(uint32_t paddingHeight) { + this->paddingTop_ = paddingHeight; + this->paddingBottom_ = paddingHeight; + return *this; + } + + inline AveragePoolingOperatorTester& paddingWidth(uint32_t paddingWidth) { + this->paddingRight_ = paddingWidth; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline AveragePoolingOperatorTester& paddingTop(uint32_t paddingTop) { + this->paddingTop_ = paddingTop; + return *this; + } + + inline uint32_t paddingTop() const { + return this->paddingTop_; + } + + inline AveragePoolingOperatorTester& paddingRight(uint32_t paddingRight) { + this->paddingRight_ = paddingRight; + return *this; + } + + inline uint32_t paddingRight() const { + return this->paddingRight_; + } + + inline AveragePoolingOperatorTester& paddingBottom(uint32_t paddingBottom) { + this->paddingBottom_ = paddingBottom; + return *this; + } + + inline uint32_t paddingBottom() const { + return this->paddingBottom_; + } + + inline AveragePoolingOperatorTester& paddingLeft(uint32_t paddingLeft) { + this->paddingLeft_ = paddingLeft; + return *this; + } + + inline uint32_t paddingLeft() const { + return this->paddingLeft_; + } + + inline AveragePoolingOperatorTester& inputSize( + size_t inputHeight, + size_t inputWidth) { + assert(inputHeight >= 1); + assert(inputWidth >= 1); + this->inputHeight_ = inputHeight; + this->inputWidth_ = inputWidth; + return *this; + } + + inline AveragePoolingOperatorTester& inputHeight(size_t inputHeight) { + assert(inputHeight >= 1); + this->inputHeight_ = inputHeight; + return *this; + } + + inline size_t inputHeight() const { + return this->inputHeight_; + } + + inline AveragePoolingOperatorTester& inputWidth(size_t inputWidth) { + assert(inputWidth >= 1); + this->inputWidth_ = inputWidth; + return *this; + } + + inline size_t inputWidth() const { + return this->inputWidth_; + } + + inline AveragePoolingOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline AveragePoolingOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline AveragePoolingOperatorTester& poolingSize(uint32_t poolingSize) { + assert(poolingSize >= 1); + this->poolingHeight_ = poolingSize; + this->poolingWidth_ = poolingSize; + return *this; + } + + inline AveragePoolingOperatorTester& poolingSize( + uint32_t poolingHeight, + uint32_t poolingWidth) { + assert(poolingHeight >= 1); + assert(poolingWidth >= 1); + this->poolingHeight_ = poolingHeight; + this->poolingWidth_ = poolingWidth; + return *this; + } + + inline AveragePoolingOperatorTester& poolingHeight(uint32_t poolingHeight) { + assert(poolingHeight >= 1); + this->poolingHeight_ = poolingHeight; + return *this; + } + + inline uint32_t poolingHeight() const { + return this->poolingHeight_; + } + + inline AveragePoolingOperatorTester& poolingWidth(uint32_t poolingWidth) { + assert(poolingWidth >= 1); + this->poolingWidth_ = poolingWidth; + return *this; + } + + inline uint32_t poolingWidth() const { + return this->poolingWidth_; + } + + inline AveragePoolingOperatorTester& stride(uint32_t stride) { + assert(stride >= 1); + this->strideHeight_ = stride; + this->strideWidth_ = stride; + return *this; + } + + inline AveragePoolingOperatorTester& stride( + uint32_t strideHeight, + uint32_t strideWidth) { + assert(strideHeight >= 1); + assert(strideWidth >= 1); + this->strideHeight_ = strideHeight; + this->strideWidth_ = strideWidth; + return *this; + } + + inline AveragePoolingOperatorTester& strideHeight(uint32_t strideHeight) { + assert(strideHeight >= 1); + this->strideHeight_ = strideHeight; + return *this; + } + + inline uint32_t strideHeight() const { + return this->strideHeight_; + } + + inline AveragePoolingOperatorTester& strideWidth(uint32_t strideWidth) { + assert(strideWidth >= 1); + this->strideWidth_ = strideWidth; + return *this; + } + + inline uint32_t strideWidth() const { + return this->strideWidth_; + } + + inline size_t outputHeight() const { + const size_t paddedInputHeight = + paddingTop() + inputHeight() + paddingBottom(); + if (paddedInputHeight <= poolingHeight()) { + return 1; + } else { + return (paddedInputHeight - poolingHeight()) / strideHeight() + 1; + } + } + + inline size_t outputWidth() const { + const size_t paddedInputWidth = + paddingLeft() + inputWidth() + paddingRight(); + if (paddedInputWidth <= poolingWidth()) { + return 1; + } else { + return (paddedInputWidth - poolingWidth()) / strideWidth() + 1; + } + } + + inline AveragePoolingOperatorTester& inputPixelStride( + size_t inputPixelStride) { + assert(inputPixelStride != 0); + this->inputPixelStride_ = inputPixelStride; + return *this; + } + + inline size_t inputPixelStride() const { + if (this->inputPixelStride_ == 0) { + return channels(); + } else { + assert(this->inputPixelStride_ >= channels()); + return this->inputPixelStride_; + } + } + + inline AveragePoolingOperatorTester& outputPixelStride( + size_t outputPixelStride) { + assert(outputPixelStride != 0); + this->outputPixelStride_ = outputPixelStride; + return *this; + } + + inline size_t outputPixelStride() const { + if (this->outputPixelStride_ == 0) { + return channels(); + } else { + assert(this->outputPixelStride_ >= channels()); + return this->outputPixelStride_; + } + } + + inline AveragePoolingOperatorTester& nextInputSize( + uint32_t nextInputHeight, + uint32_t nextInputWidth) { + assert(nextInputHeight >= 1); + assert(nextInputWidth >= 1); + this->nextInputHeight_ = nextInputHeight; + this->nextInputWidth_ = nextInputWidth; + return *this; + } + + inline AveragePoolingOperatorTester& nextInputHeight( + uint32_t nextInputHeight) { + assert(nextInputHeight >= 1); + this->nextInputHeight_ = nextInputHeight; + return *this; + } + + inline uint32_t nextInputHeight() const { + if (this->nextInputHeight_ == 0) { + return inputHeight(); + } else { + return this->nextInputHeight_; + } + } + + inline AveragePoolingOperatorTester& nextInputWidth(uint32_t nextInputWidth) { + assert(nextInputWidth >= 1); + this->nextInputWidth_ = nextInputWidth; + return *this; + } + + inline uint32_t nextInputWidth() const { + if (this->nextInputWidth_ == 0) { + return inputWidth(); + } else { + return this->nextInputWidth_; + } + } + + inline size_t nextOutputHeight() const { + const size_t paddedNextInputHeight = + paddingTop() + nextInputHeight() + paddingBottom(); + if (paddedNextInputHeight <= poolingHeight()) { + return 1; + } else { + return (paddedNextInputHeight - poolingHeight()) / strideHeight() + 1; + } + } + + inline size_t nextOutputWidth() const { + const size_t paddedNextInputWidth = + paddingLeft() + nextInputWidth() + paddingRight(); + if (paddedNextInputWidth <= poolingWidth()) { + return 1; + } else { + return (paddedNextInputWidth - poolingWidth()) / strideWidth() + 1; + } + } + + inline AveragePoolingOperatorTester& nextBatchSize(size_t nextBatchSize) { + assert(nextBatchSize >= 1); + this->nextBatchSize_ = nextBatchSize; + return *this; + } + + inline size_t nextBatchSize() const { + if (this->nextBatchSize_ == 0) { + return batchSize(); + } else { + return this->nextBatchSize_; + } + } + + inline AveragePoolingOperatorTester& inputScale(float inputScale) { + assert(inputScale > 0.0f); + assert(std::isnormal(inputScale)); + this->inputScale_ = inputScale; + return *this; + } + + inline float inputScale() const { + return this->inputScale_; + } + + inline AveragePoolingOperatorTester& inputZeroPoint(uint8_t inputZeroPoint) { + this->inputZeroPoint_ = inputZeroPoint; + return *this; + } + + inline uint8_t inputZeroPoint() const { + return this->inputZeroPoint_; + } + + inline AveragePoolingOperatorTester& outputScale(float outputScale) { + assert(outputScale > 0.0f); + assert(std::isnormal(outputScale)); + this->outputScale_ = outputScale; + return *this; + } + + inline float outputScale() const { + return this->outputScale_; + } + + inline AveragePoolingOperatorTester& outputZeroPoint( + uint8_t outputZeroPoint) { + this->outputZeroPoint_ = outputZeroPoint; + return *this; + } + + inline uint8_t outputZeroPoint() const { + return this->outputZeroPoint_; + } + + inline AveragePoolingOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline AveragePoolingOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline AveragePoolingOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + (batchSize() * inputHeight() * inputWidth() - 1) * inputPixelStride() + + channels()); + std::vector output( + (batchSize() * outputHeight() * outputWidth() - 1) * + outputPixelStride() + + channels()); + std::vector outputRef( + batchSize() * outputHeight() * outputWidth() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + const double scale = double(inputScale()) / + (double(outputScale()) * double(poolingHeight() * poolingWidth())); + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t c = 0; c < channels(); c++) { + double acc = 0.0f; + for (size_t py = 0; py < poolingHeight(); py++) { + const size_t iy = oy * strideHeight() + py - paddingTop(); + for (size_t px = 0; px < poolingWidth(); px++) { + const size_t ix = ox * strideWidth() + px - paddingLeft(); + if (ix < inputWidth() && iy < inputHeight()) { + acc += double( + int32_t(input + [((i * inputHeight() + iy) * inputWidth() + + ix) * + inputPixelStride() + + c]) - + int32_t(inputZeroPoint())); + } + } + } + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = float(acc * scale + double(outputZeroPoint())); + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = + std::min( + outputRef + [((i * outputHeight() + oy) * outputWidth() + + ox) * + channels() + + c], + float(qmax())); + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = + std::max( + outputRef + [((i * outputHeight() + oy) * outputWidth() + + ox) * + channels() + + c], + float(qmin())); + } + } + } + } + + /* Create, setup, run, and destroy Average Pooling operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t averagePoolingOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_average_pooling2d_nhwc_q8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + poolingHeight(), + poolingWidth(), + strideHeight(), + strideWidth(), + channels(), + inputZeroPoint(), + inputScale(), + outputZeroPoint(), + outputScale(), + qmin(), + qmax(), + 0, + &averagePoolingOp)); + ASSERT_NE(nullptr, averagePoolingOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_average_pooling2d_nhwc_q8( + averagePoolingOp, + batchSize(), + inputHeight(), + inputWidth(), + input.data(), + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator( + averagePoolingOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(averagePoolingOp)); + averagePoolingOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmax())); + ASSERT_GE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmin())); + ASSERT_NEAR( + float(int32_t( + output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c])), + outputRef + [((i * outputHeight() + y) * outputWidth() + x) * + channels() + + c], + 0.80f) + << "in batch index " << i << ", pixel (" << y << ", " << x + << "), channel " << c; + } + } + } + } + } + } + + void testSetupQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input(std::max( + (batchSize() * inputHeight() * inputWidth() - 1) * inputPixelStride() + + channels(), + (nextBatchSize() * nextInputHeight() * nextInputWidth() - 1) * + inputPixelStride() + + channels())); + std::vector output(std::max( + (batchSize() * outputHeight() * outputWidth() - 1) * + outputPixelStride() + + channels(), + (nextBatchSize() * nextOutputHeight() * nextOutputWidth() - 1) * + outputPixelStride() + + channels())); + std::vector outputRef( + batchSize() * outputHeight() * outputWidth() * channels()); + std::vector nextOutputRef( + nextBatchSize() * nextOutputHeight() * nextOutputWidth() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + const double scale = double(inputScale()) / + (double(outputScale()) * double(poolingHeight() * poolingWidth())); + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t c = 0; c < channels(); c++) { + double acc = 0.0f; + for (size_t py = 0; py < poolingHeight(); py++) { + const size_t iy = oy * strideHeight() + py - paddingTop(); + for (size_t px = 0; px < poolingWidth(); px++) { + const size_t ix = ox * strideWidth() + px - paddingLeft(); + if (ix < inputWidth() && iy < inputHeight()) { + acc += double( + int32_t(input + [((i * inputHeight() + iy) * inputWidth() + + ix) * + inputPixelStride() + + c]) - + int32_t(inputZeroPoint())); + } + } + } + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = float(acc * scale + double(outputZeroPoint())); + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = + std::min( + outputRef + [((i * outputHeight() + oy) * outputWidth() + + ox) * + channels() + + c], + float(qmax())); + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = + std::max( + outputRef + [((i * outputHeight() + oy) * outputWidth() + + ox) * + channels() + + c], + float(qmin())); + } + } + } + } + + /* Create, setup, and run Average Pooling operator once */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t averagePoolingOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_average_pooling2d_nhwc_q8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + poolingHeight(), + poolingWidth(), + strideHeight(), + strideWidth(), + channels(), + inputZeroPoint(), + inputScale(), + outputZeroPoint(), + outputScale(), + qmin(), + qmax(), + 0, + &averagePoolingOp)); + ASSERT_NE(nullptr, averagePoolingOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_average_pooling2d_nhwc_q8( + averagePoolingOp, + batchSize(), + inputHeight(), + inputWidth(), + input.data(), + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator( + averagePoolingOp, nullptr /* thread pool */)); + + /* Verify results of the first run */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmax())); + ASSERT_GE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmin())); + ASSERT_NEAR( + float(int32_t( + output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c])), + outputRef + [((i * outputHeight() + y) * outputWidth() + x) * + channels() + + c], + 0.80f) + << "in batch index " << i << ", pixel (" << y << ", " << x + << "), channel " << c; + } + } + } + } + + /* Re-generate data for the second run */ + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results for the second run */ + for (size_t i = 0; i < nextBatchSize(); i++) { + for (size_t oy = 0; oy < nextOutputHeight(); oy++) { + for (size_t ox = 0; ox < nextOutputWidth(); ox++) { + for (size_t c = 0; c < channels(); c++) { + double acc = 0.0f; + for (size_t py = 0; py < poolingHeight(); py++) { + const size_t iy = oy * strideHeight() + py - paddingTop(); + for (size_t px = 0; px < poolingWidth(); px++) { + const size_t ix = ox * strideWidth() + px - paddingLeft(); + if (ix < nextInputWidth() && iy < nextInputHeight()) { + acc += double( + int32_t(input + [((i * nextInputHeight() + iy) * + nextInputWidth() + + ix) * + inputPixelStride() + + c]) - + int32_t(inputZeroPoint())); + } + } + } + nextOutputRef + [((i * nextOutputHeight() + oy) * nextOutputWidth() + ox) * + channels() + + c] = float(acc * scale + double(outputZeroPoint())); + nextOutputRef + [((i * nextOutputHeight() + oy) * nextOutputWidth() + ox) * + channels() + + c] = + std::min( + nextOutputRef + [((i * nextOutputHeight() + oy) * + nextOutputWidth() + + ox) * + channels() + + c], + float(qmax())); + nextOutputRef + [((i * nextOutputHeight() + oy) * nextOutputWidth() + ox) * + channels() + + c] = + std::max( + nextOutputRef + [((i * nextOutputHeight() + oy) * + nextOutputWidth() + + ox) * + channels() + + c], + float(qmin())); + } + } + } + } + + /* Setup and run Average Pooling operator the second time, and destroy the + * operator */ + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_average_pooling2d_nhwc_q8( + averagePoolingOp, + nextBatchSize(), + nextInputHeight(), + nextInputWidth(), + input.data(), + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator( + averagePoolingOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(averagePoolingOp)); + averagePoolingOp = nullptr; + + /* Verify results of the second run */ + for (size_t i = 0; i < nextBatchSize(); i++) { + for (size_t y = 0; y < nextOutputHeight(); y++) { + for (size_t x = 0; x < nextOutputWidth(); x++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t( + output + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + outputPixelStride() + + c]), + uint32_t(qmax())); + ASSERT_GE( + uint32_t( + output + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + outputPixelStride() + + c]), + uint32_t(qmin())); + ASSERT_NEAR( + float(int32_t( + output + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + outputPixelStride() + + c])), + nextOutputRef + [((i * nextOutputHeight() + y) * nextOutputWidth() + x) * + channels() + + c], + 0.80f) + << "in batch index " << i << ", pixel (" << y << ", " << x + << "), channel " << c; + } + } + } + } + } + } + + private: + uint32_t paddingTop_{0}; + uint32_t paddingRight_{0}; + uint32_t paddingBottom_{0}; + uint32_t paddingLeft_{0}; + size_t inputHeight_{1}; + size_t inputWidth_{1}; + size_t channels_{1}; + size_t batchSize_{1}; + size_t inputPixelStride_{0}; + size_t outputPixelStride_{0}; + uint32_t poolingHeight_{1}; + uint32_t poolingWidth_{1}; + uint32_t strideHeight_{1}; + uint32_t strideWidth_{1}; + size_t nextInputHeight_{0}; + size_t nextInputWidth_{0}; + size_t nextBatchSize_{0}; + float inputScale_{1.0f}; + float outputScale_{1.0f}; + uint8_t inputZeroPoint_{121}; + uint8_t outputZeroPoint_{133}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/average-pooling.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/average-pooling.cc new file mode 100644 index 0000000000000..4977e52fa471f --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/average-pooling.cc @@ -0,0 +1,1512 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include "average-pooling-operator-tester.h" + +TEST(AVERAGE_POOLING_OP, zero_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + AveragePoolingOperatorTester() + .batchSize(0) + .inputHeight(2) + .inputWidth(4) + .poolingHeight(1) + .poolingWidth(2) + .channels(4) + .testQ8(); +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_1xM_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_1xM_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 3; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (size_t paddingLeft = 0; paddingLeft <= 1; paddingLeft++) { + for (size_t paddingRight = 0; paddingRight <= 1; paddingRight++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .paddingLeft(paddingLeft) + .paddingRight(paddingRight) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_1xM_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 4) + .poolingHeight(1) + .poolingWidth(poolSize) + .strideWidth(2) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_Mx1_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_Mx1_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_Mx1_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 3) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .strideHeight(2) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_small_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_small_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_pool_with_input_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (float inputScale = 0.01f; inputScale < 100.0f; + inputScale *= 3.14159265f) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_small_pool_with_input_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_small_pool_with_output_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (float outputScale = 0.01f; outputScale < 100.0f; + outputScale *= 3.14159265f) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_small_pool_with_output_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_pool_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmin(128) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmin(128) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_small_pool_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmax(128) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmax(128) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_1xM_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_1xM_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 3; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (size_t paddingLeft = 0; paddingLeft <= 1; paddingLeft++) { + for (size_t paddingRight = 0; paddingRight <= 1; paddingRight++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .paddingLeft(paddingLeft) + .paddingRight(paddingRight) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_1xM_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 4) + .poolingHeight(1) + .poolingWidth(poolSize) + .strideWidth(2) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_Mx1_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_Mx1_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_Mx1_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_large_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_pool_with_input_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + for (float inputScale = 0.01f; inputScale < 100.0f; + inputScale *= 3.14159265f) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_large_pool_with_input_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_large_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_large_pool_with_output_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + for (float outputScale = 0.01f; outputScale < 100.0f; + outputScale *= 3.14159265f) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + } + } + } +} + +TEST( + AVERAGE_POOLING_OP, + unit_batch_many_channels_large_pool_with_output_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_pool_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmin(128) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmin(128) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_many_channels_large_pool_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmax(128) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmax(128) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_1xM_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_1xM_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 3; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + for (size_t paddingLeft = 0; paddingLeft <= 1; paddingLeft++) { + for (size_t paddingRight = 0; paddingRight <= 1; paddingRight++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .paddingLeft(paddingLeft) + .paddingRight(paddingRight) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_1xM_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 4) + .poolingHeight(1) + .poolingWidth(poolSize) + .strideWidth(2) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_Mx1_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_Mx1_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + } + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_Mx1_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 3) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .strideHeight(2) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_input_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + for (float inputScale = 0.01f; inputScale < 100.0f; + inputScale *= 3.14159265f) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_input_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_output_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + for (float outputScale = 0.01f; outputScale < 100.0f; + outputScale *= 3.14159265f) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_output_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmin(128) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmin(128) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, unit_batch_few_channels_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmax(128) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmax(128) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, small_batch_many_channels_small_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + small_batch_many_channels_small_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + small_batch_many_channels_small_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.q8avgpool.mr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, small_batch_many_channels_large_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + small_batch_many_channels_large_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 5) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST( + AVERAGE_POOLING_OP, + small_batch_many_channels_large_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8avgpool.kr; + channels <= 3 * pytorch_qnnp_params.q8avgpool.kr; + channels += 5) { + for (size_t poolSize = pytorch_qnnp_params.q8avgpool.mr + 1; poolSize <= + pytorch_qnnp_params.q8avgpool.mr + pytorch_qnnp_params.q8avgpool.qr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, small_batch_few_channels) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize++) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, small_batch_few_channels_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize += 3) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, small_batch_few_channels_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8avgpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.q8avgpool.kr; + poolSize += 3) { + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.q8avgpool.kr) + .testQ8(); + } + } +} + +TEST(AVERAGE_POOLING_OP, setup_increasing_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + AveragePoolingOperatorTester() + .batchSize(3) + .nextBatchSize(5) + .inputHeight(8) + .inputWidth(8) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); +} + +TEST(AVERAGE_POOLING_OP, setup_decreasing_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + AveragePoolingOperatorTester() + .batchSize(5) + .nextBatchSize(3) + .inputHeight(8) + .inputWidth(8) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); +} + +TEST(AVERAGE_POOLING_OP, setup_changing_height) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputHeight(9) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputHeight(7) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); +} + +TEST(AVERAGE_POOLING_OP, setup_changing_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputWidth(9) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputWidth(7) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); +} + +TEST(AVERAGE_POOLING_OP, setup_swap_height_and_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + AveragePoolingOperatorTester() + .batchSize(3) + .inputHeight(9) + .inputWidth(8) + .nextInputHeight(8) + .nextInputWidth(9) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupQ8(); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/avgpool-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/avgpool-microkernel-tester.h new file mode 100644 index 0000000000000..5733741acea4b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/avgpool-microkernel-tester.h @@ -0,0 +1,429 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +class AvgPoolMicrokernelTester { + public: + inline AvgPoolMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline AvgPoolMicrokernelTester& s(size_t s) { + assert(s != 0); + this->s_ = s; + return *this; + } + + inline size_t s() const { + return this->s_; + } + + inline AvgPoolMicrokernelTester& kh(size_t kh) { + assert(kh != 0); + this->kh_ = kh; + return *this; + } + + inline size_t kh() const { + return this->kh_; + } + + inline AvgPoolMicrokernelTester& kw(size_t kw) { + assert(kw != 0); + this->kw_ = kw; + return *this; + } + + inline size_t kw() const { + return this->kw_; + } + + inline size_t ks() const { + return kh() * kw(); + } + + inline size_t packedKs() const { + if (kc() < kr()) { + return ks(); + } else if (ks() <= mr()) { + return mr(); + } else { + return (ks() - mr()) % qr() == 0 + ? ks() + : ((ks() - mr()) / qr() + 1) * qr() + mr(); + } + } + + inline AvgPoolMicrokernelTester& mr(size_t mr) { + assert(mr != 0); + this->mr_ = mr; + return *this; + } + + inline size_t mr() const { + return this->mr_; + } + + inline AvgPoolMicrokernelTester& qr(size_t qr) { + assert(qr != 0); + this->qr_ = qr; + return *this; + } + + inline size_t qr() const { + return this->qr_; + } + + inline AvgPoolMicrokernelTester& kc(size_t kc) { + assert(kc != 0); + this->kc_ = kc; + return *this; + } + + inline size_t kc() const { + return this->kc_; + } + + inline AvgPoolMicrokernelTester& kr(size_t kr) { + assert(kr != 0); + this->kr_ = kr; + return *this; + } + + inline size_t kr() const { + return this->kr_; + } + + inline size_t packedN() const { + return kc() % kr() == 0 ? kc() : (kc() / kr() + 1) * kr(); + } + + inline AvgPoolMicrokernelTester& xStride(size_t xStride) { + assert(xStride != 0); + this->xStride_ = xStride; + return *this; + } + + inline size_t xStride() const { + if (this->xStride_ == 0) { + return kc(); + } else { + assert(this->xStride_ >= kc()); + return this->xStride_; + } + } + + inline AvgPoolMicrokernelTester& yStride(size_t yStride) { + assert(yStride != 0); + this->yStride_ = yStride; + return *this; + } + + inline size_t yStride() const { + if (this->yStride_ == 0) { + return kc(); + } else { + assert(this->yStride_ >= kc()); + return this->yStride_; + } + } + + inline AvgPoolMicrokernelTester& xScale(float xScale) { + assert(xScale > 0.0f); + assert(std::isnormal(xScale)); + this->xScale_ = xScale; + return *this; + } + + inline float xScale() const { + return this->xScale_; + } + + inline AvgPoolMicrokernelTester& xZeroPoint(uint8_t xZeroPoint) { + this->xZeroPoint_ = xZeroPoint; + return *this; + } + + inline uint8_t xZeroPoint() const { + return this->xZeroPoint_; + } + + inline AvgPoolMicrokernelTester& yScale(float yScale) { + assert(yScale > 0.0f); + assert(std::isnormal(yScale)); + this->yScale_ = yScale; + return *this; + } + + inline float yScale() const { + return this->yScale_; + } + + inline AvgPoolMicrokernelTester& yZeroPoint(uint8_t yZeroPoint) { + this->yZeroPoint_ = yZeroPoint; + return *this; + } + + inline uint8_t yZeroPoint() const { + return this->yZeroPoint_; + } + + inline AvgPoolMicrokernelTester& yMin(uint8_t yMin) { + this->yMin_ = yMin; + return *this; + } + + inline uint8_t yMin() const { + return this->yMin_; + } + + inline AvgPoolMicrokernelTester& yMax(uint8_t yMax) { + this->yMax_ = yMax; + return *this; + } + + inline uint8_t yMax() const { + return this->yMax_; + } + + inline AvgPoolMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_q8avgpool_up_ukernel_function q8avgpool) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector indirectX(packedKs() + (n() * s() - 1) * kh()); + std::vector x((indirectX.size() - 1) * xStride() + kc()); + + std::vector zero(kc()); + std::vector y((n() - 1) * yStride() + kc()); + std::vector yRef(n() * kc()); + std::vector yFP(n() * kc()); + std::vector yAcc(n() * kc()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + for (size_t i = 0; i < indirectX.size(); i++) { + indirectX[i] = x.data() + i * xStride(); + } + std::shuffle(indirectX.begin(), indirectX.end(), rng); + + /* Prepare quantization parameters */ + const union pytorch_qnnp_avgpool_quantization_params quantizationParams = + pytorch_qnnp_compute_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(ks()), + xScale() / (yScale() * float(ks())), + yZeroPoint(), + yMin(), + yMax()); + const union pytorch_qnnp_avgpool_quantization_params + scalarQuantizationParams = + pytorch_qnnp_compute_scalar_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(ks()), + xScale() / (yScale() * float(ks())), + yZeroPoint(), + yMin(), + yMax()); + + /* Compute reference results */ + for (size_t i = 0; i < n(); i++) { + for (size_t k = 0; k < kc(); k++) { + int32_t acc = scalarQuantizationParams.scalar.bias; + for (size_t j = 0; j < ks(); j++) { + acc += indirectX[i * s() * kh() + j][k]; + } + yAcc[i * kc() + k] = acc; + yRef[i * kc() + k] = + pytorch_qnnp_avgpool_quantize(acc, scalarQuantizationParams); + yFP[i * kc() + k] = + float(acc) * (xScale() / (yScale() * float(ks()))) + + float(yZeroPoint()); + yFP[i * kc() + k] = std::min(yFP[i * kc() + k], float(yMax())); + yFP[i * kc() + k] = std::max(yFP[i * kc() + k], float(yMin())); + } + } + + /* Call optimized micro-kernel */ + q8avgpool( + n(), + ks(), + kc(), + indirectX.data(), + zero.data(), + y.data(), + kh() * s() * sizeof(void*), + (yStride() - kc()) * sizeof(uint8_t), + &quantizationParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + for (size_t k = 0; k < kc(); k++) { + ASSERT_LE(uint32_t(y[i * yStride() + k]), uint32_t(yMax())) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", kc = " << kc(); + ASSERT_GE(uint32_t(y[i * yStride() + k]), uint32_t(yMin())) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", kc = " << kc(); + ASSERT_NEAR( + float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5f) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", ks = " << kh() << "x" << kw() << " (" << ks() + << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k]; + ASSERT_EQ( + uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k])) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", ks = " << kh() << "x" << kw() << " (" << ks() + << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k]; + } + } + } + } + + void test(pytorch_q8avgpool_mp_ukernel_function q8avgpool) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector indirectX(packedKs() + (n() * s() - 1) * kh()); + std::vector x((indirectX.size() - 1) * xStride() + kc()); + std::vector> mpAcc(packedN()); + + std::vector zero(kc()); + std::vector y((n() - 1) * yStride() + kc()); + std::vector yRef(n() * kc()); + std::vector yFP(n() * kc()); + std::vector yAcc(n() * kc()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + for (size_t i = 0; i < indirectX.size(); i++) { + indirectX[i] = x.data() + i * xStride(); + } + std::shuffle(indirectX.begin(), indirectX.end(), rng); + + /* Prepare quantization parameters */ + const union pytorch_qnnp_avgpool_quantization_params quantizationParams = + pytorch_qnnp_compute_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(ks()), + xScale() / (yScale() * float(ks())), + yZeroPoint(), + yMin(), + yMax()); + const union pytorch_qnnp_avgpool_quantization_params + scalarQuantizationParams = + pytorch_qnnp_compute_scalar_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(ks()), + xScale() / (yScale() * float(ks())), + yZeroPoint(), + yMin(), + yMax()); + + /* Compute reference results */ + for (size_t i = 0; i < n(); i++) { + for (size_t k = 0; k < kc(); k++) { + int32_t acc = scalarQuantizationParams.scalar.bias; + for (size_t j = 0; j < ks(); j++) { + acc += indirectX[i * s() * kh() + j][k]; + } + yAcc[i * kc() + k] = acc; + yRef[i * kc() + k] = + pytorch_qnnp_avgpool_quantize(acc, scalarQuantizationParams); + yFP[i * kc() + k] = + float(acc) * (xScale() / (yScale() * float(ks()))) + + float(yZeroPoint()); + yFP[i * kc() + k] = std::min(yFP[i * kc() + k], float(yMax())); + yFP[i * kc() + k] = std::max(yFP[i * kc() + k], float(yMin())); + } + } + + /* Call optimized micro-kernel */ + q8avgpool( + n(), + ks(), + kc(), + indirectX.data(), + zero.data(), + mpAcc.data(), + y.data(), + (kh() * s() - (packedKs() - qr())) * sizeof(void*), + (yStride() - kc()) * sizeof(uint8_t), + &quantizationParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + for (size_t k = 0; k < kc(); k++) { + ASSERT_LE(uint32_t(y[i * yStride() + k]), uint32_t(yMax())) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", kc = " << kc(); + ASSERT_GE(uint32_t(y[i * yStride() + k]), uint32_t(yMin())) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", kc = " << kc(); + ASSERT_NEAR( + float(int32_t(y[i * yStride() + k])), yFP[i * kc() + k], 0.5f) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", ks = " << kh() << "x" << kw() << " (" << ks() + << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k]; + ASSERT_EQ( + uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k])) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", ks = " << kh() << "x" << kw() << " (" << ks() + << "), kc = " << kc() << ", acc = " << yAcc[i * kc() + k]; + } + } + } + } + + private: + size_t n_{1}; + size_t s_{1}; + size_t kh_{1}; + size_t kw_{1}; + size_t mr_{1}; + size_t qr_{1}; + size_t kc_{1}; + size_t kr_{1}; + size_t xStride_{0}; + size_t yStride_{0}; + float xScale_{1.25f}; + float yScale_{0.75f}; + uint8_t xZeroPoint_{121}; + uint8_t yZeroPoint_{133}; + uint8_t yMin_{0}; + uint8_t yMax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/channel-shuffle-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/channel-shuffle-operator-tester.h new file mode 100644 index 0000000000000..d8c85276c6bd2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/channel-shuffle-operator-tester.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class ChannelShuffleOperatorTester { + public: + inline ChannelShuffleOperatorTester& groups(size_t groups) { + assert(groups != 0); + this->groups_ = groups; + return *this; + } + + inline size_t groups() const { + return this->groups_; + } + + inline ChannelShuffleOperatorTester& groupChannels(size_t groupChannels) { + assert(groupChannels != 0); + this->groupChannels_ = groupChannels; + return *this; + } + + inline size_t groupChannels() const { + return this->groupChannels_; + } + + inline size_t channels() const { + return groups() * groupChannels(); + } + + inline ChannelShuffleOperatorTester& inputStride(size_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return channels(); + } else { + assert(this->inputStride_ >= channels()); + return this->inputStride_; + } + } + + inline ChannelShuffleOperatorTester& outputStride(size_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return channels(); + } else { + assert(this->outputStride_ >= channels()); + return this->outputStride_; + } + } + + inline ChannelShuffleOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline ChannelShuffleOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testX8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input((batchSize() - 1) * inputStride() + channels()); + std::vector output( + (batchSize() - 1) * outputStride() + channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Create, setup, run, and destroy Channel Shuffle operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t channel_shuffle_op = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_channel_shuffle_nc_x8( + groups(), groupChannels(), 0, &channel_shuffle_op)); + ASSERT_NE(nullptr, channel_shuffle_op); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_channel_shuffle_nc_x8( + channel_shuffle_op, + batchSize(), + input.data(), + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator( + channel_shuffle_op, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(channel_shuffle_op)); + channel_shuffle_op = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t c = 0; c < groupChannels(); c++) { + ASSERT_EQ( + uint32_t(input[i * inputStride() + g * groupChannels() + c]), + uint32_t(output[i * outputStride() + c * groups() + g])); + } + } + } + } + } + + private: + size_t groups_{1}; + size_t groupChannels_{1}; + size_t batchSize_{1}; + size_t inputStride_{0}; + size_t outputStride_{0}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/channel-shuffle.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/channel-shuffle.cc new file mode 100644 index 0000000000000..d5e270fe609fd --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/channel-shuffle.cc @@ -0,0 +1,268 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "channel-shuffle-operator-tester.h" + +TEST(CHANNEL_SHUFFLE_OP, zero_batch) { + ChannelShuffleOperatorTester() + .batchSize(0) + .groups(2) + .groupChannels(4) + .iterations(1) + .testX8(); +} + +TEST(CHANNEL_SHUFFLE_OP, two_groups_unit_batch) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(1) + .groups(2) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, three_groups_unit_batch) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(1) + .groups(3) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, four_groups_unit_batch) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(1) + .groups(4) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, many_groups_unit_batch) { + for (size_t groups = 5; groups < 12; groups += 3) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(1) + .groups(groups) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } + } +} + +TEST(CHANNEL_SHUFFLE_OP, two_groups_small_batch) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(2) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, three_groups_small_batch) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(3) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, four_groups_small_batch) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(4) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, many_groups_small_batch) { + for (size_t groups = 5; groups < 12; groups += 3) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(groups) + .groupChannels(groupChannels) + .iterations(3) + .testX8(); + } + } +} + +TEST(CHANNEL_SHUFFLE_OP, two_groups_small_batch_with_input_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(2) + .groupChannels(groupChannels) + .inputStride(511) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, three_groups_small_batch_with_input_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(3) + .groupChannels(groupChannels) + .inputStride(511) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, four_groups_small_batch_with_input_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(4) + .groupChannels(groupChannels) + .inputStride(511) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, many_groups_small_batch_with_input_stride) { + for (size_t groups = 5; groups < 12; groups += 3) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(groups) + .groupChannels(groupChannels) + .inputStride(1007) + .iterations(3) + .testX8(); + } + } +} + +TEST(CHANNEL_SHUFFLE_OP, two_groups_small_batch_with_output_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(2) + .groupChannels(groupChannels) + .outputStride(513) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, three_groups_small_batch_with_output_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(3) + .groupChannels(groupChannels) + .outputStride(513) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, four_groups_small_batch_with_output_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(4) + .groupChannels(groupChannels) + .outputStride(513) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, many_groups_small_batch_with_output_stride) { + for (size_t groups = 5; groups < 12; groups += 3) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(groups) + .groupChannels(groupChannels) + .outputStride(1111) + .iterations(3) + .testX8(); + } + } +} + +TEST(CHANNEL_SHUFFLE_OP, two_groups_small_batch_with_input_and_output_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(2) + .groupChannels(groupChannels) + .inputStride(511) + .outputStride(513) + .iterations(3) + .testX8(); + } +} + +TEST( + CHANNEL_SHUFFLE_OP, + three_groups_small_batch_with_input_and_output_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(3) + .groupChannels(groupChannels) + .inputStride(511) + .outputStride(513) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, four_groups_small_batch_with_input_and_output_stride) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(4) + .groupChannels(groupChannels) + .inputStride(511) + .outputStride(513) + .iterations(3) + .testX8(); + } +} + +TEST(CHANNEL_SHUFFLE_OP, many_groups_small_batch_with_input_and_output_stride) { + for (size_t groups = 5; groups < 12; groups += 3) { + for (size_t groupChannels = 1; groupChannels < 100; groupChannels += 15) { + ChannelShuffleOperatorTester() + .batchSize(3) + .groups(groups) + .groupChannels(groupChannels) + .inputStride(1007) + .outputStride(1111) + .iterations(3) + .testX8(); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp-microkernel-tester.h new file mode 100644 index 0000000000000..d2a72ca0885f8 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp-microkernel-tester.h @@ -0,0 +1,118 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class ClampMicrokernelTester { + public: + inline ClampMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline ClampMicrokernelTester& inplace(bool inplace) { + this->inplace_ = inplace; + return *this; + } + + inline bool inplace() const { + return this->inplace_; + } + + inline ClampMicrokernelTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline ClampMicrokernelTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline ClampMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_u8clamp_ukernel_function u8clamp) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x(n()); + std::vector y(n()); + std::vector yRef(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + if (inplace()) { + std::generate(y.begin(), y.end(), std::ref(u8rng)); + } else { + std::fill(y.begin(), y.end(), 0xA5); + } + const uint8_t* xData = inplace() ? y.data() : x.data(); + + /* Prepare clamping parameters */ + const union pytorch_qnnp_u8_clamping_params clampingParams = + pytorch_qnnp_compute_u8_clamping_params(qmin(), qmax()); + + /* Compute reference results */ + for (size_t i = 0; i < n(); i++) { + yRef[i] = std::max(std::min(xData[i], qmax()), qmin()); + } + + /* Call optimized micro-kernel */ + u8clamp(n(), xData, y.data(), &clampingParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + ASSERT_LE(uint32_t(y[i]), uint32_t(qmax())) + << "at position " << i << ", n = " << n(); + ASSERT_GE(uint32_t(y[i]), uint32_t(qmin())) + << "at position " << i << ", n = " << n(); + ASSERT_EQ(uint32_t(yRef[i]), uint32_t(y[i])) + << "at position " << i << ", n = " << n() << ", qmin = " << qmin() + << ", qmax = " << qmax(); + } + } + } + + private: + size_t n_{1}; + bool inplace_{false}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp-operator-tester.h new file mode 100644 index 0000000000000..4a2f2ee2185af --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp-operator-tester.h @@ -0,0 +1,177 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class ClampOperatorTester { + public: + inline ClampOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline ClampOperatorTester& inputStride(size_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return this->channels_; + } else { + assert(this->inputStride_ >= this->channels_); + return this->inputStride_; + } + } + + inline ClampOperatorTester& outputStride(size_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return this->channels_; + } else { + assert(this->outputStride_ >= this->channels_); + return this->outputStride_; + } + } + + inline ClampOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline ClampOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline ClampOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline ClampOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testU8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input((batchSize() - 1) * inputStride() + channels()); + std::vector output( + (batchSize() - 1) * outputStride() + channels()); + std::vector outputRef(batchSize() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + const uint8_t x = input[i * inputStride() + c]; + const uint8_t y = std::min(std::max(x, qmin()), qmax()); + outputRef[i * channels() + c] = y; + } + } + + /* Create, setup, run, and destroy Sigmoid operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t clampOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_clamp_nc_u8( + channels(), qmin(), qmax(), 0, &clampOp)); + ASSERT_NE(nullptr, clampOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_clamp_nc_u8( + clampOp, + batchSize(), + input.data(), + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(clampOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, pytorch_qnnp_delete_operator(clampOp)); + clampOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE(uint32_t(output[i * channels() + c]), uint32_t(qmax())) + << "at position " << i << ", batch size = " << batchSize() + << ", channels = " << channels(); + ASSERT_GE(uint32_t(output[i * channels() + c]), uint32_t(qmin())) + << "at position " << i << ", batch size = " << batchSize() + << ", channels = " << channels(); + ASSERT_EQ( + uint32_t(outputRef[i * channels() + c]), + uint32_t(output[i * outputStride() + c])) + << "at position " << i << ", batch size = " << batchSize() + << ", channels = " << channels() << ", qmin = " << qmin() + << ", qmax = " << qmax(); + } + } + } + } + + private: + size_t batchSize_{1}; + size_t channels_{1}; + size_t inputStride_{0}; + size_t outputStride_{0}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp.cc new file mode 100644 index 0000000000000..40015d51b0dfe --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/clamp.cc @@ -0,0 +1,107 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "clamp-operator-tester.h" + +TEST(CLAMP_OP, zero_batch) { + ClampOperatorTester().batchSize(0).channels(2).iterations(1).testU8(); +} + +TEST(CLAMP_OP, unit_batch) { + for (size_t channels = 1; channels < 100; channels++) { + ClampOperatorTester() + .batchSize(1) + .channels(channels) + .iterations(3) + .testU8(); + } +} + +TEST(CLAMP_OP, unit_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (uint8_t qmin = 1; qmin < 255; qmin++) { + ClampOperatorTester() + .batchSize(1) + .channels(channels) + .qmin(qmin) + .iterations(3) + .testU8(); + } + } +} + +TEST(CLAMP_OP, unit_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (uint8_t qmax = 1; qmax < 255; qmax++) { + ClampOperatorTester() + .batchSize(1) + .channels(channels) + .qmax(qmax) + .iterations(3) + .testU8(); + } + } +} + +TEST(CLAMP_OP, small_batch) { + for (size_t channels = 1; channels < 100; channels++) { + ClampOperatorTester() + .batchSize(3) + .channels(channels) + .iterations(3) + .testU8(); + } +} + +TEST(CLAMP_OP, small_batch_with_input_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + ClampOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .iterations(3) + .testU8(); + } +} + +TEST(CLAMP_OP, small_batch_with_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + ClampOperatorTester() + .batchSize(3) + .channels(channels) + .outputStride(117) + .iterations(3) + .testU8(); + } +} + +TEST(CLAMP_OP, small_batch_with_input_and_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + ClampOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .iterations(3) + .testU8(); + } +} + +TEST(CLAMP_OP, qmin_and_qmax_equal_uint8_max) { + for (size_t channels = 1; channels < 100; channels += 15) { + ClampOperatorTester() + .batchSize(3) + .channels(channels) + .qmin(255) + .qmax(255) + .iterations(3) + .testU8(); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/convolution-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/convolution-operator-tester.h new file mode 100644 index 0000000000000..c733732a659b5 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/convolution-operator-tester.h @@ -0,0 +1,629 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class ConvolutionOperatorTester { + public: + inline ConvolutionOperatorTester& padding(uint32_t padding) { + this->paddingTop_ = padding; + this->paddingRight_ = padding; + this->paddingBottom_ = padding; + this->paddingLeft_ = padding; + return *this; + } + + inline ConvolutionOperatorTester& padding( + uint32_t paddingHeight, + uint32_t paddingWidth) { + this->paddingTop_ = paddingHeight; + this->paddingRight_ = paddingWidth; + this->paddingBottom_ = paddingHeight; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline ConvolutionOperatorTester& paddingHeight(uint32_t paddingHeight) { + this->paddingTop_ = paddingHeight; + this->paddingBottom_ = paddingHeight; + return *this; + } + + inline ConvolutionOperatorTester& paddingWidth(uint32_t paddingWidth) { + this->paddingRight_ = paddingWidth; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline ConvolutionOperatorTester& paddingTop(uint32_t paddingTop) { + this->paddingTop_ = paddingTop; + return *this; + } + + inline uint32_t paddingTop() const { + return this->paddingTop_; + } + + inline ConvolutionOperatorTester& paddingRight(uint32_t paddingRight) { + this->paddingRight_ = paddingRight; + return *this; + } + + inline uint32_t paddingRight() const { + return this->paddingRight_; + } + + inline ConvolutionOperatorTester& paddingBottom(uint32_t paddingBottom) { + this->paddingBottom_ = paddingBottom; + return *this; + } + + inline uint32_t paddingBottom() const { + return this->paddingBottom_; + } + + inline ConvolutionOperatorTester& paddingLeft(uint32_t paddingLeft) { + this->paddingLeft_ = paddingLeft; + return *this; + } + + inline uint32_t paddingLeft() const { + return this->paddingLeft_; + } + + inline ConvolutionOperatorTester& inputSize( + uint32_t inputHeight, + uint32_t inputWidth) { + assert(inputHeight >= 1); + assert(inputWidth >= 1); + this->inputHeight_ = inputHeight; + this->inputWidth_ = inputWidth; + return *this; + } + + inline ConvolutionOperatorTester& inputHeight(uint32_t inputHeight) { + assert(inputHeight >= 1); + this->inputHeight_ = inputHeight; + return *this; + } + + inline uint32_t inputHeight() const { + return this->inputHeight_; + } + + inline ConvolutionOperatorTester& inputWidth(uint32_t inputWidth) { + assert(inputWidth >= 1); + this->inputWidth_ = inputWidth; + return *this; + } + + inline uint32_t inputWidth() const { + return this->inputWidth_; + } + + inline ConvolutionOperatorTester& groups(uint32_t groups) { + assert(groups >= 1); + this->groups_ = groups; + return *this; + } + + inline uint32_t groups() const { + return this->groups_; + } + + inline ConvolutionOperatorTester& groupInputChannels( + size_t groupInputChannels) { + assert(groupInputChannels >= 1); + this->groupInputChannels_ = groupInputChannels; + return *this; + } + + inline size_t groupInputChannels() const { + return this->groupInputChannels_; + } + + inline ConvolutionOperatorTester& groupOutputChannels( + size_t groupOutputChannels) { + assert(groupOutputChannels >= 1); + this->groupOutputChannels_ = groupOutputChannels; + return *this; + } + + inline size_t groupOutputChannels() const { + return this->groupOutputChannels_; + } + + inline ConvolutionOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline ConvolutionOperatorTester& kernelSize(uint32_t kernelSize) { + assert(kernelSize >= 1); + this->kernelHeight_ = kernelSize; + this->kernelWidth_ = kernelSize; + return *this; + } + + inline ConvolutionOperatorTester& kernelSize( + uint32_t kernelHeight, + uint32_t kernelWidth) { + assert(kernelHeight >= 1); + assert(kernelWidth >= 1); + this->kernelHeight_ = kernelHeight; + this->kernelWidth_ = kernelWidth; + return *this; + } + + inline ConvolutionOperatorTester& kernelHeight(uint32_t kernelHeight) { + assert(kernelHeight >= 1); + this->kernelHeight_ = kernelHeight; + return *this; + } + + inline uint32_t kernelHeight() const { + return this->kernelHeight_; + } + + inline ConvolutionOperatorTester& kernelWidth(uint32_t kernelWidth) { + assert(kernelWidth >= 1); + this->kernelWidth_ = kernelWidth; + return *this; + } + + inline uint32_t kernelWidth() const { + return this->kernelWidth_; + } + + inline ConvolutionOperatorTester& dilation(uint32_t dilation) { + assert(dilation >= 1); + this->dilationHeight_ = dilation; + this->dilationWidth_ = dilation; + return *this; + } + + inline ConvolutionOperatorTester& dilation( + uint32_t dilationHeight, + uint32_t dilationWidth) { + assert(dilationHeight >= 1); + assert(dilationWidth >= 1); + this->dilationHeight_ = dilationHeight; + this->dilationWidth_ = dilationWidth; + return *this; + } + + inline ConvolutionOperatorTester& dilationHeight(uint32_t dilationHeight) { + assert(dilationHeight >= 1); + this->dilationHeight_ = dilationHeight; + return *this; + } + + inline uint32_t dilationHeight() const { + return this->dilationHeight_; + } + + inline ConvolutionOperatorTester& dilationWidth(uint32_t dilationWidth) { + assert(dilationWidth >= 1); + this->dilationWidth_ = dilationWidth; + return *this; + } + + inline uint32_t dilationWidth() const { + return this->dilationWidth_; + } + + inline ConvolutionOperatorTester& subsampling(uint32_t subsampling) { + assert(subsampling >= 1); + this->subsamplingHeight_ = subsampling; + this->subsamplingWidth_ = subsampling; + return *this; + } + + inline ConvolutionOperatorTester& subsampling( + uint32_t subsamplingHeight, + uint32_t subsamplingWidth) { + assert(subsamplingHeight >= 1); + assert(subsamplingWidth >= 1); + this->subsamplingHeight_ = subsamplingHeight; + this->subsamplingWidth_ = subsamplingWidth; + return *this; + } + + inline ConvolutionOperatorTester& subsamplingHeight( + uint32_t subsamplingHeight) { + assert(subsamplingHeight >= 1); + this->subsamplingHeight_ = subsamplingHeight; + return *this; + } + + inline uint32_t subsamplingHeight() const { + return this->subsamplingHeight_; + } + + inline ConvolutionOperatorTester& subsamplingWidth( + uint32_t subsamplingWidth) { + assert(subsamplingWidth >= 1); + this->subsamplingWidth_ = subsamplingWidth; + return *this; + } + + inline uint32_t subsamplingWidth() const { + return this->subsamplingWidth_; + } + + inline ConvolutionOperatorTester& inputPixelStride(size_t inputPixelStride) { + assert(inputPixelStride >= 1); + this->inputPixelStride_ = inputPixelStride; + return *this; + } + + inline size_t inputPixelStride() const { + if (this->inputPixelStride_ == 0) { + return groupInputChannels() * groups(); + } else { + assert(this->inputPixelStride_ >= groupInputChannels() * groups()); + return this->inputPixelStride_; + } + } + + inline ConvolutionOperatorTester& outputPixelStride( + size_t outputPixelStride) { + assert(outputPixelStride >= 1); + this->outputPixelStride_ = outputPixelStride; + return *this; + } + + inline size_t outputPixelStride() const { + if (this->outputPixelStride_ == 0) { + return groupOutputChannels() * groups(); + } else { + assert(this->outputPixelStride_ >= groupOutputChannels() * groups()); + return this->outputPixelStride_; + } + } + + inline uint32_t dilatedKernelHeight() const { + return (kernelHeight() - 1) * dilationHeight() + 1; + } + + inline uint32_t dilatedKernelWidth() const { + return (kernelWidth() - 1) * dilationWidth() + 1; + } + + inline size_t outputHeight() const { + const size_t paddedInputHeight = + paddingTop() + inputHeight() + paddingBottom(); + if (paddedInputHeight <= dilatedKernelHeight()) { + return 1; + } else { + return (paddedInputHeight - dilatedKernelHeight()) / subsamplingHeight() + + 1; + } + } + + inline size_t outputWidth() const { + const size_t paddedInputWidth = + paddingLeft() + inputWidth() + paddingRight(); + if (paddedInputWidth <= dilatedKernelWidth()) { + return 1; + } else { + return (paddedInputWidth - dilatedKernelWidth()) / subsamplingWidth() + 1; + } + } + + inline ConvolutionOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline ConvolutionOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline ConvolutionOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8(bool runtime_quant = false) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + batchSize() * + ((inputHeight() * inputWidth() - 1) * inputPixelStride() + + groups() * groupInputChannels()) + + 8); + std::vector kernel( + groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * + groupInputChannels()); + std::vector bias(groups() * groupOutputChannels()); + std::vector output( + batchSize() * + ((outputHeight() * outputWidth() - 1) * outputPixelStride() + + groups() * groupOutputChannels())); + std::vector accumulators( + batchSize() * outputHeight() * outputWidth() * groups() * + groupOutputChannels()); + + const uint8_t* inputPtr = input.data() + 8; + const uint8_t inputZeroPoint = 127; + const uint8_t kernelZeroPoint = 127; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(output.begin(), output.end(), 0xA5); + std::fill(accumulators.begin(), accumulators.end(), 0); + + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t oc = 0; oc < groupOutputChannels(); oc++) { + accumulators + [(((i * outputHeight() + oy) * outputWidth() + ox) * + groups() + + g) * + groupOutputChannels() + + oc] = bias[g * groupOutputChannels() + oc]; + } + } + } + } + } + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t ky = 0; ky < kernelHeight(); ky++) { + const size_t iy = oy * subsamplingHeight() + + ky * dilationHeight() - paddingTop(); + if (iy < inputHeight()) { + for (size_t kx = 0; kx < kernelWidth(); kx++) { + const size_t ix = ox * subsamplingWidth() + + kx * dilationWidth() - paddingLeft(); + if (ix < inputWidth()) { + for (size_t g = 0; g < groups(); g++) { + for (size_t oc = 0; oc < groupOutputChannels(); oc++) { + for (size_t ic = 0; ic < groupInputChannels(); ic++) { + accumulators + [(((i * outputHeight() + oy) * outputWidth() + + ox) * + groups() + + g) * + groupOutputChannels() + + oc] += + (int32_t(inputPtr + [((i * inputHeight() + iy) * + inputWidth() + + ix) * + inputPixelStride() + + g * groupInputChannels() + ic]) - + int32_t(inputZeroPoint)) * + (int32_t(kernel + [(((g * groupOutputChannels() + oc) * + kernelHeight() + + ky) * + kernelWidth() + + kx) * + groupInputChannels() + + ic]) - + int32_t(kernelZeroPoint)); + } + } + } + } + } + } + } + } + } + } + // Create dummy min/max for empty inputs. + // These are only used to compute scale and zero point, + // and real callers will just pull those values from the model. + const int32_t accumulatorsMin = accumulators.empty() + ? 0 + : *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulatorsMax = accumulators.empty() + ? 900 + : *std::max_element(accumulators.cbegin(), accumulators.cend()); + + const double outputScale = + double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; + const uint8_t outputZeroPoint = uint8_t(std::max( + std::min( + lrint( + 127.5 - + 0.5 * double(accumulatorsMin + accumulatorsMax) / + outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (runtime_quant) { + qnnpack::conv_param_t conv_p( + {kernelWidth(), kernelHeight()}, + {subsamplingWidth(), subsamplingHeight()}, + {dilationWidth(), dilationHeight()}, + {paddingTop(), paddingLeft(), paddingBottom(), paddingRight()}, + groups(), + groupInputChannels() * groups(), + groupOutputChannels() * groups(), + kernelZeroPoint, + 1.0, + qmin(), + qmax()); + auto packW = std::unique_ptr( + new qnnpack::PrePackConvWeights( + conv_p, + kernel.data(), + bias.data())); + const pytorch_qnnp_status runStatus = qnnpack::qnnpackConv( + conv_p, + packW->getPackedWeights(), + batchSize(), + inputHeight(), + inputWidth(), + 1.0, + inputZeroPoint, + inputPtr, + outputScale, + outputZeroPoint, + output.data(), + nullptr); + ASSERT_EQ(pytorch_qnnp_status_success, runStatus); + } + else { + pytorch_qnnp_operator_t convolution = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_convolution2d_nhwc_q8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + kernelHeight(), + kernelWidth(), + subsamplingHeight(), + subsamplingWidth(), + dilationHeight(), + dilationWidth(), + groups(), + groupInputChannels(), + groupOutputChannels(), + inputZeroPoint, + 1.0f /* input scale */, + kernelZeroPoint, + 1.0f /* kernel scale */, + kernel.data(), + bias.data(), + outputZeroPoint, + outputScale, + qmin(), + qmax(), + 0, + &convolution)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_convolution2d_nhwc_q8( + convolution, + batchSize(), + inputHeight(), + inputWidth(), + inputPtr, + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(convolution, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(convolution)); + convolution = nullptr; + } + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t c = 0; c < groupOutputChannels(); c++) { + const double scaledAccumulator = + accumulators + [(((i * outputHeight() + y) * outputWidth() + x) * + groups() + + g) * + groupOutputChannels() + + c] / + outputScale; + const double clampedAccumulator = std::max( + std::min( + scaledAccumulator, + double(qmax()) - double(outputZeroPoint)), + double(qmin()) - double(outputZeroPoint)); + ASSERT_NEAR( + clampedAccumulator, + (int32_t( + output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + g * groupOutputChannels() + c]) - + outputZeroPoint), + 0.9) + << "(x, y) = (" << x << ", " << y << "), group = " << g + << ", channel = " << c; + } + } + } + } + } + } + } + + private: + uint32_t paddingTop_{0}; + uint32_t paddingRight_{0}; + uint32_t paddingBottom_{0}; + uint32_t paddingLeft_{0}; + size_t inputHeight_{1}; + size_t inputWidth_{1}; + uint32_t groups_{1}; + size_t groupInputChannels_{1}; + size_t inputPixelStride_{0}; + size_t groupOutputChannels_{1}; + size_t outputPixelStride_{0}; + size_t batchSize_{1}; + uint32_t kernelHeight_{1}; + uint32_t kernelWidth_{1}; + uint32_t dilationHeight_{1}; + uint32_t dilationWidth_{1}; + uint32_t subsamplingHeight_{1}; + uint32_t subsamplingWidth_{1}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/convolution.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/convolution.cc new file mode 100644 index 0000000000000..6439d71da3960 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/convolution.cc @@ -0,0 +1,648 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include "convolution-operator-tester.h" + +TEST(CONVOLUTION_OP, zero_batch) { + ConvolutionOperatorTester() + .batchSize(0) + .inputSize(5, 5) + .kernelSize(1, 1) + .groupInputChannels(2) + .groupOutputChannels(2) + .iterations(1) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 1x1) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 1x1_runtime_quant) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(true); +} + +TEST(CONVOLUTION_OP, 1x1_with_qmin) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .qmin(128) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 1x1_with_qmax) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .qmax(128) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 1x1_with_input_stride) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .inputPixelStride(28) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 1x1_with_output_stride) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .outputPixelStride(29) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 1x1_with_batch) { + ConvolutionOperatorTester() + .inputSize(13, 14) + .kernelSize(1, 1) + .batchSize(3) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, grouped_1x1) { + ConvolutionOperatorTester() + .inputSize(24, 25) + .kernelSize(1, 1) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, xzp_1x1) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, xzp_1x1_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, xzp_1x1_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, xzp_1x1_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .inputPixelStride(pytorch_qnnp_params.q8conv_xzp.kthreshold + 5) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, xzp_1x1_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .outputPixelStride(29) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, xzp_1x1_with_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(13, 14) + .kernelSize(1, 1) + .batchSize(3) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, grouped_xzp_1x1) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(24, 25) + .kernelSize(1, 1) + .groups(2) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); + } +} + +TEST(CONVOLUTION_OP, grouped_xzp_1x1_runtime_quant) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (pytorch_qnnp_params.q8conv_xzp.kthreshold != SIZE_MAX) { + ConvolutionOperatorTester() + .inputSize(24, 25) + .kernelSize(1, 1) + .groups(2) + .groupInputChannels(pytorch_qnnp_params.q8conv_xzp.kthreshold + 1) + .groupOutputChannels(19) + .iterations(3) + .testQ8(true); + } +} + +TEST(CONVOLUTION_OP, 1x3) { + ConvolutionOperatorTester() + .inputSize(20, 19) + .paddingWidth(1) + .kernelSize(1, 3) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, grouped_1x3) { + ConvolutionOperatorTester() + .inputSize(20, 19) + .paddingWidth(1) + .kernelSize(1, 3) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, grouped_1x3_runtime_quant) { + ConvolutionOperatorTester() + .inputSize(20, 19) + .paddingWidth(1) + .kernelSize(1, 3) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(true); +} + +TEST(CONVOLUTION_OP, 3x1) { + ConvolutionOperatorTester() + .inputSize(19, 20) + .paddingHeight(1) + .kernelSize(3, 1) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, grouped_3x1) { + ConvolutionOperatorTester() + .inputSize(19, 20) + .paddingHeight(1) + .kernelSize(3, 1) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .padding(1) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_without_padding) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_left_padding) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .paddingLeft(1) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_right_padding) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .paddingRight(1) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_top_padding) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .paddingTop(1) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_bottom_padding) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .paddingBottom(1) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_input_stride) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .padding(1) + .kernelSize(3, 3) + .inputPixelStride(22) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_output_stride) { + ConvolutionOperatorTester() + .inputSize(13, 12) + .padding(1) + .kernelSize(3, 3) + .outputPixelStride(23) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3_with_batch) { + ConvolutionOperatorTester() + .inputSize(10, 9) + .padding(1) + .kernelSize(3, 3) + .batchSize(3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, grouped_3x3) { + ConvolutionOperatorTester() + .inputSize(10, 11) + .padding(1) + .kernelSize(3, 3) + .groups(2) + .groupInputChannels(14) + .groupOutputChannels(13) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3s2) { + ConvolutionOperatorTester() + .inputSize(19, 21) + .padding(1) + .kernelSize(3, 3) + .subsampling(2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3s1x2) { + ConvolutionOperatorTester() + .inputSize(13, 13) + .padding(1) + .kernelSize(3, 3) + .subsampling(1, 2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3s2x1) { + ConvolutionOperatorTester() + .inputSize(13, 13) + .padding(1) + .kernelSize(3, 3) + .subsampling(2, 1) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3d2) { + ConvolutionOperatorTester() + .inputSize(13, 14) + .padding(2) + .kernelSize(3, 3) + .dilation(2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3d1x2) { + ConvolutionOperatorTester() + .inputSize(14, 15) + .padding(1, 2) + .kernelSize(3, 3) + .dilation(1, 2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, 3x3d2x1) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 1) + .kernelSize(3, 3) + .dilation(2, 1) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3_runtime_quant) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .groups(27) + .iterations(3) + .testQ8(true); +} + +TEST(CONVOLUTION_OP, depthwise_3x3s2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .subsampling(2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3s1x2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .subsampling(1, 2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3s2x1) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .subsampling(2, 1) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3d2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .dilation(2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3d1x2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .dilation(1, 2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3d2x1) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .dilation(2, 1) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_3x3d2x1_runtime_quant) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(1, 1) + .kernelSize(3, 3) + .dilation(2, 1) + .groups(27) + .iterations(3) + .testQ8(true); +} + +TEST(CONVOLUTION_OP, depthwise_5x5) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5s2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .subsampling(2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5s1x2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .subsampling(1, 2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5s2x1) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .subsampling(2, 1) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5d2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .dilation(2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5d1x2) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .dilation(1, 2) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5d2x1) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .dilation(2, 1) + .groups(27) + .iterations(3) + .testQ8(); +} + +TEST(CONVOLUTION_OP, depthwise_5x5d2x1_runtime_quant) { + ConvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 2) + .kernelSize(5, 5) + .dilation(2, 1) + .groups(27) + .iterations(3) + .testQ8(true); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/deconvolution-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/deconvolution-operator-tester.h new file mode 100644 index 0000000000000..bdba87c4b0d89 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/deconvolution-operator-tester.h @@ -0,0 +1,635 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +class DeconvolutionOperatorTester { + public: + inline DeconvolutionOperatorTester& padding(uint32_t padding) { + this->paddingTop_ = padding; + this->paddingRight_ = padding; + this->paddingBottom_ = padding; + this->paddingLeft_ = padding; + return *this; + } + + inline DeconvolutionOperatorTester& padding( + uint32_t paddingHeight, + uint32_t paddingWidth) { + this->paddingTop_ = paddingHeight; + this->paddingRight_ = paddingWidth; + this->paddingBottom_ = paddingHeight; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline DeconvolutionOperatorTester& paddingHeight(uint32_t paddingHeight) { + this->paddingTop_ = paddingHeight; + this->paddingBottom_ = paddingHeight; + return *this; + } + + inline uint32_t paddingHeight() const { + return this->paddingTop_ + this->paddingBottom_; + } + + inline DeconvolutionOperatorTester& paddingWidth(uint32_t paddingWidth) { + this->paddingRight_ = paddingWidth; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline uint32_t paddingWidth() const { + return this->paddingLeft_ + this->paddingRight_; + } + + inline DeconvolutionOperatorTester& paddingTop(uint32_t paddingTop) { + this->paddingTop_ = paddingTop; + return *this; + } + + inline uint32_t paddingTop() const { + return this->paddingTop_; + } + + inline DeconvolutionOperatorTester& paddingRight(uint32_t paddingRight) { + this->paddingRight_ = paddingRight; + return *this; + } + + inline uint32_t paddingRight() const { + return this->paddingRight_; + } + + inline DeconvolutionOperatorTester& paddingBottom(uint32_t paddingBottom) { + this->paddingBottom_ = paddingBottom; + return *this; + } + + inline uint32_t paddingBottom() const { + return this->paddingBottom_; + } + + inline DeconvolutionOperatorTester& paddingLeft(uint32_t paddingLeft) { + this->paddingLeft_ = paddingLeft; + return *this; + } + + inline uint32_t paddingLeft() const { + return this->paddingLeft_; + } + + inline DeconvolutionOperatorTester& adjustmentHeight( + uint32_t adjustmentHeight) { + this->adjustmentHeight_ = adjustmentHeight; + return *this; + } + + inline uint32_t adjustmentHeight() const { + return this->adjustmentHeight_; + } + + inline DeconvolutionOperatorTester& adjustmentWidth( + uint32_t adjustmentWidth) { + this->adjustmentWidth_ = adjustmentWidth; + return *this; + } + + inline uint32_t adjustmentWidth() const { + return this->adjustmentWidth_; + } + + inline DeconvolutionOperatorTester& inputSize( + uint32_t inputHeight, + uint32_t inputWidth) { + assert(inputHeight >= 1); + assert(inputWidth >= 1); + this->inputHeight_ = inputHeight; + this->inputWidth_ = inputWidth; + return *this; + } + + inline DeconvolutionOperatorTester& inputHeight(uint32_t inputHeight) { + assert(inputHeight >= 1); + this->inputHeight_ = inputHeight; + return *this; + } + + inline uint32_t inputHeight() const { + return this->inputHeight_; + } + + inline DeconvolutionOperatorTester& inputWidth(uint32_t inputWidth) { + assert(inputWidth >= 1); + this->inputWidth_ = inputWidth; + return *this; + } + + inline uint32_t inputWidth() const { + return this->inputWidth_; + } + + inline DeconvolutionOperatorTester& groups(uint32_t groups) { + assert(groups >= 1); + this->groups_ = groups; + return *this; + } + + inline uint32_t groups() const { + return this->groups_; + } + + inline DeconvolutionOperatorTester& groupInputChannels( + size_t groupInputChannels) { + assert(groupInputChannels >= 1); + this->groupInputChannels_ = groupInputChannels; + return *this; + } + + inline size_t groupInputChannels() const { + return this->groupInputChannels_; + } + + inline DeconvolutionOperatorTester& groupOutputChannels( + size_t groupOutputChannels) { + assert(groupOutputChannels >= 1); + this->groupOutputChannels_ = groupOutputChannels; + return *this; + } + + inline size_t groupOutputChannels() const { + return this->groupOutputChannels_; + } + + inline DeconvolutionOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline DeconvolutionOperatorTester& kernelSize(uint32_t kernelSize) { + assert(kernelSize >= 1); + this->kernelHeight_ = kernelSize; + this->kernelWidth_ = kernelSize; + return *this; + } + + inline DeconvolutionOperatorTester& kernelSize( + uint32_t kernelHeight, + uint32_t kernelWidth) { + assert(kernelHeight >= 1); + assert(kernelWidth >= 1); + this->kernelHeight_ = kernelHeight; + this->kernelWidth_ = kernelWidth; + return *this; + } + + inline DeconvolutionOperatorTester& kernelHeight(uint32_t kernelHeight) { + assert(kernelHeight >= 1); + this->kernelHeight_ = kernelHeight; + return *this; + } + + inline uint32_t kernelHeight() const { + return this->kernelHeight_; + } + + inline DeconvolutionOperatorTester& kernelWidth(uint32_t kernelWidth) { + assert(kernelWidth >= 1); + this->kernelWidth_ = kernelWidth; + return *this; + } + + inline uint32_t kernelWidth() const { + return this->kernelWidth_; + } + + inline DeconvolutionOperatorTester& dilation(uint32_t dilation) { + assert(dilation >= 1); + this->dilationHeight_ = dilation; + this->dilationWidth_ = dilation; + return *this; + } + + inline DeconvolutionOperatorTester& dilation( + uint32_t dilationHeight, + uint32_t dilationWidth) { + assert(dilationHeight >= 1); + assert(dilationWidth >= 1); + this->dilationHeight_ = dilationHeight; + this->dilationWidth_ = dilationWidth; + return *this; + } + + inline DeconvolutionOperatorTester& dilationHeight(uint32_t dilationHeight) { + assert(dilationHeight >= 1); + this->dilationHeight_ = dilationHeight; + return *this; + } + + inline uint32_t dilationHeight() const { + return this->dilationHeight_; + } + + inline DeconvolutionOperatorTester& dilationWidth(uint32_t dilationWidth) { + assert(dilationWidth >= 1); + this->dilationWidth_ = dilationWidth; + return *this; + } + + inline uint32_t dilationWidth() const { + return this->dilationWidth_; + } + + inline DeconvolutionOperatorTester& stride(uint32_t stride) { + assert(stride >= 1); + this->strideHeight_ = stride; + this->strideWidth_ = stride; + return *this; + } + + inline DeconvolutionOperatorTester& stride( + uint32_t strideHeight, + uint32_t strideWidth) { + assert(strideHeight >= 1); + assert(strideWidth >= 1); + this->strideHeight_ = strideHeight; + this->strideWidth_ = strideWidth; + return *this; + } + + inline DeconvolutionOperatorTester& strideHeight(uint32_t strideHeight) { + assert(strideHeight >= 1); + this->strideHeight_ = strideHeight; + return *this; + } + + inline uint32_t strideHeight() const { + return this->strideHeight_; + } + + inline DeconvolutionOperatorTester& strideWidth(uint32_t strideWidth) { + assert(strideWidth >= 1); + this->strideWidth_ = strideWidth; + return *this; + } + + inline uint32_t strideWidth() const { + return this->strideWidth_; + } + + inline DeconvolutionOperatorTester& inputPixelStride( + size_t inputPixelStride) { + assert(inputPixelStride >= 1); + this->inputPixelStride_ = inputPixelStride; + return *this; + } + + inline size_t inputPixelStride() const { + if (this->inputPixelStride_ == 0) { + return groupInputChannels() * groups(); + } else { + assert(this->inputPixelStride_ >= groupInputChannels() * groups()); + return this->inputPixelStride_; + } + } + + inline DeconvolutionOperatorTester& outputPixelStride( + size_t outputPixelStride) { + assert(outputPixelStride >= 1); + this->outputPixelStride_ = outputPixelStride; + return *this; + } + + inline size_t outputPixelStride() const { + if (this->outputPixelStride_ == 0) { + return groupOutputChannels() * groups(); + } else { + assert(this->outputPixelStride_ >= groupOutputChannels() * groups()); + return this->outputPixelStride_; + } + } + + inline uint32_t dilatedKernelHeight() const { + return (kernelHeight() - 1) * dilationHeight() + 1; + } + + inline uint32_t dilatedKernelWidth() const { + return (kernelWidth() - 1) * dilationWidth() + 1; + } + + inline size_t outputHeight() const { + return strideHeight() * (inputHeight() - 1) + adjustmentHeight() + + dilatedKernelHeight() - paddingHeight(); + } + + inline size_t outputWidth() const { + return strideWidth() * (inputWidth() - 1) + adjustmentWidth() + + dilatedKernelWidth() - paddingWidth(); + } + + inline DeconvolutionOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline DeconvolutionOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline DeconvolutionOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + batchSize() * + ((inputHeight() * inputWidth() - 1) * inputPixelStride() + + groups() * groupInputChannels()) + + 8); + std::vector kernel( + groups() * groupOutputChannels() * kernelHeight() * kernelWidth() * + groupInputChannels()); + std::vector bias(groups() * groupOutputChannels()); + std::vector output( + batchSize() * + ((outputHeight() * outputWidth() - 1) * outputPixelStride() + + groups() * groupOutputChannels())); + std::vector accumulators( + batchSize() * outputHeight() * outputWidth() * groups() * + groupOutputChannels()); + + const uint8_t* inputPtr = input.data() + 8; + const uint8_t inputZeroPoint = 127; + const uint8_t kernelZeroPoint = 127; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(output.begin(), output.end(), 0xA5); + std::fill(accumulators.begin(), accumulators.end(), 0); + + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t oc = 0; oc < groupOutputChannels(); oc++) { + accumulators + [(((i * outputHeight() + oy) * outputWidth() + ox) * + groups() + + g) * + groupOutputChannels() + + oc] = bias[g * groupOutputChannels() + oc]; + } + } + } + } + } + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t ky = 0; ky < kernelHeight(); ky++) { + const size_t y = oy + paddingTop() - ky * dilationHeight(); + const size_t iy = y / strideHeight(); + if (iy * strideHeight() == y && iy < inputHeight()) { + for (size_t kx = 0; kx < kernelWidth(); kx++) { + const size_t x = ox + paddingLeft() - kx * dilationWidth(); + const size_t ix = x / strideWidth(); + if (ix * strideWidth() == x && ix < inputWidth()) { + for (size_t g = 0; g < groups(); g++) { + for (size_t oc = 0; oc < groupOutputChannels(); oc++) { + for (size_t ic = 0; ic < groupInputChannels(); ic++) { + accumulators + [(((i * outputHeight() + oy) * outputWidth() + + ox) * + groups() + + g) * + groupOutputChannels() + + oc] += + (int32_t(inputPtr + [((i * inputHeight() + iy) * + inputWidth() + + ix) * + inputPixelStride() + + g * groupInputChannels() + ic]) - + int32_t(inputZeroPoint)) * + (int32_t(kernel + [(((g * groupInputChannels() + ic) * + kernelHeight() + + ky) * + kernelWidth() + + kx) * + groupOutputChannels() + + oc]) - + int32_t(kernelZeroPoint)); + } + } + } + } + } + } + } + } + } + } + // Create dummy min/max for empty inputs. + // These are only used to compute scale and zero point, + // and real callers will just pull those values from the model. + const int32_t accumulatorsMin = accumulators.empty() + ? 0 + : *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulatorsMax = accumulators.empty() + ? 900 + : *std::max_element(accumulators.cbegin(), accumulators.cend()); + + const double outputScale = + double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; + const uint8_t outputZeroPoint = uint8_t(std::max( + std::min( + lrint( + 127.5 - + 0.5 * double(accumulatorsMin + accumulatorsMax) / + outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t deconvolution = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_deconvolution2d_nhwc_q8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + adjustmentHeight(), + adjustmentWidth(), + kernelHeight(), + kernelWidth(), + strideHeight(), + strideWidth(), + dilationHeight(), + dilationWidth(), + groups(), + groupInputChannels(), + groupOutputChannels(), + inputZeroPoint, + 1.0f /* input scale */, + kernelZeroPoint, + 1.0f /* kernel scale */, + kernel.data(), + bias.data(), + outputZeroPoint, + outputScale, + qmin(), + qmax(), + 0, + &deconvolution)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_deconvolution2d_nhwc_q8( + deconvolution, + batchSize(), + inputHeight(), + inputWidth(), + inputPtr, + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(deconvolution, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(deconvolution)); + deconvolution = nullptr; + + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t g = 0; g < groups(); g++) { + for (size_t c = 0; c < groupOutputChannels(); c++) { + const double scaledAccumulator = + accumulators + [(((i * outputHeight() + y) * outputWidth() + x) * + groups() + + g) * + groupOutputChannels() + + c] / + outputScale; + const double clampedAccumulator = std::max( + std::min( + scaledAccumulator, + double(qmax()) - double(outputZeroPoint)), + double(qmin()) - double(outputZeroPoint)); + ASSERT_NEAR( + clampedAccumulator, + (int32_t( + output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + g * groupOutputChannels() + c]) - + outputZeroPoint), + 0.9) + << "(x, y) = (" << x << ", " << y << "), group = " << g + << ", channel = " << c; + ASSERT_LE( + double( + int32_t(output + [((i * outputHeight() + y) * outputWidth() + + x) * + outputPixelStride() + + g * groupOutputChannels() + c]) - + outputZeroPoint), + double(qmax()) - double(outputZeroPoint)) + << "(x, y) = (" << x << ", " << y << "), group = " << g + << ", channel = " << c; + ASSERT_GE( + double( + int32_t(output + [((i * outputHeight() + y) * outputWidth() + + x) * + outputPixelStride() + + g * groupOutputChannels() + c]) - + outputZeroPoint), + double(qmin()) - double(outputZeroPoint)) + << "(x, y) = (" << x << ", " << y << "), group = " << g + << ", channel = " << c; + } + } + } + } + } + } + } + + private: + uint32_t paddingTop_{0}; + uint32_t paddingRight_{0}; + uint32_t paddingBottom_{0}; + uint32_t paddingLeft_{0}; + size_t inputHeight_{1}; + size_t inputWidth_{1}; + uint32_t groups_{1}; + size_t groupInputChannels_{1}; + size_t inputPixelStride_{0}; + size_t groupOutputChannels_{1}; + size_t outputPixelStride_{0}; + size_t batchSize_{1}; + uint32_t kernelHeight_{1}; + uint32_t kernelWidth_{1}; + uint32_t adjustmentHeight_{0}; + uint32_t adjustmentWidth_{0}; + uint32_t dilationHeight_{1}; + uint32_t dilationWidth_{1}; + uint32_t strideHeight_{1}; + uint32_t strideWidth_{1}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/deconvolution.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/deconvolution.cc new file mode 100644 index 0000000000000..dc2c30e877a2a --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/deconvolution.cc @@ -0,0 +1,275 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "deconvolution-operator-tester.h" + +TEST(DECONVOLUTION_OP, zero_batch) { + DeconvolutionOperatorTester() + .inputSize(5, 5) + .kernelSize(1, 1) + .groupInputChannels(2) + .groupOutputChannels(2) + .iterations(1) + .batchSize(0) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x1) { + DeconvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x1_with_qmin) { + DeconvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .qmin(128) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x1_with_qmax) { + DeconvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .groupInputChannels(23) + .groupOutputChannels(19) + .qmax(128) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x1_with_input_stride) { + DeconvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .inputPixelStride(28) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x1_with_output_stride) { + DeconvolutionOperatorTester() + .inputSize(27, 29) + .kernelSize(1, 1) + .outputPixelStride(29) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x1_with_batch) { + DeconvolutionOperatorTester() + .inputSize(13, 14) + .kernelSize(1, 1) + .batchSize(3) + .groupInputChannels(23) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, grouped_1x1) { + DeconvolutionOperatorTester() + .inputSize(24, 25) + .kernelSize(1, 1) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 1x3) { + DeconvolutionOperatorTester() + .inputSize(20, 19) + .paddingWidth(1) + .kernelSize(1, 3) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, grouped_1x3) { + DeconvolutionOperatorTester() + .inputSize(20, 19) + .paddingWidth(1) + .kernelSize(1, 3) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x1) { + DeconvolutionOperatorTester() + .inputSize(19, 20) + .paddingHeight(1) + .kernelSize(3, 1) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, grouped_3x1) { + DeconvolutionOperatorTester() + .inputSize(19, 20) + .paddingHeight(1) + .kernelSize(3, 1) + .groups(2) + .groupInputChannels(17) + .groupOutputChannels(15) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3) { + DeconvolutionOperatorTester() + .inputSize(13, 12) + .padding(1) + .kernelSize(3, 3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3_with_input_stride) { + DeconvolutionOperatorTester() + .inputSize(13, 12) + .padding(1) + .kernelSize(3, 3) + .inputPixelStride(22) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3_with_output_stride) { + DeconvolutionOperatorTester() + .inputSize(13, 12) + .padding(1) + .kernelSize(3, 3) + .outputPixelStride(23) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3_with_batch) { + DeconvolutionOperatorTester() + .inputSize(10, 9) + .padding(1) + .kernelSize(3, 3) + .batchSize(3) + .groupInputChannels(15) + .groupOutputChannels(17) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, grouped_3x3) { + DeconvolutionOperatorTester() + .inputSize(10, 11) + .padding(1) + .kernelSize(3, 3) + .groups(2) + .groupInputChannels(14) + .groupOutputChannels(13) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3s2) { + DeconvolutionOperatorTester() + .inputSize(19, 21) + .padding(1) + .kernelSize(3, 3) + .stride(2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3s1x2) { + DeconvolutionOperatorTester() + .inputSize(13, 13) + .padding(1) + .kernelSize(3, 3) + .stride(1, 2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3s2x1) { + DeconvolutionOperatorTester() + .inputSize(13, 13) + .padding(1) + .kernelSize(3, 3) + .stride(2, 1) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3d2) { + DeconvolutionOperatorTester() + .inputSize(13, 14) + .padding(2) + .kernelSize(3, 3) + .dilation(2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3d1x2) { + DeconvolutionOperatorTester() + .inputSize(14, 15) + .padding(1, 2) + .kernelSize(3, 3) + .dilation(1, 2) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(DECONVOLUTION_OP, 3x3d2x1) { + DeconvolutionOperatorTester() + .inputSize(15, 14) + .padding(2, 1) + .kernelSize(3, 3) + .dilation(2, 1) + .groupInputChannels(27) + .groupOutputChannels(19) + .iterations(3) + .testQ8(); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/dwconv-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/dwconv-microkernel-tester.h new file mode 100644 index 0000000000000..18cb65735bca4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/dwconv-microkernel-tester.h @@ -0,0 +1,499 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +class DWConvMicrokernelTester { + public: + inline DWConvMicrokernelTester& width(uint32_t width) { + assert(width >= 1); + this->width_ = width; + return *this; + } + + inline uint32_t width() const { + return this->width_; + } + + inline DWConvMicrokernelTester& subsampling(uint32_t subsampling) { + assert(subsampling >= 1); + this->subsampling_ = subsampling; + return *this; + } + + inline uint32_t subsampling() const { + return this->subsampling_; + } + + inline DWConvMicrokernelTester& channels(uint32_t channels) { + assert(channels >= 1); + this->channels_ = channels; + return *this; + } + + inline uint32_t channels() const { + return this->channels_; + } + + inline DWConvMicrokernelTester& cr(uint32_t cr) { + assert(cr != 0); + assert((cr & (cr - 1)) == 0); + this->cr_ = cr; + return *this; + } + + inline uint32_t cr() const { + return this->cr_; + } + + inline uint32_t packedChannels() const { + return (channels() + (cr() - 1)) & -cr(); + } + + inline DWConvMicrokernelTester& kernelHeight(uint32_t kernelHeight) { + assert(kernelHeight != 0); + this->kernelHeight_ = kernelHeight; + return *this; + } + + inline uint32_t kernelHeight() const { + return this->kernelHeight_; + } + + inline DWConvMicrokernelTester& kernelWidth(uint32_t kernelWidth) { + assert(kernelWidth != 0); + this->kernelWidth_ = kernelWidth; + return *this; + } + + inline uint32_t kernelWidth() const { + return this->kernelWidth_; + } + + inline uint32_t kernelSize() const { + return kernelHeight() * kernelWidth(); + } + + inline DWConvMicrokernelTester& inputStride(uint32_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline uint32_t inputStride() const { + if (this->inputStride_ == 0) { + return channels(); + } else { + assert(this->inputStride_ >= channels()); + return this->inputStride_; + } + } + + inline DWConvMicrokernelTester& outputStride(uint32_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline uint32_t outputStride() const { + if (this->outputStride_ == 0) { + return channels(); + } else { + assert(this->outputStride_ >= channels()); + return this->outputStride_; + } + } + + inline DWConvMicrokernelTester& inputZeroPoint(uint8_t inputZeroPoint) { + this->inputZeroPoint_ = inputZeroPoint; + return *this; + } + + inline uint8_t inputZeroPoint() const { + return this->inputZeroPoint_; + } + + inline DWConvMicrokernelTester& kernelZeroPoint(uint8_t kernelZeroPoint) { + this->kernelZeroPoint_ = kernelZeroPoint; + return *this; + } + + inline uint8_t kernelZeroPoint() const { + return this->kernelZeroPoint_; + } + + inline DWConvMicrokernelTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline DWConvMicrokernelTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline DWConvMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_q8dwconv_up_ukernel_function q8dwconv) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + (kernelSize() + (width() * subsampling() - 1) * kernelHeight() - 1) * + inputStride() + + channels() + 8); + std::vector kernel(channels() * kernelSize()); + std::vector> packedWeights( + (kernelSize() + sizeof(int32_t) / sizeof(uint8_t)) * packedChannels()); + std::vector bias(packedChannels()); + std::vector accumulators(width() * channels()); + std::vector output((width() - 1) * outputStride() + channels()); + std::vector indirectInput( + kernelSize() + (width() * subsampling() - 1) * kernelHeight()); + + const uint8_t* inputPtr = input.data() + 8; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(accumulators.begin(), accumulators.end(), 0); + + ASSERT_NE( + *std::max_element(input.cbegin(), input.cend()), + *std::min_element(input.cbegin(), input.cend())); + ASSERT_NE( + *std::max_element(kernel.cbegin(), kernel.cend()), + *std::min_element(kernel.cbegin(), kernel.cend())); + + std::fill(packedWeights.begin(), packedWeights.end(), 0xA5); + + pytorch_pack_q8dw_w( + kernelHeight(), + kernelWidth(), + channels(), + cr(), +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + inputZeroPoint(), + kernelZeroPoint(), +#endif + kernel.data(), + bias.data(), + packedWeights.data()); + + for (size_t i = 0; + i < kernelSize() + (width() * subsampling() - 1) * kernelHeight(); + i++) { + indirectInput[i] = inputPtr + i * inputStride(); + } + std::shuffle(indirectInput.begin(), indirectInput.end(), rng); + + for (size_t x = 0; x < width(); x++) { + for (size_t c = 0; c < channels(); c++) { + int32_t acc = bias[c]; + for (size_t kx = 0; kx < kernelWidth(); kx++) { + for (size_t ky = 0; ky < kernelHeight(); ky++) { + acc += (int32_t(indirectInput + [(x * subsampling() + kx) * kernelHeight() + + ky][c]) - + int32_t(inputZeroPoint())) * + (int32_t( + kernel[(c * kernelHeight() + ky) * kernelWidth() + kx]) - + int32_t(kernelZeroPoint())); + } + } + accumulators[x * channels() + c] = acc; + } + } + const int32_t accumulatorsMin = + *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulatorsMax = + *std::max_element(accumulators.cbegin(), accumulators.cend()); + const uint32_t accumulatorsRange = + uint32_t(accumulatorsMax) - uint32_t(accumulatorsMin); + ASSERT_NE(0, accumulatorsRange); + + const double outputScale = accumulatorsRange >= 256 + ? double(accumulatorsRange) / 255.0 + : 1.00001; + const uint8_t outputZeroPoint = uint8_t(std::max( + std::min( + lrint( + 127.5 - + 0.5 * double(accumulatorsMin + accumulatorsMax) / + outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + const float requantizationScale = 1.0f / float(outputScale); + const union pytorch_qnnp_conv_quantization_params quantizationParams = + pytorch_qnnp_compute_conv_quantization_params( + inputZeroPoint(), + kernelZeroPoint(), + requantizationScale, + outputZeroPoint, + qmin(), + qmax()); + const union pytorch_qnnp_q31_requantization_params + scalarRequantizationParams = + pytorch_qnnp_compute_scalar_requantization_params( + requantizationScale, outputZeroPoint, qmin(), qmax()); + + q8dwconv( + channels(), + width(), + indirectInput.data(), + packedWeights.data(), + output.data(), + kernelHeight() * subsampling() * sizeof(void*), + (outputStride() - channels()) * sizeof(uint8_t), + &quantizationParams); + + for (size_t x = 0; x < width(); x++) { + for (size_t c = 0; c < channels(); c++) { + const uint8_t referenceOutput = pytorch_qnnp_q31_requantize( + accumulators[x * channels() + c], scalarRequantizationParams); + const double scaledAccumulator = + accumulators[x * channels() + c] / outputScale + + double(outputZeroPoint); + const double clampedAccumulator = std::max( + std::min(scaledAccumulator, double(qmax())), double(qmin())); + ASSERT_NEAR( + clampedAccumulator, double(output[x * outputStride() + c]), 0.6) + << "x = " << x << ", channel = " << c; + ASSERT_EQ( + uint32_t(referenceOutput), + uint32_t(output[x * outputStride() + c])) + << "x = " << x << ", channel = " << c; + } + } + } + } + + void test(pytorch_q8dwconv_mp_ukernel_function q8dwconv) const { + ASSERT_EQ(25, kernelSize()) + << "only 5x5 microkernel is currently supported"; + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + (kernelSize() + (width() * subsampling() - 1) * kernelHeight() - 1) * + inputStride() + + channels() + 8); + std::vector kernel(channels() * kernelSize()); + std::vector> packedWeights( + (kernelSize() + sizeof(int32_t) / sizeof(uint8_t)) * packedChannels()); + std::vector bias(packedChannels()); + std::vector accumulators(width() * channels()); + std::vector mpAcc(width() * packedChannels()); + std::vector output((width() - 1) * outputStride() + channels()); + std::vector indirectInput( + kernelSize() + (width() * subsampling() - 1) * kernelHeight()); + + const uint8_t* inputPtr = input.data() + 8; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(accumulators.begin(), accumulators.end(), 0); + std::fill(mpAcc.begin(), mpAcc.end(), 0xA5A55A5A); + + ASSERT_NE( + *std::max_element(input.cbegin(), input.cend()), + *std::min_element(input.cbegin(), input.cend())); + ASSERT_NE( + *std::max_element(kernel.cbegin(), kernel.cend()), + *std::min_element(kernel.cbegin(), kernel.cend())); + + std::fill(packedWeights.begin(), packedWeights.end(), 0xA5); + + ASSERT_EQ(25, kernelSize()) + << "only 5x5 microkernel is currently supported"; + pytorch_pack_q8dw_w_dilation( + kernelHeight(), + kernelWidth(), + channels(), + cr(), + 0, + kernelHeight(), + 0, + 2, + kernel.data(), + bias.data(), + packedWeights.data(), + true); + pytorch_pack_q8dw_w_dilation( + kernelHeight(), + kernelWidth(), + channels(), + cr(), + 0, + kernelHeight(), + 2, + 4, + kernel.data(), + bias.data(), + packedWeights.data() + + (10 + sizeof(int32_t) / sizeof(uint8_t)) * packedChannels(), + false); + pytorch_pack_q8dw_w_dilation( + kernelHeight(), + kernelWidth(), + channels(), + cr(), + 0, + kernelHeight(), + 4, + 5, + kernel.data(), + bias.data(), + packedWeights.data() + + (20 + sizeof(int32_t) / sizeof(uint8_t)) * packedChannels(), + false); + for (size_t i = 0; + i < kernelSize() + (width() * subsampling() - 1) * kernelHeight(); + i++) { + indirectInput[i] = inputPtr + i * inputStride(); + } + std::shuffle(indirectInput.begin(), indirectInput.end(), rng); + + for (size_t x = 0; x < width(); x++) { + for (size_t c = 0; c < channels(); c++) { + int32_t acc = bias[c]; + for (size_t kx = 0; kx < kernelWidth(); kx++) { + for (size_t ky = 0; ky < kernelHeight(); ky++) { + acc += (int32_t(indirectInput + [(x * subsampling() + kx) * kernelHeight() + + ky][c]) - + int32_t(inputZeroPoint())) * + (int32_t( + kernel[(c * kernelHeight() + ky) * kernelWidth() + kx]) - + int32_t(kernelZeroPoint())); + } + } + accumulators[x * channels() + c] = acc; + } + } + const int32_t accumulatorsMin = + *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulatorsMax = + *std::max_element(accumulators.cbegin(), accumulators.cend()); + const uint32_t accumulatorsRange = + uint32_t(accumulatorsMax) - uint32_t(accumulatorsMin); + ASSERT_NE(0, accumulatorsRange); + + const double outputScale = accumulatorsRange >= 256 + ? double(accumulatorsRange) / 255.0 + : 1.00001; + const uint8_t outputZeroPoint = uint8_t(std::max( + std::min( + lrint( + 127.5 - + 0.5 * double(accumulatorsMin + accumulatorsMax) / + outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + const float requantizationScale = 1.0f / float(outputScale); + const union pytorch_qnnp_conv_quantization_params quantizationParams = + pytorch_qnnp_compute_conv_quantization_params( + inputZeroPoint(), + kernelZeroPoint(), + requantizationScale, + outputZeroPoint, + qmin(), + qmax()); + const union pytorch_qnnp_q31_requantization_params + scalarRequantizationParams = + pytorch_qnnp_compute_scalar_requantization_params( + requantizationScale, outputZeroPoint, qmin(), qmax()); + + q8dwconv( + channels(), + width(), + indirectInput.data(), + packedWeights.data(), + mpAcc.data(), + output.data(), + kernelHeight() * subsampling() * sizeof(void*), + (outputStride() - channels()) * sizeof(uint8_t), + &quantizationParams); + + for (size_t x = 0; x < width(); x++) { + for (size_t c = 0; c < channels(); c++) { + const uint8_t referenceOutput = pytorch_qnnp_q31_requantize( + accumulators[x * channels() + c], scalarRequantizationParams); + const double scaledAccumulator = + accumulators[x * channels() + c] / outputScale + + double(outputZeroPoint); + const double clampedAccumulator = std::max( + std::min(scaledAccumulator, double(qmax())), double(qmin())); + ASSERT_NEAR( + clampedAccumulator, double(output[x * outputStride() + c]), 0.6) + << "x = " << x << ", channel = " << c; + ASSERT_EQ( + uint32_t(referenceOutput), + uint32_t(output[x * outputStride() + c])) + << "x = " << x << ", channel = " << c; + } + } + } + } + + private: + uint32_t channels_{1}; + uint32_t cr_{1}; + uint32_t width_{1}; + uint32_t subsampling_{1}; + uint32_t kernelHeight_{1}; + uint32_t kernelWidth_{1}; + uint32_t inputStride_{0}; + uint32_t outputStride_{0}; + uint8_t inputZeroPoint_{127}; + uint8_t kernelZeroPoint_{127}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{3}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h new file mode 100644 index 0000000000000..d5b67e73a69f3 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected-operator-tester.h @@ -0,0 +1,273 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class FullyConnectedOperatorTester { + public: + inline FullyConnectedOperatorTester& inputChannels(size_t inputChannels) { + assert(inputChannels >= 1); + this->inputChannels_ = inputChannels; + return *this; + } + + inline size_t inputChannels() const { + return this->inputChannels_; + } + + inline FullyConnectedOperatorTester& outputChannels(size_t outputChannels) { + assert(outputChannels >= 1); + this->outputChannels_ = outputChannels; + return *this; + } + + inline size_t outputChannels() const { + return this->outputChannels_; + } + + inline FullyConnectedOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline FullyConnectedOperatorTester& inputStride(size_t inputStride) { + assert(inputStride >= 1); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return inputChannels(); + } else { + assert(this->inputStride_ >= inputChannels()); + return this->inputStride_; + } + } + + inline FullyConnectedOperatorTester& outputStride(size_t outputStride) { + assert(outputStride >= 1); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return outputChannels(); + } else { + assert(this->outputStride_ >= outputChannels()); + return this->outputStride_; + } + } + + inline FullyConnectedOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline FullyConnectedOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline FullyConnectedOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8(bool runtime_quant = false ) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + (batchSize() - 1) * inputStride() + inputChannels() + 8); + std::vector kernel(outputChannels() * inputChannels()); + std::vector bias(outputChannels()); + std::vector output( + (batchSize() - 1) * outputStride() + outputChannels()); + std::vector accumulators(batchSize() * outputChannels()); + + const uint8_t* inputPtr = input.data() + 8; + const uint8_t inputZeroPoint = 127; + const uint8_t kernelZeroPoint = 127; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::generate(kernel.begin(), kernel.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(output.begin(), output.end(), 0xA5); + std::fill(accumulators.begin(), accumulators.end(), 0); + + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oc = 0; oc < outputChannels(); oc++) { + accumulators[i * outputChannels() + oc] = bias[oc]; + } + } + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oc = 0; oc < outputChannels(); oc++) { + for (size_t ic = 0; ic < inputChannels(); ic++) { + accumulators[i * outputChannels() + oc] += + (int32_t(inputPtr[i * inputStride() + ic]) - + int32_t(inputZeroPoint)) * + (int32_t(kernel[oc * inputChannels() + ic]) - + int32_t(kernelZeroPoint)); + } + } + } + // Create dummy min/max for empty inputs. + // These are only used to compute scale and zero point, + // and real callers will just pull those values from the model. + const int32_t accumulatorsMin = accumulators.empty() + ? 0 + : *std::min_element(accumulators.cbegin(), accumulators.cend()); + const int32_t accumulatorsMax = accumulators.empty() + ? 900 + : *std::max_element(accumulators.cbegin(), accumulators.cend()); + + const double outputScale = + double(uint32_t(accumulatorsMax - accumulatorsMin)) / 255.0; + const uint8_t outputZeroPoint = uint8_t(std::max( + std::min( + lrint( + 127.5 - + 0.5 * double(accumulatorsMin + accumulatorsMax) / + outputScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + if (runtime_quant) { + auto packW = std::unique_ptr( + new qnnpack::PackBMatrix( + inputChannels(), + outputChannels(), + kernelZeroPoint, + 1.0f, + kernel.data(), + bias.data())); + + const pytorch_qnnp_status runStatus = qnnpack::qnnpackLinear( + batchSize() /* batch_size */, + inputChannels() /* input_channels */, + outputChannels() /* output_channels */, + inputZeroPoint, + 1.0f /* input scale */, + kernelZeroPoint, + 1.0f /* kernel scale */, + outputZeroPoint, + outputScale, + qmin(), + qmax(), + inputPtr, + inputChannels() /* input_stride */, + packW->getPackedWeights(), + output.data(), + outputStride() /* output_stride */, + nullptr /* threadpool */); + ASSERT_EQ(pytorch_qnnp_status_success, runStatus); + + } + else { + pytorch_qnnp_operator_t convolution = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_fully_connected_nc_q8( + inputChannels(), + outputChannels(), + inputZeroPoint, + 1.0f /* input scale */, + kernelZeroPoint, + 1.0f /* kernel scale */, + kernel.data(), + bias.data(), + outputZeroPoint, + outputScale, + qmin(), + qmax(), + 0, + &convolution)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_fully_connected_nc_q8( + convolution, + batchSize(), + inputPtr, + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(convolution, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(convolution)); + convolution = nullptr; + } + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < outputChannels(); c++) { + const double scaledAccumulator = + accumulators[i * outputChannels() + c] / outputScale; + const double clampedAccumulator = std::max( + std::min( + scaledAccumulator, double(qmax()) - double(outputZeroPoint)), + double(qmin()) - double(outputZeroPoint)); + ASSERT_NEAR( + clampedAccumulator, + (int32_t(output[i * outputStride() + c]) - outputZeroPoint), + 0.9) + << "batch index = " << i << ", channel = " << c; + } + } + } + } + + private: + size_t inputChannels_{1}; + size_t inputStride_{0}; + size_t outputChannels_{1}; + size_t outputStride_{0}; + size_t batchSize_{1}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected.cc new file mode 100644 index 0000000000000..b4ec8b725f161 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/fully-connected.cc @@ -0,0 +1,136 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "fully-connected-operator-tester.h" + +TEST(FULLY_CONNECTED_OP, integration_test) { + FullyConnectedOperatorTester() + .batchSize(4) + .inputChannels(4) + .outputChannels(4) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, zero_batch) { + FullyConnectedOperatorTester() + .batchSize(0) + .inputChannels(2) + .outputChannels(2) + .iterations(1) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, unit_batch) { + FullyConnectedOperatorTester() + .batchSize(1) + .inputChannels(23) + .outputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, unit_batch_with_qmin) { + FullyConnectedOperatorTester() + .batchSize(1) + .inputChannels(23) + .outputChannels(19) + .qmin(128) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, unit_batch_with_qmax) { + FullyConnectedOperatorTester() + .batchSize(1) + .inputChannels(23) + .outputChannels(19) + .qmax(128) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, unit_batch_with_input_stride) { + FullyConnectedOperatorTester() + .batchSize(1) + .inputChannels(23) + .inputStride(28) + .outputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, unit_batch_with_output_stride) { + FullyConnectedOperatorTester() + .batchSize(1) + .inputChannels(23) + .outputChannels(19) + .outputStride(29) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, small_batch) { + FullyConnectedOperatorTester() + .batchSize(12) + .inputChannels(23) + .outputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, small_batch_with_qmin) { + FullyConnectedOperatorTester() + .batchSize(12) + .inputChannels(23) + .outputChannels(19) + .qmin(128) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, small_batch_with_qmax) { + FullyConnectedOperatorTester() + .batchSize(12) + .inputChannels(23) + .outputChannels(19) + .qmax(128) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, small_batch_with_input_stride) { + FullyConnectedOperatorTester() + .batchSize(12) + .inputChannels(23) + .inputStride(28) + .outputChannels(19) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, small_batch_with_output_stride) { + FullyConnectedOperatorTester() + .batchSize(12) + .inputChannels(23) + .outputChannels(19) + .outputStride(29) + .iterations(3) + .testQ8(); +} + +TEST(FULLY_CONNECTED_OP, runtime_quant) { + FullyConnectedOperatorTester() + .batchSize(4) + .inputChannels(4) + .outputChannels(4) + .iterations(3) + .testQ8(true); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/gavgpool-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/gavgpool-microkernel-tester.h new file mode 100644 index 0000000000000..5449cc8a39396 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/gavgpool-microkernel-tester.h @@ -0,0 +1,301 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +class GAvgPoolMicrokernelTester { + public: + inline GAvgPoolMicrokernelTester& m(size_t m) { + assert(m != 0); + this->m_ = m; + return *this; + } + + inline size_t m() const { + return this->m_; + } + + inline GAvgPoolMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline GAvgPoolMicrokernelTester& nr(size_t nr) { + assert(nr != 0); + this->nr_ = nr; + return *this; + } + + inline size_t nr() const { + return this->nr_; + } + + inline size_t packedN() const { + return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr(); + } + + inline GAvgPoolMicrokernelTester& xStride(size_t xStride) { + assert(xStride != 0); + this->xStride_ = xStride; + return *this; + } + + inline size_t xStride() const { + if (this->xStride_ == 0) { + return n(); + } else { + assert(this->xStride_ >= n()); + return this->xStride_; + } + } + + inline GAvgPoolMicrokernelTester& xScale(float xScale) { + assert(xScale > 0.0f); + assert(std::isnormal(xScale)); + this->xScale_ = xScale; + return *this; + } + + inline float xScale() const { + return this->xScale_; + } + + inline GAvgPoolMicrokernelTester& xZeroPoint(uint8_t xZeroPoint) { + this->xZeroPoint_ = xZeroPoint; + return *this; + } + + inline uint8_t xZeroPoint() const { + return this->xZeroPoint_; + } + + inline GAvgPoolMicrokernelTester& yScale(float yScale) { + assert(yScale > 0.0f); + assert(std::isnormal(yScale)); + this->yScale_ = yScale; + return *this; + } + + inline float yScale() const { + return this->yScale_; + } + + inline GAvgPoolMicrokernelTester& yZeroPoint(uint8_t yZeroPoint) { + this->yZeroPoint_ = yZeroPoint; + return *this; + } + + inline uint8_t yZeroPoint() const { + return this->yZeroPoint_; + } + + inline GAvgPoolMicrokernelTester& yMin(uint8_t yMin) { + this->yMin_ = yMin; + return *this; + } + + inline uint8_t yMin() const { + return this->yMin_; + } + + inline GAvgPoolMicrokernelTester& yMax(uint8_t yMax) { + this->yMax_ = yMax; + return *this; + } + + inline uint8_t yMax() const { + return this->yMax_; + } + + inline GAvgPoolMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_q8gavgpool_up_ukernel_function q8gavgpool) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x((m() - 1) * xStride() + n()); + std::vector zero(n()); + std::vector y(n()); + std::vector yRef(n()); + std::vector yFP(n()); + std::vector yAcc(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + /* Prepare quantization parameters */ + const union pytorch_qnnp_avgpool_quantization_params quantizationParams = + pytorch_qnnp_compute_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(m()), + xScale() / (yScale() * float(m())), + yZeroPoint(), + yMin(), + yMax()); + const union pytorch_qnnp_avgpool_quantization_params + scalarQuantizationParams = + pytorch_qnnp_compute_scalar_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(m()), + xScale() / (yScale() * float(m())), + yZeroPoint(), + yMin(), + yMax()); + + /* Compute reference results */ + for (size_t j = 0; j < n(); j++) { + int32_t acc = scalarQuantizationParams.scalar.bias; + for (size_t i = 0; i < m(); i++) { + acc += x[i * xStride() + j]; + } + yAcc[j] = acc; + yRef[j] = pytorch_qnnp_avgpool_quantize(acc, scalarQuantizationParams); + yFP[j] = float(acc) * (xScale() / (yScale() * float(m()))) + + float(yZeroPoint()); + yFP[j] = std::min(yFP[j], float(yMax())); + yFP[j] = std::max(yFP[j], float(yMin())); + } + + /* Call optimized micro-kernel */ + q8gavgpool( + m(), + n(), + x.data(), + xStride() * sizeof(uint8_t), + zero.data(), + y.data(), + &quantizationParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + ASSERT_LE(uint32_t(y[i]), uint32_t(yMax())) + << "at position " << i << ", m = " << m() << ", n = " << n(); + ASSERT_GE(uint32_t(y[i]), uint32_t(yMin())) + << "at position " << i << ", m = " << m() << ", n = " << n(); + ASSERT_NEAR(float(int32_t(y[i])), yFP[i], 0.5f) + << "at position " << i << ", m = " << m() << ", n = " << n() + << ", acc = " << yAcc[i]; + ASSERT_EQ(uint32_t(yRef[i]), uint32_t(y[i])) + << "at position " << i << ", m = " << m() << ", n = " << n() + << ", acc = " << yAcc[i]; + } + } + } + + void test(pytorch_q8gavgpool_mp_ukernel_function q8gavgpool) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x((m() - 1) * xStride() + n()); + std::vector> mpAcc(packedN()); + std::vector zero(n()); + std::vector y(n()); + std::vector yRef(n()); + std::vector yFP(n()); + std::vector yAcc(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + /* Prepare quantization parameters */ + const union pytorch_qnnp_avgpool_quantization_params quantizationParams = + pytorch_qnnp_compute_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(m()), + xScale() / (yScale() * float(m())), + yZeroPoint(), + yMin(), + yMax()); + const union pytorch_qnnp_avgpool_quantization_params + scalarQuantizationParams = + pytorch_qnnp_compute_scalar_avgpool_quantization_params( + -int32_t(xZeroPoint()) * int32_t(m()), + xScale() / (yScale() * float(m())), + yZeroPoint(), + yMin(), + yMax()); + + /* Compute reference results */ + for (size_t j = 0; j < n(); j++) { + int32_t acc = scalarQuantizationParams.scalar.bias; + for (size_t i = 0; i < m(); i++) { + acc += x[i * xStride() + j]; + } + + yAcc[j] = acc; + yRef[j] = pytorch_qnnp_avgpool_quantize(acc, scalarQuantizationParams); + yFP[j] = float(acc) * (xScale() / (yScale() * float(m()))) + + float(yZeroPoint()); + yFP[j] = std::min(yFP[j], float(yMax())); + yFP[j] = std::max(yFP[j], float(yMin())); + } + + /* Call optimized micro-kernel */ + q8gavgpool( + m(), + n(), + x.data(), + xStride() * sizeof(uint8_t), + zero.data(), + mpAcc.data(), + y.data(), + &quantizationParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + ASSERT_LE(uint32_t(y[i]), uint32_t(yMax())) + << "at position " << i << ", m = " << m() << ", n = " << n(); + ASSERT_GE(uint32_t(y[i]), uint32_t(yMin())) + << "at position " << i << ", m = " << m() << ", n = " << n(); + ASSERT_NEAR(float(int32_t(y[i])), yFP[i], 0.5f) + << "at position " << i << ", m = " << m() << ", n = " << n() + << ", acc = " << yAcc[i]; + ASSERT_EQ(uint32_t(yRef[i]), uint32_t(y[i])) + << "at position " << i << ", m = " << m() << ", n = " << n() + << ", acc = " << yAcc[i]; + } + } + } + + private: + size_t m_{1}; + size_t n_{1}; + size_t nr_{1}; + size_t xStride_{0}; + float xScale_{1.25f}; + float yScale_{0.75f}; + uint8_t xZeroPoint_{121}; + uint8_t yZeroPoint_{133}; + uint8_t yMin_{0}; + uint8_t yMax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/gemm-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/gemm-microkernel-tester.h new file mode 100644 index 0000000000000..997dcb66a3641 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/gemm-microkernel-tester.h @@ -0,0 +1,997 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +class GemmMicrokernelTester { + public: + inline GemmMicrokernelTester& mr(size_t mr) { + this->mr_ = mr; + return *this; + } + + inline size_t mr() const { + return this->mr_; + } + + inline GemmMicrokernelTester& nr(size_t nr) { + this->nr_ = nr; + return *this; + } + + inline size_t nr() const { + return this->nr_; + } + + inline GemmMicrokernelTester& np(size_t np) { + this->np_ = np; + return *this; + } + + inline size_t np() const { + return this->np_; + } + + inline GemmMicrokernelTester& kr(size_t kr) { + this->kr_ = kr; + return *this; + } + + inline size_t kr() const { + return this->kr_; + } + + inline GemmMicrokernelTester& m(size_t m) { + this->m_ = m; + return *this; + } + + inline size_t m() const { + return this->m_; + } + + inline GemmMicrokernelTester& n(size_t n) { + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline GemmMicrokernelTester& k(size_t k) { + this->k_ = k; + return *this; + } + + inline size_t k() const { + return this->k_; + } + + inline GemmMicrokernelTester& ks(size_t ks) { + this->ks_ = ks; + return *this; + } + + inline size_t ks() const { + return this->ks_; + } + + inline size_t packedK() const { + return k() % kr() == 0 ? k() : (k() / kr() + 1) * kr(); + } + + inline size_t packedN() const { + return n() % np() == 0 ? n() : (n() / np() + 1) * np(); + } + + inline size_t biasN() const { + return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr(); + } + + inline GemmMicrokernelTester& aStride(size_t aStride) { + this->aStride_ = aStride; + return *this; + } + + inline size_t aStride() const { + return this->aStride_ == 0 ? k() : this->aStride_; + } + + inline GemmMicrokernelTester& cStride(size_t cStride) { + this->cStride_ = cStride; + return *this; + } + + inline size_t cStride() const { + return this->cStride_ == 0 ? n() : this->cStride_; + } + + inline GemmMicrokernelTester& aZeroPoint(uint8_t aZeroPoint) { + this->aZeroPoint_ = aZeroPoint; + return *this; + } + + inline uint8_t aZeroPoint() const { + return this->aZeroPoint_; + } + + inline GemmMicrokernelTester& bZeroPoint(uint8_t bZeroPoint) { + this->bZeroPoint_ = bZeroPoint; + return *this; + } + + inline uint8_t bZeroPoint() const { + return this->bZeroPoint_; + } + + inline GemmMicrokernelTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline GemmMicrokernelTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline GemmMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_q8gemm_ukernel_function qgemm) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a((m() - 1) * aStride() + k() + 8); + std::vector b(n() * k()); + std::vector bias(n()); + std::vector> packedW( + packedN() * packedK() + biasN() * sizeof(uint32_t) / sizeof(uint8_t)); + std::vector c((m() - 1) * cStride() + n()); + std::vector acc(m() * n()); + std::vector cRef(m() * n()); + + const uint8_t* aPtr = a.data() + 8; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(c.begin(), c.end(), 0xA5); + + std::fill(packedW.begin(), packedW.end(), bZeroPoint()); + + pytorch_pack_q8gemm_w( + n(), + k(), + nr(), + np(), + kr(), +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + aZeroPoint(), + bZeroPoint(), +#endif + b.data(), + bias.data(), + packedW.data()); + + ASSERT_NE( + *std::max_element(a.cbegin(), a.cend()), + *std::min_element(a.cbegin(), a.cend())); + ASSERT_NE( + *std::max_element(b.cbegin(), b.cend()), + *std::min_element(b.cbegin(), b.cend())); + + /* Compute 32-bit results and output quantization arguments */ + std::fill(acc.begin(), acc.end(), 0); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t kIndex = 0; kIndex < k(); kIndex++) { + ASSERT_LE(n(), packedN()); + ASSERT_LT(mIndex * n() + nIndex, acc.size()); + ASSERT_LT(mIndex * k() + kIndex, a.size()); + acc[mIndex * n() + nIndex] += + (int32_t(aPtr[mIndex * aStride() + kIndex]) - + int32_t(aZeroPoint())) * + (int32_t(b[nIndex * k() + kIndex]) - int32_t(bZeroPoint())); + } + acc[mIndex * n() + nIndex] += bias[nIndex]; + } + } + + const int32_t accMin = *std::min_element(acc.cbegin(), acc.cend()); + const int32_t accMax = *std::max_element(acc.cbegin(), acc.cend()); + if (m() * n() >= 3) { + ASSERT_NE(accMax, accMin) + << "Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + + const double cScale = uint32_t(accMax - accMin) >= 256 + ? double(uint32_t(accMax - accMin)) / 255.0 + : 1.00001; + const uint8_t cZeroPoint = uint8_t(std::max( + std::min( + lrint(127.5 - 0.5 * double(accMin + accMax) / cScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + const float requantizationScale = 1.0f / float(cScale); + const union pytorch_qnnp_conv_quantization_params quantizationParams = + pytorch_qnnp_compute_conv_quantization_params( + aZeroPoint(), + bZeroPoint(), + requantizationScale, + cZeroPoint, + qmin(), + qmax()); + const union pytorch_qnnp_q31_requantization_params + scalarRequantizationParams = + pytorch_qnnp_compute_scalar_requantization_params( + requantizationScale, cZeroPoint, qmin(), qmax()); + + qgemm( + m(), + n(), + k(), + aPtr, + aStride() * sizeof(uint8_t), + packedW.data(), + c.data(), + cStride() * sizeof(uint8_t), + &quantizationParams); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + cRef[mIndex * n() + nIndex] = pytorch_qnnp_q31_requantize( + acc[mIndex * n() + nIndex], scalarRequantizationParams); + } + } + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_LE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmax())); + ASSERT_GE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmin())); + ASSERT_EQ( + uint32_t(c[mIndex * cStride() + nIndex]), + uint32_t(cRef[mIndex * n() + nIndex])) + << "at " << mIndex << ", " << nIndex + << ": reference = " << (uint32_t)cRef[mIndex * n() + nIndex] + << " (accumulator = " << acc[mIndex * n() + nIndex] + << "), optimized = " << (uint32_t)c[mIndex * cStride() + nIndex] + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k() + << ", requantization scale = " << requantizationScale + << ", output zero point = " << int32_t(cZeroPoint); + } + } + } + } + + void test(pytorch_q8conv_ukernel_function qconv) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a((mr() - 1) * aStride() + k() + 8); + std::vector b(n() * ks() * k()); + std::vector> packedW( + ks() * packedN() * packedK() + + biasN() * sizeof(uint32_t) / sizeof(uint8_t)); + std::vector bias(n()); + std::vector c((m() - 1) * cStride() + n()); + std::vector acc(m() * n()); + std::vector cRef(m() * n()); + std::vector im2col(mr() * ks()); + + const uint8_t* aPtr = a.data() + 8; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + std::fill(c.begin(), c.end(), 0xA5); + + std::fill(packedW.begin(), packedW.end(), bZeroPoint()); + + pytorch_pack_q8conv_w( + n(), + ks(), + k(), + np(), + kr(), +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + aZeroPoint(), + bZeroPoint(), +#endif + b.data(), + bias.data(), + packedW.data()); + + ASSERT_NE( + *std::max_element(a.cbegin(), a.cend()), + *std::min_element(a.cbegin(), a.cend())); + ASSERT_NE( + *std::max_element(b.cbegin(), b.cend()), + *std::min_element(b.cbegin(), b.cend())); + + for (size_t ksIndex = 0; ksIndex < ks(); ksIndex++) { + for (size_t mIndex = 0; mIndex < mr(); mIndex++) { + im2col[ksIndex * mr() + mIndex] = aPtr + aStride() * mIndex; + } + } + std::shuffle(im2col.begin(), im2col.end(), rng); + for (size_t ksIndex = 0; ksIndex < ks(); ksIndex++) { + for (size_t mIndex = m(); mIndex < mr(); mIndex++) { + im2col[ksIndex * mr() + mIndex] = im2col[ksIndex * mr() + m() - 1]; + } + } + + /* Compute 32-bit results and output quantization arguments */ + std::fill(acc.begin(), acc.end(), 0); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t ksIndex = 0; ksIndex < ks(); ksIndex++) { + for (size_t kBlockStart = 0; kBlockStart < k(); + kBlockStart += kr()) { + for (size_t kBlockOffset = 0; + kBlockOffset < std::min(k() - kBlockStart, kr()); + kBlockOffset++) { + ASSERT_LT(ksIndex * mr() + mIndex, im2col.size()); + ASSERT_LT(kBlockStart + kBlockOffset, k()); + ASSERT_LT(kBlockStart + kBlockOffset, aStride()); + + acc[mIndex * n() + nIndex] += + (int32_t(im2col[ksIndex * mr() + mIndex] + [kBlockStart + kBlockOffset]) - + int32_t(aZeroPoint())) * + (int32_t( + b[(nIndex * ks() + ksIndex) * k() + kBlockStart + + kBlockOffset]) - + int32_t(bZeroPoint())); + } + } + } + acc[mIndex * n() + nIndex] += bias[nIndex]; + } + } + + const int32_t accMin = *std::min_element(acc.cbegin(), acc.cend()); + const int32_t accMax = *std::max_element(acc.cbegin(), acc.cend()); + if (m() * n() >= 3) { + ASSERT_NE(accMax, accMin) + << "Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + + const double cScale = uint32_t(accMax - accMin) >= 256 + ? double(uint32_t(accMax - accMin)) / 255.0 + : 1.00001; + const uint8_t cZeroPoint = uint8_t(std::max( + std::min( + lrint(127.5 - 0.5 * double(accMin + accMax) / cScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + const float requantizationScale = 1.0f / float(cScale); + const union pytorch_qnnp_conv_quantization_params quantizationParams = + pytorch_qnnp_compute_conv_quantization_params( + aZeroPoint(), + bZeroPoint(), + requantizationScale, + cZeroPoint, + qmin(), + qmax()); + const union pytorch_qnnp_q31_requantization_params + scalarRequantizationParams = + pytorch_qnnp_compute_scalar_requantization_params( + requantizationScale, cZeroPoint, qmin(), qmax()); + + qconv( + m(), + n(), + k(), + ks(), + im2col.data(), + packedW.data(), + c.data(), + cStride() * sizeof(uint8_t), + &quantizationParams); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + cRef[mIndex * n() + nIndex] = pytorch_qnnp_q31_requantize( + acc[mIndex * n() + nIndex], scalarRequantizationParams); + } + } + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_LE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmax())); + ASSERT_GE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmin())); + ASSERT_EQ( + uint32_t(c[mIndex * cStride() + nIndex]), + uint32_t(cRef[mIndex * n() + nIndex])) + << "at " << mIndex << ", " << nIndex + << ": reference = " << uint32_t(cRef[mIndex * n() + nIndex]) + << " (accumulator = " << acc[mIndex * n() + nIndex] + << "), optimized = " << uint32_t(c[mIndex * cStride() + nIndex]) + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k() + << ", requantization scale = " << requantizationScale + << ", output zero point = " << int32_t(cZeroPoint); + } + } + } + } + + static void q8gemm_compute_row_sum( + const uint8_t* a, + size_t m, + size_t k, + size_t stride, + const int32_t multiplier, + int32_t* row_sum, + pytorch_q8sum_rows_ukernel_function q8sum_rows) { + const size_t block_size = 4; + for (size_t block_start = 0; block_start < m; block_start += block_size) { + q8sum_rows( + a + block_start * stride, + std::min(block_size, m - block_start), + k, + stride, + multiplier, + row_sum + block_start); + } + } + + void test(pytorch_q8gemm_xzp_ukernel_function qgemm) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto s32rng = + std::bind(std::uniform_int_distribution(-10000, 10000), rng); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a((m() - 1) * aStride() + k() + 8); + std::vector b(n() * k()); + std::vector bias(n()); + std::vector> packedW( + packedN() * packedK() + biasN() * sizeof(uint32_t) / sizeof(uint8_t)); + std::vector aRowSums(m()); + std::vector c((m() - 1) * cStride() + n()); + std::vector acc(m() * n()); + std::vector cRef(m() * n()); + + const uint8_t* aPtr = a.data() + 8; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + std::generate(bias.begin(), bias.end(), std::ref(s32rng)); + + std::fill(packedW.begin(), packedW.end(), 0); + pytorch_pack_swizzle_q8gemm_b( + n(), + k(), + np(), + kr(), + 8, +#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION + aZeroPoint(), + bZeroPoint(), +#endif + b.data(), + bias.data(), + packedW.data()); + + ASSERT_NE( + *std::max_element(a.cbegin(), a.cend()), + *std::min_element(a.cbegin(), a.cend())); + ASSERT_NE( + *std::max_element(b.cbegin(), b.cend()), + *std::min_element(b.cbegin(), b.cend())); + + std::fill(aRowSums.begin(), aRowSums.end(), 0); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + int32_t sum = 0; + for (size_t kIndex = 0; kIndex < k(); kIndex++) { + sum += int32_t(aPtr[mIndex * aStride() + kIndex]); + } + aRowSums[mIndex] = -sum * int32_t(bZeroPoint()); + } + + /* Compute 32-bit results and output quantization arguments */ + std::fill(acc.begin(), acc.end(), 0); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t kIndex = 0; kIndex < k(); kIndex++) { + ASSERT_LE(n(), packedN()); + ASSERT_LT(mIndex * n() + nIndex, acc.size()); + ASSERT_LT(mIndex * k() + kIndex, a.size()); + acc[mIndex * n() + nIndex] += + (int32_t(aPtr[mIndex * aStride() + kIndex]) - + int32_t(aZeroPoint())) * + (int32_t(b[nIndex * k() + kIndex]) - int32_t(bZeroPoint())); + } + acc[mIndex * n() + nIndex] += bias[nIndex]; + } + } + + const int32_t accMin = *std::min_element(acc.cbegin(), acc.cend()); + const int32_t accMax = *std::max_element(acc.cbegin(), acc.cend()); + if (m() * n() >= 3) { + ASSERT_NE(accMax, accMin) + << "Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + + const double cScale = uint32_t(accMax - accMin) >= 256 + ? double(uint32_t(accMax - accMin)) / 255.0 + : 1.00001; + const uint8_t cZeroPoint = uint8_t(std::max( + std::min( + lrint(127.5 - 0.5 * double(accMin + accMax) / cScale), + long(std::numeric_limits::max())), + long(std::numeric_limits::min()))); + + const float requantizationScale = 1.0f / float(cScale); + const union pytorch_qnnp_q31_requantization_params requantizationParams = + pytorch_qnnp_compute_requantization_params( + requantizationScale, cZeroPoint, qmin(), qmax()); + const union pytorch_qnnp_q31_requantization_params + scalarRequantizationParams = + pytorch_qnnp_compute_scalar_requantization_params( + requantizationScale, cZeroPoint, qmin(), qmax()); + + std::fill(c.begin(), c.end(), 0xA5); + qgemm( + m(), + n(), + k(), + aPtr, + aStride(), + aRowSums.data(), + packedW.data(), + c.data(), + cStride(), + &requantizationParams); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + cRef[mIndex * n() + nIndex] = pytorch_qnnp_q31_requantize( + acc[mIndex * n() + nIndex], scalarRequantizationParams); + } + } + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_LE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmax())); + ASSERT_GE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmin())); + ASSERT_EQ(c[mIndex * cStride() + nIndex], cRef[mIndex * n() + nIndex]) + << "at " << mIndex << ", " << nIndex + << ": reference = " << (uint32_t)cRef[mIndex * n() + nIndex] + << ", optimized = " << (uint32_t)c[mIndex * cStride() + nIndex] + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + } + } + } + + void test(pytorch_hgemm_ukernel_function hgemm) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + ASSERT_GE(aStride(), k()); + ASSERT_GE(cStride(), n()); + + std::random_device randomDevice; + auto rng = std::bind( + fp16_ieee_from_fp32_value, + std::bind( + std::uniform_real_distribution(), + std::mt19937(randomDevice()))); + + std::vector a((m() - 1) * aStride() + k() + 4); + std::vector b(n() * k()); + std::vector> packedW( + packedN() * packedK() + biasN()); + std::vector bias(n()); + std::vector c((mr() - 1) * cStride() + nr()); + std::vector cRef(m() * n()); + + const uint16_t* aPtr = a.data() + 4; + + struct pytorch_qnnp_fp16_clamping_params clampingParams; + clampingParams.scale = UINT16_C(0x3C00) /* 1.0 */; + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(rng)); + std::generate(b.begin(), b.end(), std::ref(rng)); + std::generate(bias.begin(), bias.end(), std::ref(rng)); + std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */); + std::fill(cRef.begin(), cRef.end(), 0.0f); + + std::fill(packedW.begin(), packedW.end(), 0); + pytorch_pack_hgemm_w(n(), k(), np(), kr(), b.data(), bias.data(), packedW.data()); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t kBlockStart = 0; kBlockStart < k(); kBlockStart += kr()) { + for (size_t kBlockOffset = 0; + kBlockOffset < std::min(k() - kBlockStart, kr()); + kBlockOffset++) { + ASSERT_LE(n(), packedN()); + ASSERT_LT(mIndex * n() + nIndex, cRef.size()); + ASSERT_LT(mIndex * k() + kBlockStart + kBlockOffset, a.size()); + + cRef[mIndex * n() + nIndex] += + fp16_ieee_to_fp32_value( + aPtr[mIndex * aStride() + kBlockStart + kBlockOffset]) * + fp16_ieee_to_fp32_value( + b[nIndex * k() + kBlockStart + kBlockOffset]); + } + } + cRef[mIndex * n() + nIndex] += fp16_ieee_to_fp32_value(bias[nIndex]); + } + } + + const float accMin = *std::min_element(cRef.cbegin(), cRef.cend()); + const float accMax = *std::max_element(cRef.cbegin(), cRef.cend()); + const float cMin = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value( + accMin + (accMax - accMin) / 255.0f * float(qmin()))); + const float cMax = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value( + accMax - (accMax - accMin) / 255.0f * float(255 - qmax()))); + clampingParams.max = fp16_ieee_from_fp32_value(cMax); + clampingParams.min = fp16_ieee_from_fp32_value(cMin); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + cRef[mIndex * n() + nIndex] = + std::max(std::min(cRef[mIndex * n() + nIndex], cMax), cMin); + } + } + + hgemm( + m(), + n(), + k(), + aPtr, + aStride() * sizeof(uint16_t), + packedW.data(), + c.data(), + cStride() * sizeof(uint16_t), + &clampingParams); + + /* Validate micro-kernel outputs */ + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_NEAR( + fp16_ieee_to_fp32_value(c[mIndex * cStride() + nIndex]), + cRef[mIndex * n() + nIndex], + std::abs(cRef[mIndex * n() + nIndex]) * 1.0e-2f) + << "at " << mIndex << ", " << nIndex + << ": reference = " << cRef[mIndex * n() + nIndex] + << ", optimized = " + << fp16_ieee_to_fp32_value(c[mIndex * cStride() + nIndex]) + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + } + /* Check that micro-kernel did not overwrite data beyond bounds */ + for (size_t mIndex = 0; mIndex < m() - 1; mIndex++) { + for (size_t nIndex = n(); nIndex < cStride(); nIndex++) { + ASSERT_EQ(UINT16_C(0x7E00) /* NaN */, c[mIndex * cStride() + nIndex]) + << "at " << mIndex << ", " << nIndex + << ": Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + } + for (size_t i = (m() - 1) * cStride() + n(); i < c.size(); i++) { + ASSERT_EQ(UINT16_C(0x7E00) /* NaN */, c[i]) + << "at i = " << i << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " + << k(); + } + } + } + + void test(pytorch_sgemm_ukernel_function sgemm) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + ASSERT_GE(aStride(), k()); + ASSERT_GE(cStride(), n()); + + std::random_device randomDevice; + auto rng = std::bind( + std::uniform_real_distribution(), std::mt19937(randomDevice())); + + std::vector a((m() - 1) * aStride() + k()); + std::vector b(n() * k()); + std::vector bias(n()); + std::vector> packedW( + packedN() * packedK() + biasN()); + std::vector c((mr() - 1) * cStride() + nr()); + std::vector cRef(m() * n()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(rng)); + std::generate(b.begin(), b.end(), std::ref(rng)); + std::generate(bias.begin(), bias.end(), std::ref(rng)); + std::fill(c.begin(), c.end(), nanf("")); + std::fill(cRef.begin(), cRef.end(), 0.0f); + + std::fill(packedW.begin(), packedW.end(), 0.0f); + pytorch_pack_sgemm_w(n(), k(), np(), kr(), b.data(), bias.data(), packedW.data()); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t kIndex = 0; kIndex < k(); kIndex++) { + ASSERT_LE(n(), packedN()); + ASSERT_LT(mIndex * n() + nIndex, cRef.size()); + cRef[mIndex * n() + nIndex] += + a[mIndex * aStride() + kIndex] * b[nIndex * k() + kIndex]; + } + cRef[mIndex * n() + nIndex] += bias[nIndex]; + } + } + + const float accMin = *std::min_element(cRef.cbegin(), cRef.cend()); + const float accMax = *std::max_element(cRef.cbegin(), cRef.cend()); + const float cMin = accMin + (accMax - accMin) / 255.0f * float(qmin()); + const float cMax = + accMax - (accMax - accMin) / 255.0f * float(255 - qmax()); + struct pytorch_qnnp_fp32_clamping_params clampingParams = { + .max = cMax, + .min = cMin, + }; + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + cRef[mIndex * n() + nIndex] = + std::max(std::min(cRef[mIndex * n() + nIndex], cMax), cMin); + } + } + + sgemm( + m(), + n(), + k(), + a.data(), + aStride() * sizeof(float), + packedW.data(), + c.data(), + cStride() * sizeof(float), + &clampingParams); + + /* Validate micro-kernel outputs */ + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_NEAR( + c[mIndex * cStride() + nIndex], + cRef[mIndex * n() + nIndex], + std::abs(cRef[mIndex * n() + nIndex]) * 1.0e-6f) + << "at " << mIndex << ", " << nIndex + << ": reference = " << cRef[mIndex * n() + nIndex] + << ", optimized = " << c[mIndex * cStride() + nIndex] + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + } + /* Check that micro-kernel did not overwrite data beyond bounds */ + for (size_t mIndex = 0; mIndex < m() - 1; mIndex++) { + for (size_t nIndex = n(); nIndex < cStride(); nIndex++) { + ASSERT_TRUE(std::isnan(c[mIndex * cStride() + nIndex])) + << "at " << mIndex << ", " << nIndex + << ": Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + } + for (size_t i = (m() - 1) * cStride() + n(); i < c.size(); i++) { + ASSERT_TRUE(std::isnan(c[i])) + << "at i = " << i << ", Mr x Nr x Kr = " << mr() << " x " << nr() + << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " + << k(); + } + } + } + + void test(pytorch_sconv_ukernel_function sconv) const { + ASSERT_LE(m(), mr()); + ASSERT_LE(n(), nr()); + ASSERT_GE(k(), kr()); + + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto f32rng = std::bind( + std::uniform_real_distribution(), std::mt19937(randomDevice())); + + std::vector a((mr() - 1) * aStride() + k() + 8); + std::vector b(n() * ks() * k()); + std::vector> packedW( + ks() * packedK() * packedN() + biasN()); + std::vector bias(n()); + std::vector c((m() - 1) * cStride() + n()); + std::vector cRef(m() * n()); + std::vector im2col(mr() * ks()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(f32rng)); + std::generate(b.begin(), b.end(), std::ref(f32rng)); + std::generate(bias.begin(), bias.end(), std::ref(f32rng)); + std::fill(c.begin(), c.end(), nanf("")); + std::fill(cRef.begin(), cRef.end(), 0.0f); + + std::fill(packedW.begin(), packedW.end(), 0.0f); + pytorch_pack_sconv_w( + n(), ks(), k(), np(), kr(), b.data(), bias.data(), packedW.data()); + + ASSERT_NE( + *std::max_element(a.cbegin(), a.cend()), + *std::min_element(a.cbegin(), a.cend())); + ASSERT_NE( + *std::max_element(b.cbegin(), b.cend()), + *std::min_element(b.cbegin(), b.cend())); + + for (size_t ksIndex = 0; ksIndex < ks(); ksIndex++) { + for (size_t mIndex = 0; mIndex < mr(); mIndex++) { + im2col[ksIndex * mr() + mIndex] = a.data() + aStride() * mIndex; + } + } + std::shuffle(im2col.begin(), im2col.end(), rng); + for (size_t ksIndex = 0; ksIndex < ks(); ksIndex++) { + for (size_t mIndex = m(); mIndex < mr(); mIndex++) { + im2col[ksIndex * mr() + mIndex] = im2col[ksIndex * mr() + m() - 1]; + } + } + + std::fill(cRef.begin(), cRef.end(), 0.0); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + for (size_t ksIndex = 0; ksIndex < ks(); ksIndex++) { + for (size_t kBlockStart = 0; kBlockStart < k(); + kBlockStart += kr()) { + for (size_t kBlockOffset = 0; + kBlockOffset < std::min(k() - kBlockStart, kr()); + kBlockOffset++) { + ASSERT_LT(ksIndex * mr() + mIndex, im2col.size()); + ASSERT_LT(kBlockStart + kBlockOffset, k()); + ASSERT_LT(kBlockStart + kBlockOffset, aStride()); + + cRef[mIndex * n() + nIndex] += + double(im2col[ksIndex * mr() + mIndex] + [kBlockStart + kBlockOffset]) * + double( + b[(nIndex * ks() + ksIndex) * k() + kBlockStart + + kBlockOffset]); + } + } + } + cRef[mIndex * n() + nIndex] += bias[nIndex]; + } + } + + const float accMin = *std::min_element(cRef.cbegin(), cRef.cend()); + const float accMax = *std::max_element(cRef.cbegin(), cRef.cend()); + if (m() * n() >= 3) { + ASSERT_NE(accMax, accMin) + << "Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x K = " << m() << " x " << n() << " x " << k(); + } + + const float cRefMin = accMin + float(qmin()) / 255.0f * (accMax - accMin); + const float cRefMax = + accMax - float(255 - qmax()) / 255.0f * (accMax - accMin); + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + cRef[mIndex * n() + nIndex] = + std::min(cRef[mIndex * n() + nIndex], cRefMax); + cRef[mIndex * n() + nIndex] = + std::max(cRef[mIndex * n() + nIndex], cRefMin); + } + } + + const struct pytorch_qnnp_fp32_clamping_params clampingParams { + cRefMax, cRefMin + }; + + sconv( + m(), + n(), + k(), + ks(), + im2col.data(), + packedW.data(), + c.data(), + cStride() * sizeof(float), + &clampingParams); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_LE(c[mIndex * cStride() + nIndex], cRefMax); + ASSERT_GE(c[mIndex * cStride() + nIndex], cRefMin); + ASSERT_NEAR( + c[mIndex * cStride() + nIndex], + cRef[mIndex * n() + nIndex], + std::abs(cRef[mIndex * n() + nIndex]) * 1.0e-6f) + << "at " << mIndex << ", " << nIndex + << ": reference = " << cRef[mIndex * n() + nIndex] + << ", optimized = " << c[mIndex * cStride() + nIndex] + << ", Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr() + << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() + << " x " << ks(); + } + } + } + } + + private: + size_t mr_{1}; + size_t nr_{1}; + size_t np_{1}; + size_t kr_{1}; + size_t m_{1}; + size_t n_{1}; + size_t k_{1}; + size_t ks_{1}; + size_t aStride_{0}; + size_t cStride_{0}; + uint8_t aZeroPoint_{127}; + uint8_t bZeroPoint_{127}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/global-average-pooling-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/global-average-pooling-operator-tester.h new file mode 100644 index 0000000000000..42016aa7158c2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/global-average-pooling-operator-tester.h @@ -0,0 +1,253 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class GlobalAveragePoolingOperatorTester { + public: + inline GlobalAveragePoolingOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline GlobalAveragePoolingOperatorTester& width(size_t width) { + assert(width != 0); + this->width_ = width; + return *this; + } + + inline size_t width() const { + return this->width_; + } + + inline GlobalAveragePoolingOperatorTester& inputStride(size_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return channels(); + } else { + assert(this->inputStride_ >= channels()); + return this->inputStride_; + } + } + + inline GlobalAveragePoolingOperatorTester& outputStride(size_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return channels(); + } else { + assert(this->outputStride_ >= channels()); + return this->outputStride_; + } + } + + inline GlobalAveragePoolingOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline GlobalAveragePoolingOperatorTester& inputScale(float inputScale) { + assert(inputScale > 0.0f); + assert(std::isnormal(inputScale)); + this->inputScale_ = inputScale; + return *this; + } + + inline float inputScale() const { + return this->inputScale_; + } + + inline GlobalAveragePoolingOperatorTester& inputZeroPoint( + uint8_t inputZeroPoint) { + this->inputZeroPoint_ = inputZeroPoint; + return *this; + } + + inline uint8_t inputZeroPoint() const { + return this->inputZeroPoint_; + } + + inline GlobalAveragePoolingOperatorTester& outputScale(float outputScale) { + assert(outputScale > 0.0f); + assert(std::isnormal(outputScale)); + this->outputScale_ = outputScale; + return *this; + } + + inline float outputScale() const { + return this->outputScale_; + } + + inline GlobalAveragePoolingOperatorTester& outputZeroPoint( + uint8_t outputZeroPoint) { + this->outputZeroPoint_ = outputZeroPoint; + return *this; + } + + inline uint8_t outputZeroPoint() const { + return this->outputZeroPoint_; + } + + inline GlobalAveragePoolingOperatorTester& outputMin(uint8_t outputMin) { + this->outputMin_ = outputMin; + return *this; + } + + inline uint8_t outputMin() const { + return this->outputMin_; + } + + inline GlobalAveragePoolingOperatorTester& outputMax(uint8_t outputMax) { + this->outputMax_ = outputMax; + return *this; + } + + inline uint8_t outputMax() const { + return this->outputMax_; + } + + inline GlobalAveragePoolingOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + (batchSize() * width() - 1) * inputStride() + channels()); + std::vector output(batchSize() * outputStride()); + std::vector outputRef(batchSize() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + const double scale = + double(inputScale()) / (double(width()) * double(outputScale())); + for (size_t i = 0; i < batchSize(); i++) { + for (size_t j = 0; j < channels(); j++) { + double acc = 0.0f; + for (size_t k = 0; k < width(); k++) { + acc += double( + int32_t(input[(i * width() + k) * inputStride() + j]) - + int32_t(inputZeroPoint())); + } + outputRef[i * channels() + j] = + float(acc * scale + double(outputZeroPoint())); + outputRef[i * channels() + j] = std::min( + outputRef[i * channels() + j], float(outputMax())); + outputRef[i * channels() + j] = std::max( + outputRef[i * channels() + j], float(outputMin())); + } + } + + /* Create, setup, run, and destroy Add operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t globalAveragePoolingOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_global_average_pooling_nwc_q8( + channels(), + inputZeroPoint(), + inputScale(), + outputZeroPoint(), + outputScale(), + outputMin(), + outputMax(), + 0, + &globalAveragePoolingOp)); + ASSERT_NE(nullptr, globalAveragePoolingOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_global_average_pooling_nwc_q8( + globalAveragePoolingOp, + batchSize(), + width(), + input.data(), + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator( + globalAveragePoolingOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(globalAveragePoolingOp)); + globalAveragePoolingOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t(output[i * outputStride() + c]), uint32_t(outputMax())); + ASSERT_GE( + uint32_t(output[i * outputStride() + c]), uint32_t(outputMin())); + ASSERT_NEAR( + float(int32_t(output[i * outputStride() + c])), + outputRef[i * channels() + c], + 0.80f) + << "in batch index " << i << ", channel " << c; + } + } + } + } + + private: + size_t batchSize_{1}; + size_t width_{1}; + size_t channels_{1}; + size_t inputStride_{0}; + size_t outputStride_{0}; + float inputScale_{1.0f}; + float outputScale_{1.0f}; + uint8_t inputZeroPoint_{121}; + uint8_t outputZeroPoint_{133}; + uint8_t outputMin_{0}; + uint8_t outputMax_{255}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/global-average-pooling.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/global-average-pooling.cc new file mode 100644 index 0000000000000..ea86682791174 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/global-average-pooling.cc @@ -0,0 +1,651 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "global-average-pooling-operator-tester.h" + +#include + +TEST(GLOBAL_AVERAGE_POOLING_OP, zero_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + GlobalAveragePoolingOperatorTester() + .batchSize(0) + .width(1) + .channels(8) + .testQ8(); +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_many_channels_small_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_input_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (float inputScale = 0.01f; inputScale < 100.0f; + inputScale *= 3.14159265f) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_input_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_output_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (float outputScale = 0.01f; outputScale < 100.0f; + outputScale *= 3.14159265f) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_output_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_output_min) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputMin(128) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_small_width_with_output_max) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputMax(128) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_many_channels_large_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_input_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (float inputScale = 0.01f; inputScale < 100.0f; + inputScale *= 3.14159265f) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_input_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_output_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (float outputScale = 0.01f; outputScale < 100.0f; + outputScale *= 3.14159265f) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_output_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_output_min) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputMin(128) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_many_channels_large_width_with_output_max) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputMax(128) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels_with_input_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + for (float inputScale = 0.01f; inputScale < 100.0f; + inputScale *= 3.14159265f) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputScale(inputScale) + .testQ8(); + } + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels_with_input_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels_with_output_scale) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + for (float outputScale = 0.01f; outputScale < 100.0f; + outputScale *= 3.14159265f) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputScale(outputScale) + .testQ8(); + } + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + unit_batch_few_channels_with_output_zero_point) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .testQ8(); + } + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels_with_output_min) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputMin(128) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, unit_batch_few_channels_with_output_max) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(1) + .width(width) + .channels(channels) + .outputMax(128) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, small_batch_many_channels_small_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + small_batch_many_channels_small_width_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .inputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + small_batch_many_channels_small_width_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .outputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, small_batch_many_channels_large_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + small_batch_many_channels_large_width_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .inputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST( + GLOBAL_AVERAGE_POOLING_OP, + small_batch_many_channels_large_width_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.q8gavgpool.nr; + channels <= 3 * pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = pytorch_qnnp_params.q8gavgpool.mr; + width <= 4 * pytorch_qnnp_params.q8gavgpool.mr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .outputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, small_batch_few_channels) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, small_batch_few_channels_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .inputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} + +TEST(GLOBAL_AVERAGE_POOLING_OP, small_batch_few_channels_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.q8gavgpool.nr; + channels++) { + for (size_t width = 1; width <= 2 * pytorch_qnnp_params.q8gavgpool.nr; + width++) { + GlobalAveragePoolingOperatorTester() + .batchSize(3) + .width(width) + .channels(channels) + .outputStride(5 * pytorch_qnnp_params.q8gavgpool.nr) + .testQ8(); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/hgemm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/hgemm.cc new file mode 100644 index 0000000000000..6e98dcc545cb0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/hgemm.cc @@ -0,0 +1,183 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "gemm-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_eq_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(4).test( + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_eq_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(4) + .aStride(37) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_eq_4_strided_c) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(4) + .cStride(17) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_eq_4_qmin128) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(4).qmin(128).test( + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_eq_4_qmax128) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(4).qmax(128).test( + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_gt_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(k).test( + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_gt_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_gt_4_strided_c) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_gt_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 5; k < 8; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } + } + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_div_4) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k < 64; k += 4) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(k).test( + pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_div_4_strided_a) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k < 64; k += 4) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_div_4_strided_c) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k < 64; k += 4) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } +} + +TEST(HGEMM_8x8__AARCH32_NEONFP16ARITH, k_div_4_subtile) { + TEST_REQUIRES_ARM_NEON_FP16_ARITH; + for (size_t k = 8; k < 64; k += 12) { + for (uint32_t m = 1; m <= 1; m++) { + for (uint32_t n = 8; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_hgemm_ukernel_8x8__aarch32_neonfp16arith); + } + } + } +} +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/leaky-relu-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/leaky-relu-operator-tester.h new file mode 100644 index 0000000000000..57bbe7851c704 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/leaky-relu-operator-tester.h @@ -0,0 +1,240 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +class LeakyReLUOperatorTester { + public: + inline LeakyReLUOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline LeakyReLUOperatorTester& inputStride(size_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return this->channels_; + } else { + assert(this->inputStride_ >= this->channels_); + return this->inputStride_; + } + } + + inline LeakyReLUOperatorTester& outputStride(size_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return this->channels_; + } else { + assert(this->outputStride_ >= this->channels_); + return this->outputStride_; + } + } + + inline LeakyReLUOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline LeakyReLUOperatorTester& negativeSlope(float negativeSlope) { + assert(negativeSlope > 0.0f); + assert(negativeSlope < 1.0f); + this->negativeSlope_ = negativeSlope; + return *this; + } + + inline float negativeSlope() const { + return this->negativeSlope_; + } + + inline LeakyReLUOperatorTester& inputScale(float inputScale) { + assert(inputScale > 0.0f); + assert(std::isnormal(inputScale)); + this->inputScale_ = inputScale; + return *this; + } + + inline float inputScale() const { + return this->inputScale_; + } + + inline LeakyReLUOperatorTester& inputZeroPoint(uint8_t inputZeroPoint) { + this->inputZeroPoint_ = inputZeroPoint; + return *this; + } + + inline uint8_t inputZeroPoint() const { + return this->inputZeroPoint_; + } + + inline LeakyReLUOperatorTester& outputScale(float outputScale) { + assert(outputScale > 0.0f); + assert(std::isnormal(outputScale)); + this->outputScale_ = outputScale; + return *this; + } + + inline float outputScale() const { + return this->outputScale_; + } + + inline LeakyReLUOperatorTester& outputZeroPoint(uint8_t outputZeroPoint) { + this->outputZeroPoint_ = outputZeroPoint; + return *this; + } + + inline uint8_t outputZeroPoint() const { + return this->outputZeroPoint_; + } + + inline LeakyReLUOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline LeakyReLUOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline LeakyReLUOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input((batchSize() - 1) * inputStride() + channels()); + std::vector output( + (batchSize() - 1) * outputStride() + channels()); + std::vector outputRef(batchSize() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + const float x = inputScale() * + (int32_t(input[i * inputStride() + c]) - + int32_t(inputZeroPoint())); + float y = (x < 0.0f ? x * negativeSlope() : x) / outputScale(); + y = std::min(y, int32_t(qmax()) - int32_t(outputZeroPoint())); + y = std::max(y, int32_t(qmin()) - int32_t(outputZeroPoint())); + outputRef[i * channels() + c] = y + float(int32_t(outputZeroPoint())); + } + } + + /* Create, setup, run, and destroy LeakyReLU operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t leakyReLUOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_leaky_relu_nc_q8( + channels(), + negativeSlope(), + inputZeroPoint(), + inputScale(), + outputZeroPoint(), + outputScale(), + qmin(), + qmax(), + 0, + &leakyReLUOp)); + ASSERT_NE(nullptr, leakyReLUOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_leaky_relu_nc_q8( + leakyReLUOp, + batchSize(), + input.data(), + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(leakyReLUOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(leakyReLUOp)); + leakyReLUOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_NEAR( + float(int32_t(output[i * outputStride() + c])), + outputRef[i * channels() + c], + 0.6f); + } + } + } + } + + private: + size_t batchSize_{1}; + size_t channels_{1}; + size_t inputStride_{0}; + size_t outputStride_{0}; + float negativeSlope_{0.5f}; + float outputScale_{0.75f}; + uint8_t outputZeroPoint_{133}; + float inputScale_{1.25f}; + uint8_t inputZeroPoint_{121}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/leaky-relu.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/leaky-relu.cc new file mode 100644 index 0000000000000..b760c46b2dc76 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/leaky-relu.cc @@ -0,0 +1,161 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "leaky-relu-operator-tester.h" + +TEST(LEAKY_RELU_OP, zero_batch) { + LeakyReLUOperatorTester().batchSize(0).channels(2).iterations(1).testQ8(); +} + +TEST(LEAKY_RELU_OP, unit_batch) { + for (size_t channels = 1; channels < 100; channels++) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .iterations(3) + .testQ8(); + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_negative_slope) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float negativeSlope = 1.0e-4f; negativeSlope < 1.0f; + negativeSlope *= 3.14159265f) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .negativeSlope(negativeSlope) + .iterations(1) + .testQ8(); + } + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_input_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float inputScale = 1.0e-2f; inputScale < 1.0e+2f; + inputScale *= 3.14159265f) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .inputScale(inputScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_input_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_output_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float outputScale = 1.0e-2f; outputScale < 1.0e+2f; + outputScale *= 3.14159265f) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .outputScale(outputScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(LEAKY_RELU_OP, unit_batch_with_output_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t outputZeroPoint = 0; outputZeroPoint <= 255; + outputZeroPoint += 51) { + LeakyReLUOperatorTester() + .batchSize(1) + .channels(channels) + .outputZeroPoint(uint8_t(outputZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(LEAKY_RELU_OP, small_batch) { + for (size_t channels = 1; channels < 100; channels++) { + LeakyReLUOperatorTester() + .batchSize(3) + .channels(channels) + .iterations(3) + .testQ8(); + } +} + +TEST(LEAKY_RELU_OP, small_batch_with_input_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + LeakyReLUOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .iterations(3) + .testQ8(); + } +} + +TEST(LEAKY_RELU_OP, small_batch_with_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + LeakyReLUOperatorTester() + .batchSize(3) + .channels(channels) + .outputStride(117) + .iterations(3) + .testQ8(); + } +} + +TEST(LEAKY_RELU_OP, small_batch_with_input_and_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + LeakyReLUOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .iterations(3) + .testQ8(); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/lut-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/lut-microkernel-tester.h new file mode 100644 index 0000000000000..79c07ea380cff --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/lut-microkernel-tester.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class LUTMicrokernelTester { + public: + inline LUTMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline LUTMicrokernelTester& inplace(bool inplace) { + this->inplace_ = inplace; + return *this; + } + + inline bool inplace() const { + return this->inplace_; + } + + inline LUTMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_x8lut_ukernel_function x8lut) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x(n()); + std::vector t(256); + std::vector y(n()); + std::vector yRef(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::generate(t.begin(), t.end(), std::ref(u8rng)); + if (inplace()) { + std::generate(y.begin(), y.end(), std::ref(u8rng)); + } else { + std::fill(y.begin(), y.end(), 0xA5); + } + const uint8_t* xData = inplace() ? y.data() : x.data(); + + /* Compute reference results */ + for (size_t i = 0; i < n(); i++) { + yRef[i] = t[xData[i]]; + } + + /* Call optimized micro-kernel */ + x8lut(n(), xData, t.data(), y.data()); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + ASSERT_EQ(uint32_t(yRef[i]), uint32_t(y[i])) + << "at position " << i << ", n = " << n(); + } + } + } + + private: + size_t n_{1}; + bool inplace_{false}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/lut-norm-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/lut-norm-microkernel-tester.h new file mode 100644 index 0000000000000..57debca8c793c --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/lut-norm-microkernel-tester.h @@ -0,0 +1,99 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class LUTNormMicrokernelTester { + public: + inline LUTNormMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline LUTNormMicrokernelTester& inplace(bool inplace) { + this->inplace_ = inplace; + return *this; + } + + inline bool inplace() const { + return this->inplace_; + } + + inline LUTNormMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_u8lut32norm_ukernel_function u8lut32norm) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + auto u32rng = std::bind( + std::uniform_int_distribution( + 1, std::numeric_limits::max() / (257 * n())), + rng); + + std::vector x(n()); + std::vector t(256); + std::vector y(n()); + std::vector yRef(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::generate(t.begin(), t.end(), std::ref(u32rng)); + if (inplace()) { + std::generate(y.begin(), y.end(), std::ref(u8rng)); + } else { + std::fill(y.begin(), y.end(), 0xA5); + } + const uint8_t* xData = inplace() ? y.data() : x.data(); + + /* Compute reference results */ + uint32_t sum = 0; + for (size_t i = 0; i < n(); i++) { + sum += t[xData[i]]; + } + for (size_t i = 0; i < n(); i++) { + yRef[i] = 256.0f * float(t[xData[i]]) / float(sum); + yRef[i] = std::min(yRef[i], 255.0f); + } + + /* Call optimized micro-kernel */ + u8lut32norm(n(), xData, t.data(), y.data()); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + ASSERT_NEAR(yRef[i], float(y[i]), 0.5f) + << "at position " << i << ", n = " << n() << ", sum = " << sum; + } + } + } + + private: + size_t n_{1}; + bool inplace_{false}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/max-pooling-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/max-pooling-operator-tester.h new file mode 100644 index 0000000000000..7c17919df75ea --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/max-pooling-operator-tester.h @@ -0,0 +1,800 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class MaxPoolingOperatorTester { + public: + inline MaxPoolingOperatorTester& padding(uint32_t padding) { + this->paddingTop_ = padding; + this->paddingRight_ = padding; + this->paddingBottom_ = padding; + this->paddingLeft_ = padding; + return *this; + } + + inline MaxPoolingOperatorTester& padding( + uint32_t paddingHeight, + uint32_t paddingWidth) { + this->paddingTop_ = paddingHeight; + this->paddingRight_ = paddingWidth; + this->paddingBottom_ = paddingHeight; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline MaxPoolingOperatorTester& paddingHeight(uint32_t paddingHeight) { + this->paddingTop_ = paddingHeight; + this->paddingBottom_ = paddingHeight; + return *this; + } + + inline MaxPoolingOperatorTester& paddingWidth(uint32_t paddingWidth) { + this->paddingRight_ = paddingWidth; + this->paddingLeft_ = paddingWidth; + return *this; + } + + inline MaxPoolingOperatorTester& paddingTop(uint32_t paddingTop) { + this->paddingTop_ = paddingTop; + return *this; + } + + inline uint32_t paddingTop() const { + return this->paddingTop_; + } + + inline MaxPoolingOperatorTester& paddingRight(uint32_t paddingRight) { + this->paddingRight_ = paddingRight; + return *this; + } + + inline uint32_t paddingRight() const { + return this->paddingRight_; + } + + inline MaxPoolingOperatorTester& paddingBottom(uint32_t paddingBottom) { + this->paddingBottom_ = paddingBottom; + return *this; + } + + inline uint32_t paddingBottom() const { + return this->paddingBottom_; + } + + inline MaxPoolingOperatorTester& paddingLeft(uint32_t paddingLeft) { + this->paddingLeft_ = paddingLeft; + return *this; + } + + inline uint32_t paddingLeft() const { + return this->paddingLeft_; + } + + inline MaxPoolingOperatorTester& inputSize( + size_t inputHeight, + size_t inputWidth) { + assert(inputHeight >= 1); + assert(inputWidth >= 1); + this->inputHeight_ = inputHeight; + this->inputWidth_ = inputWidth; + return *this; + } + + inline MaxPoolingOperatorTester& inputHeight(size_t inputHeight) { + assert(inputHeight >= 1); + this->inputHeight_ = inputHeight; + return *this; + } + + inline size_t inputHeight() const { + return this->inputHeight_; + } + + inline MaxPoolingOperatorTester& inputWidth(size_t inputWidth) { + assert(inputWidth >= 1); + this->inputWidth_ = inputWidth; + return *this; + } + + inline size_t inputWidth() const { + return this->inputWidth_; + } + + inline MaxPoolingOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline MaxPoolingOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline MaxPoolingOperatorTester& poolingSize(uint32_t poolingSize) { + assert(poolingSize >= 1); + this->poolingHeight_ = poolingSize; + this->poolingWidth_ = poolingSize; + return *this; + } + + inline MaxPoolingOperatorTester& poolingSize( + uint32_t poolingHeight, + uint32_t poolingWidth) { + assert(poolingHeight >= 1); + assert(poolingWidth >= 1); + this->poolingHeight_ = poolingHeight; + this->poolingWidth_ = poolingWidth; + return *this; + } + + inline MaxPoolingOperatorTester& poolingHeight(uint32_t poolingHeight) { + assert(poolingHeight >= 1); + this->poolingHeight_ = poolingHeight; + return *this; + } + + inline uint32_t poolingHeight() const { + return this->poolingHeight_; + } + + inline MaxPoolingOperatorTester& poolingWidth(uint32_t poolingWidth) { + assert(poolingWidth >= 1); + this->poolingWidth_ = poolingWidth; + return *this; + } + + inline uint32_t poolingWidth() const { + return this->poolingWidth_; + } + + inline MaxPoolingOperatorTester& stride(uint32_t stride) { + assert(stride >= 1); + this->strideHeight_ = stride; + this->strideWidth_ = stride; + return *this; + } + + inline MaxPoolingOperatorTester& stride( + uint32_t strideHeight, + uint32_t strideWidth) { + assert(strideHeight >= 1); + assert(strideWidth >= 1); + this->strideHeight_ = strideHeight; + this->strideWidth_ = strideWidth; + return *this; + } + + inline MaxPoolingOperatorTester& strideHeight(uint32_t strideHeight) { + assert(strideHeight >= 1); + this->strideHeight_ = strideHeight; + return *this; + } + + inline uint32_t strideHeight() const { + return this->strideHeight_; + } + + inline MaxPoolingOperatorTester& strideWidth(uint32_t strideWidth) { + assert(strideWidth >= 1); + this->strideWidth_ = strideWidth; + return *this; + } + + inline uint32_t strideWidth() const { + return this->strideWidth_; + } + + inline MaxPoolingOperatorTester& dilation(uint32_t dilation) { + assert(dilation >= 1); + this->dilationHeight_ = dilation; + this->dilationWidth_ = dilation; + return *this; + } + + inline MaxPoolingOperatorTester& dilation( + uint32_t dilationHeight, + uint32_t dilationWidth) { + assert(dilationHeight >= 1); + assert(dilationWidth >= 1); + this->dilationHeight_ = dilationHeight; + this->dilationWidth_ = dilationWidth; + return *this; + } + + inline MaxPoolingOperatorTester& dilationHeight(uint32_t dilationHeight) { + assert(dilationHeight >= 1); + this->dilationHeight_ = dilationHeight; + return *this; + } + + inline uint32_t dilationHeight() const { + return this->dilationHeight_; + } + + inline MaxPoolingOperatorTester& dilationWidth(uint32_t dilationWidth) { + assert(dilationWidth >= 1); + this->dilationWidth_ = dilationWidth; + return *this; + } + + inline uint32_t dilationWidth() const { + return this->dilationWidth_; + } + + inline uint32_t dilatedPoolingHeight() const { + return (poolingHeight() - 1) * dilationHeight() + 1; + } + + inline uint32_t dilatedPoolingWidth() const { + return (poolingWidth() - 1) * dilationWidth() + 1; + } + + inline size_t outputHeight() const { + const size_t paddedInputHeight = + paddingTop() + inputHeight() + paddingBottom(); + if (paddedInputHeight <= dilatedPoolingHeight()) { + return 1; + } else { + return (paddedInputHeight - dilatedPoolingHeight()) / strideHeight() + 1; + } + } + + inline size_t outputWidth() const { + const size_t paddedInputWidth = + paddingLeft() + inputWidth() + paddingRight(); + if (paddedInputWidth <= dilatedPoolingWidth()) { + return 1; + } else { + return (paddedInputWidth - dilatedPoolingWidth()) / strideWidth() + 1; + } + } + + inline MaxPoolingOperatorTester& inputPixelStride(size_t inputPixelStride) { + assert(inputPixelStride != 0); + this->inputPixelStride_ = inputPixelStride; + return *this; + } + + inline size_t inputPixelStride() const { + if (this->inputPixelStride_ == 0) { + return channels(); + } else { + assert(this->inputPixelStride_ >= channels()); + return this->inputPixelStride_; + } + } + + inline MaxPoolingOperatorTester& outputPixelStride(size_t outputPixelStride) { + assert(outputPixelStride != 0); + this->outputPixelStride_ = outputPixelStride; + return *this; + } + + inline size_t outputPixelStride() const { + if (this->outputPixelStride_ == 0) { + return channels(); + } else { + assert(this->outputPixelStride_ >= channels()); + return this->outputPixelStride_; + } + } + + inline MaxPoolingOperatorTester& nextInputSize( + uint32_t nextInputHeight, + uint32_t nextInputWidth) { + assert(nextInputHeight >= 1); + assert(nextInputWidth >= 1); + this->nextInputHeight_ = nextInputHeight; + this->nextInputWidth_ = nextInputWidth; + return *this; + } + + inline MaxPoolingOperatorTester& nextInputHeight(uint32_t nextInputHeight) { + assert(nextInputHeight >= 1); + this->nextInputHeight_ = nextInputHeight; + return *this; + } + + inline uint32_t nextInputHeight() const { + if (this->nextInputHeight_ == 0) { + return inputHeight(); + } else { + return this->nextInputHeight_; + } + } + + inline MaxPoolingOperatorTester& nextInputWidth(uint32_t nextInputWidth) { + assert(nextInputWidth >= 1); + this->nextInputWidth_ = nextInputWidth; + return *this; + } + + inline uint32_t nextInputWidth() const { + if (this->nextInputWidth_ == 0) { + return inputWidth(); + } else { + return this->nextInputWidth_; + } + } + + inline size_t nextOutputHeight() const { + const size_t paddedNextInputHeight = + paddingTop() + nextInputHeight() + paddingBottom(); + if (paddedNextInputHeight <= dilatedPoolingHeight()) { + return 1; + } else { + return (paddedNextInputHeight - dilatedPoolingHeight()) / strideHeight() + + 1; + } + } + + inline size_t nextOutputWidth() const { + const size_t paddedNextInputWidth = + paddingLeft() + nextInputWidth() + paddingRight(); + if (paddedNextInputWidth <= dilatedPoolingWidth()) { + return 1; + } else { + return (paddedNextInputWidth - dilatedPoolingWidth()) / strideWidth() + 1; + } + } + + inline MaxPoolingOperatorTester& nextBatchSize(size_t nextBatchSize) { + assert(nextBatchSize >= 1); + this->nextBatchSize_ = nextBatchSize; + return *this; + } + + inline size_t nextBatchSize() const { + if (this->nextBatchSize_ == 0) { + return batchSize(); + } else { + return this->nextBatchSize_; + } + } + + inline MaxPoolingOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline MaxPoolingOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline MaxPoolingOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testU8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input( + (batchSize() * inputHeight() * inputWidth() - 1) * inputPixelStride() + + channels()); + std::vector output( + (batchSize() * outputHeight() * outputWidth() - 1) * + outputPixelStride() + + channels()); + std::vector outputRef( + batchSize() * outputHeight() * outputWidth() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t c = 0; c < channels(); c++) { + uint8_t maxValue = 0; + for (size_t py = 0; py < poolingHeight(); py++) { + const size_t iy = + oy * strideHeight() + py * dilationHeight() - paddingTop(); + for (size_t px = 0; px < poolingWidth(); px++) { + const size_t ix = + ox * strideWidth() + px * dilationWidth() - paddingLeft(); + if (ix < inputWidth() && iy < inputHeight()) { + maxValue = std::max( + maxValue, + input + [((i * inputHeight() + iy) * inputWidth() + ix) * + inputPixelStride() + + c]); + } + } + } + maxValue = std::min(maxValue, qmax()); + maxValue = std::max(maxValue, qmin()); + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = maxValue; + } + } + } + } + + /* Create, setup, run, and destroy Max Pooling operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t maxPoolingOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_max_pooling2d_nhwc_u8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + poolingHeight(), + poolingWidth(), + strideHeight(), + strideWidth(), + dilationHeight(), + dilationWidth(), + channels(), + qmin(), + qmax(), + 0, + &maxPoolingOp)); + ASSERT_NE(nullptr, maxPoolingOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + maxPoolingOp, + batchSize(), + inputHeight(), + inputWidth(), + input.data(), + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(maxPoolingOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(maxPoolingOp)); + maxPoolingOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmax())); + ASSERT_GE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmin())); + ASSERT_EQ( + uint32_t(outputRef + [((i * outputHeight() + y) * outputWidth() + x) * + channels() + + c]), + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c])) + << "in batch index " << i << ", pixel (" << y << ", " << x + << "), channel " << c; + } + } + } + } + } + } + + void testSetupU8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input(std::max( + (batchSize() * inputHeight() * inputWidth() - 1) * inputPixelStride() + + channels(), + (nextBatchSize() * nextInputHeight() * nextInputWidth() - 1) * + inputPixelStride() + + channels())); + std::vector output(std::max( + (batchSize() * outputHeight() * outputWidth() - 1) * + outputPixelStride() + + channels(), + (nextBatchSize() * nextOutputHeight() * nextOutputWidth() - 1) * + outputPixelStride() + + channels())); + std::vector outputRef( + batchSize() * outputHeight() * outputWidth() * channels()); + std::vector nextOutputRef( + nextBatchSize() * nextOutputHeight() * nextOutputWidth() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t oy = 0; oy < outputHeight(); oy++) { + for (size_t ox = 0; ox < outputWidth(); ox++) { + for (size_t c = 0; c < channels(); c++) { + uint8_t maxValue = 0; + for (size_t py = 0; py < poolingHeight(); py++) { + const size_t iy = + oy * strideHeight() + py * dilationHeight() - paddingTop(); + for (size_t px = 0; px < poolingWidth(); px++) { + const size_t ix = + ox * strideWidth() + px * dilationWidth() - paddingLeft(); + if (ix < inputWidth() && iy < inputHeight()) { + maxValue = std::max( + maxValue, + input + [((i * inputHeight() + iy) * inputWidth() + ix) * + inputPixelStride() + + c]); + } + } + } + maxValue = std::min(maxValue, qmax()); + maxValue = std::max(maxValue, qmin()); + outputRef + [((i * outputHeight() + oy) * outputWidth() + ox) * + channels() + + c] = maxValue; + } + } + } + } + + /* Create, setup, and run Max Pooling operator once */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t maxPoolingOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_max_pooling2d_nhwc_u8( + paddingTop(), + paddingRight(), + paddingBottom(), + paddingLeft(), + poolingHeight(), + poolingWidth(), + strideHeight(), + strideWidth(), + dilationHeight(), + dilationWidth(), + channels(), + qmin(), + qmax(), + 0, + &maxPoolingOp)); + ASSERT_NE(nullptr, maxPoolingOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + maxPoolingOp, + batchSize(), + inputHeight(), + inputWidth(), + input.data(), + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(maxPoolingOp, nullptr /* thread pool */)); + + /* Verify results of the first run */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t y = 0; y < outputHeight(); y++) { + for (size_t x = 0; x < outputWidth(); x++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmax())); + ASSERT_GE( + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c]), + uint32_t(qmin())); + ASSERT_EQ( + uint32_t(outputRef + [((i * outputHeight() + y) * outputWidth() + x) * + channels() + + c]), + uint32_t(output + [((i * outputHeight() + y) * outputWidth() + x) * + outputPixelStride() + + c])) + << "in batch index " << i << ", pixel (" << y << ", " << x + << "), channel " << c; + } + } + } + } + + /* Re-generate data for the second run */ + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results for the second run */ + for (size_t i = 0; i < nextBatchSize(); i++) { + for (size_t oy = 0; oy < nextOutputHeight(); oy++) { + for (size_t ox = 0; ox < nextOutputWidth(); ox++) { + for (size_t c = 0; c < channels(); c++) { + uint8_t maxValue = 0; + for (size_t py = 0; py < poolingHeight(); py++) { + const size_t iy = + oy * strideHeight() + py * dilationHeight() - paddingTop(); + for (size_t px = 0; px < poolingWidth(); px++) { + const size_t ix = + ox * strideWidth() + px * dilationWidth() - paddingLeft(); + if (ix < nextInputWidth() && iy < nextInputHeight()) { + maxValue = std::max( + maxValue, + input + [((i * nextInputHeight() + iy) * nextInputWidth() + + ix) * + inputPixelStride() + + c]); + } + } + } + maxValue = std::min(maxValue, qmax()); + maxValue = std::max(maxValue, qmin()); + nextOutputRef + [((i * nextOutputHeight() + oy) * nextOutputWidth() + ox) * + channels() + + c] = maxValue; + } + } + } + } + + /* Setup and run Max Pooling operator the second time, and destroy the + * operator */ + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + maxPoolingOp, + nextBatchSize(), + nextInputHeight(), + nextInputWidth(), + input.data(), + inputPixelStride(), + output.data(), + outputPixelStride(), + nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(maxPoolingOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(maxPoolingOp)); + maxPoolingOp = nullptr; + + /* Verify results of the second run */ + for (size_t i = 0; i < nextBatchSize(); i++) { + for (size_t y = 0; y < nextOutputHeight(); y++) { + for (size_t x = 0; x < nextOutputWidth(); x++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_LE( + uint32_t( + output + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + outputPixelStride() + + c]), + uint32_t(qmax())); + ASSERT_GE( + uint32_t( + output + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + outputPixelStride() + + c]), + uint32_t(qmin())); + ASSERT_EQ( + uint32_t( + nextOutputRef + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + channels() + + c]), + uint32_t( + output + [((i * nextOutputHeight() + y) * nextOutputWidth() + + x) * + outputPixelStride() + + c])) + << "in batch index " << i << ", pixel (" << y << ", " << x + << "), channel " << c; + } + } + } + } + } + } + + private: + uint32_t paddingTop_{0}; + uint32_t paddingRight_{0}; + uint32_t paddingBottom_{0}; + uint32_t paddingLeft_{0}; + size_t inputHeight_{1}; + size_t inputWidth_{1}; + size_t channels_{1}; + size_t batchSize_{1}; + size_t inputPixelStride_{0}; + size_t outputPixelStride_{0}; + uint32_t poolingHeight_{1}; + uint32_t poolingWidth_{1}; + uint32_t strideHeight_{1}; + uint32_t strideWidth_{1}; + uint32_t dilationHeight_{1}; + uint32_t dilationWidth_{1}; + size_t nextInputHeight_{0}; + size_t nextInputWidth_{0}; + size_t nextBatchSize_{0}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/max-pooling.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/max-pooling.cc new file mode 100644 index 0000000000000..281929539ab40 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/max-pooling.cc @@ -0,0 +1,1215 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "max-pooling-operator-tester.h" + +#include + +TEST(MAX_POOLING_OP, zero_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + MaxPoolingOperatorTester() + .batchSize(0) + .inputHeight(2) + .inputWidth(6) + .poolingHeight(1) + .poolingWidth(8) + .channels(8) + .testU8(); +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_1xM_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_1xM_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 3; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + for (size_t paddingLeft = 0; paddingLeft <= 1; paddingLeft++) { + for (size_t paddingRight = 0; paddingRight <= 1; paddingRight++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .paddingLeft(paddingLeft) + .paddingRight(paddingRight) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_1xM_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 4) + .poolingHeight(1) + .poolingWidth(poolSize) + .strideWidth(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_1xM_pool_with_dilation) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(2 * poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .dilationWidth(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_Mx1_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_Mx1_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + } + } + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_Mx1_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 3) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .strideHeight(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_Mx1_pool_with_dilation) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2 * poolSize) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .dilationHeight(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_pool_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmin(192) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmin(192) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_small_pool_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmax(192) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmax(192) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_1xM_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_1xM_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 3; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + for (size_t paddingLeft = 0; paddingLeft <= 1; paddingLeft++) { + for (size_t paddingRight = 0; paddingRight <= 1; paddingRight++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .paddingLeft(paddingLeft) + .paddingRight(paddingRight) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_1xM_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 4) + .poolingHeight(1) + .poolingWidth(poolSize) + .strideWidth(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_1xM_pool_with_dilation) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(2 * poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .dilationWidth(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_Mx1_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_Mx1_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + } + } + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_Mx1_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 3) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .strideHeight(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_Mx1_pool_with_dilation) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2 * poolSize) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .dilationHeight(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_pool_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmin(192) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmin(192) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_many_channels_large_pool_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmax(192) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmax(192) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_1xM_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_1xM_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 3; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + for (size_t paddingLeft = 0; paddingLeft <= 1; paddingLeft++) { + for (size_t paddingRight = 0; paddingRight <= 1; paddingRight++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .paddingLeft(paddingLeft) + .paddingRight(paddingRight) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_1xM_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 4) + .poolingHeight(1) + .poolingWidth(poolSize) + .strideWidth(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_1xM_pool_with_dilation) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(2 * poolSize + 1) + .poolingHeight(1) + .poolingWidth(poolSize) + .dilationWidth(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_Mx1_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_Mx1_pool_with_padding) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + for (size_t paddingTop = 0; paddingTop <= 1; paddingTop++) { + for (size_t paddingBottom = 0; paddingBottom <= 1; paddingBottom++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .paddingTop(paddingTop) + .paddingBottom(paddingBottom) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + } + } + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_Mx1_pool_with_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 3) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .strideHeight(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_Mx1_pool_with_dilation) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2 * poolSize) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .dilationHeight(2) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_with_qmin) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmin(192) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmin(192) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, unit_batch_few_channels_with_qmax) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .qmax(192) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(1) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .qmax(192) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_many_channels_small_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_many_channels_small_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_many_channels_small_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 3) { + for (size_t poolSize = 2; poolSize <= pytorch_qnnp_params.u8maxpool.mr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_many_channels_large_pool) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr + 1; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_many_channels_large_pool_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 5) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr + 1; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_many_channels_large_pool_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = pytorch_qnnp_params.u8maxpool.kr; + channels <= 3 * pytorch_qnnp_params.u8maxpool.kr; + channels += 5) { + for (size_t poolSize = pytorch_qnnp_params.u8maxpool.mr + 1; poolSize <= + pytorch_qnnp_params.u8maxpool.mr + pytorch_qnnp_params.u8maxpool.qr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_few_channels) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize++) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_few_channels_with_input_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize += 3) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .inputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, small_batch_few_channels_with_output_stride) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + for (size_t channels = 1; channels < pytorch_qnnp_params.u8maxpool.kr; + channels++) { + for (size_t poolSize = 2; poolSize <= 2 * pytorch_qnnp_params.u8maxpool.kr; + poolSize += 3) { + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(poolSize + 1) + .inputWidth(3) + .poolingHeight(poolSize) + .poolingWidth(1) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(2) + .inputWidth(poolSize + 2) + .poolingHeight(1) + .poolingWidth(poolSize) + .channels(channels) + .outputPixelStride(5 * pytorch_qnnp_params.u8maxpool.kr) + .testU8(); + } + } +} + +TEST(MAX_POOLING_OP, setup_increasing_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + MaxPoolingOperatorTester() + .batchSize(3) + .nextBatchSize(5) + .inputHeight(8) + .inputWidth(8) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); +} + +TEST(MAX_POOLING_OP, setup_decreasing_batch) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + MaxPoolingOperatorTester() + .batchSize(5) + .nextBatchSize(3) + .inputHeight(8) + .inputWidth(8) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); +} + +TEST(MAX_POOLING_OP, setup_changing_height) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputHeight(9) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputHeight(7) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); +} + +TEST(MAX_POOLING_OP, setup_changing_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputWidth(9) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(8) + .inputWidth(8) + .nextInputWidth(7) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); +} + +TEST(MAX_POOLING_OP, setup_swap_height_and_width) { + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + MaxPoolingOperatorTester() + .batchSize(3) + .inputHeight(9) + .inputWidth(8) + .nextInputHeight(8) + .nextInputWidth(9) + .poolingHeight(5) + .poolingWidth(3) + .channels(24) + .testSetupU8(); +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/maxpool-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/maxpool-microkernel-tester.h new file mode 100644 index 0000000000000..e1583a2c058ef --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/maxpool-microkernel-tester.h @@ -0,0 +1,256 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class MaxPoolMicrokernelTester { + public: + inline MaxPoolMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline MaxPoolMicrokernelTester& s(size_t s) { + assert(s != 0); + this->s_ = s; + return *this; + } + + inline size_t s() const { + return this->s_; + } + + inline MaxPoolMicrokernelTester& kh(size_t kh) { + assert(kh != 0); + this->kh_ = kh; + return *this; + } + + inline size_t kh() const { + return this->kh_; + } + + inline MaxPoolMicrokernelTester& kw(size_t kw) { + assert(kw != 0); + this->kw_ = kw; + return *this; + } + + inline size_t kw() const { + return this->kw_; + } + + inline size_t ks() const { + return kh() * kw(); + } + + inline size_t packedKs() const { + if (kc() < kr()) { + return ks(); + } else if (ks() <= mr()) { + return mr(); + } else { + return (ks() - mr()) % qr() == 0 + ? ks() + : ((ks() - mr()) / qr() + 1) * qr() + mr(); + } + } + + inline MaxPoolMicrokernelTester& mr(size_t mr) { + assert(mr != 0); + this->mr_ = mr; + return *this; + } + + inline size_t mr() const { + return this->mr_; + } + + inline MaxPoolMicrokernelTester& qr(size_t qr) { + assert(qr != 0); + this->qr_ = qr; + return *this; + } + + inline size_t qr() const { + return this->qr_; + } + + inline MaxPoolMicrokernelTester& kc(size_t kc) { + assert(kc != 0); + this->kc_ = kc; + return *this; + } + + inline size_t kc() const { + return this->kc_; + } + + inline MaxPoolMicrokernelTester& kr(size_t kr) { + assert(kr != 0); + this->kr_ = kr; + return *this; + } + + inline size_t kr() const { + return this->kr_; + } + + inline size_t packedN() const { + return kc() % kr() == 0 ? kc() : (kc() / kr() + 1) * kr(); + } + + inline MaxPoolMicrokernelTester& xStride(size_t xStride) { + assert(xStride != 0); + this->xStride_ = xStride; + return *this; + } + + inline size_t xStride() const { + if (this->xStride_ == 0) { + return kc(); + } else { + assert(this->xStride_ >= kc()); + return this->xStride_; + } + } + + inline MaxPoolMicrokernelTester& yStride(size_t yStride) { + assert(yStride != 0); + this->yStride_ = yStride; + return *this; + } + + inline size_t yStride() const { + if (this->yStride_ == 0) { + return kc(); + } else { + assert(this->yStride_ >= kc()); + return this->yStride_; + } + } + + inline MaxPoolMicrokernelTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline MaxPoolMicrokernelTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline MaxPoolMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_u8maxpool_ukernel_function u8maxpool) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector indirectX(packedKs() + (n() * s() - 1) * kh()); + std::vector x((indirectX.size() - 1) * xStride() + kc()); + + std::vector zero(kc()); + std::vector y((n() - 1) * yStride() + kc()); + std::vector yRef(n() * kc()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + for (size_t i = 0; i < indirectX.size(); i++) { + indirectX[i] = x.data() + i * xStride(); + } + std::shuffle(indirectX.begin(), indirectX.end(), rng); + + /* Prepare quantization parameters */ + const union pytorch_qnnp_u8_clamping_params clampingParams = + pytorch_qnnp_compute_u8_clamping_params(qmin(), qmax()); + + /* Compute reference results */ + for (size_t i = 0; i < n(); i++) { + for (size_t k = 0; k < kc(); k++) { + uint8_t maxValue = 0; + for (size_t j = 0; j < ks(); j++) { + maxValue = std::max(maxValue, indirectX[i * s() * kh() + j][k]); + } + maxValue = std::min(maxValue, qmax()); + maxValue = std::max(maxValue, qmin()); + yRef[i * kc() + k] = maxValue; + } + } + + /* Call optimized micro-kernel */ + u8maxpool( + n(), + ks(), + kc(), + indirectX.data(), + y.data(), + (kh() * s() - packedKs()) * sizeof(void*), + (yStride() - kc()) * sizeof(uint8_t), + &clampingParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + for (size_t k = 0; k < kc(); k++) { + ASSERT_EQ( + uint32_t(yRef[i * kc() + k]), uint32_t(y[i * yStride() + k])) + << "at pixel " << i << ", channel " << k << ", n = " << n() + << ", ks = " << kh() << "x" << kw() << " (" << ks() + << "), kc = " << kc(); + } + } + } + } + + private: + size_t n_{1}; + size_t s_{1}; + size_t kh_{1}; + size_t kw_{1}; + size_t mr_{1}; + size_t qr_{1}; + size_t kc_{1}; + size_t kr_{1}; + size_t xStride_{0}; + size_t yStride_{0}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8avgpool.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8avgpool.cc new file mode 100644 index 0000000000000..4bfeea1878142 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8avgpool.cc @@ -0,0 +1,1965 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "avgpool-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_small_ks) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks = 1; ks < 8; ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + AvgPoolMicrokernelTester().kr(8).kh(kh).kw(kw).kc(kc).test( + pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_large_ks) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks = 8; ks < 16; ks++) { + AvgPoolMicrokernelTester().kr(8).kh(ks).kw(1).kc(kc).test( + pytorch_q8avgpool_ukernel_up8xm__neon); + AvgPoolMicrokernelTester().kr(8).kh(1).kw(ks).kc(kc).test( + pytorch_q8avgpool_ukernel_up8xm__neon); + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xScale(xScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xZeroPoint(uint8_t(xZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yScale(yScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_with_y_max) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, kc_lt_8_with_y_min) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, small_n) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, small_n_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(11) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, small_n_with_y_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(13) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__NEON, small_n_with_s) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t s = 2; s <= 5; s++) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_eq_8_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).kc(8); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_eq_8_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).kc(8); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).xStride(131).test( + pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_gt_8_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_gt_8_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(23).test( + pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xScale(xScale) + .iterations(2) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xZeroPoint(uint8_t(xZeroPoint)) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .yScale(yScale) + .iterations(2) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_with_y_max) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, kc_div_8_with_y_min) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, small_n) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester().kr(8).mr(9).n(n).kh(ks).kw(ks).kc(kc).test( + pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, small_n_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(29) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, small_n_with_y_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(31) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__NEON, small_n_with_s) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + for (size_t s = 2; s <= ks; s++) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .test(pytorch_q8avgpool_ukernel_up8x9__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_eq_8_twopass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_eq_8_twopass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + for (size_t ks = 10; ks < tester.mr() + tester.qr(); ks++) { + tester.kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_eq_8_multipass_fulltile) { + TEST_REQUIRES_ARM_NEON; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_eq_8_multipass_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t ksMax : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + for (size_t ks = ksMax - tester.qr() + 1; ks < ksMax; ks++) { + tester.kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_twopass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = 17; + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_twopass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = 10; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).xStride(131).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_multipass_fulltile) { + TEST_REQUIRES_ARM_NEON; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_multipass_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t ksMax : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = ksMax - tester.qr() + 1; ks < ksMax; ks++) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_multipass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).xStride(131).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_gt_8_twopass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_gt_8_twopass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = 10; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_gt_8_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(23).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_gt_8_multipass_fulltile) { + TEST_REQUIRES_ARM_NEON; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_gt_8_multipass_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t ksMax : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = ksMax - tester.qr() + 1; ks < ksMax; ks++) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_gt_8_multipass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(23).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xScale(xScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xZeroPoint(uint8_t(xZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .yScale(yScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_with_y_max) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, kc_div_8_with_y_min) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, small_n) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, small_n_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(29) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, small_n_with_y_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(31) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__NEON, small_n_with_s) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t s = 2; s <= 5; s++) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__neon); + } + } + } + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_small_ks) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks = 1; ks < 8; ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + AvgPoolMicrokernelTester().kr(8).kh(kh).kw(kw).kc(kc).test( + pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_large_ks) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks = 8; ks < 16; ks++) { + AvgPoolMicrokernelTester().kr(8).kh(ks).kw(1).kc(kc).test( + pytorch_q8avgpool_ukernel_up8xm__sse2); + AvgPoolMicrokernelTester().kr(8).kh(1).kw(ks).kc(kc).test( + pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xScale(xScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xZeroPoint(uint8_t(xZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yScale(yScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_with_y_max) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, kc_lt_8_with_y_min) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 3; n += 2) { + for (size_t kc = 1; kc < 8; kc++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, small_n) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, small_n_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(11) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, small_n_with_y_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(13) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8xM__SSE2, small_n_with_s) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t s = 2; s <= 5; s++) { + for (size_t kc = 1; kc < 8; kc++) { + AvgPoolMicrokernelTester() + .kr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_up8xm__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_eq_8_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).kc(8); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_eq_8_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).kc(8); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).xStride(131).test( + pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_gt_8_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_gt_8_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_gt_8_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(23).test( + pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xScale(xScale) + .iterations(2) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xZeroPoint(uint8_t(xZeroPoint)) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .yScale(yScale) + .iterations(2) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_with_y_max) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, kc_div_8_with_y_min) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(3) + .kw(3) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, small_n) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester().kr(8).mr(9).n(n).kh(ks).kw(ks).kc(kc).test( + pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, small_n_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(29) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, small_n_with_y_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(31) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_UP8x9__SSE2, small_n_with_s) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + for (size_t s = 2; s <= ks; s++) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .test(pytorch_q8avgpool_ukernel_up8x9__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_eq_8_twopass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_eq_8_twopass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + for (size_t ks = 10; ks < tester.mr() + tester.qr(); ks++) { + tester.kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_eq_8_multipass_fulltile) { + TEST_REQUIRES_X86_SSE2; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + tester.kh(kh).kw(kw).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_eq_8_multipass_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t ksMax : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).kc(8); + for (size_t ks = ksMax - tester.qr() + 1; ks < ksMax; ks++) { + tester.kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_twopass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = 17; + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_twopass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = 10; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).xStride(131).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_multipass_fulltile) { + TEST_REQUIRES_X86_SSE2; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_multipass_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t ksMax : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = ksMax - tester.qr() + 1; ks < ksMax; ks++) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_multipass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 8; kc < 128; kc += 24) { + tester.kh(kh).kw(kw).kc(kc).xStride(131).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_gt_8_twopass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_gt_8_twopass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = 10; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_gt_8_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + const size_t ks = tester.mr() + tester.qr(); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(23).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_gt_8_multipass_fulltile) { + TEST_REQUIRES_X86_SSE2; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_gt_8_multipass_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t ksMax : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t ks = ksMax - tester.qr() + 1; ks < ksMax; ks++) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kc(kc).kh(ks).kw(1).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + tester.kc(kc).kh(1).kw(ks).test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_gt_8_multipass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t ks : std::vector{{25, 49}}) { + auto tester = AvgPoolMicrokernelTester().kr(8).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= ks; kh++) { + for (size_t kw = 1; kw <= ks; kw++) { + if (kh * kw == ks) { + for (size_t kc = 9; kc < 16; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(23).test( + pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xScale(xScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xZeroPoint(uint8_t(xZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .yScale(yScale) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .yZeroPoint(uint8_t(yZeroPoint)) + .iterations(1) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_with_y_max) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, kc_div_8_with_y_min) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n <= 5; n += 2) { + for (size_t kc = 8; kc < 128; kc += 24) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(5) + .kw(5) + .kc(kc) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .iterations(3) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, small_n) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, small_n_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(29) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, small_n_with_y_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(31) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } +} + +TEST(Q8AVGPOOL_MP8x9P8Q__SSE2, small_n_with_s) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{5, 7}}) { + for (size_t s = 2; s <= 5; s++) { + for (size_t kc = 8; kc < 25; kc += 5) { + AvgPoolMicrokernelTester() + .kr(8) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .test(pytorch_q8avgpool_ukernel_mp8x9p8q__sse2); + } + } + } + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8conv.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8conv.cc new file mode 100644 index 0000000000000..2cc2efb82441d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8conv.cc @@ -0,0 +1,1084 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "gemm-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM +TEST(Q8CONV_4x8__AARCH32_NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmin(128).test( + pytorch_q8conv_ukernel_4x8__aarch32_neon); +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmax(128).test( + pytorch_q8conv_ukernel_4x8__aarch32_neon); +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_eq_8_azp_only) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_eq_8_bzp_only) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_gt_8_azp_only) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_gt_8_bzp_only) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(37) + .iterations(3) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } + } + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8CONV_4x8__AARCH32_NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(171) + .iterations(3) + .test(pytorch_q8conv_ukernel_4x8__aarch32_neon); + } + } + } +} +#endif + +#if CPUINFO_ARCH_ARM64 +TEST(Q8CONV_8x8__AARCH64_NEON, k_eq_8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_eq_8_strided_c) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_eq_8_qmin128) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmin(128).test( + pytorch_q8conv_ukernel_8x8__aarch64_neon); +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_eq_8_qmax128) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmax(128).test( + pytorch_q8conv_ukernel_8x8__aarch64_neon); +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_eq_8_azp_only) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_eq_8_bzp_only) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_gt_8) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_gt_8_strided_c) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_gt_8_azp_only) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_gt_8_bzp_only) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_gt_8_subtile) { + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(37) + .iterations(3) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } + } + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_div_8) { + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_div_8_strided_c) { + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .cStride(17) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8CONV_8x8__AARCH64_NEON, k_div_8_subtile) { + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(171) + .iterations(3) + .test(pytorch_q8conv_ukernel_8x8__aarch64_neon); + } + } + } +} +#endif + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(Q8CONV_4x8__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8conv_ukernel_4x8__neon); +} + +TEST(Q8CONV_4x8__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x8__neon); +} + +TEST(Q8CONV_4x8__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmin(128).test( + pytorch_q8conv_ukernel_4x8__neon); +} + +TEST(Q8CONV_4x8__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmax(128).test( + pytorch_q8conv_ukernel_4x8__neon); +} + +TEST(Q8CONV_4x8__NEON, k_eq_8_azp_only) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_4x8__neon); +} + +TEST(Q8CONV_4x8__NEON, k_eq_8_bzp_only) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_4x8__neon); +} + +TEST(Q8CONV_4x8__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8conv_ukernel_4x8__neon); + } +} + +TEST(Q8CONV_4x8__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x8__neon); + } +} + +TEST(Q8CONV_4x8__NEON, k_gt_8_azp_only) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_4x8__neon); + } +} + +TEST(Q8CONV_4x8__NEON, k_gt_8_bzp_only) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_4x8__neon); + } +} + +TEST(Q8CONV_4x8__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(37) + .iterations(3) + .test(pytorch_q8conv_ukernel_4x8__neon); + } + } + } +} + +TEST(Q8CONV_4x8__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8conv_ukernel_4x8__neon); + } +} + +TEST(Q8CONV_4x8__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x8__neon); + } +} + +TEST(Q8CONV_4x8__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(171) + .iterations(3) + .test(pytorch_q8conv_ukernel_4x8__neon); + } + } + } +} + +TEST(Q8CONV_8x8__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8conv_ukernel_8x8__neon); +} + +TEST(Q8CONV_8x8__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_8x8__neon); +} + +TEST(Q8CONV_8x8__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmin(128).test( + pytorch_q8conv_ukernel_8x8__neon); +} + +TEST(Q8CONV_8x8__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmax(128).test( + pytorch_q8conv_ukernel_8x8__neon); +} + +TEST(Q8CONV_8x8__NEON, k_eq_8_azp_only) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_8x8__neon); +} + +TEST(Q8CONV_8x8__NEON, k_eq_8_bzp_only) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_8x8__neon); +} + +TEST(Q8CONV_8x8__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8conv_ukernel_8x8__neon); + } +} + +TEST(Q8CONV_8x8__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_8x8__neon); + } +} + +TEST(Q8CONV_8x8__NEON, k_gt_8_azp_only) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_8x8__neon); + } +} + +TEST(Q8CONV_8x8__NEON, k_gt_8_bzp_only) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_8x8__neon); + } +} + +TEST(Q8CONV_8x8__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(37) + .iterations(3) + .test(pytorch_q8conv_ukernel_8x8__neon); + } + } + } +} + +TEST(Q8CONV_8x8__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8conv_ukernel_8x8__neon); + } +} + +TEST(Q8CONV_8x8__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .cStride(17) + .test(pytorch_q8conv_ukernel_8x8__neon); + } +} + +TEST(Q8CONV_8x8__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(171) + .iterations(3) + .test(pytorch_q8conv_ukernel_8x8__neon); + } + } + } +} +#endif + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(Q8CONV_4x4c2__SSE2, k_eq_8) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aStride(37) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); +} + +TEST(Q8CONV_4x4c2__SSE2, k_eq_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); +} + +TEST(Q8CONV_4x4c2__SSE2, k_eq_8_qmin128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(8).qmin(128).test( + pytorch_q8conv_ukernel_4x4c2__sse2); +} + +TEST(Q8CONV_4x4c2__SSE2, k_eq_8_qmax128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(8).qmax(128).test( + pytorch_q8conv_ukernel_4x4c2__sse2); +} + +TEST(Q8CONV_4x4c2__SSE2, k_eq_8_azp_only) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); +} + +TEST(Q8CONV_4x4c2__SSE2, k_eq_8_bzp_only) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); +} + +TEST(Q8CONV_4x4c2__SSE2, k_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(37) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_gt_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(37) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_gt_8_azp_only) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(37) + .aZeroPoint(255) + .bZeroPoint(0) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_gt_8_bzp_only) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(37) + .aZeroPoint(0) + .bZeroPoint(255) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_gt_8_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(m) + .n(n) + .k(k) + .aStride(37) + .iterations(3) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } + } + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_div_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(171) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_div_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(171) + .cStride(17) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } +} + +TEST(Q8CONV_4x4c2__SSE2, k_div_8_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(m) + .n(n) + .k(k) + .aStride(171) + .iterations(3) + .test(pytorch_q8conv_ukernel_4x4c2__sse2); + } + } + } +} +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8dwconv.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8dwconv.cc new file mode 100644 index 0000000000000..594c602a92458 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8dwconv.cc @@ -0,0 +1,1267 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "dwconv-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_eq_8) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_eq_8_with_qmin) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_eq_8_with_qmax) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST( + Q8DWCONV_UP8x9__NEON, + single_output_channels_eq_8_with_input_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST( + Q8DWCONV_UP8x9__NEON, + single_output_channels_eq_8_with_kernel_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_eq_8) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_eq_8_with_subsampling) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .subsampling(2) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_eq_8_with_input_stride) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .inputStride(17) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_eq_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .outputStride(19) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); +} + +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_div_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .outputStride(171) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_gt_8_with_qmin) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, single_output_channels_gt_8_with_qmax) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST( + Q8DWCONV_UP8x9__NEON, + single_output_channels_gt_8_with_input_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST( + Q8DWCONV_UP8x9__NEON, + single_output_channels_gt_8_with_kernel_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_UP8x9__NEON, multi_output_channels_gt_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .outputStride(17) + .test(pytorch_q8dwconv_ukernel_up8x9__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_eq_8) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_eq_8_with_subsampling) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .subsampling(2) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_eq_8_with_input_stride) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(5) + .inputStride(17) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_eq_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(5) + .outputStride(19) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_eq_8_with_qmin) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_eq_8_with_qmax) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST( + Q8DWCONV_MP8x25__NEON, + single_output_channels_eq_8_with_input_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST( + Q8DWCONV_MP8x25__NEON, + single_output_channels_eq_8_with_kernel_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_eq_8) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(3) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_div_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .outputStride(171) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_gt_8_with_qmin) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, single_output_channels_gt_8_with_qmax) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} + +TEST(Q8DWCONV_MP8x25__NEON, multi_output_channels_gt_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .outputStride(17) + .test(pytorch_q8dwconv_ukernel_mp8x25__neon); + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ + +#if CPUINFO_ARCH_ARM +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_eq_8) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_eq_8_with_qmin) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_eq_8_with_qmax) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + single_output_channels_eq_8_with_input_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + single_output_channels_eq_8_with_kernel_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, multi_output_channels_eq_8) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + multi_output_channels_eq_8_with_subsampling) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .subsampling(2) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + multi_output_channels_eq_8_with_input_stride) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .inputStride(17) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + multi_output_channels_eq_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .outputStride(19) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, multi_output_channels_div_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + multi_output_channels_div_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .outputStride(171) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_gt_8_with_qmin) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, single_output_channels_gt_8_with_qmax) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + single_output_channels_gt_8_with_input_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + single_output_channels_gt_8_with_kernel_zero_point_only) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST(Q8DWCONV_UP8x9__AARCH32_NEON, multi_output_channels_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} + +TEST( + Q8DWCONV_UP8x9__AARCH32_NEON, + multi_output_channels_gt_8_with_output_stride) { + TEST_REQUIRES_ARM_NEON; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .outputStride(17) + .test(pytorch_q8dwconv_ukernel_up8x9__aarch32_neon); + } +} +#endif /* CPUINFO_ARCH_ARM */ + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_eq_8) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_eq_8_with_qmin) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_eq_8_with_qmax) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST( + Q8DWCONV_UP8x9__SSE2, + single_output_channels_eq_8_with_input_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST( + Q8DWCONV_UP8x9__SSE2, + single_output_channels_eq_8_with_kernel_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_eq_8) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_eq_8_with_subsampling) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .subsampling(2) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_eq_8_with_input_stride) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .inputStride(17) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_eq_8_with_output_stride) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(8) + .width(5) + .outputStride(19) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); +} + +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_div_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_div_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_div_8_with_output_stride) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .outputStride(171) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_gt_8_with_qmin) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, single_output_channels_gt_8_with_qmax) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST( + Q8DWCONV_UP8x9__SSE2, + single_output_channels_gt_8_with_input_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST( + Q8DWCONV_UP8x9__SSE2, + single_output_channels_gt_8_with_kernel_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_UP8x9__SSE2, multi_output_channels_gt_8_with_output_stride) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(3) + .kernelWidth(3) + .cr(8) + .channels(channels) + .width(5) + .outputStride(17) + .test(pytorch_q8dwconv_ukernel_up8x9__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_eq_8) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_eq_8_with_qmin) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_eq_8_with_qmax) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST( + Q8DWCONV_MP8x25__SSE2, + single_output_channels_eq_8_with_input_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST( + Q8DWCONV_MP8x25__SSE2, + single_output_channels_eq_8_with_kernel_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_eq_8) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_eq_8_with_subsampling) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .subsampling(2) + .cr(8) + .channels(8) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_eq_8_with_input_stride) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(5) + .inputStride(17) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_eq_8_with_output_stride) { + TEST_REQUIRES_X86_SSE2; + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(8) + .width(5) + .outputStride(19) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_div_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_div_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_div_8_with_output_stride) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 16; channels < 128; channels += 24) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .outputStride(171) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_gt_8_with_qmin) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .qmin(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, single_output_channels_gt_8_with_qmax) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .qmax(128) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST( + Q8DWCONV_MP8x25__SSE2, + single_output_channels_gt_8_with_input_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(255) + .kernelZeroPoint(0) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST( + Q8DWCONV_MP8x25__SSE2, + single_output_channels_gt_8_with_kernel_zero_point_only) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(1) + .inputZeroPoint(0) + .kernelZeroPoint(255) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} + +TEST(Q8DWCONV_MP8x25__SSE2, multi_output_channels_gt_8_with_output_stride) { + TEST_REQUIRES_X86_SSE2; + for (uint32_t channels = 9; channels < 16; channels++) { + DWConvMicrokernelTester() + .kernelHeight(5) + .kernelWidth(5) + .cr(8) + .channels(channels) + .width(5) + .outputStride(17) + .test(pytorch_q8dwconv_ukernel_mp8x25__sse2); + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8gavgpool.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8gavgpool.cc new file mode 100644 index 0000000000000..74a3d5a89f770 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8gavgpool.cc @@ -0,0 +1,1155 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "gavgpool-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester().m(7).n(8).test(pytorch_q8gavgpool_ukernel_up8x7__neon); +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_few_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(m).n(8).test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester().m(7).n(8).xStride(11).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(8).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(8).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_y_max) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_eq_8_all_m_with_y_min) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_div_8_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + GAvgPoolMicrokernelTester().m(7).n(n).test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_div_8_few_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(7).n(n).test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_few_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(n).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(n).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_up8x7__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m_with_y_max) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_UP8x7__NEON, n_gt_8_all_m_with_y_min) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).xStride(11).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_y_max) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_all_m_with_y_min) { + TEST_REQUIRES_ARM_NEON; + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_few_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_2pass_few_m_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(8).nr(8).xStride(11).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_multipass_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_eq_8_multipass_all_m_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_div_8_2pass_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_div_8_2pass_few_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_div_8_multipass_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_div_8_multipass_all_m_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).nr(8).xStride(131).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m_with_y_max) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_all_m_with_y_min) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_2pass_few_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_multipass_all_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__NEON, n_gt_8_multipass_all_m_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).xStride(23).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_small_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 8; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_large_m) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 8; m < 16; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_with_x_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(m).n(n).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_with_x_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_with_y_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(m).n(n).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_with_y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_with_y_max) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__NEON, n_lt_8_with_y_min) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_up8xm__neon); + } + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester().m(7).n(8).test(pytorch_q8gavgpool_ukernel_up8x7__sse2); +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_few_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(m).n(8).test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester().m(7).n(8).xStride(11).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(8).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(8).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_y_max) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_eq_8_all_m_with_y_min) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester() + .m(7) + .n(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_div_8_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + GAvgPoolMicrokernelTester().m(7).n(n).test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_div_8_few_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(7).n(n).test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_few_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(n).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(7).n(n).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_up8x7__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m_with_y_max) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_UP8x7__SSE2, n_gt_8_all_m_with_y_min) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(7) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_up8x7__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).xStride(11).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(14).n(8).nr(8).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_y_max) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_all_m_with_y_min) { + TEST_REQUIRES_X86_SSE2; + GAvgPoolMicrokernelTester() + .m(14) + .n(8) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_few_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_2pass_few_m_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(8).nr(8).xStride(11).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_multipass_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_eq_8_multipass_all_m_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_div_8_2pass_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_div_8_2pass_few_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_div_8_multipass_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_div_8_multipass_all_m_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).nr(8).xStride(131).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester().m(14).n(n).nr(8).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m_with_y_max) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_all_m_with_y_min) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + GAvgPoolMicrokernelTester() + .m(14) + .n(n) + .nr(8) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_2pass_few_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 1; m < 7; m++) { + GAvgPoolMicrokernelTester().m(7 + m).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_multipass_all_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_MP8x7p7q__SSE2, n_gt_8_multipass_all_m_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + for (size_t m = 14; m <= 35; m += 7) { + GAvgPoolMicrokernelTester().m(m).n(n).nr(8).xStride(23).test( + pytorch_q8gavgpool_ukernel_mp8x7p7q__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_small_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 8; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_large_m) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 8; m < 16; m++) { + GAvgPoolMicrokernelTester().m(m).n(n).test( + pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_with_x_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (float xScale = 0.01f; xScale < 100.0f; xScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(m).n(n).xScale(xScale).test( + pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_with_x_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (int32_t xZeroPoint = 0; xZeroPoint <= 255; xZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .xZeroPoint(xZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_with_y_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (float yScale = 0.01f; yScale < 100.0f; yScale *= 3.14159265f) { + GAvgPoolMicrokernelTester().m(m).n(n).yScale(yScale).test( + pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_with_y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .yZeroPoint(yZeroPoint) + .test(pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_with_y_max) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMax(128) + .test(pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } +} + +TEST(Q8GAVGPOOL_UP8xM__SSE2, n_lt_8_with_y_min) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + for (size_t m = 1; m < 16; m += 5) { + GAvgPoolMicrokernelTester() + .m(m) + .n(n) + .xZeroPoint(128) + .yZeroPoint(128) + .xScale(1.0f) + .yScale(1.0f) + .yMin(128) + .test(pytorch_q8gavgpool_ukernel_up8xm__sse2); + } + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8gemm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8gemm.cc new file mode 100644 index 0000000000000..630da3b8252a6 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8gemm.cc @@ -0,0 +1,2613 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "gemm-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).test( + pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(k).test( + pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } + } + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(k).test( + pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8__AARCH32_NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_4x8__aarch32_neon); + } + } + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(8).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(8).qmin(128).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(8).qmax(128).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(k).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } + } + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(k).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__AARCH32_NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__aarch32_neon); + } + } + } +} +#endif + +#if CPUINFO_ARCH_ARM64 +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).test( + pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_strided_a) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_strided_c) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_qmin128) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_qmax128) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_azp0) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_bzp0) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_nozp) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(k).test( + pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_strided_a) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_strided_c) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_azp0) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_bzp0) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_nozp) { + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_subtile) { + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } + } + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8) { + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(k).test( + pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_strided_a) { + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_strided_c) { + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } +} + +TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_subtile) { + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_8x8__aarch64_neon); + } + } + } +} +#endif + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(Q8GEMM_4x8__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).test( + pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__neon); +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(k).test( + pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } + } + } +} + +TEST(Q8GEMM_4x8__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(k).test( + pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } +} + +TEST(Q8GEMM_4x8__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_4x8__neon); + } + } + } +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).test( + pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__neon); +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(k).test( + pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } + } + } +} + +TEST(Q8GEMM_8x8__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(8).nr(8).np(8).kr(1).m(8).n(8).k(k).test( + pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } +} + +TEST(Q8GEMM_8x8__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_8x8__neon); + } + } + } +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(6).nr(4).np(4).kr(1).m(6).n(4).k(8).test( + pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(6).nr(4).np(4).kr(1).m(6).n(4).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(6).nr(4).np(4).kr(1).m(6).n(4).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_6x4__neon); +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(6).nr(4).np(4).kr(1).m(6).n(4).k(k).test( + pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } + } + } +} + +TEST(Q8GEMM_6x4__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(6).nr(4).np(4).kr(1).m(6).n(4).k(k).test( + pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_6x4__neon); + } +} + +TEST(Q8GEMM_6x4__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester().mr(6).nr(4).np(4).kr(1).m(m).n(n).k(k).test( + pytorch_q8gemm_ukernel_6x4__neon); + } + } + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(8).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(8).qmin(128).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(8).qmax(128).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(k).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } + } + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(2).m(4).n(8).k(k).test( + pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } +} + +TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_xzp_ukernel_4x8c2__neon); + } + } + } +} +#endif + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(2).nr(4).np(1).kr(8).m(2).n(4).k(8).test( + pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_qmin128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(2).nr(4).np(1).kr(8).m(2).n(4).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_qmax128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(2).nr(4).np(1).kr(8).m(2).n(4).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_azp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_bzp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_eq_8_nozp) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(2).nr(4).np(1).kr(8).m(2).n(4).k(k).test( + pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8_azp0) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8_bzp0) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8_nozp) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_gt_8_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } + } + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_div_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(2).nr(4).np(1).kr(8).m(2).n(4).k(k).test( + pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_div_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_div_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(2) + .n(4) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } +} + +TEST(Q8GEMM_2x4c8__SSE2, k_div_8_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 2; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(2) + .nr(4) + .np(1) + .kr(8) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_2x4c8__sse2); + } + } + } +} + +// Following tests fail both on original QNNPack and the version +// with runtime requantization. + +#if 0 + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_strided_a) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_strided_c) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_qmin128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .qmin(128) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_qmax128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .qmax(128) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_azp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_bzp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + + TEST(Q8GEMM_4x4c2__SSE2, k_eq_1_nozp) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(1) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +#endif + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(3).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_strided_a) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(3) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_strided_c) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(3) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_qmin128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(3).qmin(128).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_qmax128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(3).qmax(128).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_azp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(3) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_bzp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(3) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_4_nozp) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(3) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(5).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(5) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(5) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_qmin128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(5).qmin(128).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_qmax128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(5).qmax(128).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_azp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(5) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_bzp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(5) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_lt_8_nozp) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(5) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(8).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_qmin128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(8).qmin(128).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_qmax128) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(8).qmax(128).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_azp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_bzp0) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_eq_8_nozp) { + TEST_REQUIRES_X86_SSE2; + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(k).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(37) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8_azp0) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8_bzp0) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8_nozp) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_gt_8_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + } + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_div_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester().mr(4).nr(4).np(4).kr(2).m(4).n(4).k(k).test( + pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_div_8_strided_a) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .aStride(171) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_div_8_strided_c) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(4) + .n(4) + .k(k) + .cStride(17) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } +} + +TEST(Q8GEMM_4x4c2__SSE2, k_div_8_subtile) { + TEST_REQUIRES_X86_SSE2; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(4) + .nr(4) + .np(4) + .kr(2) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_q8gemm_ukernel_4x4c2__sse2); + } + } + } +} +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8vadd.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8vadd.cc new file mode 100644 index 0000000000000..8d5a5a83b9455 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/q8vadd.cc @@ -0,0 +1,297 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "vadd-microkernel-tester.h" + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(Q8VADD__SSE2, n_eq_8) { + TEST_REQUIRES_X86_SSE2; + VAddMicrokernelTester().n(8).test(pytorch_q8vadd_ukernel__sse2); +} + +TEST(Q8VADD__SSE2, n_div_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 128; n += 24) { + VAddMicrokernelTester().n(n).test(pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, n_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + VAddMicrokernelTester().n(n).test(pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, n_lt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + VAddMicrokernelTester().n(n).test(pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, inplace_a) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).inplaceA(true).test( + pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, inplace_b) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).inplaceB(true).test( + pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, inplace_a_and_b) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .inplaceA(true) + .inplaceB(true) + .test(pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, a_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (float aScale = 1.0e-2; aScale < 1.0e+2; aScale *= 1.7f) { + VAddMicrokernelTester().iterations(1).n(n).aScale(aScale).test( + pytorch_q8vadd_ukernel__sse2); + } + } +} + +TEST(Q8VADD__SSE2, b_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (float bScale = 1.0e-2; bScale < 1.0e+2; bScale *= 1.7f) { + VAddMicrokernelTester().iterations(1).n(n).bScale(bScale).test( + pytorch_q8vadd_ukernel__sse2); + } + } +} + +TEST(Q8VADD__SSE2, y_scale) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (float yScale = 1.0e-2; yScale < 1.0e+2; yScale *= 1.7f) { + VAddMicrokernelTester().iterations(1).n(n).yScale(yScale).test( + pytorch_q8vadd_ukernel__sse2); + } + } +} + +TEST(Q8VADD__SSE2, a_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (int32_t aZeroPoint = 0; aZeroPoint <= 255; aZeroPoint += 51) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .aZeroPoint(uint8_t(aZeroPoint)) + .test(pytorch_q8vadd_ukernel__sse2); + } + } +} + +TEST(Q8VADD__SSE2, b_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (int32_t bZeroPoint = 0; bZeroPoint <= 255; bZeroPoint += 51) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .bZeroPoint(uint8_t(bZeroPoint)) + .test(pytorch_q8vadd_ukernel__sse2); + } + } +} + +TEST(Q8VADD__SSE2, y_zero_point) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .yZeroPoint(uint8_t(yZeroPoint)) + .test(pytorch_q8vadd_ukernel__sse2); + } + } +} + +TEST(Q8VADD__SSE2, qmin) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).qmin(128).test( + pytorch_q8vadd_ukernel__sse2); + } +} + +TEST(Q8VADD__SSE2, qmax) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).qmax(128).test( + pytorch_q8vadd_ukernel__sse2); + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(Q8VADD__NEON, n_eq_8) { + TEST_REQUIRES_ARM_NEON; + VAddMicrokernelTester().n(8).test(pytorch_q8vadd_ukernel__neon); +} + +TEST(Q8VADD__NEON, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 24) { + VAddMicrokernelTester().n(n).test(pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + VAddMicrokernelTester().n(n).test(pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, n_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + VAddMicrokernelTester().n(n).test(pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, inplace_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).inplaceA(true).test( + pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, inplace_b) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).inplaceB(true).test( + pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, inplace_a_and_b) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .inplaceA(true) + .inplaceB(true) + .test(pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, a_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (float aScale = 1.0e-2; aScale < 1.0e+2; aScale *= 1.7f) { + VAddMicrokernelTester().iterations(1).n(n).aScale(aScale).test( + pytorch_q8vadd_ukernel__neon); + } + } +} + +TEST(Q8VADD__NEON, b_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (float bScale = 1.0e-2; bScale < 1.0e+2; bScale *= 1.7f) { + VAddMicrokernelTester().iterations(1).n(n).bScale(bScale).test( + pytorch_q8vadd_ukernel__neon); + } + } +} + +TEST(Q8VADD__NEON, y_scale) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (float yScale = 1.0e-2; yScale < 1.0e+2; yScale *= 1.7f) { + VAddMicrokernelTester().iterations(1).n(n).yScale(yScale).test( + pytorch_q8vadd_ukernel__neon); + } + } +} + +TEST(Q8VADD__NEON, a_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (int32_t aZeroPoint = 0; aZeroPoint <= 255; aZeroPoint += 51) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .aZeroPoint(uint8_t(aZeroPoint)) + .test(pytorch_q8vadd_ukernel__neon); + } + } +} + +TEST(Q8VADD__NEON, b_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (int32_t bZeroPoint = 0; bZeroPoint <= 255; bZeroPoint += 51) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .bZeroPoint(uint8_t(bZeroPoint)) + .test(pytorch_q8vadd_ukernel__neon); + } + } +} + +TEST(Q8VADD__NEON, y_zero_point) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (int32_t yZeroPoint = 0; yZeroPoint <= 255; yZeroPoint += 51) { + VAddMicrokernelTester() + .iterations(1) + .n(n) + .yZeroPoint(uint8_t(yZeroPoint)) + .test(pytorch_q8vadd_ukernel__neon); + } + } +} + +TEST(Q8VADD__NEON, qmin) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).qmin(128).test( + pytorch_q8vadd_ukernel__neon); + } +} + +TEST(Q8VADD__NEON, qmax) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + VAddMicrokernelTester().iterations(1).n(n).qmax(128).test( + pytorch_q8vadd_ukernel__neon); + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization-tester.h new file mode 100644 index 0000000000000..ef7abb5f9cda2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization-tester.h @@ -0,0 +1,470 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class RequantizationTester { + public: + inline RequantizationTester& s(uint32_t s) { + this->s_ = s; + return *this; + } + + inline uint32_t s() const { + return this->s_; + } + + inline float scale() const { + return ldexpf(1.0f, -s()); + } + + inline RequantizationTester& zeroPoint(int32_t zeroPoint) { + this->zeroPoint_ = zeroPoint; + return *this; + } + + inline int32_t zeroPoint() const { + return this->zeroPoint_; + } + + inline RequantizationTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline RequantizationTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline RequantizationTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + /* + * Test that requantization of numbers ((i - zero point) * 2**s) with + * - scale = exp2(-s) + * - zero point in [0, 255] + * - no output clamping + * produces exactly i, provided that ((i - zero point) * 2**s) does not + * overflow. + */ + void testExactDivideByPO2(pytorch_requantization_function requantize) const { + ASSERT_GE(zeroPoint(), 0); + ASSERT_LE(zeroPoint(), 255); + + /* Note: need s >= 1 to ensure scale = exp2(-s) < 1.0 */ + ASSERT_GE(s(), 1); + ASSERT_LT(s(), 32); + + std::vector inputs(256); + std::vector outputs(inputs.size()); + const int32_t maxI = + (uint32_t(std::numeric_limits::max()) >> s()) + zeroPoint(); + const int32_t minI = + -(-uint32_t(std::numeric_limits::min()) >> s()) + zeroPoint(); + for (int32_t i = 0; i < 256; i++) { + const int32_t clampedI = std::max(minI, std::min(maxI, i)); + inputs[i] = int32_t(uint32_t(clampedI - zeroPoint()) << s()); + } + requantize( + inputs.size(), + inputs.data(), + scale(), + zeroPoint(), + qmin(), + qmax(), + outputs.data()); + for (int32_t i = 0; i < 256; i++) { + const int32_t clampedI = std::max(minI, std::min(maxI, i)); + ASSERT_EQ(clampedI, outputs[i]) + << "i = " << i << ", clamped i = " << clampedI << ", min i = " << minI + << ", max i = " << maxI << ", s = " << s() + << ", zero point = " << zeroPoint(); + } + } + + /* + * Test that requantization of numbers (i * 2**s + sign(i - zero point) * + * 2**(s-1)) with + * - scale = exp2(-s) + * - zero point in [1, 255] + * - no output clamping + * produces exactly i, provided that ((i - zero point) * 2**s) does not + * overflow. + */ + void testDivideByPO2WithRoundingUp(pytorch_requantization_function requantize) { + ASSERT_GE(zeroPoint(), 0); + ASSERT_LE(zeroPoint(), 255); + + /* Note: need s >= 1 to ensure scale = exp2(-s) < 1.0 */ + ASSERT_GE(s(), 1); + ASSERT_LT(s(), 32); + + std::vector inputs(256); + std::vector outputs(inputs.size()); + for (int32_t i = 0; i < 256; i++) { + const int64_t input = + RequantizationTester::shiftLeft(i - zeroPoint(), s()) - + (INT64_C(1) << (s() - 1)) + (int64_t)(i <= zeroPoint()); + inputs[i] = int32_t(input); + } + requantize( + inputs.size(), + inputs.data(), + scale(), + zeroPoint(), + qmin(), + qmax(), + outputs.data()); + for (int32_t i = 0; i < 256; i++) { + const int64_t input = + RequantizationTester::shiftLeft(i - zeroPoint(), s()) - + (INT64_C(1) << (s() - 1)) + (int64_t)(i <= zeroPoint()); + if (int32_t(input) == input) { + ASSERT_EQ(i, uint32_t(outputs[i])) + << "i = " << i << ", input = " << input << ", s = " << s() + << ", zero point = " << zeroPoint(); + } + } + } + + /* + * Test that requantization of numbers (i * 2**s + sign(i - zero point) * + * 2**(s-1)) with + * - scale = exp2(-s) + * - zero point in [1, 255] + * - no output clamping + * produces exactly i, provided that ((i - zero point) * 2**s) does not + * overflow. + */ + void testDivideByPO2WithRoundingDown(pytorch_requantization_function requantize) { + ASSERT_GE(zeroPoint(), 0); + ASSERT_LE(zeroPoint(), 255); + + /* Note: need s >= 1 to ensure scale = exp2(-s) < 1.0 */ + ASSERT_GE(s(), 1); + ASSERT_LT(s(), 32); + + std::vector inputs(256); + std::vector outputs(inputs.size()); + for (int32_t i = 0; i < 256; i++) { + const int64_t input = + RequantizationTester::shiftLeft(i - zeroPoint(), s()) + + (INT64_C(1) << (s() - 1)) - (int64_t)(i >= zeroPoint()); + inputs[i] = int32_t(input); + } + requantize( + inputs.size(), + inputs.data(), + scale(), + zeroPoint(), + qmin(), + qmax(), + outputs.data()); + for (int32_t i = 0; i < 256; i++) { + const int64_t input = + RequantizationTester::shiftLeft(i - zeroPoint(), s()) + + (INT64_C(1) << (s() - 1)) - (int64_t)(i >= zeroPoint()); + if (int32_t(input) == input) { + ASSERT_EQ(i, uint32_t(outputs[i])) + << "i = " << i << ", input = " << input << ", s = " << s() + << ", zero point = " << zeroPoint(); + } + } + } + + void testDivideByPO2WithRoundingAway(pytorch_requantization_function requantize) { + ASSERT_GE(zeroPoint(), 0); + ASSERT_LE(zeroPoint(), 255); + + /* Note: need s >= 1 to ensure scale = exp2(-s) < 1.0 */ + ASSERT_GE(s(), 1); + ASSERT_LT(s(), 32); + + std::vector inputs(256); + std::vector outputs(inputs.size()); + for (int32_t i = 0; i < 256; i++) { + int64_t input = RequantizationTester::shiftLeft(i - zeroPoint(), s()); + if (input > 0) { + input -= INT64_C(1) << (s() - 1); + } else if (input < 0) { + input += INT64_C(1) << (s() - 1); + } + inputs[i] = int32_t(input); + } + requantize( + inputs.size(), + inputs.data(), + scale(), + zeroPoint(), + qmin(), + qmax(), + outputs.data()); + for (uint32_t i = 0; i < 256; i++) { + int64_t input = RequantizationTester::shiftLeft(i - zeroPoint(), s()); + if (input > 0) { + input -= INT64_C(1) << (s() - 1); + } else if (input < 0) { + input += INT64_C(1) << (s() - 1); + } + if (int32_t(input) == input) { + ASSERT_EQ(i, uint32_t(outputs[i])) + << "i = " << i << ", input = " << input << ", s = " << s() + << ", zero point = " << zeroPoint(); + } + } + } + + void testSpecialCases(pytorch_requantization_function requantize) { + std::vector inputs(256); + std::vector outputs(inputs.size()); + + std::fill( + inputs.begin(), inputs.end(), std::numeric_limits::min()); + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + requantize( + inputs.size(), + inputs.data(), + ldexpf(1.0f, -32) /* scale */, + zeroPoint /* zero point */, + std::numeric_limits::min(), + std::numeric_limits::max(), + outputs.data()); + ASSERT_EQ( + std::max(int32_t(0), zeroPoint - 1), + *std::min_element(outputs.cbegin(), outputs.cend())); + } + + std::fill( + inputs.begin(), inputs.end(), std::numeric_limits::max()); + requantize( + inputs.size(), + inputs.data(), + 0x1.FFFFFEp-1f /* scale */, + std::numeric_limits::max() /* zero point */, + std::numeric_limits::min(), + std::numeric_limits::max(), + outputs.data()); + for (size_t i = 0; i < inputs.size(); i++) { + ASSERT_EQ(std::numeric_limits::max(), outputs[i]); + } + } + + void testRandomCasesPrecise(pytorch_requantization_function requantize) { + std::random_device randomDevice; + std::mt19937 mtRng(randomDevice()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + auto rng = std::bind(std::uniform_int_distribution(), mtRng); + + std::vector inputs(4096); + std::vector outputs(inputs.size()); + + const uint8_t zeroPoint = UINT8_C(128); + std::uniform_real_distribution scaleDistribution( + 0x1.000000p-23f, 0x1.FFFFFEp-1f); + const float scale = scaleDistribution(mtRng); + for (size_t i = 0; i < inputs.size(); i++) { + const uint8_t approximateOutput = rng(); + const int32_t input = + int32_t(double(approximateOutput) / double(scale)); + inputs[i] = input; + } + + requantize( + inputs.size(), + inputs.data(), + scale, + zeroPoint, + std::numeric_limits::min(), + std::numeric_limits::max(), + outputs.data()); + + /* Ensure that outputs are not all identical, as in this case test doesn't + * validate much */ + ASSERT_NE( + *std::max_element(outputs.cbegin(), outputs.cend()), + *std::min_element(outputs.cbegin(), outputs.cend())); + + for (size_t i = 0; i < inputs.size(); i++) { + const uint8_t referenceOutput = pytorch_scalar_requantize_precise( + inputs[i], + scale, + zeroPoint, + std::numeric_limits::min(), + std::numeric_limits::max()); + ASSERT_EQ(uint32_t(referenceOutput), uint32_t(outputs[i])); + } + } + } + + void testRandomCasesApproximate(pytorch_requantization_function requantize) { + std::random_device randomDevice; + std::mt19937 mtRng(randomDevice()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + auto rng = std::bind(std::uniform_int_distribution(), mtRng); + + std::vector inputs(4096); + std::vector outputs(inputs.size()); + + const uint8_t zeroPoint = UINT8_C(128); + std::uniform_real_distribution scaleDistribution( + 0x1.000000p-23f, 0x1.FFFFFEp-1f); + const float scale = scaleDistribution(mtRng); + for (size_t i = 0; i < inputs.size(); i++) { + const uint8_t approximateOutput = rng(); + const int32_t input = + int32_t(double(approximateOutput) / double(scale)); + inputs[i] = input; + } + + requantize( + inputs.size(), + inputs.data(), + scale, + zeroPoint, + std::numeric_limits::min(), + std::numeric_limits::max(), + outputs.data()); + + /* Ensure that outputs are not all identical, as in this case test doesn't + * validate much */ + ASSERT_NE( + *std::max_element(outputs.cbegin(), outputs.cend()), + *std::min_element(outputs.cbegin(), outputs.cend())); + + for (size_t i = 0; i < inputs.size(); i++) { + const double referenceOutput = + RequantizationTester::requantizeApproximate( + inputs[i], + scale, + zeroPoint, + std::numeric_limits::min(), + std::numeric_limits::max()); + ASSERT_LE(fabs(referenceOutput - double(outputs[i])), 0.55) + << "input = " << inputs[i] << ", output = " << uint32_t(outputs[i]) + << ", reference output = " << referenceOutput; + } + } + } + + void testRandomCasesAgainstReference( + pytorch_requantization_function requantize, + pytorch_requantization_function requantizeReference) { + std::random_device randomDevice; + std::mt19937 mtRng(randomDevice()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + auto rng = std::bind(std::uniform_int_distribution(), mtRng); + + std::vector inputs(4096); + std::vector outputs(inputs.size()); + std::vector referenceOutputs(inputs.size()); + + const uint8_t zeroPoint = UINT8_C(128); + std::uniform_real_distribution scaleDistribution( + 0x1.000000p-23f, 0x1.FFFFFEp-1f); + const float scale = scaleDistribution(mtRng); + for (size_t i = 0; i < inputs.size(); i++) { + const uint8_t approximateOutput = rng(); + const int32_t input = + int32_t(double(approximateOutput) / double(scale)); + inputs[i] = input; + } + + requantize( + inputs.size(), + inputs.data(), + scale, + zeroPoint, + std::numeric_limits::min(), + std::numeric_limits::max(), + outputs.data()); + + requantizeReference( + inputs.size(), + inputs.data(), + scale, + zeroPoint, + std::numeric_limits::min(), + std::numeric_limits::max(), + referenceOutputs.data()); + + /* Ensure that outputs are not all identical, as in this case test doesn't + * validate much */ + ASSERT_NE( + *std::max_element(outputs.cbegin(), outputs.cend()), + *std::min_element(outputs.cbegin(), outputs.cend())); + + for (size_t i = 0; i < inputs.size(); i++) { + ASSERT_EQ(uint32_t(referenceOutputs[i]), uint32_t(outputs[i])); + } + } + } + + static inline int64_t shiftLeft(int64_t w, uint32_t n) { + return (int64_t)((uint64_t)w << n); + } + + static inline double requantizeApproximate( + int32_t value, + float scale, + uint8_t zeroPoint, + uint8_t qmin, + uint8_t qmax) { + assert(scale < 1.0f); + assert(scale >= 0x1.0p-32f); + + double clampedValue = double(value) * double(scale) + double(zeroPoint); + + const double fmin = double(qmin); + if (clampedValue < fmin) { + clampedValue = fmin; + } + + const double fmax = double(qmax); + if (clampedValue > fmax) { + clampedValue = fmax; + } + + return clampedValue; + } + + private: + size_t zeroPoint_{0}; + size_t s_{1}; + uint8_t qmin_{std::numeric_limits::min()}; + uint8_t qmax_{std::numeric_limits::max()}; + size_t iterations_{1}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc new file mode 100644 index 0000000000000..a837974dd9fc0 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/requantization.cc @@ -0,0 +1,1077 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include + +#include "requantization-tester.h" + +/* + * Precise scalar implementation using unsigned 32-bit arithmetics. + */ + +TEST(PRECISE__SCALAR_UNSIGNED32, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__scalar_unsigned32); + } +} + +TEST(PRECISE__SCALAR_UNSIGNED32, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__scalar_unsigned32); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED32, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_precise__scalar_unsigned32); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED32, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__scalar_unsigned32); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED32, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__scalar_unsigned32); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED32, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__scalar_unsigned32); +} + +TEST(PRECISE__SCALAR_UNSIGNED32, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__scalar_unsigned32); +} + +/* + * Precise scalar implementation using unsigned 64-bit arithmetics. + */ + +TEST(PRECISE__SCALAR_UNSIGNED64, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__scalar_unsigned64); + } +} + +TEST(PRECISE__SCALAR_UNSIGNED64, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__scalar_unsigned64); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED64, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_precise__scalar_unsigned64); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED64, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__scalar_unsigned64); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED64, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__scalar_unsigned64); + } + } +} + +TEST(PRECISE__SCALAR_UNSIGNED64, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__scalar_unsigned64); +} + +TEST(PRECISE__SCALAR_UNSIGNED64, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__scalar_unsigned64); +} + +/* + * Precise scalar implementation using signed 64-bit arithmetics. + */ + +TEST(PRECISE__SCALAR_SIGNED64, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__scalar_signed64); + } +} + +TEST(PRECISE__SCALAR_SIGNED64, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__scalar_signed64); + } + } +} + +TEST(PRECISE__SCALAR_SIGNED64, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_precise__scalar_signed64); + } + } +} + +TEST(PRECISE__SCALAR_SIGNED64, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__scalar_signed64); + } + } +} + +TEST(PRECISE__SCALAR_SIGNED64, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__scalar_signed64); + } + } +} + +TEST(PRECISE__SCALAR_SIGNED64, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__scalar_signed64); +} + +TEST(PRECISE__SCALAR_SIGNED64, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__scalar_signed64); +} + +/* + * FP32-based scalar implementation using lrintf function. + */ + +TEST(FP32__SCALAR_LRINTF, random_cases) { + RequantizationTester().iterations(1000).testRandomCasesApproximate( + pytorch_qnnp_requantize_fp32__scalar_lrintf); +} + +/* + * FP32-based scalar implementation using magic trick for FP32->INT32 + * conversion. + */ + +TEST(FP32__SCALAR_MAGIC, random_cases) { + RequantizationTester().iterations(1000).testRandomCasesApproximate( + pytorch_qnnp_requantize_fp32__scalar_magic); +} + +/* + * Q31-based scalar implementation. + */ + +TEST(Q31__SCALAR, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__scalar); + } +} + +TEST(Q31__SCALAR, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__scalar); + } + } +} + +TEST(Q31__SCALAR, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_q31__scalar); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(Q31__SCALAR, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway(pytorch_qnnp_requantize_q31__scalar); + } + } +} + +TEST(Q31__SCALAR, special_cases) { + RequantizationTester().testSpecialCases(pytorch_qnnp_requantize_q31__scalar); +} + +TEST(Q31__SCALAR, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_q31__scalar); +} + +TEST(Q31__SCALAR, random_match_gemmlowp) { + RequantizationTester().iterations(100).testRandomCasesAgainstReference( + pytorch_qnnp_requantize_q31__scalar, + pytorch_qnnp_requantize_gemmlowp__scalar); +} + +/* + * Scalar implementation from gemmlowp. + */ + +TEST(GEMMLOWP__SCALAR, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_gemmlowp__scalar); +} + +/* + * Precise PSIMD implementation using unsigned 32-bit arithmetics. + */ + +TEST(PRECISE__PSIMD, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__psimd); + } +} + +TEST(PRECISE__PSIMD, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__psimd); + } + } +} + +TEST(PRECISE__PSIMD, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_precise__psimd); + } + } +} + +TEST(PRECISE__PSIMD, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__psimd); + } + } +} + +TEST(PRECISE__PSIMD, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__psimd); + } + } +} + +TEST(PRECISE__PSIMD, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__psimd); +} + +TEST(PRECISE__PSIMD, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__psimd); +} + +/* + * FP32-based PSIMD implementation using magic trick for FP32->INT32 conversion. + */ + +TEST(FP32__PSIMD, random_cases) { + RequantizationTester().iterations(1000).testRandomCasesApproximate( + pytorch_qnnp_requantize_fp32__psimd); +} + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + +/* + * Precise SSE2 implementation using floating-point shuffle. + */ + +TEST(PRECISE__SSE2, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__sse2); + } +} + +TEST(PRECISE__SSE2, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__sse2); + } + } +} + +TEST(PRECISE__SSE2, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_precise__sse2); + } + } +} + +TEST(PRECISE__SSE2, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__sse2); + } + } +} + +TEST(PRECISE__SSE2, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__sse2); + } + } +} + +TEST(PRECISE__SSE2, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__sse2); +} + +TEST(PRECISE__SSE2, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__sse2); +} + +/* + * Precise SSSE3 implementation using floating-point shuffle. + */ + +TEST(PRECISE__SSSE3, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__ssse3); + } +} + +TEST(PRECISE__SSSE3, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__ssse3); + } + } +} + +TEST(PRECISE__SSSE3, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_precise__ssse3); + } + } +} + +TEST(PRECISE__SSSE3, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__ssse3); + } + } +} + +TEST(PRECISE__SSSE3, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__ssse3); + } + } +} + +TEST(PRECISE__SSSE3, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__ssse3); +} + +TEST(PRECISE__SSSE3, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__ssse3); +} + +/* + * Precise SSE4.1 implementation using static blend instruction. + */ + +TEST(PRECISE__SSE4, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__sse4); + } +} + +TEST(PRECISE__SSE4, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__sse4); + } + } +} + +TEST(PRECISE__SSE4, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_precise__sse4); + } + } +} + +TEST(PRECISE__SSE4, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__sse4); + } + } +} + +TEST(PRECISE__SSE4, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__sse4); + } + } +} + +TEST(PRECISE__SSE4, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__sse4); +} + +TEST(PRECISE__SSE4, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__sse4); +} + +/* + * FP32-based x86 SSE2 implementation. + */ + +TEST(FP32__SSE2, random_cases) { + RequantizationTester().iterations(1000).testRandomCasesApproximate( + pytorch_qnnp_requantize_fp32__sse2); +} + +/* + * Q31-based x86 SSE2 implementation. + */ + +TEST(Q31__SSE2, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__sse2); + } +} + +TEST(Q31__SSE2, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__sse2); + } + } +} + +TEST(Q31__SSE2, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_q31__sse2); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(Q31__SSE2, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway(pytorch_qnnp_requantize_q31__sse2); + } + } +} + +TEST(Q31__SSE2, special_cases) { + RequantizationTester().testSpecialCases(pytorch_qnnp_requantize_q31__sse2); +} + +TEST(Q31__SSE2, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_q31__sse2); +} + +TEST(Q31__SSE2, random_match_gemmlowp) { + RequantizationTester().iterations(100).testRandomCasesAgainstReference( + pytorch_qnnp_requantize_q31__sse2, + pytorch_qnnp_requantize_gemmlowp__sse2); +} + +/* + * Q31-based x86 SSSE3 implementation. + */ + +TEST(Q31__SSSE3, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__ssse3); + } +} + +TEST(Q31__SSSE3, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__ssse3); + } + } +} + +TEST(Q31__SSSE3, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_q31__ssse3); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(Q31__SSSE3, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway(pytorch_qnnp_requantize_q31__ssse3); + } + } +} + +TEST(Q31__SSSE3, special_cases) { + RequantizationTester().testSpecialCases(pytorch_qnnp_requantize_q31__ssse3); +} + +TEST(Q31__SSSE3, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_q31__ssse3); +} + +TEST(Q31__SSSE3, random_match_gemmlowp) { + RequantizationTester().iterations(100).testRandomCasesAgainstReference( + pytorch_qnnp_requantize_q31__ssse3, + pytorch_qnnp_requantize_gemmlowp__ssse3); +} + +/* + * Q31-based x86 SSE4 implementation. + */ + +TEST(Q31__SSE4, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__sse4); + } +} + +TEST(Q31__SSE4, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__sse4); + } + } +} + +TEST(Q31__SSE4, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_q31__sse4); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(Q31__SSE4, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway(pytorch_qnnp_requantize_q31__sse4); + } + } +} + +TEST(Q31__SSE4, special_cases) { + RequantizationTester().testSpecialCases(pytorch_qnnp_requantize_q31__sse4); +} + +TEST(Q31__SSE4, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_q31__sse4); +} + +TEST(Q31__SSE4, random_match_gemmlowp) { + RequantizationTester().iterations(100).testRandomCasesAgainstReference( + pytorch_qnnp_requantize_q31__sse4, + pytorch_qnnp_requantize_gemmlowp__sse4); +} + +/* + * x86 SSE2 implementation from gemmlowp. + */ + +TEST(GEMMLOWP__SSE2, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_gemmlowp__sse2); + } +} + +TEST(GEMMLOWP__SSE2, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_gemmlowp__sse2); + } + } +} + +TEST(GEMMLOWP__SSE2, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_gemmlowp__sse2); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(GEMMLOWP__SSE2, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_gemmlowp__sse2); + } + } +} + +TEST(GEMMLOWP__SSE2, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_gemmlowp__sse2); +} + +TEST(GEMMLOWP__SSE2, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_gemmlowp__sse2); +} + +/* + * x86 SSSE3 implementation from gemmlowp. + */ + +TEST(GEMMLOWP__SSSE3, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_gemmlowp__ssse3); + } +} + +TEST(GEMMLOWP__SSSE3, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_gemmlowp__ssse3); + } + } +} + +TEST(GEMMLOWP__SSSE3, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_gemmlowp__ssse3); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(GEMMLOWP__SSSE3, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_gemmlowp__ssse3); + } + } +} + +TEST(GEMMLOWP__SSSE3, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_gemmlowp__ssse3); +} + +TEST(GEMMLOWP__SSSE3, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_gemmlowp__ssse3); +} + +/* + * x86 SSE4 implementation from gemmlowp. + */ + +TEST(GEMMLOWP__SSE4, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_gemmlowp__sse4); + } +} + +TEST(GEMMLOWP__SSE4, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_gemmlowp__sse4); + } + } +} + +TEST(GEMMLOWP__SSE4, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp( + pytorch_qnnp_requantize_gemmlowp__sse4); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(GEMMLOWP__SSE4, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_gemmlowp__sse4); + } + } +} + +TEST(GEMMLOWP__SSE4, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_gemmlowp__sse4); +} + +TEST(GEMMLOWP__SSE4, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_gemmlowp__sse4); +} + +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + +/* + * Precise ARM NEON implementation. + */ + +TEST(PRECISE__NEON, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__neon); + } +} + +TEST(PRECISE__NEON, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_precise__neon); + } + } +} + +TEST(PRECISE__NEON, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_precise__neon); + } + } +} + +TEST(PRECISE__NEON, divide_by_po2_with_rounding_down) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingDown( + pytorch_qnnp_requantize_precise__neon); + } + } +} + +TEST(PRECISE__NEON, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway( + pytorch_qnnp_requantize_precise__neon); + } + } +} + +TEST(PRECISE__NEON, special_cases) { + RequantizationTester().testSpecialCases( + pytorch_qnnp_requantize_precise__neon); +} + +TEST(PRECISE__NEON, random_cases) { + RequantizationTester().iterations(100).testRandomCasesPrecise( + pytorch_qnnp_requantize_precise__neon); +} + +/* + * FP32-based ARM NEON implementation. + */ + +TEST(FP32__NEON, random_cases) { + RequantizationTester().iterations(1000).testRandomCasesApproximate( + pytorch_qnnp_requantize_fp32__neon); +} + +/* + * Q31-based ARM NEON implementation. + */ + +TEST(Q31__NEON, exact_divide_by_po2) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__neon); + } +} + +TEST(Q31__NEON, exact_divide_by_po2_with_zero_point) { + for (int32_t zeroPoint = 1; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester().zeroPoint(zeroPoint).s(s).testExactDivideByPO2( + pytorch_qnnp_requantize_q31__neon); + } + } +} + +TEST(Q31__NEON, divide_by_po2_with_rounding_up) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingUp(pytorch_qnnp_requantize_q31__neon); + } + } +} + +/* No rounding down test - it fails because of upward bias in multiplication */ + +TEST(Q31__NEON, divide_by_po2_with_rounding_away) { + for (int32_t zeroPoint = 0; zeroPoint < 256; zeroPoint++) { + for (uint32_t s = 1; s < 32; s++) { + RequantizationTester() + .zeroPoint(zeroPoint) + .s(s) + .testDivideByPO2WithRoundingAway(pytorch_qnnp_requantize_q31__neon); + } + } +} + +TEST(Q31__NEON, special_cases) { + RequantizationTester().testSpecialCases(pytorch_qnnp_requantize_q31__neon); +} + +TEST(Q31__NEON, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_q31__neon); +} + +TEST(Q31__NEON, random_match_gemmlowp) { + RequantizationTester().iterations(100).testRandomCasesAgainstReference( + pytorch_qnnp_requantize_q31__neon, + pytorch_qnnp_requantize_gemmlowp__neon); +} + +/* + * ARM NEON implementation from gemmlowp. + */ + +TEST(GEMMLOWP__NEON, random_cases) { + RequantizationTester().iterations(100).testRandomCasesApproximate( + pytorch_qnnp_requantize_gemmlowp__neon); +} + +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/rmax-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/rmax-microkernel-tester.h new file mode 100644 index 0000000000000..16e6487baf60d --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/rmax-microkernel-tester.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +class RMaxMicrokernelTester { + public: + inline RMaxMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline RMaxMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_u8rmax_ukernel_function u8rmax) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + + /* Compute reference results */ + uint8_t yRef = 0; + for (size_t i = 0; i < n(); i++) { + yRef = std::max(yRef, x[i]); + } + + /* Call optimized micro-kernel */ + const uint8_t y = u8rmax(n(), x.data()); + + /* Verify results */ + ASSERT_EQ(yRef, y) << "n = " << n(); + } + } + + private: + size_t n_{1}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/sconv.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sconv.cc new file mode 100644 index 0000000000000..405edd411e2d2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sconv.cc @@ -0,0 +1,103 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "gemm-microkernel-tester.h" + +TEST(SCONV_6x8__PSIMD, k_eq_1) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(1) + .aStride(37) + .test(pytorch_sconv_ukernel_6x8__psimd); +} + +TEST(SCONV_6x8__PSIMD, k_eq_1_strided_c) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(1) + .aStride(37) + .cStride(17) + .test(pytorch_sconv_ukernel_6x8__psimd); +} + +TEST(SCONV_6x8__PSIMD, k_eq_1_qmin128) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(1).qmin(128).test( + pytorch_sconv_ukernel_6x8__psimd); +} + +TEST(SCONV_6x8__PSIMD, k_eq_1_qmax128) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(1).qmax(128).test( + pytorch_sconv_ukernel_6x8__psimd); +} + +TEST(SCONV_6x8__PSIMD, k_gt_1) { + for (size_t k = 2; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_sconv_ukernel_6x8__psimd); + } +} + +TEST(SCONV_6x8__PSIMD, k_gt_1_strided_c) { + for (size_t k = 2; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .aStride(37) + .cStride(17) + .test(pytorch_sconv_ukernel_6x8__psimd); + } +} + +TEST(SCONV_6x8__PSIMD, k_gt_1_subtile) { + for (size_t k = 2; k < 16; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .aStride(37) + .iterations(3) + .test(pytorch_sconv_ukernel_6x8__psimd); + } + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/sgemm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sgemm.cc new file mode 100644 index 0000000000000..5e0f5bb896c97 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sgemm.cc @@ -0,0 +1,502 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "gemm-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(SGEMM_5x8__NEON, k_eq_2) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(5).nr(8).np(8).kr(1).m(5).n(8).k(2).test( + pytorch_sgemm_ukernel_5x8__neon); +} + +TEST(SGEMM_5x8__NEON, k_eq_2_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(5) + .n(8) + .k(2) + .aStride(37) + .test(pytorch_sgemm_ukernel_5x8__neon); +} + +TEST(SGEMM_5x8__NEON, k_eq_2_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(5) + .n(8) + .k(2) + .cStride(17) + .test(pytorch_sgemm_ukernel_5x8__neon); +} + +TEST(SGEMM_5x8__NEON, k_eq_8_rmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmin(128).test( + pytorch_sgemm_ukernel_5x8__neon); +} + +TEST(SGEMM_5x8__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(4).nr(8).np(8).kr(1).m(4).n(8).k(8).qmax(128).test( + pytorch_sgemm_ukernel_5x8__neon); +} + +TEST(SGEMM_5x8__NEON, k_gt_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester().mr(5).nr(8).np(8).kr(1).m(5).n(8).k(k).test( + pytorch_sgemm_ukernel_5x8__neon); + } +} + +TEST(SGEMM_5x8__NEON, k_gt_2_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(5) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_sgemm_ukernel_5x8__neon); + } +} + +TEST(SGEMM_5x8__NEON, k_gt_2_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(5) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_sgemm_ukernel_5x8__neon); + } +} + +TEST(SGEMM_5x8__NEON, k_gt_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + for (uint32_t m = 1; m <= 5; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_sgemm_ukernel_5x8__neon); + } + } + } +} + +TEST(SGEMM_5x8__NEON, k_div_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester().mr(5).nr(8).np(8).kr(1).m(5).n(8).k(k).test( + pytorch_sgemm_ukernel_5x8__neon); + } +} + +TEST(SGEMM_5x8__NEON, k_div_2_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(5) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_sgemm_ukernel_5x8__neon); + } +} + +TEST(SGEMM_5x8__NEON, k_div_2_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(5) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_sgemm_ukernel_5x8__neon); + } +} + +TEST(SGEMM_5x8__NEON, k_div_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 6) { + for (uint32_t m = 1; m <= 5; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(5) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_sgemm_ukernel_5x8__neon); + } + } + } +} + +TEST(SGEMM_6x8__NEON, k_eq_2) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(2).test( + pytorch_sgemm_ukernel_6x8__neon); +} + +TEST(SGEMM_6x8__NEON, k_eq_2_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(2) + .aStride(37) + .test(pytorch_sgemm_ukernel_6x8__neon); +} + +TEST(SGEMM_6x8__NEON, k_eq_2_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(2) + .cStride(17) + .test(pytorch_sgemm_ukernel_6x8__neon); +} + +TEST(SGEMM_6x8__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(8).qmin(128).test( + pytorch_sgemm_ukernel_6x8__neon); +} + +TEST(SGEMM_6x8__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(8).qmax(128).test( + pytorch_sgemm_ukernel_6x8__neon); +} + +TEST(SGEMM_6x8__NEON, k_gt_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(k).test( + pytorch_sgemm_ukernel_6x8__neon); + } +} + +TEST(SGEMM_6x8__NEON, k_gt_2_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_sgemm_ukernel_6x8__neon); + } +} + +TEST(SGEMM_6x8__NEON, k_gt_2_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_sgemm_ukernel_6x8__neon); + } +} + +TEST(SGEMM_6x8__NEON, k_gt_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 3; k < 16; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_sgemm_ukernel_6x8__neon); + } + } + } +} + +TEST(SGEMM_6x8__NEON, k_div_2) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(k).test( + pytorch_sgemm_ukernel_6x8__neon); + } +} + +TEST(SGEMM_6x8__NEON, k_div_2_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_sgemm_ukernel_6x8__neon); + } +} + +TEST(SGEMM_6x8__NEON, k_div_2_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_sgemm_ukernel_6x8__neon); + } +} + +TEST(SGEMM_6x8__NEON, k_div_2_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 2; k < 32; k += 6) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_sgemm_ukernel_6x8__neon); + } + } + } +} +#endif + +TEST(SGEMM_6x8__PSIMD, k_eq_2) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(2).test( + pytorch_sgemm_ukernel_6x8__psimd); +} + +TEST(SGEMM_6x8__PSIMD, k_eq_2_strided_a) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(2) + .aStride(37) + .test(pytorch_sgemm_ukernel_6x8__psimd); +} + +TEST(SGEMM_6x8__PSIMD, k_eq_2_strided_c) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(2) + .cStride(17) + .test(pytorch_sgemm_ukernel_6x8__psimd); +} + +TEST(SGEMM_6x8__PSIMD, k_eq_8_qmin128) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(8).qmin(128).test( + pytorch_sgemm_ukernel_6x8__psimd); +} + +TEST(SGEMM_6x8__PSIMD, k_eq_8_qmax128) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(8).qmax(128).test( + pytorch_sgemm_ukernel_6x8__psimd); +} + +TEST(SGEMM_6x8__PSIMD, k_gt_2) { + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(k).test( + pytorch_sgemm_ukernel_6x8__psimd); + } +} + +TEST(SGEMM_6x8__PSIMD, k_gt_2_strided_a) { + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .aStride(37) + .test(pytorch_sgemm_ukernel_6x8__psimd); + } +} + +TEST(SGEMM_6x8__PSIMD, k_gt_2_strided_c) { + for (size_t k = 3; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_sgemm_ukernel_6x8__psimd); + } +} + +TEST(SGEMM_6x8__PSIMD, k_gt_2_subtile) { + for (size_t k = 3; k < 16; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_sgemm_ukernel_6x8__psimd); + } + } + } +} + +TEST(SGEMM_6x8__PSIMD, k_div_2) { + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester().mr(6).nr(8).np(8).kr(1).m(6).n(8).k(k).test( + pytorch_sgemm_ukernel_6x8__psimd); + } +} + +TEST(SGEMM_6x8__PSIMD, k_div_2_strided_a) { + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .aStride(171) + .test(pytorch_sgemm_ukernel_6x8__psimd); + } +} + +TEST(SGEMM_6x8__PSIMD, k_div_2_strided_c) { + for (size_t k = 2; k < 32; k += 2) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(6) + .n(8) + .k(k) + .cStride(17) + .test(pytorch_sgemm_ukernel_6x8__psimd); + } +} + +TEST(SGEMM_6x8__PSIMD, k_div_2_subtile) { + for (size_t k = 2; k < 32; k += 6) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(pytorch_sgemm_ukernel_6x8__psimd); + } + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/sigmoid-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sigmoid-operator-tester.h new file mode 100644 index 0000000000000..d1d0c23b59d56 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sigmoid-operator-tester.h @@ -0,0 +1,214 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +class SigmoidOperatorTester { + public: + inline SigmoidOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline SigmoidOperatorTester& inputStride(size_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return this->channels_; + } else { + assert(this->inputStride_ >= this->channels_); + return this->inputStride_; + } + } + + inline SigmoidOperatorTester& outputStride(size_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return this->channels_; + } else { + assert(this->outputStride_ >= this->channels_); + return this->outputStride_; + } + } + + inline SigmoidOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline SigmoidOperatorTester& inputScale(float inputScale) { + assert(inputScale > 0.0f); + assert(std::isnormal(inputScale)); + this->inputScale_ = inputScale; + return *this; + } + + inline float inputScale() const { + return this->inputScale_; + } + + inline SigmoidOperatorTester& inputZeroPoint(uint8_t inputZeroPoint) { + this->inputZeroPoint_ = inputZeroPoint; + return *this; + } + + inline uint8_t inputZeroPoint() const { + return this->inputZeroPoint_; + } + + inline float outputScale() const { + return 1.0f / 256.0f; + } + + inline uint8_t outputZeroPoint() const { + return 0; + } + + inline SigmoidOperatorTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline SigmoidOperatorTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline SigmoidOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input((batchSize() - 1) * inputStride() + channels()); + std::vector output( + (batchSize() - 1) * outputStride() + channels()); + std::vector outputRef(batchSize() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + const float x = inputScale() * + (int32_t(input[i * inputStride() + c]) - + int32_t(inputZeroPoint())); + const float sigmoidX = 1.0f / (1.0f + exp(-x)); + const float scaledSigmoidX = sigmoidX / outputScale(); + float y = scaledSigmoidX; + y = std::min(y, int32_t(qmax()) - int32_t(outputZeroPoint())); + y = std::max(y, int32_t(qmin()) - int32_t(outputZeroPoint())); + outputRef[i * channels() + c] = y + int32_t(outputZeroPoint()); + } + } + + /* Create, setup, run, and destroy Sigmoid operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t sigmoidOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_sigmoid_nc_q8( + channels(), + inputZeroPoint(), + inputScale(), + outputZeroPoint(), + outputScale(), + qmin(), + qmax(), + 0, + &sigmoidOp)); + ASSERT_NE(nullptr, sigmoidOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_sigmoid_nc_q8( + sigmoidOp, + batchSize(), + input.data(), + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(sigmoidOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, pytorch_qnnp_delete_operator(sigmoidOp)); + sigmoidOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_NEAR( + float(int32_t(output[i * outputStride() + c])), + outputRef[i * channels() + c], + 0.6f); + } + } + } + } + + private: + size_t batchSize_{1}; + size_t channels_{1}; + size_t inputStride_{0}; + size_t outputStride_{0}; + float inputScale_{0.75f}; + uint8_t inputZeroPoint_{121}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/sigmoid.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sigmoid.cc new file mode 100644 index 0000000000000..fd17560839417 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/sigmoid.cc @@ -0,0 +1,229 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "sigmoid-operator-tester.h" + +#include + +TEST(SIGMOID_OP, zero_batch) { + SigmoidOperatorTester().batchSize(0).channels(8).iterations(1).testQ8(); +} + +TEST(SIGMOID_OP, unit_batch) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(1) + .channels(channels) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, unit_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(1) + .channels(channels) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, unit_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(1) + .channels(channels) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, unit_batch_with_input_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float inputScale = 1.0e-2f; inputScale < 1.0e+2f; + inputScale *= 10.0f) { + SigmoidOperatorTester() + .batchSize(1) + .channels(channels) + .inputScale(inputScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SIGMOID_OP, unit_batch_with_input_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + SigmoidOperatorTester() + .batchSize(1) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SIGMOID_OP, small_batch) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, small_batch_with_input_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, small_batch_with_output_stride) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .outputStride(117) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, small_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, small_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, small_batch_with_input_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float inputScale = 1.0e-2f; inputScale < 1.0e+2f; + inputScale *= 10.0f) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputScale(inputScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SIGMOID_OP, small_batch_with_input_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SIGMOID_OP, strided_batch) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, strided_batch_with_qmin) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .qmin(128) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, strided_batch_with_qmax) { + for (size_t channels = 1; channels < 100; channels += 15) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .qmax(128) + .iterations(3) + .testQ8(); + } +} + +TEST(SIGMOID_OP, strided_batch_with_input_scale) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (float inputScale = 1.0e-2f; inputScale < 1.0e+2f; + inputScale *= 10.0f) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .inputScale(inputScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SIGMOID_OP, strided_batch_with_input_zero_point) { + for (size_t channels = 1; channels < 100; channels += 15) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + SigmoidOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/softargmax-operator-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/softargmax-operator-tester.h new file mode 100644 index 0000000000000..54fbc3c2cff40 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/softargmax-operator-tester.h @@ -0,0 +1,198 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +class SoftArgMaxOperatorTester { + public: + inline SoftArgMaxOperatorTester& channels(size_t channels) { + assert(channels != 0); + this->channels_ = channels; + return *this; + } + + inline size_t channels() const { + return this->channels_; + } + + inline SoftArgMaxOperatorTester& inputStride(size_t inputStride) { + assert(inputStride != 0); + this->inputStride_ = inputStride; + return *this; + } + + inline size_t inputStride() const { + if (this->inputStride_ == 0) { + return this->channels_; + } else { + assert(this->inputStride_ >= this->channels_); + return this->inputStride_; + } + } + + inline SoftArgMaxOperatorTester& outputStride(size_t outputStride) { + assert(outputStride != 0); + this->outputStride_ = outputStride; + return *this; + } + + inline size_t outputStride() const { + if (this->outputStride_ == 0) { + return this->channels_; + } else { + assert(this->outputStride_ >= this->channels_); + return this->outputStride_; + } + } + + inline SoftArgMaxOperatorTester& batchSize(size_t batchSize) { + this->batchSize_ = batchSize; + return *this; + } + + inline size_t batchSize() const { + return this->batchSize_; + } + + inline SoftArgMaxOperatorTester& inputScale(float inputScale) { + assert(inputScale > 0.0f); + assert(std::isnormal(inputScale)); + this->inputScale_ = inputScale; + return *this; + } + + inline float inputScale() const { + return this->inputScale_; + } + + inline SoftArgMaxOperatorTester& inputZeroPoint(uint8_t inputZeroPoint) { + this->inputZeroPoint_ = inputZeroPoint; + return *this; + } + + inline uint8_t inputZeroPoint() const { + return this->inputZeroPoint_; + } + + inline float outputScale() const { + return 1.0f / 256.0f; + } + + inline uint8_t outputZeroPoint() const { + return 0; + } + + inline SoftArgMaxOperatorTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void testQ8() const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector input((batchSize() - 1) * inputStride() + channels()); + std::vector output( + (batchSize() - 1) * outputStride() + channels()); + std::vector outputRef(batchSize() * channels()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(input.begin(), input.end(), std::ref(u8rng)); + std::fill(output.begin(), output.end(), 0xA5); + + /* Compute reference results */ + for (size_t i = 0; i < batchSize(); i++) { + const int32_t maxInput = *std::max_element( + input.data() + i * inputStride(), + input.data() + i * inputStride() + channels()); + float sumExp = 0.0f; + for (size_t c = 0; c < channels(); c++) { + sumExp += + exp((int32_t(input[i * inputStride() + c]) - maxInput) * + inputScale()); + } + for (size_t c = 0; c < channels(); c++) { + outputRef[i * channels() + c] = + exp((int32_t(input[i * inputStride() + c]) - maxInput) * + inputScale()) / + (sumExp * outputScale()); + outputRef[i * channels() + c] = + std::min(outputRef[i * channels() + c], 255.0f); + } + } + + /* Create, setup, run, and destroy SoftArgMax operator */ + ASSERT_EQ(pytorch_qnnp_status_success, pytorch_qnnp_initialize()); + pytorch_qnnp_operator_t softArgMaxOp = nullptr; + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_create_softargmax_nc_q8( + channels(), + inputScale(), + outputZeroPoint(), + outputScale(), + 0, + &softArgMaxOp)); + ASSERT_NE(nullptr, softArgMaxOp); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_setup_softargmax_nc_q8( + softArgMaxOp, + batchSize(), + input.data(), + inputStride(), + output.data(), + outputStride())); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_run_operator(softArgMaxOp, nullptr /* thread pool */)); + + ASSERT_EQ( + pytorch_qnnp_status_success, + pytorch_qnnp_delete_operator(softArgMaxOp)); + softArgMaxOp = nullptr; + + /* Verify results */ + for (size_t i = 0; i < batchSize(); i++) { + for (size_t c = 0; c < channels(); c++) { + ASSERT_NEAR( + float(int32_t(output[i * outputStride() + c])), + outputRef[i * channels() + c], + 0.6f); + } + } + } + } + + private: + size_t batchSize_{1}; + size_t channels_{1}; + size_t inputStride_{0}; + size_t outputStride_{0}; + float inputScale_{0.176080093}; + uint8_t inputZeroPoint_{121}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/softargmax.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/softargmax.cc new file mode 100644 index 0000000000000..a89a957451851 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/softargmax.cc @@ -0,0 +1,135 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include "softargmax-operator-tester.h" + +#include + +TEST(SOFTARGMAX_OP, zero_batch) { + SoftArgMaxOperatorTester().batchSize(0).channels(1).iterations(1).testQ8(); +} + +TEST(SOFTARGMAX_OP, single_class) { + SoftArgMaxOperatorTester().batchSize(1).channels(1).iterations(100).testQ8(); +} + +TEST(SOFTARGMAX_OP, two_classes) { + SoftArgMaxOperatorTester().batchSize(1).channels(2).iterations(100).testQ8(); +} + +TEST(SOFTARGMAX_OP, many_classes) { + for (size_t channels = 3; channels < 100; channels++) { + SoftArgMaxOperatorTester() + .batchSize(1) + .channels(channels) + .iterations(1) + .testQ8(); + } +} + +TEST(SOFTARGMAX_OP, cifar_classes) { + /* CIFAR-10 */ + SoftArgMaxOperatorTester().batchSize(1).channels(10).iterations(15).testQ8(); + /* CIFAR-100 */ + SoftArgMaxOperatorTester().batchSize(1).channels(100).iterations(15).testQ8(); +} + +TEST(SOFTARGMAX_OP, imagenet_classes) { + /* ImageNet-1K */ + SoftArgMaxOperatorTester() + .batchSize(1) + .channels(1000) + .iterations(10) + .testQ8(); + /* ImageNet-1K+1 */ + SoftArgMaxOperatorTester() + .batchSize(1) + .channels(1001) + .iterations(10) + .testQ8(); + /* ImageNet-22K */ + SoftArgMaxOperatorTester() + .batchSize(1) + .channels(21841) + .iterations(10) + .testQ8(); +} + +TEST(SOFTARGMAX_OP, many_channels_with_input_scale) { + for (size_t channels = 1; channels < 100; channels += 5) { + for (float inputScale = 1.0e-2f; inputScale < 1.0e+2f; + inputScale *= 3.14159265f) { + SoftArgMaxOperatorTester() + .batchSize(1) + .channels(channels) + .inputScale(inputScale) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SOFTARGMAX_OP, many_channels_with_input_zero_point) { + for (size_t channels = 1; channels < 100; channels += 5) { + for (int32_t inputZeroPoint = 0; inputZeroPoint <= 255; + inputZeroPoint += 51) { + SoftArgMaxOperatorTester() + .batchSize(1) + .channels(channels) + .inputZeroPoint(uint8_t(inputZeroPoint)) + .iterations(1) + .testQ8(); + } + } +} + +TEST(SOFTARGMAX_OP, small_batch) { + for (size_t channels = 1; channels < 100; channels += 5) { + SoftArgMaxOperatorTester() + .batchSize(3) + .channels(channels) + .iterations(3) + .testQ8(); + } +} + +TEST(SOFTARGMAX_OP, small_batch_with_input_stride) { + for (size_t channels = 1; channels < 100; channels += 5) { + SoftArgMaxOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .iterations(3) + .testQ8(); + } +} + +TEST(SOFTARGMAX_OP, small_batch_with_output_stride) { + for (size_t channels = 1; channels < 100; channels += 5) { + SoftArgMaxOperatorTester() + .batchSize(3) + .channels(channels) + .outputStride(117) + .iterations(3) + .testQ8(); + } +} + +TEST(SOFTARGMAX_OP, strided_batch_with_input_and_output_stride) { + for (size_t channels = 1; channels < 100; channels += 5) { + SoftArgMaxOperatorTester() + .batchSize(3) + .channels(channels) + .inputStride(129) + .outputStride(117) + .iterations(3) + .testQ8(); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8clamp.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8clamp.cc new file mode 100644 index 0000000000000..3506b453f8e4b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8clamp.cc @@ -0,0 +1,127 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "clamp-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(U8CLAMP__NEON, n_eq_8) { + TEST_REQUIRES_ARM_NEON; + ClampMicrokernelTester().n(8).test(pytorch_u8clamp_ukernel__neon); +} + +TEST(U8CLAMP__NEON, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 512; n += 8) { + ClampMicrokernelTester().n(n).test(pytorch_u8clamp_ukernel__neon); + } +} + +TEST(U8CLAMP__NEON, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + ClampMicrokernelTester().n(n).test(pytorch_u8clamp_ukernel__neon); + } +} + +TEST(U8CLAMP__NEON, n_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + ClampMicrokernelTester().n(n).test(pytorch_u8clamp_ukernel__neon); + } +} + +TEST(U8CLAMP__NEON, inplace) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 5) { + ClampMicrokernelTester().iterations(1).n(n).inplace(true).test( + pytorch_u8clamp_ukernel__neon); + } +} + +TEST(U8CLAMP__NEON, qmin) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (uint8_t qmin = 1; qmin < 255; qmin++) { + ClampMicrokernelTester().iterations(1).n(n).qmin(qmin).test( + pytorch_u8clamp_ukernel__neon); + } + } +} + +TEST(U8CLAMP__NEON, qmax) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 128; n += 11) { + for (uint8_t qmax = 1; qmax < 255; qmax++) { + ClampMicrokernelTester().iterations(1).n(n).qmax(qmax).test( + pytorch_u8clamp_ukernel__neon); + } + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(U8CLAMP__SSE2, n_eq_8) { + TEST_REQUIRES_X86_SSE2; + ClampMicrokernelTester().n(8).test(pytorch_u8clamp_ukernel__sse2); +} + +TEST(U8CLAMP__SSE2, n_div_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 8; n < 512; n += 8) { + ClampMicrokernelTester().n(n).test(pytorch_u8clamp_ukernel__sse2); + } +} + +TEST(U8CLAMP__SSE2, n_gt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 9; n < 16; n++) { + ClampMicrokernelTester().n(n).test(pytorch_u8clamp_ukernel__sse2); + } +} + +TEST(U8CLAMP__SSE2, n_lt_8) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 8; n++) { + ClampMicrokernelTester().n(n).test(pytorch_u8clamp_ukernel__sse2); + } +} + +TEST(U8CLAMP__SSE2, inplace) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 5) { + ClampMicrokernelTester().iterations(1).n(n).inplace(true).test( + pytorch_u8clamp_ukernel__sse2); + } +} + +TEST(U8CLAMP__SSE2, qmin) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (uint8_t qmin = 1; qmin < 255; qmin++) { + ClampMicrokernelTester().iterations(1).n(n).qmin(qmin).test( + pytorch_u8clamp_ukernel__sse2); + } + } +} + +TEST(U8CLAMP__SSE2, qmax) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 128; n += 11) { + for (uint8_t qmax = 1; qmax < 255; qmax++) { + ClampMicrokernelTester().iterations(1).n(n).qmax(qmax).test( + pytorch_u8clamp_ukernel__sse2); + } + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8lut32norm.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8lut32norm.cc new file mode 100644 index 0000000000000..9ab19a4280cad --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8lut32norm.cc @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include "lut-norm-microkernel-tester.h" + +TEST(U8LUT32NORM__SCALAR, n_eq_1) { + LUTNormMicrokernelTester().n(1).test(pytorch_u8lut32norm_ukernel__scalar); +} + +TEST(U8LUT32NORM__SCALAR, small_n) { + for (size_t n = 2; n <= 16; n++) { + LUTNormMicrokernelTester().n(n).test(pytorch_u8lut32norm_ukernel__scalar); + } +} + +TEST(U8LUT32NORM__SCALAR, large_n) { + for (size_t n = 16; n <= 128; n += 2) { + LUTNormMicrokernelTester().n(n).test(pytorch_u8lut32norm_ukernel__scalar); + } +} + +TEST(U8LUT32NORM__SCALAR, n_eq_1_inplace) { + LUTNormMicrokernelTester().n(1).inplace(true).test( + pytorch_u8lut32norm_ukernel__scalar); +} + +TEST(U8LUT32NORM__SCALAR, small_n_inplace) { + for (size_t n = 2; n <= 16; n++) { + LUTNormMicrokernelTester().n(n).inplace(true).test( + pytorch_u8lut32norm_ukernel__scalar); + } +} + +TEST(U8LUT32NORM__SCALAR, large_n_inplace) { + for (size_t n = 16; n <= 128; n += 2) { + LUTNormMicrokernelTester().n(n).inplace(true).test( + pytorch_u8lut32norm_ukernel__scalar); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8maxpool.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8maxpool.cc new file mode 100644 index 0000000000000..f73c428579520 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8maxpool.cc @@ -0,0 +1,1539 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "maxpool-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_mx1_pool) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(ks).kw(1).kc(kc).test( + pytorch_u8maxpool_ukernel_sub16__neon); + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_mx1_pool_with_qmin) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(ks).kw(1).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_sub16__neon); + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_mx1_pool_with_qmax) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(ks).kw(1).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_sub16__neon); + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_1xm_pool) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(1).kw(ks).kc(kc).test( + pytorch_u8maxpool_ukernel_sub16__neon); + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_1xm_pool_with_qmin) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(1).kw(ks).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_sub16__neon); + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_1xm_pool_with_qmax) { + TEST_REQUIRES_ARM_NEON; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(1).kw(ks).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_sub16__neon); + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_small_n) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__neon); + } + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_small_n_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(17) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__neon); + } + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_small_n_with_s) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t s = 2; s <= 5; s++) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_sub16__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_small_n_with_qmin) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .qmin(192) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__neon); + } + } + } +} + +TEST(U8MAXPOOL_SUB16__NEON, kc_lt_16_small_n_with_qmax) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .qmax(192) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_unipass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_unipass_fulltile_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_unipass_fulltile_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_unipass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t ks = 2; ks < tester.mr(); ks++) { + tester.kh(ks).kw(1).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_unipass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_unipass_fulltile_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_unipass_fulltile_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_unipass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_unipass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_unipass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_unipass_fulltile_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_unipass_fulltile_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_unipass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_unipass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_twopass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + tester.kh(kh).kw(kw).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_twopass_fulltile_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + tester.kh(kh).kw(kw).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_twopass_fulltile_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + tester.kh(kh).kw(kw).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_twopass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + 1; ks < tester.mr() + tester.qr(); ks++) { + tester.kh(ks).kw(1).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_twopass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_twopass_fulltile_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_twopass_fulltile_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_twopass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + 1; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_twopass_fulltile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_twopass_fulltile_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_twopass_fulltile_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_twopass_subtile) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + 1; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_multipass) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + tester.kh(ks).kw(1).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_multipass_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + tester.kh(ks).kw(1).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_eq_16_multipass_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + tester.kh(ks).kw(1).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_multipass) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_multipass_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_multipass_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_div_16_multipass_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_multipass) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_multipass_with_qmin) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_multipass_with_qmax) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, kc_gt_16_multipass_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + tester.kh(1).kw(ks).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, small_n) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5, 10}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, small_n_with_x_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5, 10}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(101) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, small_n_with_y_stride) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5, 10}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(103) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__NEON, small_n_with_s) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + for (size_t s = 2; s <= ks; s++) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_16x9p8q__neon); + } + } + } + } +} +#endif + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_mx1_pool) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(ks).kw(1).kc(kc).test( + pytorch_u8maxpool_ukernel_sub16__sse2); + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_mx1_pool_with_qmin) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(ks).kw(1).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_sub16__sse2); + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_mx1_pool_with_qmax) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(ks).kw(1).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_sub16__sse2); + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_1xm_pool) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(1).kw(ks).kc(kc).test( + pytorch_u8maxpool_ukernel_sub16__sse2); + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_1xm_pool_with_qmin) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(1).kw(ks).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_sub16__sse2); + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_1xm_pool_with_qmax) { + TEST_REQUIRES_X86_SSE2; + for (size_t kc = 1; kc < 16; kc++) { + for (size_t ks = 2; ks < 16; ks++) { + MaxPoolMicrokernelTester().kr(16).kh(1).kw(ks).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_sub16__sse2); + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_small_n) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__sse2); + } + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_small_n_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(17) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__sse2); + } + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_small_n_with_s) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t s = 2; s <= 5; s++) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_sub16__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_small_n_with_qmin) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .qmin(192) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__sse2); + } + } + } +} + +TEST(U8MAXPOOL_SUB16__SSE2, kc_lt_16_small_n_with_qmax) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 1; kc < 16; kc++) { + MaxPoolMicrokernelTester() + .kr(16) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .qmax(192) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_sub16__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_unipass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_unipass_fulltile_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_unipass_fulltile_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + tester.kh(kh).kw(kw).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_unipass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).kc(16); + for (size_t ks = 2; ks < tester.mr(); ks++) { + tester.kh(ks).kw(1).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_unipass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_unipass_fulltile_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_unipass_fulltile_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_unipass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_unipass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_unipass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_unipass_fulltile_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_unipass_fulltile_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_unipass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t kh = 1; kh <= tester.mr(); kh++) { + for (size_t kw = 1; kw <= tester.mr(); kw++) { + if (kh * kw == tester.mr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_unipass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).iterations(3); + for (size_t ks = 2; ks < tester.mr(); ks++) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_twopass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + tester.kh(kh).kw(kw).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_twopass_fulltile_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + tester.kh(kh).kw(kw).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_twopass_fulltile_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + tester.kh(kh).kw(kw).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_twopass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + 1; ks < tester.mr() + tester.qr(); ks++) { + tester.kh(ks).kw(1).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_twopass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_twopass_fulltile_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_twopass_fulltile_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_twopass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + 1; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_twopass_fulltile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_twopass_fulltile_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_twopass_fulltile_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_twopass_fulltile_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t kh = 1; kh <= tester.mr() + tester.qr(); kh++) { + for (size_t kw = 1; kw <= tester.mr() + tester.qr(); kw++) { + if (kh * kw == tester.mr() + tester.qr()) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(kh).kw(kw).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_twopass_subtile) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + 1; ks < tester.mr() + tester.qr(); ks++) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_multipass) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + tester.kh(ks).kw(1).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_multipass_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + tester.kh(ks).kw(1).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).qmin(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_eq_16_multipass_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).kc(16); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + tester.kh(ks).kw(1).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).qmax(192).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_multipass) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_multipass_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_multipass_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_div_16_multipass_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 16; kc < 256; kc += 48) { + tester.kh(ks).kw(1).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_multipass) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_multipass_with_qmin) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).qmin(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_multipass_with_qmax) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).qmax(192).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, kc_gt_16_multipass_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + auto tester = MaxPoolMicrokernelTester().kr(16).mr(9).qr(8).iterations(3); + for (size_t ks = tester.mr() + tester.qr() + 1; + ks < tester.mr() + 3 * tester.qr(); + ks += 3) { + for (size_t kc = 17; kc < 32; kc++) { + tester.kh(ks).kw(1).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + tester.kh(1).kw(ks).kc(kc).xStride(257).test( + pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, small_n) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5, 10}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .iterations(3) + .test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, small_n_with_x_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5, 10}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .xStride(101) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, small_n_with_y_stride) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5, 10}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .yStride(103) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } +} + +TEST(U8MAXPOOL_16x9P8Q__SSE2, small_n_with_s) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 2; n < 5; n++) { + for (size_t ks : std::vector{{2, 3, 5}}) { + for (size_t kc = 16; kc < 51; kc += 5) { + for (size_t s = 2; s <= ks; s++) { + MaxPoolMicrokernelTester() + .kr(16) + .mr(9) + .qr(8) + .n(n) + .kh(ks) + .kw(ks) + .kc(kc) + .s(s) + .iterations(1) + .test(pytorch_u8maxpool_ukernel_16x9p8q__sse2); + } + } + } + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8rmax.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8rmax.cc new file mode 100644 index 0000000000000..60628d8451682 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/u8rmax.cc @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "rmax-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(U8RMAX__NEON, n_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 16; n++) { + RMaxMicrokernelTester().n(n).test(pytorch_u8rmax_ukernel__neon); + } +} + +TEST(U8RMAX__NEON, n_eq_16) { + TEST_REQUIRES_ARM_NEON; + RMaxMicrokernelTester().n(16).test(pytorch_u8rmax_ukernel__neon); +} + +TEST(U8RMAX__NEON, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 16; n < 128; n += 16) { + RMaxMicrokernelTester().n(n).test(pytorch_u8rmax_ukernel__neon); + } +} + +TEST(U8RMAX__NEON, n_gt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 16; n < 32; n++) { + RMaxMicrokernelTester().n(n).test(pytorch_u8rmax_ukernel__neon); + } +} +#endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(U8RMAX__SSE2, n_lt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 16; n++) { + RMaxMicrokernelTester().n(n).test(pytorch_u8rmax_ukernel__sse2); + } +} + +TEST(U8RMAX__SSE2, n_eq_16) { + TEST_REQUIRES_X86_SSE2; + RMaxMicrokernelTester().n(16).test(pytorch_u8rmax_ukernel__sse2); +} + +TEST(U8RMAX__SSE2, n_div_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 128; n += 16) { + RMaxMicrokernelTester().n(n).test(pytorch_u8rmax_ukernel__sse2); + } +} + +TEST(U8RMAX__SSE2, n_gt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + RMaxMicrokernelTester().n(n).test(pytorch_u8rmax_ukernel__sse2); + } +} +#endif /* CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 */ diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/vadd-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/vadd-microkernel-tester.h new file mode 100644 index 0000000000000..bf5c616677fa4 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/vadd-microkernel-tester.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +class VAddMicrokernelTester { + public: + inline VAddMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline VAddMicrokernelTester& inplaceA(bool inplaceA) { + this->inplaceA_ = inplaceA; + return *this; + } + + inline bool inplaceA() const { + return this->inplaceA_; + } + + inline VAddMicrokernelTester& inplaceB(bool inplaceB) { + this->inplaceB_ = inplaceB; + return *this; + } + + inline bool inplaceB() const { + return this->inplaceB_; + } + + inline VAddMicrokernelTester& aScale(float aScale) { + assert(aScale > 0.0f); + assert(std::isnormal(aScale)); + this->aScale_ = aScale; + return *this; + } + + inline float aScale() const { + return this->aScale_; + } + + inline VAddMicrokernelTester& aZeroPoint(uint8_t aZeroPoint) { + this->aZeroPoint_ = aZeroPoint; + return *this; + } + + inline uint8_t aZeroPoint() const { + return this->aZeroPoint_; + } + + inline VAddMicrokernelTester& bScale(float bScale) { + assert(bScale > 0.0f); + assert(std::isnormal(bScale)); + this->bScale_ = bScale; + return *this; + } + + inline float bScale() const { + return this->bScale_; + } + + inline VAddMicrokernelTester& bZeroPoint(uint8_t bZeroPoint) { + this->bZeroPoint_ = bZeroPoint; + return *this; + } + + inline uint8_t bZeroPoint() const { + return this->bZeroPoint_; + } + + inline VAddMicrokernelTester& yScale(float yScale) { + assert(yScale > 0.0f); + assert(std::isnormal(yScale)); + this->yScale_ = yScale; + return *this; + } + + inline float yScale() const { + return this->yScale_; + } + + inline VAddMicrokernelTester& yZeroPoint(uint8_t yZeroPoint) { + this->yZeroPoint_ = yZeroPoint; + return *this; + } + + inline uint8_t yZeroPoint() const { + return this->yZeroPoint_; + } + + inline VAddMicrokernelTester& qmin(uint8_t qmin) { + this->qmin_ = qmin; + return *this; + } + + inline uint8_t qmin() const { + return this->qmin_; + } + + inline VAddMicrokernelTester& qmax(uint8_t qmax) { + this->qmax_ = qmax; + return *this; + } + + inline uint8_t qmax() const { + return this->qmax_; + } + + inline VAddMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_q8vadd_ukernel_function q8vadd) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector a(n()); + std::vector b(n()); + std::vector y(n()); + std::vector yFP(n()); + std::vector yRef(n()); + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(a.begin(), a.end(), std::ref(u8rng)); + std::generate(b.begin(), b.end(), std::ref(u8rng)); + if (inplaceA() || inplaceB()) { + std::generate(y.begin(), y.end(), std::ref(u8rng)); + } else { + std::fill(y.begin(), y.end(), 0xA5); + } + const uint8_t* aData = inplaceA() ? y.data() : a.data(); + const uint8_t* bData = inplaceB() ? y.data() : b.data(); + + /* Prepare quantization parameters */ + const union pytorch_qnnp_add_quantization_params quantizationParams = + pytorch_qnnp_compute_add_quantization_params( + aZeroPoint(), + bZeroPoint(), + yZeroPoint(), + aScale() / yScale(), + bScale() / yScale(), + qmin(), + qmax()); + const union pytorch_qnnp_add_quantization_params + scalarQuantizationParams = + pytorch_qnnp_compute_scalar_add_quantization_params( + aZeroPoint(), + bZeroPoint(), + yZeroPoint(), + aScale() / yScale(), + bScale() / yScale(), + qmin(), + qmax()); + + /* Compute reference results */ + for (size_t i = 0; i < n(); i++) { + yFP[i] = float(yZeroPoint()) + + float(int32_t(aData[i]) - int32_t(aZeroPoint())) * + (aScale() / yScale()) + + float(int32_t(bData[i]) - int32_t(bZeroPoint())) * + (bScale() / yScale()); + yFP[i] = std::min(yFP[i], float(qmax())); + yFP[i] = std::max(yFP[i], float(qmin())); + yRef[i] = pytorch_qnnp_add_quantize( + aData[i], bData[i], scalarQuantizationParams); + } + + /* Call optimized micro-kernel */ + q8vadd(n(), aData, bData, y.data(), &quantizationParams); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + ASSERT_LE(uint32_t(y[i]), uint32_t(qmax())) + << "at " << i << ", n = " << n(); + ASSERT_GE(uint32_t(y[i]), uint32_t(qmin())) + << "at " << i << ", n = " << n(); + ASSERT_NEAR(float(int32_t(y[i])), yFP[i], 0.6f) + << "at " << i << ", n = " << n(); + ASSERT_EQ(uint32_t(yRef[i]), uint32_t(y[i])) + << "at " << i << ", n = " << n(); + } + } + } + + private: + size_t n_{1}; + bool inplaceA_{false}; + bool inplaceB_{false}; + float aScale_{0.75f}; + float bScale_{1.25f}; + float yScale_{0.96875f}; + uint8_t aZeroPoint_{121}; + uint8_t bZeroPoint_{127}; + uint8_t yZeroPoint_{133}; + uint8_t qmin_{0}; + uint8_t qmax_{255}; + size_t iterations_{15}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/x8lut.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/x8lut.cc new file mode 100644 index 0000000000000..ae87934b906e5 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/x8lut.cc @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include "lut-microkernel-tester.h" + +TEST(X8LUT__SCALAR, n_eq_1) { + LUTMicrokernelTester().n(1).test(pytorch_x8lut_ukernel__scalar); +} + +TEST(X8LUT__SCALAR, small_n) { + for (size_t n = 2; n <= 16; n++) { + LUTMicrokernelTester().n(n).test(pytorch_x8lut_ukernel__scalar); + } +} + +TEST(X8LUT__SCALAR, large_n) { + for (size_t n = 16; n <= 128; n += 2) { + LUTMicrokernelTester().n(n).test(pytorch_x8lut_ukernel__scalar); + } +} + +TEST(X8LUT__SCALAR, n_eq_1_inplace) { + LUTMicrokernelTester().n(1).inplace(true).test(pytorch_x8lut_ukernel__scalar); +} + +TEST(X8LUT__SCALAR, small_n_inplace) { + for (size_t n = 2; n <= 16; n++) { + LUTMicrokernelTester().n(n).inplace(true).test(pytorch_x8lut_ukernel__scalar); + } +} + +TEST(X8LUT__SCALAR, large_n_inplace) { + for (size_t n = 16; n <= 128; n += 2) { + LUTMicrokernelTester().n(n).inplace(true).test(pytorch_x8lut_ukernel__scalar); + } +} diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/x8zip.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/test/x8zip.cc new file mode 100644 index 0000000000000..50ff4ee38a03b --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/x8zip.cc @@ -0,0 +1,350 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include + +#include "zip-microkernel-tester.h" + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +TEST(X8ZIP_X2__NEON, n_eq_8) { + TEST_REQUIRES_ARM_NEON; + ZipMicrokernelTester().n(8).g(2).test(pytorch_qnnp_x8zip_x2__neon); +} + +TEST(X8ZIP_X2__NEON, n_div_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 8) { + ZipMicrokernelTester().n(n).g(2).test(pytorch_qnnp_x8zip_x2__neon); + } +} + +TEST(X8ZIP_X2__NEON, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + ZipMicrokernelTester().n(n).g(2).test(pytorch_qnnp_x8zip_x2__neon); + } +} + +TEST(X8ZIP_X2__NEON, n_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + ZipMicrokernelTester().n(n).g(2).test(pytorch_qnnp_x8zip_x2__neon); + } +} + +TEST(X8ZIP_X3__NEON, n_eq_8) { + TEST_REQUIRES_ARM_NEON; + ZipMicrokernelTester().n(9).g(3).test(pytorch_qnnp_x8zip_x3__neon); +} + +TEST(X8ZIP_X3__NEON, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 8) { + ZipMicrokernelTester().n(n).g(3).test(pytorch_qnnp_x8zip_x3__neon); + } +} + +TEST(X8ZIP_X3__NEON, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + ZipMicrokernelTester().n(n).g(3).test(pytorch_qnnp_x8zip_x3__neon); + } +} + +TEST(X8ZIP_X3__NEON, n_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + ZipMicrokernelTester().n(n).g(3).test(pytorch_qnnp_x8zip_x3__neon); + } +} + +TEST(X8ZIP_X4__NEON, n_eq_8) { + TEST_REQUIRES_ARM_NEON; + ZipMicrokernelTester().n(8).g(4).test(pytorch_qnnp_x8zip_x4__neon); +} + +TEST(X8ZIP_X4__NEON, n_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 8) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_x4__neon); + } +} + +TEST(X8ZIP_X4__NEON, n_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_x4__neon); + } +} + +TEST(X8ZIP_X4__NEON, n_lt_16) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 16; n++) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_x4__neon); + } +} + +TEST(X8ZIP_XM__NEON, n_eq_8_m_eq_4) { + TEST_REQUIRES_ARM_NEON; + ZipMicrokernelTester().n(8).g(4).test(pytorch_qnnp_x8zip_xm__neon); +} + +TEST(X8ZIP_XM__NEON, n_eq_8_m_div_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(8).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } +} + +TEST(X8ZIP_XM__NEON, n_eq_8_m_gt_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(8).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } +} + +TEST(X8ZIP_XM__NEON, n_div_8_m_eq_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 8) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_xm__neon); + } +} + +TEST(X8ZIP_XM__NEON, n_div_8_m_div_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 8) { + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } + } +} + +TEST(X8ZIP_XM__NEON, n_div_8_m_gt_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 8; n < 128; n += 8) { + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } + } +} + +TEST(X8ZIP_XM__NEON, n_gt_8_m_eq_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_xm__neon); + } +} + +TEST(X8ZIP_XM__NEON, n_gt_8_m_div_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } + } +} + +TEST(X8ZIP_XM__NEON, n_gt_8_m_gt_4) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 9; n < 16; n++) { + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } + } +} + +TEST(X8ZIP_XM__NEON, n_lt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t n = 1; n < 8; n++) { + for (size_t g = 4; g < 12; g++) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__neon); + } + } +} +#endif + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +TEST(X8ZIP_X2__SSE2, n_eq_16) { + TEST_REQUIRES_X86_SSE2; + ZipMicrokernelTester().n(16).g(2).test(pytorch_qnnp_x8zip_x2__sse2); +} + +TEST(X8ZIP_X2__SSE2, n_div_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 256; n += 16) { + ZipMicrokernelTester().n(n).g(2).test(pytorch_qnnp_x8zip_x2__sse2); + } +} + +TEST(X8ZIP_X2__SSE2, n_gt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + ZipMicrokernelTester().n(n).g(2).test(pytorch_qnnp_x8zip_x2__sse2); + } +} + +TEST(X8ZIP_X2__SSE2, n_lt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 16; n++) { + ZipMicrokernelTester().n(n).g(2).test(pytorch_qnnp_x8zip_x2__sse2); + } +} + +TEST(X8ZIP_X3__SSE2, n_eq_16) { + TEST_REQUIRES_X86_SSE2; + ZipMicrokernelTester().n(16).g(3).test(pytorch_qnnp_x8zip_x3__sse2); +} + +TEST(X8ZIP_X3__SSE2, n_div_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 256; n += 16) { + ZipMicrokernelTester().n(n).g(3).test(pytorch_qnnp_x8zip_x3__sse2); + } +} + +TEST(X8ZIP_X3__SSE2, n_gt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + ZipMicrokernelTester().n(n).g(3).test(pytorch_qnnp_x8zip_x3__sse2); + } +} + +TEST(X8ZIP_X3__SSE2, n_lt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 16; n++) { + ZipMicrokernelTester().n(n).g(3).test(pytorch_qnnp_x8zip_x3__sse2); + } +} + +TEST(X8ZIP_X4__SSE2, n_eq_16) { + TEST_REQUIRES_X86_SSE2; + ZipMicrokernelTester().n(16).g(4).test(pytorch_qnnp_x8zip_x4__sse2); +} + +TEST(X8ZIP_X4__SSE2, n_div_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 256; n += 16) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_x4__sse2); + } +} + +TEST(X8ZIP_X4__SSE2, n_gt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_x4__sse2); + } +} + +TEST(X8ZIP_X4__SSE2, n_lt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 16; n++) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_x4__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_eq_8_m_eq_4) { + TEST_REQUIRES_X86_SSE2; + ZipMicrokernelTester().n(8).g(4).test(pytorch_qnnp_x8zip_xm__sse2); +} + +TEST(X8ZIP_XM__SSE2, n_eq_8_m_div_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(8).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_eq_8_m_gt_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(8).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_eq_16_m_eq_4) { + TEST_REQUIRES_X86_SSE2; + ZipMicrokernelTester().n(16).g(4).test(pytorch_qnnp_x8zip_xm__sse2); +} + +TEST(X8ZIP_XM__SSE2, n_eq_16_m_div_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(16).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_eq_16_m_gt_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(16).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_div_16_m_eq_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 256; n += 16) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_xm__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_div_16_m_div_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 256; n += 16) { + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } + } +} + +TEST(X8ZIP_XM__SSE2, n_div_16_m_gt_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 16; n < 256; n += 16) { + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } + } +} + +TEST(X8ZIP_XM__SSE2, n_gt_16_m_eq_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + ZipMicrokernelTester().n(n).g(4).test(pytorch_qnnp_x8zip_xm__sse2); + } +} + +TEST(X8ZIP_XM__SSE2, n_gt_16_m_div_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + for (size_t g = 4; g < 32; g += 4) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } + } +} + +TEST(X8ZIP_XM__SSE2, n_gt_16_m_gt_4) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 17; n < 32; n++) { + for (size_t g = 5; g < 8; g++) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } + } +} + +TEST(X8ZIP_XM__SSE2, n_lt_16) { + TEST_REQUIRES_X86_SSE2; + for (size_t n = 1; n < 16; n++) { + for (size_t g = 4; g < 12; g++) { + ZipMicrokernelTester().n(n).g(g).test(pytorch_qnnp_x8zip_xm__sse2); + } + } +} +#endif diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/test/zip-microkernel-tester.h b/aten/src/ATen/native/quantized/cpu/qnnpack/test/zip-microkernel-tester.h new file mode 100644 index 0000000000000..455b90fac00c8 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/test/zip-microkernel-tester.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +class ZipMicrokernelTester { + public: + inline ZipMicrokernelTester& n(size_t n) { + assert(n != 0); + this->n_ = n; + return *this; + } + + inline size_t n() const { + return this->n_; + } + + inline ZipMicrokernelTester& g(size_t g) { + assert(g != 0); + this->g_ = g; + return *this; + } + + inline size_t g() const { + return this->g_; + } + + inline ZipMicrokernelTester& iterations(size_t iterations) { + this->iterations_ = iterations; + return *this; + } + + inline size_t iterations() const { + return this->iterations_; + } + + void test(pytorch_xzipc_ukernel_function xzip) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x(n() * g()); + std::vector y(g() * n()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + /* Call optimized micro-kernel */ + xzip(n(), x.data(), y.data()); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + for (size_t j = 0; j < g(); j++) { + ASSERT_EQ(uint32_t(y[i * g() + j]), uint32_t(x[j * n() + i])) + << "at element " << i << ", group " << j; + } + } + } + } + + void test(pytorch_xzipv_ukernel_function xzip) const { + std::random_device randomDevice; + auto rng = std::mt19937(randomDevice()); + auto u8rng = std::bind(std::uniform_int_distribution(), rng); + + std::vector x(n() * g()); + std::vector y(g() * n()); + + for (size_t iteration = 0; iteration < iterations(); iteration++) { + std::generate(x.begin(), x.end(), std::ref(u8rng)); + std::fill(y.begin(), y.end(), 0xA5); + + /* Call optimized micro-kernel */ + xzip(n(), g(), x.data(), y.data()); + + /* Verify results */ + for (size_t i = 0; i < n(); i++) { + for (size_t j = 0; j < g(); j++) { + ASSERT_EQ(uint32_t(y[i * g() + j]), uint32_t(x[j * n() + i])) + << "at element " << i << ", group " << j; + } + } + } + } + + private: + size_t n_{1}; + size_t g_{1}; + size_t iterations_{3}; +}; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_add.cpp b/aten/src/ATen/native/quantized/cpu/qnnpack_add.cpp index e958b77661896..57be0f3091c4a 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack_add.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnnpack_add.cpp @@ -11,7 +11,7 @@ namespace { class QNNPACKAdd final : public torch::OperatorKernel { public: -#ifdef USE_QNNPACK +#ifdef USE_PYTORCH_QNNPACK Tensor operator()(Tensor qa, Tensor qb, double scale, int64_t zero_point) { TORCH_CHECK(qa.ndimension() > 0, "qnnpack_add(): Got empty input tensor."); TORCH_CHECK( @@ -38,11 +38,11 @@ class QNNPACKAdd final : public torch::OperatorKernel { initQNNPACK(); - qnnp_operator_t qnnpack_operator{nullptr}; + pytorch_qnnp_operator_t qnnpack_operator{nullptr}; size_t num_elems = qa_contig.numel() / qa_contig.size(0); - const qnnp_status createStatus = qnnp_create_add_nc_q8( + const pytorch_qnnp_status createStatus = pytorch_qnnp_create_add_nc_q8( num_elems /* input size */, a_zero_point /* a zero_point */, a_scale /* a scale */, @@ -56,14 +56,14 @@ class QNNPACKAdd final : public torch::OperatorKernel { &qnnpack_operator); TORCH_INTERNAL_ASSERT( - createStatus == qnnp_status_success, + createStatus == pytorch_qnnp_status_success, "failed to create QNNPACK Add operator"); TORCH_INTERNAL_ASSERT(qnnpack_operator != nullptr); - std::unique_ptr qnnpack_uniq_ptr( - qnnpack_operator); + std::unique_ptr + qnnpack_uniq_ptr(qnnpack_operator); - const qnnp_status setupStatus = qnnp_setup_add_nc_q8( + const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_add_nc_q8( qnnpack_operator /* add op */, qa_contig.size(0) /* batch size */, (uint8_t*)qa_contig.data_ptr() /* a data */, @@ -73,15 +73,16 @@ class QNNPACKAdd final : public torch::OperatorKernel { (uint8_t*)qy.data_ptr() /* output data */, num_elems /* sum stride */); TORCH_INTERNAL_ASSERT( - setupStatus == qnnp_status_success, + setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Add operator"); pthreadpool_t threadpool = caffe2::mobile_threadpool(); - const qnnp_status runStatus = - qnnp_run_operator(qnnpack_operator, threadpool); + const pytorch_qnnp_status runStatus = + pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( - runStatus == qnnp_status_success, "failed to run QNNPACK Add operator"); + runStatus == pytorch_qnnp_status_success, + "failed to run QNNPACK Add operator"); return qy; } diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_fc.cpp b/aten/src/ATen/native/quantized/cpu/qnnpack_fc.cpp deleted file mode 100644 index 45f96b2764c55..0000000000000 --- a/aten/src/ATen/native/quantized/cpu/qnnpack_fc.cpp +++ /dev/null @@ -1,130 +0,0 @@ -#include -#include -#include -#include -#include - -#include "init_qnnpack.h" -#include "qnnpack_utils.h" - -namespace at { -namespace native { -namespace { - -class QNNPACKLinear final : public torch::OperatorKernel { - public: -#ifdef USE_QNNPACK - Tensor operator()( - at::Tensor input, - at::Tensor weight, - at::Tensor bias, - double output_scale, - int64_t output_zero_point) { - TORCH_CHECK(input.dim() >= 2, "Input tensor rank should be >= 2"); - TORCH_CHECK(weight.dim() == 2, "Weight tensor rank should be == 2"); - - Tensor input_contig = input.contiguous(); - - // Y(output) = X(input_contig) x W(weight) - int64_t rows_x = 1; - int64_t cols_x = input_contig.size(input_contig.dim() - 1); - for (size_t i = 0; i < input_contig.dim() - 1; ++i) { - rows_x *= input_contig.size(i); - } - - int64_t rows_y = weight.size(0); - - TORCH_CHECK( - cols_x == weight.size(1), - "qnnpack_linear(): input size does not match weight dimension 1 size: got ", - cols_x, - " but expected ", - weight.size(1)); - - TORCH_CHECK( - !bias.defined() || (bias.ndimension() == 1 && bias.size(0) == rows_y), - "qnnpack_linear(): Given weight of size ", - weight.sizes(), - ", expected bias to be 1-dimensional with ", - rows_y, - " elements", - ", but got bias of size ", - bias.sizes(), - " instead"); - - initQNNPACK(); - - // Allocate output Tensor and a buffer for QNNPACK to use - Tensor output = at::_empty_affine_quantized( - {rows_x, rows_y}, input.options(), output_scale, output_zero_point); - - qnnp_operator_t qnnpack_operator{nullptr}; - - // QNNPACK expects both weights and inputs to be uint8 - const qnnp_status createStatus = qnnp_create_fully_connected_nc_q8( - cols_x /* input channels */, - rows_y /* output channels */, - input_contig.q_zero_point() /* input zero_point */, - input_contig.q_scale() /* input scale */, - weight.q_zero_point() /* kernel zero_point */, - weight.q_scale() /* kernel scale */, - (uint8_t*)weight.data_ptr() /* kernel data */, - (int32_t*)bias.data_ptr() /* bias data */, - output.q_zero_point() /* output zero_point */, - output.q_scale() /* output scale */, - std::numeric_limits::min() /* output_min */, - std::numeric_limits::max() /* output_max */, - 0 /* flags */, - &qnnpack_operator); - - std::unique_ptr qnnpack_uniq_ptr( - qnnpack_operator); - - TORCH_INTERNAL_ASSERT( - createStatus == qnnp_status_success, - "failed to create QNNPACK Linear operator"); - TORCH_INTERNAL_ASSERT(qnnpack_operator != nullptr); - - const qnnp_status setupStatus = qnnp_setup_fully_connected_nc_q8( - qnnpack_operator /* fully_connected */, - rows_x /* batch_size */, - (uint8_t*)input_contig.data_ptr() /* input */, - cols_x /* input stride */, - (uint8_t*)output.data_ptr() /* output */, - rows_y /* output stride */); - - TORCH_INTERNAL_ASSERT( - setupStatus == qnnp_status_success, - "failed to setup QNNPACK Linear operator"); - pthreadpool_t threadpool = caffe2::mobile_threadpool(); - - const qnnp_status runStatus = - qnnp_run_operator(qnnpack_operator, threadpool); - - TORCH_INTERNAL_ASSERT( - runStatus == qnnp_status_success, "failed to run QNNPACK operator"); - - return output; - } -#else - Tensor operator()( - at::Tensor /* input */, - at::Tensor /* weight */, - at::Tensor /* bias */, - double /* output_scale */, - int64_t /* output_zero_point */) { - TORCH_CHECK( - false, - "This PyTorch installation was not built " - "with QNNPACK operators"); - } -#endif -}; - -static auto registry = torch::RegisterOperators().op( - "quantized::qnnpack_linear(Tensor X, Tensor W, Tensor b, float Y_scale, int Y_zero_point) -> Tensor", - torch::RegisterOperators::options().kernel( - TensorTypeId::QuantizedCPUTensorId)); -} // namespace -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_maxpool.cpp b/aten/src/ATen/native/quantized/cpu/qnnpack_maxpool.cpp index 083610551a890..c9a4d4cf1bcc5 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack_maxpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnnpack_maxpool.cpp @@ -12,7 +12,7 @@ namespace { class QNNPACKMaxPool2D final : public torch::OperatorKernel { public: -#ifdef USE_QNNPACK +#ifdef USE_PYTORCH_QNNPACK Tensor operator()( Tensor input, const torch::List& kernel_size, @@ -47,7 +47,7 @@ class QNNPACKMaxPool2D final : public torch::OperatorKernel { initQNNPACK(); const auto scale = input_contig.q_scale(); const auto zero_point = input_contig.q_zero_point(); - qnnp_operator_t qnnpack_operator{nullptr}; + pytorch_qnnp_operator_t qnnpack_operator{nullptr}; int64_t padH = padding[0]; int64_t padW = padding[1]; @@ -71,24 +71,25 @@ class QNNPACKMaxPool2D final : public torch::OperatorKernel { int64_t inW = input_contig.size(2); int64_t inC = input_contig.size(3); - const qnnp_status createStatus = qnnp_create_max_pooling2d_nhwc_u8( - padH /* input_padding_top */, - padW /* input_padding_right */, - padH /* input_padding_bottom */, - padW /* input_padding_left */, - kH /* pooling height */, - kW /* pooling width */, - strideH /* stride height */, - strideW /* stride width */, - dilationH /* dilation height */, - dilationW /* dilation width */, - inC /* input channels */, - std::numeric_limits::min() /* output min */, - std::numeric_limits::max() /* output max */, - 0 /* flags */, - &qnnpack_operator); + const pytorch_qnnp_status createStatus = + pytorch_qnnp_create_max_pooling2d_nhwc_u8( + padH /* input_padding_top */, + padW /* input_padding_right */, + padH /* input_padding_bottom */, + padW /* input_padding_left */, + kH /* pooling height */, + kW /* pooling width */, + strideH /* stride height */, + strideW /* stride width */, + dilationH /* dilation height */, + dilationW /* dilation width */, + inC /* input channels */, + std::numeric_limits::min() /* output min */, + std::numeric_limits::max() /* output max */, + 0 /* flags */, + &qnnpack_operator); TORCH_INTERNAL_ASSERT( - createStatus == qnnp_status_success, + createStatus == pytorch_qnnp_status_success, "failed to create QNNPACK MaxPool operator"); TORCH_INTERNAL_ASSERT(qnnpack_operator != nullptr); @@ -102,8 +103,8 @@ class QNNPACKMaxPool2D final : public torch::OperatorKernel { outH > 0 && outW > 0, "qnnpack_maxpool(): the resulting output Tensor size should be >= 0"); - std::unique_ptr qnnpack_uniq_ptr( - qnnpack_operator); + std::unique_ptr + qnnpack_uniq_ptr(qnnpack_operator); // NHWC output qy = at::_empty_affine_quantized( @@ -112,25 +113,26 @@ class QNNPACKMaxPool2D final : public torch::OperatorKernel { scale, zero_point); - const qnnp_status setupStatus = qnnp_setup_max_pooling2d_nhwc_u8( - qnnpack_operator /* max pooling */, - batch_size /* batch size */, - inH /* input height */, - inW /* input width */, - (uint8_t*)input_contig.data_ptr() /* input */, - inC /* input_pixel_stride */, - (uint8_t*)qy.data_ptr() /* output data */, - outC /* output_pixel_stride */, - nullptr /* thread pool */); + const pytorch_qnnp_status setupStatus = + pytorch_qnnp_setup_max_pooling2d_nhwc_u8( + qnnpack_operator /* max pooling */, + batch_size /* batch size */, + inH /* input height */, + inW /* input width */, + (uint8_t*)input_contig.data_ptr() /* input */, + inC /* input_pixel_stride */, + (uint8_t*)qy.data_ptr() /* output data */, + outC /* output_pixel_stride */, + nullptr /* thread pool */); TORCH_INTERNAL_ASSERT( - setupStatus == qnnp_status_success, + setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK MaxPool operator"); pthreadpool_t threadpool = caffe2::mobile_threadpool(); - const qnnp_status runStatus = - qnnp_run_operator(qnnpack_operator, threadpool); + const pytorch_qnnp_status runStatus = + pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( - runStatus == qnnp_status_success, + runStatus == pytorch_qnnp_status_success, "failed to run QNNPACK MaxPool operator"); return qy; } diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_relu.cpp b/aten/src/ATen/native/quantized/cpu/qnnpack_relu.cpp index 93db4cfa5e8d2..8cf68a1918412 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack_relu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qnnpack_relu.cpp @@ -11,7 +11,7 @@ namespace { class QNNPACKRelu final : public torch::OperatorKernel { public: -#ifdef USE_QNNPACK +#ifdef USE_PYTORCH_QNNPACK Tensor operator()(Tensor input) { Tensor qy; @@ -31,20 +31,20 @@ class QNNPACKRelu final : public torch::OperatorKernel { num_elems_x *= input_contig.size(i); } - qnnp_operator_t qnnpack_operator{nullptr}; + pytorch_qnnp_operator_t qnnpack_operator{nullptr}; - const qnnp_status createStatus = qnnp_create_clamp_nc_u8( + const pytorch_qnnp_status createStatus = pytorch_qnnp_create_clamp_nc_u8( num_elems_x /* channels */, zero_point /* output min */, std::numeric_limits::max() /* output max */, 0 /* flags */, &qnnpack_operator); - std::unique_ptr qnnpack_uniq_ptr( - qnnpack_operator); + std::unique_ptr + qnnpack_uniq_ptr(qnnpack_operator); TORCH_INTERNAL_ASSERT( - createStatus == qnnp_status_success, + createStatus == pytorch_qnnp_status_success, "failed to create QNNPACK Relu operator"); TORCH_INTERNAL_ASSERT(qnnpack_operator != nullptr); @@ -56,7 +56,7 @@ class QNNPACKRelu final : public torch::OperatorKernel { size_t num_elems_y = volume / qy.size(0); - const qnnp_status setupStatus = qnnp_setup_clamp_nc_u8( + const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_clamp_nc_u8( qnnpack_operator, /* clamp */ input_contig.size(0) /* batch size */, (uint8_t*)input_contig.data_ptr() /* input data */, @@ -64,16 +64,16 @@ class QNNPACKRelu final : public torch::OperatorKernel { (uint8_t*)qy.data_ptr() /* output data */, num_elems_y /* output stride */); TORCH_INTERNAL_ASSERT( - setupStatus == qnnp_status_success, + setupStatus == pytorch_qnnp_status_success, "failed to setup QNNPACK Relu operator"); pthreadpool_t threadpool = caffe2::mobile_threadpool(); - const qnnp_status runStatus = - qnnp_run_operator(qnnpack_operator, threadpool); + const pytorch_qnnp_status runStatus = + pytorch_qnnp_run_operator(qnnpack_operator, threadpool); TORCH_INTERNAL_ASSERT( - runStatus == qnnp_status_success, + runStatus == pytorch_qnnp_status_success, "failed to run QNNPACK Relu operator"); return qy; diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h index 320797b501ac1..298cf472338f0 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h +++ b/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h @@ -1,11 +1,80 @@ #pragma once -#ifdef USE_QNNPACK -#include +#ifdef USE_PYTORCH_QNNPACK +#include +#include struct QnnpackOperatorDeleter { - void operator()(qnnp_operator_t op) { - qnnp_delete_operator(op); + void operator()(pytorch_qnnp_operator_t op) { + pytorch_qnnp_delete_operator(op); } }; + +// PackedWeight struct stores the original Weight and Bias as QNNPACK currently +// does not support an unpack function. +// Possible optimiation - For PyTorch Mobile, once the model is scripted and +// serialized we don't need to call unpack, so we can save some memory by +// checking for this case. +struct PackedLinearWeightsQnnp { + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias; + double w_scale; + int64_t w_zp; +}; + +struct PackedConvWeightsQnnp { + std::unique_ptr w; + at::Tensor orig_weight; + at::Tensor bias; + std::vector kernel; + double w_scale; + int64_t w_zp; +}; + +enum class Activation : uint8_t { NONE = 0, RELU = 1 }; + +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float Round(const float x) { + return ::nearbyintf(x); +} +inline double Round(const double x) { + return ::nearbyint(x); +} +#else +template +inline T Round(const T x) { + return std::nearbyint(x); +} +#endif + +inline uint8_t QuantizeUint8(float scale, int32_t zero_point, float value) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + auto r = zero_point + static_cast(Round(value / scale)); + r = std::max(r, qmin); + r = std::min(r, qmax); + return static_cast(r); +} + +inline std::pair activationLimits( + float scale, + int32_t zero_point, + Activation Ac) { + switch (Ac) { + case Activation::NONE: + return {std::numeric_limits::min(), + std::numeric_limits::max()}; + case Activation::RELU: + return {QuantizeUint8(scale, zero_point, 0.0), + std::numeric_limits::max()}; + default: +#ifdef _MSC_VER + __assume(0); +#else + __builtin_unreachable(); +#endif + } +} #endif diff --git a/aten/src/ATen/native/quantized/cpu/qpool.cpp b/aten/src/ATen/native/quantized/cpu/qpool.cpp index a0bc00f953fc9..fc5f00ca93d29 100644 --- a/aten/src/ATen/native/quantized/cpu/qpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/qpool.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -14,6 +15,8 @@ namespace at { namespace native { namespace { +DEFINE_DISPATCH(qmaxpool_2d_nhwc_stub); + /* Computes the spatial 2D max pooling with dilation. Argument description in the argument list. @@ -135,58 +138,73 @@ Tensor q_maxpool_2d( oSizes = {nbatch, oC, oH, oW}; } - Tensor qy = at::_empty_affine_quantized( - oSizes, - qx.options().dtype(toQIntType(qx.scalar_type())), - qx.q_scale(), - qx.q_zero_point()); - auto qx_contig = qx.contiguous(); - auto qxd = qx_contig.data_ptr(); - auto qyd = qy.data_ptr(); - if (ndim == 3 || nbatch == 1) { - auto* iData = qxd; - auto* oData = qyd; - spatial_dilated_max_pooling( - iData, - iC, - iH, - iW, - oH, - oW, - kH, - kW, - sH, - sW, - pH, - pW, - dH, - dW, - oData); + if (qx.is_contiguous(c10::MemoryFormat::ChannelsLast)) { + // Fast path case for channels-last case. + // In this case, we can preserve the data layout in memory + // as well as use a loop nest that is more amenable to + // vectorization. + Tensor qy = at::_empty_affine_quantized( + oSizes, + qx.options().dtype(toQIntType(qx.scalar_type())), + qx.q_scale(), + qx.q_zero_point(), + qx.suggest_memory_format()); + qmaxpool_2d_nhwc_stub(qx.device().type(), qx, iC, iH, iW, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, qy); + return qy; } else { - at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) { - for (auto p = start; p < end; ++p) { - auto* iData = qxd + p * iC * iW * iH; - auto* oData = qyd + p * oC * oW * oH; - spatial_dilated_max_pooling( - iData, - iC, - iH, - iW, - oH, - oW, - kH, - kW, - sH, - sW, - pH, - pW, - dH, - dW, - oData); - } - }); + Tensor qy = at::_empty_affine_quantized( + oSizes, + qx.options().dtype(toQIntType(qx.scalar_type())), + qx.q_scale(), + qx.q_zero_point()); + auto qx_contig = qx.contiguous(); + auto qxd = qx_contig.data_ptr(); + auto qyd = qy.data_ptr(); + if (ndim == 3 || nbatch == 1) { + auto* iData = qxd; + auto* oData = qyd; + spatial_dilated_max_pooling( + iData, + iC, + iH, + iW, + oH, + oW, + kH, + kW, + sH, + sW, + pH, + pW, + dH, + dW, + oData); + } else { + at::parallel_for(0, nbatch, 0, [&](int64_t start, int64_t end) { + for (auto p = start; p < end; ++p) { + auto* iData = qxd + p * iC * iW * iH; + auto* oData = qyd + p * oC * oW * oH; + spatial_dilated_max_pooling( + iData, + iC, + iH, + iW, + oH, + oW, + kH, + kW, + sH, + sW, + pH, + pW, + dH, + dW, + oData); + } + }); + } + return qy; } - return qy; } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/quantized_ops.h b/aten/src/ATen/native/quantized/cpu/quantized_ops.h index 9475e6016b2c2..b43c9713bd9d9 100644 --- a/aten/src/ATen/native/quantized/cpu/quantized_ops.h +++ b/aten/src/ATen/native/quantized/cpu/quantized_ops.h @@ -8,11 +8,29 @@ namespace native { using qrelu_fn = void (*)(const at::Tensor& /*qx*/, at::Tensor& /*qy*/); using qadd_fn = void (*)(Tensor& /*out*/, const Tensor& /*self*/, const Tensor& /*other*/); +using qmaxpool_2d_fn = + void (*)(const Tensor &qx, + int64_t iC, // input/output channels + int64_t iH, + int64_t iW, // input sizes + int64_t oH, + int64_t oW, // output sizes + int64_t kH, + int64_t kW, // kernel size + int64_t sH, + int64_t sW, // strides + int64_t pH, + int64_t pW, // padding + int64_t dH, + int64_t dW, // dilation + Tensor &qy + ); DECLARE_DISPATCH(qrelu_fn, qrelu_stub); DECLARE_DISPATCH(qrelu_fn, qrelu6_stub); DECLARE_DISPATCH(qadd_fn, qadd_stub); DECLARE_DISPATCH(qadd_fn, qadd_relu_stub); +DECLARE_DISPATCH(qmaxpool_2d_fn, qmaxpool_2d_nhwc_stub); } // namespace native } // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 8c53ac5cce7d6..9fef7782474c4 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -80,7 +80,7 @@ SparseTensor new_sparse(const TensorOptions& options) { type_id = TensorTypeId::SparseCPUTensorId; } return detail::make_tensor( - type_id, options.dtype()); + TensorTypeSet(type_id), options.dtype()); } /** Actual dispatched creation methods ***/ diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index f9cbe6c96c839..b01166245aae2 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -1,3 +1,5 @@ +#include + #include #include #include @@ -148,10 +150,23 @@ SparseTensor pow_sparse_scalar(const SparseTensor& t, Scalar value) { // div(SparseTensor, Scalar) // -------------------------------------------------------------------- +SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value); + +Tensor div_sparse(const Tensor& self, const Tensor& value) { + Tensor result = at::empty({0}, self.options()); + return div_out_sparse_zerodim(result, self, value); +} + +Tensor& div_sparse_(Tensor& self, const Tensor& value) { + return div_out_sparse_zerodim(self, self, value); +} + SparseTensor& div_out_sparse_zerodim(SparseTensor& r, const SparseTensor& t, const Tensor& value) { + TORCH_CHECK(value.dim() == 0, "sparse division only supports division by a scalar (got shape ", + value.sizes(), " for argument 'other')"); + AT_ASSERT(r.is_sparse()); AT_ASSERT(t.is_sparse()); - AT_ASSERT(value.dim() == 0); if (is_same_tensor(r, t)) { r._values().div_(value); @@ -187,9 +202,40 @@ Tensor norm_sparse(const SparseTensor& self, Scalar value) { // add(SparseTensor, SparseTensor, Scalar) [broadcasts] // -------------------------------------------------------------------- +Tensor add_sparse(const Tensor& self, const Tensor& other, Scalar alpha) { + // TODO: Why?! Can't we just flip the order here... + TORCH_CHECK(!(self.is_sparse() && !other.is_sparse()), + "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); + Tensor result = at::empty({0}, self.options()); + return at::add_out(result, self, other, alpha); // redispatch! +} + +Tensor& add_sparse_(Tensor& self, const Tensor& other, Scalar alpha) { + return at::add_out(self, self, other, alpha); // redispatch! +} + +// There's actually nothing sparse specific about these implementations + +Tensor sub_sparse(const Tensor& self, const Tensor& other, Scalar alpha) { + return native::add_sparse(self, other, -alpha); +} + +Tensor& sub_sparse_(Tensor& self, const Tensor& other, Scalar alpha) { + return native::add_sparse_(self, other, -alpha); +} + +Tensor& sub_out_sparse(Tensor& r, const Tensor& self, const Tensor& other, Scalar alpha) { + return at::add_out(r, self, other, -alpha); // redispatch! +} + +Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); + SparseTensor& add_out_sparse_cpu(SparseTensor& r, const SparseTensor& t, const SparseTensor& src, Scalar value) { - AT_ASSERT(r.is_sparse()); - AT_ASSERT(t.is_sparse()); + if (!t.is_sparse()) { + return add_out_dense_sparse_cpu(r, t, src, value); + } + // TODO: This test seems a bit goofy + TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); AT_ASSERT(!t.is_cuda()); // the dispatch argument TORCH_CHECK(!r.is_cuda(), "add: expected 'out' to be CPU tensor, but got CUDA tensor"); TORCH_CHECK(!src.is_cuda(), "add: expected 'other' to be a CPU tensor, but got a CUDA tensor"); @@ -375,6 +421,15 @@ Tensor& add_out_dense_sparse_cpu(Tensor& r, const Tensor& dense, const SparseTen // mul(SparseTensor, SparseTensor) [broadcasts] // -------------------------------------------------------------------- +Tensor mul_sparse(const Tensor& self, const Tensor& other) { + Tensor result = at::empty({0}, self.options()); + return at::mul_out(result, self, other); // redispatch! +} + +Tensor& mul_sparse_(Tensor& self, const Tensor& other) { + return at::mul_out(self, self, other); // redispatch! +} + SparseTensor& mul_out_sparse_cpu(SparseTensor& r, const Tensor& t_, const Tensor& src_) { if (src_.dim() == 0) { return mul_out_sparse_zerodim(r, t_, src_); @@ -576,6 +631,19 @@ Tensor& s_addmm_out_sparse_dense_cpu( } +Tensor& addmm_out_sparse_dense_cpu( + Tensor& result, + const Tensor& self, + const SparseTensor& mat1, + const Tensor& mat2, + Scalar beta, + Scalar alpha +) { + Tensor b_self; + std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); + return s_addmm_out_sparse_dense_cpu(result, b_self, mat1, mat2, beta, alpha); +} + Tensor s_addmm_sparse_dense_cpu( const Tensor& t, const SparseTensor& sparse, @@ -588,6 +656,18 @@ Tensor s_addmm_sparse_dense_cpu( return r; } +Tensor addmm_sparse_dense_cpu( + const Tensor& self, + const SparseTensor& mat1, + const Tensor& mat2, + Scalar beta, + Scalar alpha +) { + Tensor b_self; + std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); + return s_addmm_sparse_dense_cpu(b_self, mat1, mat2, beta, alpha); +} + Tensor& s_addmm_sparse_dense_cpu_( Tensor& t, const SparseTensor& sparse, @@ -598,6 +678,8 @@ Tensor& s_addmm_sparse_dense_cpu_( return s_addmm_out_sparse_dense_cpu(t, t, sparse, dense, beta, alpha); } +// NB: Purposely no broadcasting version of addmm inplace + Tensor _sparse_addmm( const Tensor& t, const SparseTensor& sparse, @@ -605,9 +687,10 @@ Tensor _sparse_addmm( Scalar beta, Scalar alpha ) { - Tensor b_t; - std::tie(b_t) = expand_size(t, {sparse.size(0), dense.size(1)}, "addmm"); - return at::s_native_addmm(b_t, sparse, dense, beta, alpha); + // _sparse_addmm forward is functionally equivalent to addmm; it's + // just the backward that is different. This technically does an + // unnecessary redispatch, I was too lazy to make it not do that + return at::addmm(t, sparse, dense, beta, alpha); } Tensor _sparse_mm( @@ -615,16 +698,19 @@ Tensor _sparse_mm( const Tensor& dense ) { Tensor t = at::zeros({}, dense.options()); - return at::_sparse_addmm(t, sparse, dense, 0, 1); + return at::_sparse_addmm(t, sparse, dense, 0, 1); // redispatch! } +// NB: Despite its suggestive name, this actually only exists so that +// we can redispatch to addmm_out; this is NOT an implementation of +// the sparse masking version of mm SparseTensor& _sparse_mm_out( SparseTensor& result, const SparseTensor& sparse, const Tensor& dense ) { Tensor t = at::zeros({}, dense.options()); - return at::addmm_out(result, t, sparse, dense, 0, 1); + return at::addmm_out(result, t, sparse, dense, 0, 1); // redispatch! } // -------------------------------------------------------------------- diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.h b/aten/src/ATen/native/sparse/SparseTensorMath.h new file mode 100644 index 0000000000000..514f84fd8e6e5 --- /dev/null +++ b/aten/src/ATen/native/sparse/SparseTensorMath.h @@ -0,0 +1,11 @@ +#pragma once + +#include +#include + +namespace at { namespace native { + +sparse::SparseTensor& mul_out_sparse_scalar(sparse::SparseTensor& r, const sparse::SparseTensor& t, Scalar value); +sparse::SparseTensor& mul_out_sparse_zerodim(sparse::SparseTensor& r, const sparse::SparseTensor& t, const Tensor& value); + +}} diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index c4cf79dfdb82d..d8f45e4fbe49f 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -1,6 +1,7 @@ #pragma once #include +#include namespace at { namespace native { @@ -9,8 +10,6 @@ namespace apply { using at::cuda::detail::TensorInfo; using indexT = int64_t; -const int WARP_SIZE = 32; - template __device__ void applyOp2( Op op, IndexType blockSize, @@ -324,7 +323,7 @@ __global__ void coalesceValuesKernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int featureDim = startFeature + ii * WARP_SIZE; + int featureDim = startFeature + ii * C10_WARP_SIZE; if (featureDim < stride) { tmp[ii] += static_cast(values[valueRow + featureDim]); @@ -334,7 +333,7 @@ __global__ void coalesceValuesKernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int featureDim = startFeature + ii * WARP_SIZE; + int featureDim = startFeature + ii * C10_WARP_SIZE; if (featureDim < stride) { newValues[newValueRow + featureDim] = static_cast(tmp[ii]); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp index 927d91797de39..4d43150547679 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cpp @@ -12,7 +12,7 @@ SparseTensor& sparse_mask_out_cuda(SparseTensor& r, const Tensor& t, const Spars TORCH_CHECK(mask.is_coalesced(), "sparse_mask: mask is uncoalesced"); TORCH_CHECK(mask.sizes().equals(t.sizes()), "sparse_mask: operands have incompatible sizes; self has size ", t.sizes(), " but mask has size ", mask.sizes()); - AT_ASSERT(t.is_cuda()); // dispatch argument + TORCH_CHECK(t.is_cuda(), "sparse_mask: expected 'self' to be CUDA, but got CPU"); TORCH_CHECK(mask.is_cuda(), "sparse_mask: expected 'mask' to be CUDA, but got CPU"); TORCH_CHECK(r.is_cuda(), "sparse_mask: expected 'out' to be CUDA, but got CPU"); TORCH_CHECK(cuda::check_device({r, t, mask}), diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu index 6f2ad1c7cb430..096f34de65575 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu @@ -21,6 +21,7 @@ #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__ #include #endif +#include namespace at { namespace native { @@ -90,10 +91,11 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) { // If there is no values to copy, save running the kernel. if (newValues.numel() > 0) { + const int SZ = 4; values = values.contiguous(); int64_t stride = at::prod_intlist(values.sizes().slice(1)); - dim3 grid(THCCeilDiv(newNnz, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128)); - dim3 block(32, 4); + dim3 grid(THCCeilDiv(newNnz, (int64_t) SZ), THCCeilDiv(stride, (int64_t) C10_WARP_SIZE*SZ)); + dim3 block(C10_WARP_SIZE, SZ); AT_DISPATCH_ALL_TYPES_AND( at::ScalarType::Half,values.scalar_type(), "coalesce_sparse_cuda", [&] { using cuda_accscalar_t = acc_type; diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index fdc17e4dfdc36..f681f16ce49b3 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -2,12 +2,14 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include @@ -51,7 +53,7 @@ namespace { // -------------------------------------------------------------------- Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseTensor& sparse_, const Tensor& dense, Scalar beta, Scalar alpha) { - AT_ASSERT(t.is_cuda()); // dispatch argument + TORCH_CHECK(t.is_cuda(), "addmm: expected 'self' to be CUDA, but got CPU"); TORCH_CHECK(r_.is_cuda(), "addmm: expected 'out' to be CUDA, but got CPU"); TORCH_CHECK(sparse_.is_cuda(), "addmm: expected 'mat1' to be CUDA, but got CPU"); TORCH_CHECK(dense.is_cuda(), "addmm: expected 'mat2' to be CUDA, but got CPU"); @@ -151,6 +153,19 @@ Tensor& s_addmm_out_sparse_dense_cuda(Tensor& r_, const Tensor& t, const SparseT return r_; } +Tensor& addmm_out_sparse_dense_cuda( + Tensor& result, + const Tensor& self, + const SparseTensor& mat1, + const Tensor& mat2, + Scalar beta, + Scalar alpha +) { + Tensor b_self; + std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); + return s_addmm_out_sparse_dense_cuda(result, b_self, mat1, mat2, beta, alpha); +} + Tensor s_addmm_sparse_dense_cuda( const Tensor& t, const SparseTensor& sparse, @@ -163,6 +178,18 @@ Tensor s_addmm_sparse_dense_cuda( return r; } +Tensor addmm_sparse_dense_cuda( + const Tensor& self, + const SparseTensor& mat1, + const Tensor& mat2, + Scalar beta, + Scalar alpha +) { + Tensor b_self; + std::tie(b_self) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm_out"); + return s_addmm_sparse_dense_cuda(b_self, mat1, mat2, beta, alpha); +} + Tensor& s_addmm_sparse_dense_cuda_( Tensor& t, const SparseTensor& sparse, @@ -173,6 +200,8 @@ Tensor& s_addmm_sparse_dense_cuda_( return s_addmm_out_sparse_dense_cuda(t, t, sparse, dense, beta, alpha); } +// NB: Purposely no broadcasting version of addmm inplace + // Deleted sspaddmm (sparse, dense) -> sparse // -------------------------------------------------------------------- @@ -180,7 +209,7 @@ Tensor& s_addmm_sparse_dense_cuda_( // -------------------------------------------------------------------- SparseTensor& hspmm_out_sparse_cuda(SparseTensor& r_, const SparseTensor& sparse_, const Tensor& dense/* , Scalar alpha */) { - AT_ASSERT(sparse_.is_cuda()); // dispatch argument + TORCH_CHECK(sparse_.is_cuda(), "hspmm: expected 'self' to be CUDA, but got CPU"); TORCH_CHECK(r_.is_cuda(), "hspmm: expected 'out' to be CUDA, but got CPU"); TORCH_CHECK(dense.is_cuda(), "hspmm: expected 'mat2' to be CUDA, but got CPU"); @@ -249,9 +278,9 @@ SparseTensor hspmm_sparse_cuda(const SparseTensor& sparse, const Tensor& dense) // -------------------------------------------------------------------- Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseTensor& sparse, at::Scalar value) { - AT_ASSERT(dense.is_cuda()); // dispatch argument - TORCH_CHECK(sparse.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); - TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); + TORCH_CHECK(dense.is_cuda(), "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"); + TORCH_CHECK(sparse.is_cuda(), "add: expected 'other' to be a CUDA tensor, but got a CPU tensor"); + TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be a CUDA tensor, but got a CPU tensor"); TORCH_CHECK(cuda::check_device({sparse, r_, dense})); @@ -350,8 +379,17 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT // add(SparseTensor, SparseTensor, Scalar) [broadcasts] // -------------------------------------------------------------------- +Tensor& add_out_dense_sparse_cuda(Tensor& r, const Tensor& dense, const SparseTensor& sparse_, Scalar value); + SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const SparseTensor& src, Scalar value) { - AT_ASSERT(t.is_cuda()); // dispatch argument + if (!t.is_sparse()) { + return add_out_dense_sparse_cuda(r_, t, src, value); + } + + // TODO: This test seems a bit goofy + TORCH_CHECK(src.is_sparse(), "add(sparse, dense) is not supported. Use add(dense, sparse) instead."); + + TORCH_CHECK(t.is_cuda(), "add: expected 'self' to be CUDA, but got CPU"); TORCH_CHECK(src.is_cuda(), "add: expected 'other' to be CUDA, but got CPU"); TORCH_CHECK(r_.is_cuda(), "add: expected 'out' to be CUDA, but got CPU"); @@ -410,7 +448,7 @@ SparseTensor& mul_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t_, cons return mul_out_sparse_zerodim(r_, src_, t_); } - AT_ASSERT(t_.is_cuda()); // dispatch argument + TORCH_CHECK(t_.is_cuda(), "mul: expected 'self' to be CUDA, but got CPU"); TORCH_CHECK(src_.is_cuda(), "mul: expected 'other' to be CUDA, but got CPU"); TORCH_CHECK(r_.is_cuda(), "mul: expected 'out' to be CUDA, but got CPU"); TORCH_CHECK(cuda::check_device({r_, t_, src_})); diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py index d8962140b4fb5..2924e53baee2b 100644 --- a/aten/src/ATen/native_parse.py +++ b/aten/src/ATen/native_parse.py @@ -395,6 +395,7 @@ def run(paths): assert arguments[-1] == ")", "Expecting closing ) for {}".format(func['func']) arguments = arguments[:-1] # Expect closing ) declaration['name'] = func.get('name', fn_name) + declaration['operator_name'] = func.get('name', fn_name) declaration['overload_name'] = func.get('overload_name', overload_name) declaration['inplace'] = re.search('(^__i|[^_]_$)', fn_name) is not None return_arguments = parse_return_arguments(return_decl, declaration['inplace'], func) @@ -412,7 +413,8 @@ def run(paths): declaration['cuda_bool'] = func.get('cuda_bool', False) declaration['deprecated'] = func.get('deprecated', False) declaration['device_guard'] = func.get('device_guard', True) - declaration['named_guard'] = func.get('named_guard', True) + declaration['supports_named_tensor'] = func.get('supports_named_tensor', False) + declaration['use_c10_dispatcher'] = func.get('use_c10_dispatcher', False) declaration['arguments'] = func.get('arguments', arguments) declaration['type_method_definition_dispatch'] = func.get('dispatch', declaration['name']) declaration['python_module'] = func.get('python_module', '') diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml index d022cfcca9407..898bf3542442f 100644 --- a/aten/src/ATen/nn.yaml +++ b/aten/src/ATen/nn.yaml @@ -1,6 +1,6 @@ # Loss functions -- name: _thnn_binary_cross_entropy(Tensor self, Tensor target, Tensor? weight={}, int64_t reduction=Reduction::Mean) +- name: _thnn_binary_cross_entropy(Tensor self, Tensor target, Tensor? weight, int64_t reduction) cname: BCECriterion scalar_check: output: 'false' @@ -15,19 +15,19 @@ forward_scalar_types: ['Float', 'Double', 'Half'] backward_scalar_types: ['Float', 'Double', 'Half'] -- name: _thnn_l1_loss(Tensor self, Tensor target, int64_t reduction=Reduction::Mean) +- name: _thnn_l1_loss(Tensor self, Tensor target, int64_t reduction) cname: AbsCriterion scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_mse_loss(Tensor self, Tensor target, int64_t reduction=Reduction::Mean) +- name: _thnn_mse_loss(Tensor self, Tensor target, int64_t reduction) cname: MSECriterion scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_multi_margin_loss(Tensor self, LongTensor target, Scalar p=1, Scalar margin=1, Tensor? weight={}, int64_t reduction=Reduction::Mean) +- name: _thnn_multi_margin_loss(Tensor self, LongTensor target, Scalar p, Scalar margin, Tensor? weight, int64_t reduction) cname: MultiMarginCriterion scalar_check: output: reduction != Reduction::None || self_->dim() == 0 || (reduction == Reduction::None && self_->dim() == 1) @@ -39,27 +39,30 @@ output: reduction != Reduction::None || self_->dim() == 0 is_target: target_->dim() == 0 -- name: _thnn_nll_loss(Tensor self, LongTensor target, Tensor? weight={}, int64_t reduction=Reduction::Mean, int64_t ignore_index=-100) +- name: _thnn_nll_loss(Tensor self, LongTensor target, Tensor? weight, int64_t reduction, int64_t ignore_index) cname: ClassNLLCriterion buffers: [total_weight] scalar_check: output: reduction != Reduction::None || self_->dim() == 0 total_weight: 'true' + CPU: + forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16'] + backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16'] -- name: _thnn_nll_loss2d(Tensor self, LongTensor target, Tensor? weight={}, int64_t reduction=Reduction::Mean, int64_t ignore_index=-100) +- name: _thnn_nll_loss2d(Tensor self, LongTensor target, Tensor? weight, int64_t reduction, int64_t ignore_index) cname: SpatialClassNLLCriterion buffers: [total_weight] scalar_check: output: reduction != Reduction::None || self_->dim() == 0 total_weight: 'true' -- name: _thnn_smooth_l1_loss(Tensor self, Tensor target, int64_t reduction=Reduction::Mean) +- name: _thnn_smooth_l1_loss(Tensor self, Tensor target, int64_t reduction) cname: SmoothL1Criterion scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_soft_margin_loss(Tensor self, Tensor target, int64_t reduction=Reduction::Mean) +- name: _thnn_soft_margin_loss(Tensor self, Tensor target, int64_t reduction) cname: SoftMarginCriterion scalar_check: output: 'false' @@ -67,14 +70,14 @@ # Activation functions -- name: _thnn_elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) +- name: _thnn_elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) cname: ELU has_inplace: True scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_glu(Tensor self, int64_t dim=-1) +- name: _thnn_glu(Tensor self, int64_t dim) cname: GatedLinear wrap_dim: dim: self @@ -82,14 +85,14 @@ output: 'false' grad_input: 'false' -- name: _thnn_hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) +- name: _thnn_hardtanh(Tensor self, Scalar min_val, Scalar max_val) cname: HardTanh has_inplace: True scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_leaky_relu(Tensor self, Scalar negative_slope=0.01) +- name: _thnn_leaky_relu(Tensor self, Scalar negative_slope) cname: LeakyReLU has_inplace: True scalar_check: @@ -106,20 +109,20 @@ # NOTE: we treat noise as an input (it's really a buffer) because the codegen # can't handle in-place functions that have buffers -- name: _thnn_rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) +- name: _thnn_rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator* generator) cname: RReLU has_inplace: True scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_softplus(Tensor self, Scalar beta=1, Scalar threshold=20) +- name: _thnn_softplus(Tensor self, Scalar beta, Scalar threshold) cname: SoftPlus scalar_check: output: 'false' grad_input: 'false' -- name: _thnn_softshrink(Tensor self, Scalar lambd=0.5) +- name: _thnn_softshrink(Tensor self, Scalar lambd) cname: SoftShrink scalar_check: output: 'false' @@ -142,20 +145,20 @@ # Convolutions -- name: _thnn_conv2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias={}, IntArrayRef[2] stride=1, IntArrayRef[2] padding=0) +- name: _thnn_conv2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding) cname: SpatialConvolutionMM buffers: [finput, fgrad_input] CPU: - forward_scalar_types: ['Float', 'Double', 'Long'] - backward_scalar_types: ['Float', 'Double'] + forward_scalar_types: ['Float', 'Double', 'Long', 'BFloat16'] + backward_scalar_types: ['Float', 'Double', 'BFloat16'] -- name: _thnn_conv_depthwise2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias={}, IntArrayRef[2] stride=1, IntArrayRef[2] padding=0, IntArrayRef[2] dilation=1) +- name: _thnn_conv_depthwise2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding, IntArrayRef[2] dilation) cname: SpatialDepthwiseConvolution buffers: [] -- name: _thnn_conv3d(Tensor self, Tensor weight, IntArrayRef[3] kernel_size, Tensor? bias={}, IntArrayRef[3] stride=1, IntArrayRef[3] padding=0) +- name: _thnn_conv3d(Tensor self, Tensor weight, IntArrayRef[3] kernel_size, Tensor? bias, IntArrayRef[3] stride, IntArrayRef[3] padding) cname: VolumetricConvolutionMM buffers: [finput, fgrad_input] CPU: - forward_scalar_types: ['Float', 'Double', 'Long'] - backward_scalar_types: ['Float', 'Double'] + forward_scalar_types: ['Float', 'Double', 'Long', 'BFloat16'] + backward_scalar_types: ['Float', 'Double', 'BFloat16'] diff --git a/aten/src/ATen/nn_parse.py b/aten/src/ATen/nn_parse.py index 2832db3ef0a2a..416a308158219 100644 --- a/aten/src/ATen/nn_parse.py +++ b/aten/src/ATen/nn_parse.py @@ -42,11 +42,7 @@ def argument_to_declaration(param, func=None): arg['name'] = name if func is not None: - default_inits = func.get('default_init', {}) wrap_dims = func.get('wrap_dim', {}) - if name in default_inits: - # non constexpr defaults - arg['default_init'] = default_inits[name] if name in wrap_dims: arg['wrap_dim'] = wrap_dims[name] @@ -235,6 +231,8 @@ def function_info(name, arguments, cimpls, buffers, backends, inplace, scalar_ch return { 'mode': 'NN', 'name': name, + 'cpu_bfloat16': True if backend_types is not None and 'CPU' in backend_types and + 'BFloat16' in backend_types['CPU'] else False, 'backend_types': backend_types, 'arguments': arguments, 'return': 'argument 0' if inplace else get_return(arguments), diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py index 368ec01851358..806f28c5651a3 100644 --- a/aten/src/ATen/preprocess_declarations.py +++ b/aten/src/ATen/preprocess_declarations.py @@ -178,14 +178,10 @@ def set_mode(option): def discover_zero_dim_tensor_operations(declaration): - def exclude(arg): - return arg.get('ignore_check') - def signature(option, i=None, value=None): elements = [TYPE_FORMAL_GENERIC.get(arg['type'], arg['type']) if i is None or j != i else value - for j, arg in enumerate(option['arguments']) - if not exclude(arg)] + for j, arg in enumerate(option['arguments'])] return '#'.join(elements) signature_to_option = {signature(option): option for option in declaration['options']} @@ -197,8 +193,7 @@ def signature(option, i=None, value=None): if signature_of_tensor_version in signature_to_option: tensor_version = \ signature_to_option[signature_of_tensor_version] - names = [arg['name'] for arg in tensor_version['arguments'] - if not exclude(arg)] + names = [arg['name'] for arg in tensor_version['arguments']] tensor_version['zero_dim_dispatch_when_scalar'] = names[i] # print("FOUND "+str(i) ) # print("Scalar Version ===== ") @@ -227,7 +222,7 @@ def run(declarations): type_to_signature=TYPE_FORMAL_GENERIC, remove_self=True) - common_with_cwrap.sort_by_number_of_options(declaration) + common_with_cwrap.sort_by_number_of_args(declaration) discover_zero_dim_tensor_operations(declaration) diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp index 0006d509666fa..925ee7319d98b 100644 --- a/aten/src/ATen/quantized/QTensorImpl.cpp +++ b/aten/src/ATen/quantized/QTensorImpl.cpp @@ -4,9 +4,9 @@ namespace at { QTensorImpl::QTensorImpl( Storage&& storage, - TensorTypeId type_id, + TensorTypeSet type_set, QuantizerPtr quantizer) - : TensorImpl(std::move(storage), type_id), + : TensorImpl(std::move(storage), type_set), quantizer_(quantizer) {} } // namespace at diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index 74014f00665c9..9c187f1025e2f 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -17,7 +17,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { public: QTensorImpl( Storage&& storage, - TensorTypeId type_id, + TensorTypeSet type_set, QuantizerPtr quantizer); // TODO: Expose in PyTorch Frontend @@ -39,7 +39,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive( - Storage(storage()), type_id(), quantizer_); + Storage(storage()), type_set(), quantizer_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -57,7 +57,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { * see NOTE [ TensorImpl Shallow-Copying ]. */ void shallow_copy_from(const c10::intrusive_ptr& impl) override { - AT_ASSERT(has_compatible_shallow_copy_type(impl->type_id())); + AT_ASSERT(has_compatible_shallow_copy_type(impl->type_set())); auto q_impl = static_cast(impl.get()); copy_tensor_metadata( /*src_impl=*/q_impl, diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 8a5d000b3903c..1b4a55b104f14 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -15,6 +15,13 @@ namespace at { +// Note: this is not a native function as Quantizer is not exposed to python yet +QuantizerPtr Tensor::quantizer() const { + // This is a terrible hack to emulate what VariableType is doing + at::AutoNonVariableTypeMode non_var_type_mode(true); + return get_qtensorimpl(*this)->quantizer(); +} + void checkFloatCPUTensor(std::string fn_name, Tensor t) { TORCH_CHECK( t.scalar_type() == kFloat, @@ -363,7 +370,7 @@ inline Tensor new_qtensor_cpu( allocator, /*resizable=*/true); auto tensor = detail::make_tensor( - storage, at::TensorTypeId::QuantizedCPUTensorId, quantizer); + storage, at::TensorTypeSet(at::TensorTypeId::QuantizedCPUTensorId), quantizer); get_qtensorimpl(tensor)->set_sizes_contiguous(sizes); get_qtensorimpl(tensor)->empty_tensor_restride(memory_format); return tensor; diff --git a/aten/src/ATen/templates/Functions.h b/aten/src/ATen/templates/Functions.h index 9274338b1827b..59634806be09a 100644 --- a/aten/src/ATen/templates/Functions.h +++ b/aten/src/ATen/templates/Functions.h @@ -15,6 +15,7 @@ #include #include #include +#include namespace at { diff --git a/aten/src/ATen/templates/LegacyTHFunctions.cpp b/aten/src/ATen/templates/LegacyTHFunctions.cpp index 50baa294fb1a5..3cf1d4b1b3b62 100644 --- a/aten/src/ATen/templates/LegacyTHFunctions.cpp +++ b/aten/src/ATen/templates/LegacyTHFunctions.cpp @@ -4,11 +4,10 @@ #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif #include #include +#include ${th_headers} ${extra_cuda_headers} diff --git a/aten/src/ATen/templates/NativeFunctions.h b/aten/src/ATen/templates/NativeFunctions.h index 159a0e9c12f16..c201539dfda0b 100644 --- a/aten/src/ATen/templates/NativeFunctions.h +++ b/aten/src/ATen/templates/NativeFunctions.h @@ -6,6 +6,7 @@ #include #include #include +#include #include #include diff --git a/aten/src/ATen/templates/OpsAlreadyMovedToC10.cpp b/aten/src/ATen/templates/OpsAlreadyMovedToC10.cpp new file mode 100644 index 0000000000000..360977bb597b3 --- /dev/null +++ b/aten/src/ATen/templates/OpsAlreadyMovedToC10.cpp @@ -0,0 +1,47 @@ +#include +#include + +#include +#include +#include +#include +#include + +// ${generated_comment} + +// TODO Once all ATen ops are moved to c10, this file should be removed + +namespace at { + +namespace { +struct OpNameEquals final { + bool operator()(const std::pair& lhs, const std::pair& rhs) const { + return 0 == strcmp(lhs.first, rhs.first) && 0 == strcmp(lhs.second, rhs.second); + } +}; + +struct OpNameHash final { + size_t operator()(const std::pair& p) const { + // use std::hash because std::hash would hash pointers and not pointed-to strings + return std::hash()(p.first) ^ (~ std::hash()(p.second)); + } +}; +} + +bool aten_op_is_already_moved_to_c10(const c10::OperatorName& opName) { + static std::unordered_set, OpNameHash, OpNameEquals> ops { + ${c10_ops_already_moved_from_aten_to_c10} + {"", ""} + }; + return ops.count(std::make_pair(opName.name.c_str(), opName.overload_name.c_str())) != 0; +} + +bool aten_op_is_not_moved_to_c10_yet(const c10::OperatorName& opName) { + static std::unordered_set, OpNameHash, OpNameEquals> ops { + ${c10_ops_not_moved_from_aten_to_c10_yet} + {"", ""} + }; + return ops.count(std::make_pair(opName.name.c_str(), opName.overload_name.c_str())) != 0; +} + +} diff --git a/aten/src/ATen/templates/SparseTypeDerived.cpp b/aten/src/ATen/templates/SparseTypeDerived.cpp index a706ac0f70777..075950524b0f9 100644 --- a/aten/src/ATen/templates/SparseTypeDerived.cpp +++ b/aten/src/ATen/templates/SparseTypeDerived.cpp @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include @@ -30,7 +32,8 @@ namespace at { ${type_derived_method_definitions} -static auto& registerer = globalATenDispatch() +#ifndef USE_STATIC_DISPATCH +static auto registerer = torch::RegisterOperators() ${function_registrations}; - +#endif } diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 49c23e695fda7..ab2215359c4f6 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -11,14 +11,14 @@ #include #include #include +#include #include #include #include #include #include -#ifdef BUILD_NAMEDTENSOR +#include #include -#endif namespace caffe2 { class Tensor; @@ -42,6 +42,7 @@ struct Quantizer; // This is temporary typedef to enable Quantizer in aten native function API // we'll remove them when we are actually exposing Quantizer class // to frontend +using QuantizerPtr = c10::intrusive_ptr; using ConstQuantizerPtr = const c10::intrusive_ptr&; // Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which @@ -219,12 +220,12 @@ class CAFFE2_API Tensor { DeprecatedTypeProperties & type() const { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( - tensorTypeIdToBackend(type_id()), + tensorTypeIdToBackend(legacyExtractTypeId(type_set())), scalar_type(), is_variable()); } - TensorTypeId type_id() const { - return impl_->type_id(); + TensorTypeSet type_set() const { + return impl_->type_set(); } ScalarType scalar_type() const { return typeMetaToScalarType(impl_->dtype()); @@ -274,6 +275,10 @@ class CAFFE2_API Tensor { /// Returns if a `Tensor` has quantized backend. bool is_quantized() const; + /// If a tensor is a quantized tensor, returns its quantizer + /// TODO: it's not in native_functions.yaml yet as it's not exposed to python + QuantizerPtr quantizer() const; + #ifdef BUILD_NAMEDTENSOR /// Returns if a `Tensor` has any dimension names bool has_names() const; @@ -317,19 +322,42 @@ class CAFFE2_API Tensor { template TensorAccessor accessor() && = delete; - // Return a `PackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and // dimension. You can optionally specify RestrictPtrTraits as a template parameter to // cast the data pointer to a __restrict__ pointer. - // In order to use this, your CUDA kernel has to take a corresponding PackedTensorAccessor + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor // as an argument. template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() const& { + GenericPackedTensorAccessor generic_packed_accessor() const& { static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); TORCH_CHECK(dim() == N, "expected ", N, " dims but tensor has ", dim()); - return PackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); + return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - PackedTensorAccessor packed_accessor() && = delete; + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() const & { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") + GenericPackedTensorAccessor packed_accessor() && = delete; Tensor operator-() const; Tensor& operator+=(const Tensor & other); @@ -397,7 +425,7 @@ class CAFFE2_API Tensor { }; namespace detail { -// Helper creator for Tensor clas which doesn't requires the users to pass +// Helper creator for Tensor class which doesn't requires the users to pass // in an intrusive_ptr instead it just converts the argument passed to // requested intrusive_ptr type. template @@ -405,23 +433,10 @@ Tensor make_tensor(Args&&... args) { return Tensor(c10::make_intrusive(std::forward(args)...)); } -inline Backend infer_backend(const Tensor & t) { - TORCH_CHECK(t.defined(), "undefined Tensor"); - return tensorTypeIdToBackend(t.type_id()); -} -inline Backend infer_backend(const TensorList & tl) { - TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); - return tensorTypeIdToBackend(tl[0].type_id()); -} +} // namespace detail -inline bool infer_is_variable(const Tensor & t) { - TORCH_CHECK(t.defined(), "undefined Tensor"); - return t.is_variable(); +static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { + return legacyExtractTypeId(t.type_set()); } -inline bool infer_is_variable(const TensorList & tl) { - TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors"); - return tl[0].is_variable(); -} -} // namespace detail } // namespace at diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h index d9ff24911b3e2..a64eebdd8a20c 100644 --- a/aten/src/ATen/templates/TensorMethods.h +++ b/aten/src/ATen/templates/TensorMethods.h @@ -8,12 +8,11 @@ #include #include #include -#if !defined(CAFFE2_IS_XPLAT_BUILD) #include -#endif -#ifdef BUILD_NAMEDTENSOR +#include #include -#endif +#include + #ifdef USE_STATIC_DISPATCH #include #include @@ -23,6 +22,30 @@ namespace at { +namespace detail { + +struct MultiDispatchTensorTypeSet : IterArgs { + TensorTypeSet ts; + void operator()(const at::Tensor& x) { + ts = ts | x.type_set(); + } + void operator()(TensorOptions x) { + ts = ts | x.type_set(); + } + void operator()(at::ArrayRef xs) { + for (const auto& x : xs) { + ts = ts | x.type_set(); + } + } +}; + +template +TensorTypeSet multi_dispatch_tensor_type_set(Args&&... args) { + return MultiDispatchTensorTypeSet().apply(std::forward(args)...).ts; +} + +} + struct Quantizer; // This is temporary typedef to enable Quantizer in aten native function API // we'll remove them when we are actually exposing Quantizer class diff --git a/aten/src/ATen/templates/TypeDefault.cpp b/aten/src/ATen/templates/TypeDefault.cpp index 8f000827dca87..0fb2eb5741164 100644 --- a/aten/src/ATen/templates/TypeDefault.cpp +++ b/aten/src/ATen/templates/TypeDefault.cpp @@ -5,9 +5,7 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif #include #include #include @@ -16,11 +14,15 @@ #include #include #include +#include +#include namespace at { ${type_method_definitions} -static auto& registerer = globalATenDispatch() +#ifndef USE_STATIC_DISPATCH +static auto registerer = torch::RegisterOperators() ${function_registrations}; +#endif } diff --git a/aten/src/ATen/templates/TypeDefault.h b/aten/src/ATen/templates/TypeDefault.h index faedcc6bc349a..7d7f5aa211427 100644 --- a/aten/src/ATen/templates/TypeDefault.h +++ b/aten/src/ATen/templates/TypeDefault.h @@ -9,9 +9,8 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR -#include -#endif +#include +#include namespace c10 { struct Storage; diff --git a/aten/src/ATen/templates/TypeDerived.cpp b/aten/src/ATen/templates/TypeDerived.cpp index 58851b0f584aa..b5c800024c557 100644 --- a/aten/src/ATen/templates/TypeDerived.cpp +++ b/aten/src/ATen/templates/TypeDerived.cpp @@ -12,9 +12,7 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif #include #include #include @@ -23,6 +21,7 @@ #include #include #include +#include #include #include @@ -30,6 +29,7 @@ #include #include +#include $extra_cuda_headers $legacy_th_headers @@ -44,6 +44,8 @@ Tensor * ${Type}::add(Tensor & a, Tensor & b) { ${type_derived_method_definitions} -static auto& registerer = globalATenDispatch() +#ifndef USE_STATIC_DISPATCH +static auto registerer = torch::RegisterOperators() ${function_registrations}; +#endif } diff --git a/aten/src/ATen/test/Dimname_test.cpp b/aten/src/ATen/test/Dimname_test.cpp index 16465987137b4..f5aac1e347b7a 100644 --- a/aten/src/ATen/test/Dimname_test.cpp +++ b/aten/src/ATen/test/Dimname_test.cpp @@ -1,77 +1,61 @@ -#ifdef BUILD_NAMEDTENSOR #include #include #include #include +#include -using at::is_valid_identifier; +#ifdef BUILD_NAMEDTENSOR using at::NameType; using at::Symbol; using at::Dimname; TEST(DimnameTest, isValidIdentifier) { - ASSERT_TRUE(is_valid_identifier("a")); - ASSERT_TRUE(is_valid_identifier("batch")); - ASSERT_TRUE(is_valid_identifier("N")); - ASSERT_TRUE(is_valid_identifier("CHANNELS")); - ASSERT_TRUE(is_valid_identifier("foo_bar_baz")); + ASSERT_TRUE(Dimname::isValidName("a")); + ASSERT_TRUE(Dimname::isValidName("batch")); + ASSERT_TRUE(Dimname::isValidName("N")); + ASSERT_TRUE(Dimname::isValidName("CHANNELS")); + ASSERT_TRUE(Dimname::isValidName("foo_bar_baz")); - ASSERT_FALSE(is_valid_identifier("")); - ASSERT_FALSE(is_valid_identifier(" ")); - ASSERT_FALSE(is_valid_identifier(" a ")); - ASSERT_FALSE(is_valid_identifier("batch1")); - ASSERT_FALSE(is_valid_identifier("foo_bar_1")); - ASSERT_FALSE(is_valid_identifier("?")); - ASSERT_FALSE(is_valid_identifier("-")); + ASSERT_FALSE(Dimname::isValidName("")); + ASSERT_FALSE(Dimname::isValidName(" ")); + ASSERT_FALSE(Dimname::isValidName(" a ")); + ASSERT_FALSE(Dimname::isValidName("batch1")); + ASSERT_FALSE(Dimname::isValidName("foo_bar_1")); + ASSERT_FALSE(Dimname::isValidName("?")); + ASSERT_FALSE(Dimname::isValidName("-")); } TEST(DimnameTest, wildcardName) { Dimname wildcard = Dimname::wildcard(); ASSERT_EQ(wildcard.type(), NameType::WILDCARD); - ASSERT_EQ(wildcard.full_name(), Symbol::dimname("*")); - ASSERT_EQ(wildcard.untagged_name(), Symbol::dimname("*")); + ASSERT_EQ(wildcard.symbol(), Symbol::dimname("*")); } TEST(DimnameTest, createNormalName) { auto foo = Symbol::dimname("foo"); auto dimname = Dimname::fromSymbol(foo); - ASSERT_EQ(dimname.type(), NameType::NORMAL); - ASSERT_EQ(dimname.full_name(), foo); - ASSERT_EQ(dimname.untagged_name(), foo); - + ASSERT_EQ(dimname.type(), NameType::BASIC); + ASSERT_EQ(dimname.symbol(), foo); + ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("inva.lid")), c10::Error); ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("invalid1")), c10::Error); } -TEST(DimnameTest, createTaggedName) { - auto foo_bar = Symbol::dimname("foo.bar"); - auto foo = Symbol::dimname("foo"); - auto dimname = Dimname::fromSymbol(foo_bar); - ASSERT_EQ(dimname.type(), NameType::TAGGED); - ASSERT_EQ(dimname.full_name(), foo_bar); - ASSERT_EQ(dimname.untagged_name(), foo); - - ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname(".bar")), c10::Error); - ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.")), c10::Error); - ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname("foo.bar.baz")), c10::Error); -} - static void check_unify_and_match( const std::string& dimname, const std::string& other, at::optional expected) { auto dimname1 = Dimname::fromSymbol(Symbol::dimname(dimname)); auto dimname2 = Dimname::fromSymbol(Symbol::dimname(other)); - auto result = at::unify(dimname1, dimname2); + auto result = dimname1.unify(dimname2); if (expected) { auto expected_result = Dimname::fromSymbol(Symbol::dimname(*expected)); - ASSERT_EQ(result->full_name(), expected_result.full_name()); + ASSERT_EQ(result->symbol(), expected_result.symbol()); ASSERT_EQ(result->type(), expected_result.type()); - ASSERT_EQ(result->untagged_name(), expected_result.untagged_name()); - ASSERT_TRUE(match(dimname1, dimname2)); + ASSERT_TRUE(dimname1.matches(dimname2)); } else { ASSERT_FALSE(result); - ASSERT_FALSE(match(dimname1, dimname2)); + ASSERT_FALSE(dimname1.matches(dimname2)); } } @@ -81,13 +65,5 @@ TEST(DimnameTest, unifyAndMatch) { check_unify_and_match("*", "a", "a"); check_unify_and_match("*", "*", "*"); check_unify_and_match("a", "b", c10::nullopt); - - check_unify_and_match("*", "a.b", "a.b"); - check_unify_and_match("a", "a.b", "a"); - check_unify_and_match("c", "a.b", c10::nullopt); - check_unify_and_match("a.b", "a.c", "a"); - check_unify_and_match("a.b", "a.b", "a.b"); - check_unify_and_match("c.b", "a.b", c10::nullopt); - check_unify_and_match("c.b", "a", c10::nullopt); } #endif diff --git a/aten/src/ATen/test/NamedTensor_test.cpp b/aten/src/ATen/test/NamedTensor_test.cpp index 1385f33485d1d..80176386fdb7b 100644 --- a/aten/src/ATen/test/NamedTensor_test.cpp +++ b/aten/src/ATen/test/NamedTensor_test.cpp @@ -1,11 +1,12 @@ -#ifdef BUILD_NAMEDTENSOR #include #include #include #include #include +#include +#ifdef BUILD_NAMEDTENSOR using at::Dimname; using at::DimnameList; using at::NamedTensorMeta; @@ -51,7 +52,7 @@ static bool dimnames_equal(at::DimnameList names, at::DimnameList other) { for (auto i = 0; i < names.size(); i++) { const auto& name = names[i]; const auto& other_name = other[i]; - if (name.type() != other_name.type() || name.full_name() != other_name.full_name()) { + if (name.type() != other_name.type() || name.symbol() != other_name.symbol()) { return false; } } @@ -128,17 +129,6 @@ TEST(NamedTensorTest, dimnameToPosition) { tensor = at::empty({1, 1, 1, 1}, names); ASSERT_EQ(dimname_to_position(tensor, H), 2); - - auto Cin = dimnameFromString("C.in"); - auto Cout = dimnameFromString("C.out"); - tensor = at::empty({1, 1, 1, 1}, names); - ASSERT_THROW(dimname_to_position(tensor, Cin), c10::Error); - - tensor = at::empty({1, 1}, std::vector({ Cin, Cout })); - ASSERT_THROW(dimname_to_position(tensor, C), c10::Error); - - tensor = at::empty({1, 1}, std::vector({ Cin, N })); - ASSERT_EQ(dimname_to_position(tensor, C), 0); } static void check_unify( diff --git a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu index ff0c0c4eb89e5..12d3b3d9731f1 100644 --- a/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu +++ b/aten/src/ATen/test/cuda_packedtensoraccessor_test.cu @@ -9,9 +9,9 @@ using namespace at; __global__ void test_tensor_packed_accessor_kernel( - PackedTensorAccessor resa, - PackedTensorAccessor t1a, - PackedTensorAccessor t2a) { + PackedTensorAccessor64 resa, + PackedTensorAccessor64 t1a, + PackedTensorAccessor64 t2a) { for (int64_t i = 0; i < resa.size(0); i++) { float val = 0.0f; for (int64_t j = 0; j < t1a.size(1); j++) { @@ -21,7 +21,7 @@ __global__ void test_tensor_packed_accessor_kernel( } } -// test PackedTensorAccessor and Tensor.packed_accessor +// test GenericPackedTensorAccessor and Tensor.generic_packed_accessor TEST(PackedtensoraccessorTest, PackedtensoraccessorTestCUDA) { if (!at::cuda::is_available()) return; manual_seed(123); @@ -30,9 +30,9 @@ TEST(PackedtensoraccessorTest, PackedtensoraccessorTestCUDA) { Tensor t2 = rand({4}, CUDA(kFloat)); Tensor res = empty({4}, CUDA(kFloat)); - auto t1a = t1.packed_accessor(); - auto t2a = t2.packed_accessor(); - auto resa = res.packed_accessor(); + auto t1a = t1.packed_accessor64(); + auto t2a = t2.packed_accessor64(); + auto resa = res.packed_accessor64(); auto stream = at::cuda::getCurrentCUDAStream(); diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index 367ea9004c414..c7ac6ed042c50 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include using namespace at; @@ -24,10 +24,11 @@ Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) { TEST(BackendExtensionTest, TestRegisterOp) { EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU)); - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - &empty_override); + auto registry1 = torch::RegisterOperators() + .op(torch::RegisterOperators::options() + .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)); Tensor a = empty({5, 5}, at::kMSNPU); ASSERT_EQ(a.device().type(), at::kMSNPU); ASSERT_EQ(a.device().index(), 1); @@ -40,10 +41,11 @@ TEST(BackendExtensionTest, TestRegisterOp) { ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make()); EXPECT_ANY_THROW(add(a, b)); - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", - &add_override); + auto registry2 = torch::RegisterOperators() + .op(torch::RegisterOperators::options() + .schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)); add(a, b); ASSERT_EQ(test_int, 2); @@ -53,9 +55,10 @@ TEST(BackendExtensionTest, TestRegisterOp) { // Attempt to register on a schema that has already has a function EXPECT_ANY_THROW( - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - &empty_override) + torch::RegisterOperators() + .op(torch::RegisterOperators::options() + .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) ); } diff --git a/aten/src/ATen/test/quantized_test.cpp b/aten/src/ATen/test/quantized_test.cpp index eafe4af0522f8..ee279c6074096 100644 --- a/aten/src/ATen/test/quantized_test.cpp +++ b/aten/src/ATen/test/quantized_test.cpp @@ -131,14 +131,14 @@ TEST(TestQTensor, EmptyPerchannelQuantized) { {ch_axis}, at::device(at::kCPU).dtype(kQUInt8)); // Assigning to QTensor - auto* q_data = q.data(); + auto* q_data = q.data_ptr(); for (int i = 0; i < numel; ++i) { q_data[i].val_ = val; } // dequantize auto r = q.dequantize(); - auto* r_data = r.data(); + auto* r_data = r.data_ptr(); for (int i = 0; i < numel; ++i) { ASSERT_EQ( r_data[i], diff --git a/aten/src/ATen/test/scalar_test.cpp b/aten/src/ATen/test/scalar_test.cpp index b88abbcb04811..2ff2585a26f44 100644 --- a/aten/src/ATen/test/scalar_test.cpp +++ b/aten/src/ATen/test/scalar_test.cpp @@ -58,7 +58,7 @@ TEST(TestScalar, TestScalar) { Half h = bar.toHalf(); Scalar h2 = h; cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " " - << bar.toDouble() << " " << what.isIntegral() << "\n"; + << bar.toDouble() << " " << what.isIntegral(false) << "\n"; auto gen = at::detail::getDefaultCPUGenerator(); { // See Note [Acquire lock when using random generators] diff --git a/aten/src/ATen/test/tensor_iterator_test.cpp b/aten/src/ATen/test/tensor_iterator_test.cpp index ae6c18e6343ca..0ba5da05810bd 100644 --- a/aten/src/ATen/test/tensor_iterator_test.cpp +++ b/aten/src/ATen/test/tensor_iterator_test.cpp @@ -105,3 +105,59 @@ TEST(TensorIteratorTest, SerialLoopSingleThread) { }); } +TEST(TensorIteratorTest, InputDType) { + auto iter = at::TensorIterator(); + iter.add_output(at::ones({1, 1}, at::dtype(at::kBool))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble))); + iter.dont_compute_common_dtype(); + iter.build(); + EXPECT_TRUE(iter.input_dtype() == at::kFloat); + EXPECT_TRUE(iter.input_dtype(0) == at::kFloat); + EXPECT_TRUE(iter.input_dtype(1) == at::kDouble); +} + +TEST(TensorIteratorTest, ComputeCommonDTypeInputOnly) { + auto iter = at::TensorIterator(); + iter.add_output(at::ones({1, 1}, at::dtype(at::kBool))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble))); + iter.compute_common_dtype_only_for_inputs(); + iter.build(); + EXPECT_TRUE(iter.dtype(0) == at::kBool); + EXPECT_TRUE(iter.dtype(1) == at::kDouble); + EXPECT_TRUE(iter.dtype(2) == at::kDouble); +} + +TEST(TensorIteratorTest, DoNotComputeCommonDTypeInputOnly) { + auto iter = at::TensorIterator(); + iter.add_output(at::ones({1, 1}, at::dtype(at::kLong))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble))); + iter.compute_common_dtype_only_for_inputs(); + iter.dont_compute_common_dtype(); + iter.build(); + EXPECT_TRUE(iter.dtype(0) == at::kLong); + EXPECT_TRUE(iter.dtype(1) == at::kFloat); + EXPECT_TRUE(iter.dtype(2) == at::kDouble); +} + +TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfInputSameAsOutput) { + Tensor inout = at::ones({1, 1}, at::dtype(at::kFloat)); + auto iter = at::TensorIterator(); + iter.add_output(inout); + iter.add_input(inout); + iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble))); + iter.compute_common_dtype_only_for_inputs(); + ASSERT_ANY_THROW(iter.build()); +} + +TEST(TensorIteratorTest, DoNotComputeCommonDTypeIfOutputIsUndefined) { + Tensor out; + auto iter = at::TensorIterator(); + iter.add_output(out); + iter.add_input(at::ones({1, 1}, at::dtype(at::kDouble))); + iter.add_input(at::ones({1, 1}, at::dtype(at::kFloat))); + iter.compute_common_dtype_only_for_inputs(); + ASSERT_ANY_THROW(iter.build()); +} diff --git a/aten/src/TH/THGeneral.h.in b/aten/src/TH/THGeneral.h.in index f5ea79791023d..5ae5201ceb338 100644 --- a/aten/src/TH/THGeneral.h.in +++ b/aten/src/TH/THGeneral.h.in @@ -48,6 +48,11 @@ #endif // defined(__GNUC__) #endif // _WIN32 +#ifdef NO_EXPORT +#undef TH_CPP_API +#define TH_CPP_API +#endif + #define TH_API TH_EXTERNC TH_CPP_API #ifdef _WIN32 diff --git a/aten/src/TH/THGenerateBFloat16Type.h b/aten/src/TH/THGenerateBFloat16Type.h index 40c34b66e41df..40f0a8c570330 100644 --- a/aten/src/TH/THGenerateBFloat16Type.h +++ b/aten/src/TH/THGenerateBFloat16Type.h @@ -4,11 +4,13 @@ #include #define scalar_t at::BFloat16 +#define accreal double #define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val) #define Real BFloat16 #define TH_REAL_IS_BFLOAT16 #line 1 TH_GENERIC_FILE #include TH_GENERIC_FILE +#undef accreal #undef scalar_t #undef Real #undef TH_REAL_IS_BFLOAT16 diff --git a/aten/src/TH/THTensor.hpp b/aten/src/TH/THTensor.hpp index b0a0603f54239..699ebaa0cee8b 100644 --- a/aten/src/TH/THTensor.hpp +++ b/aten/src/TH/THTensor.hpp @@ -92,6 +92,9 @@ inline int64_t THTensor_sizeLegacyNoScalars(const THTensor *self, int dim) #include #include +#include +#include + inline std::vector THTensor_sizesLegacyNoScalars(const THTensor *self) { if (self->dim() == 0) { return {1}; diff --git a/aten/src/TH/generic/THStorageCopy.cpp b/aten/src/TH/generic/THStorageCopy.cpp index c5eda5699f541..f4234130683ea 100644 --- a/aten/src/TH/generic/THStorageCopy.cpp +++ b/aten/src/TH/generic/THStorageCopy.cpp @@ -2,18 +2,14 @@ #define TH_GENERIC_FILE "TH/generic/THStorageCopy.cpp" #else -void THStorage_(rawCopy)(THStorage *storage, scalar_t *src) -{ - ptrdiff_t i; - scalar_t *data = THStorage_(data)(storage); - for(i = 0; i < storage->numel(); i++) - data[i] = src[i]; -} - void THStorage_(copy)(THStorage *storage, THStorage *src) { THArgCheck(storage->numel() == src->numel(), 2, "size mismatch"); - THStorage_(rawCopy)(storage, THStorage_(data)(src)); + scalar_t *scalar_src = THStorage_(data)(src); + scalar_t *data = THStorage_(data)(storage); + for (ptrdiff_t i = 0; i < storage->numel(); ++i) { + data[i] = scalar_src[i]; + } } // NOTE: for performance, these macros generally use the raw data pointer in the inner loops, diff --git a/aten/src/TH/generic/THStorageCopy.h b/aten/src/TH/generic/THStorageCopy.h index 4797ba6761f05..bddc6db4e1297 100644 --- a/aten/src/TH/generic/THStorageCopy.h +++ b/aten/src/TH/generic/THStorageCopy.h @@ -3,8 +3,6 @@ #else /* Support for copy between different Storage types */ - -TH_API void THStorage_(rawCopy)(THStorage *storage, scalar_t *src); TH_API void THStorage_(copy)(THStorage *storage, THStorage *src); TH_API void THStorage_(copyByte)(THStorage *storage, struct THByteStorage *src); TH_API void THStorage_(copyChar)(THStorage *storage, struct THCharStorage *src); diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index d71ab8d6e6ccd..0850ceae5e811 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -5,9 +5,7 @@ #include #include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif /**** access methods ****/ THStorage *THTensor_(storage)(const THTensor *self) diff --git a/aten/src/TH/generic/THTensorApply.hpp b/aten/src/TH/generic/THTensorApply.hpp index a7994c6bbad19..579a0c32c4c06 100644 --- a/aten/src/TH/generic/THTensorApply.hpp +++ b/aten/src/TH/generic/THTensorApply.hpp @@ -17,31 +17,31 @@ } // Used for `scatter` and `scatterAdd` -// Assumes TENSOR1 is real -// TENSOR2 is src -// TENSOR3 is index +// Assumes TENSOR1 is index +// TENSOR2 is real +// TENSOR3 is src // Tests: // 1. index->size(d) <= src->size(d) for all d // 2. index->size(d) <= real->size(d) for all d != dim #define TH_TENSOR_DIM_APPLY3_SIZE_SCATTER(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \ { \ int shape_check_flag = 0; \ - for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR1); TH_TENSOR_DIM_APPLY_i++) \ + for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < THTensor_nDimensionLegacyAll(TENSOR2); TH_TENSOR_DIM_APPLY_i++) \ { \ - int64_t TENSOR3##_dim_size = THTensor_sizeLegacyNoScalars(TENSOR3, TH_TENSOR_DIM_APPLY_i); \ + int64_t TENSOR1##_dim_size = THTensor_sizeLegacyNoScalars(TENSOR1, TH_TENSOR_DIM_APPLY_i); \ if (TH_TENSOR_DIM_APPLY_i != DIMENSION) { \ - if (TENSOR3##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR1, TH_TENSOR_DIM_APPLY_i)) { \ + if (TENSOR1##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR2, TH_TENSOR_DIM_APPLY_i)) { \ shape_check_flag = 1; \ break; \ } \ } \ - if (TENSOR3##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR2, TH_TENSOR_DIM_APPLY_i)) { \ + if (TENSOR1##_dim_size > THTensor_sizeLegacyNoScalars(TENSOR3, TH_TENSOR_DIM_APPLY_i)) { \ shape_check_flag = 1; \ break; \ } \ } \ if (shape_check_flag == 1) { \ - AT_ERROR("Expected ", #TENSOR3, " ", TENSOR3->sizes(), " to be smaller size than ", #TENSOR2, " ", TENSOR2->sizes(), " and to be smaller than ", #TENSOR1, " ", TENSOR1->sizes(), " apart from dimension ", DIMENSION); \ + AT_ERROR("Expected ", #TENSOR1, " ", TENSOR1->sizes(), " to be smaller size than ", #TENSOR3, " ", TENSOR3->sizes(), " and to be smaller than ", #TENSOR2, " ", TENSOR2->sizes(), " apart from dimension ", DIMENSION); \ } \ } diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp index 6d14eb257c3a5..0e6d14030c3b5 100644 --- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp +++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp @@ -3,9 +3,8 @@ #else #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include // Finds non-zero elements of a tensor and returns their subscripts void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor) @@ -672,7 +671,7 @@ void THTensor_(scatter)(THTensor *tensor, int dim, THLongTensor *index, THTensor elems_per_row = THTensor_sizeLegacyNoScalars(index, dim); - TH_TENSOR_DIM_APPLY3(scalar_t, tensor, scalar_t, src, int64_t, index, dim, + TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim, TH_TENSOR_DIM_APPLY3_SIZE_SCATTER, for (i = 0; i < elems_per_row; ++i) { @@ -704,7 +703,7 @@ void THTensor_(scatterAdd)(THTensor *tensor, int dim, THLongTensor *index, THTen elems_per_row = THTensor_sizeLegacyNoScalars(index, dim); - TH_TENSOR_DIM_APPLY3(scalar_t, tensor, scalar_t, src, int64_t, index, dim, + TH_TENSOR_DIM_APPLY3(int64_t, index, scalar_t, tensor, scalar_t, src, dim, TH_TENSOR_DIM_APPLY3_SIZE_SCATTER, for (i = 0; i < elems_per_row; ++i) { diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index 57cbafeb2c7f8..02b2d786e3cb8 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -179,8 +179,9 @@ void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a) if (free_b) c10::raw::intrusive_ptr::decref(b); } -void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr) +void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, bool eigenvectors) { + char jobvr = eigenvectors ? 'V' : 'N'; int n, lda, lwork, info, ldvr; THTensor *work=nullptr, *wi, *wr, *a; scalar_t wkopt; @@ -204,7 +205,7 @@ void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *job rv_data = NULL; ldvr = 1; - if (*jobvr == 'V') + if (jobvr == 'V') { THTensor_(resize2d)(rv_,n,n); /* guard against someone passing a correct size, but wrong stride */ @@ -217,13 +218,13 @@ void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *job if (n > 0) { // lapack doesn't work with size 0 /* get optimal workspace size */ - THLapack_(geev)('N', jobvr[0], n, a->data(), lda, wr->data(), wi->data(), + THLapack_(geev)('N', jobvr, n, a->data(), lda, wr->data(), wi->data(), NULL, 1, rv_data, ldvr, &wkopt, -1, &info); lwork = (int)wkopt; work = THTensor_(newWithSize1d)(lwork); - THLapack_(geev)('N', jobvr[0], n, a->data(), lda, wr->data(), wi->data(), + THLapack_(geev)('N', jobvr, n, a->data(), lda, wr->data(), wi->data(), NULL, 1, rv_data, ldvr, work->data(), lwork, &info); THLapackCheckWithCleanup(" Lapack Error in %s : %d off-diagonal elements of an didn't converge to zero", @@ -247,7 +248,7 @@ void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *job } } - if (*jobvr == 'V') + if (jobvr == 'V') { THTensor_(checkTransposed)(rv_); THTensor_(freeCopyTo)(rv__, rv_); @@ -259,7 +260,7 @@ void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *job c10::raw::intrusive_ptr::decref(work); } -void THTensor_(copyUpLoTriangle)(THTensor *a, const char *uplo) +void THTensor_(copyUpLoTriangle)(THTensor *a, char uplo) { THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional"); THArgCheck(a->size(0) == a->size(1), 1, "A should be square"); @@ -271,7 +272,7 @@ void THTensor_(copyUpLoTriangle)(THTensor *a, const char *uplo) int64_t i, j; /* Upper Triangular Case */ - if (uplo[0] == 'U') + if (uplo == 'U') { /* Clear lower triangle (excluding diagonals) */ for (i=0; isize(0) == a->size(1), 1, "A should be square"); @@ -307,7 +309,7 @@ void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo) lda = n; /* Run inverse */ - THLapack_(potri)(uplo[0], n, ra__->data(), lda, &info); + THLapack_(potri)(uplo, n, ra__->data(), lda, &info); THLapackCheckWithCleanup("Lapack Error %s : A(%d,%d) is 0, A cannot be factorized", THCleanup(c10::raw::intrusive_ptr::decref(ra__);), "potri", info, info); @@ -430,21 +432,23 @@ void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau) elementary reflectors, such as is produced by the geqrf function. Args: - * `ra_` - result Tensor, which will contain the matrix Q' c. - * `a` - input Tensor, which should be a matrix with the directions of the - elementary reflectors below the diagonal. If NULL, `ra_` is used as - input. - * `tau` - input Tensor, containing the magnitudes of the elementary - reflectors. - * `c` - input Tensor, containing the matrix to be multiplied. - * `side` - char, determining whether c is left- or right-multiplied with Q. - * `trans` - char, determining whether to transpose Q before multiplying. + * `ra_` - result Tensor, which will contain the matrix Q' c. + * `a` - input Tensor, which should be a matrix with the directions of the + elementary reflectors below the diagonal. If NULL, `ra_` is used as + input. + * `tau` - input Tensor, containing the magnitudes of the elementary + reflectors. + * `c` - input Tensor, containing the matrix to be multiplied. + * `left` - bool, determining whether c is left- or right-multiplied with Q. + * `transpose` - bool, determining whether to transpose Q before multiplying. For further details, please see the LAPACK documentation. */ -void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans) +void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, bool left, bool transpose) { + char side = left ? 'L' : 'R'; + char trans = transpose ? 'T' : 'N'; if (a == NULL) a = ra_; THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1, "A should be 2 dimensional"); @@ -455,7 +459,7 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co int n = c->size(1); int k = THTensor_sizeLegacyNoScalars(tau, 0); int lda; - if (*side == 'L') + if (side == 'L') { lda = m; } @@ -468,14 +472,14 @@ void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, co /* Dry-run to query the suggested size of the workspace. */ int info = 0; scalar_t wkopt = 0; - THLapack_(ormqr)(side[0], trans[0], m, n, k, a->data(), lda, + THLapack_(ormqr)(side, trans, m, n, k, a->data(), lda, tau->data(), ra__->data(), ldc, &wkopt, -1, &info); /* Allocate the workspace and call LAPACK to do the real work. */ int lwork = (int)wkopt; THTensor *work = THTensor_(newWithSize1d)(lwork); - THLapack_(ormqr)(side[0], trans[0], m, n, k, a->data(), lda, + THLapack_(ormqr)(side, trans, m, n, k, a->data(), lda, tau->data(), ra__->data(), ldc, work->data(), lwork, &info); diff --git a/aten/src/TH/generic/THTensorLapack.h b/aten/src/TH/generic/THTensorLapack.h index 8db76fb0ea51e..05dbbf9f12ec5 100644 --- a/aten/src/TH/generic/THTensorLapack.h +++ b/aten/src/TH/generic/THTensorLapack.h @@ -3,10 +3,10 @@ #else TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); -TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, const char *jobvr); -TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, const char *uplo); +TH_API void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_, bool eigenvectors); +TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, bool upper); TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a); TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau); -TH_API void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, const char *side, const char *trans); +TH_API void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, bool left, bool transpose); #endif diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp index 116f2ad01a307..1e25ec71239c3 100644 --- a/aten/src/TH/generic/THTensorMath.cpp +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -2,10 +2,9 @@ #define TH_GENERIC_FILE "TH/generic/THTensorMath.cpp" #else +#include #include -#ifdef BUILD_NAMEDTENSOR #include -#endif // HEY YOU! // @@ -182,7 +181,7 @@ void THTensor_(bitor)(THTensor *r_, THTensor *t, scalar_t value) #if !defined(TH_REAL_IS_BOOL) /* non bool only part */ -void THTensor_(addmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *m1, THTensor *m2) +void THTensor_(addmm)(THTensor *r_, THTensor *t, THTensor *m1, THTensor *m2, scalar_t beta, scalar_t alpha) { char transpose_r, transpose_m1, transpose_m2; THTensor *r__, *m1_, *m2_; @@ -338,7 +337,7 @@ void THTensor_(addmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, #endif } -void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat, THTensor *vec) +void THTensor_(addmv)(THTensor *r_, THTensor *t, THTensor *mat, THTensor *vec, scalar_t beta, scalar_t alpha) { if( (mat->dim() != 2) || (THTensor_nDimension(vec) != 1) ) THError("matrix and vector expected, got %dD, %dD", @@ -414,7 +413,7 @@ void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, #undef LDA_COND } -void THTensor_(addr)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *vec1, THTensor *vec2) +void THTensor_(addr)(THTensor *r_, THTensor *t, THTensor *vec1, THTensor *vec2, scalar_t beta, scalar_t alpha) { if( (THTensor_nDimension(vec1) != 1) || (THTensor_nDimension(vec2) != 1) ) THError("vector and vector expected, got %dD, %dD tensors", @@ -794,7 +793,7 @@ void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, scalar_t gain) c10::raw::intrusive_ptr::decref(m2); } -void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2) +void THTensor_(addbmm)(THTensor *result, THTensor *t, THTensor *batch1, THTensor *batch2, scalar_t beta, scalar_t alpha) { int64_t batch; @@ -829,7 +828,7 @@ void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t al THTensor_(select)(matrix1, batch1, 0, batch); THTensor_(select)(matrix2, batch2, 0, batch); - THTensor_(addmm)(result, beta, result, alpha, matrix1, matrix2); + THTensor_(addmm)(result, result, matrix1, matrix2, beta, alpha); beta = 1; // accumulate output once } diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 275f20e5f6539..382c480a5d2bc 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -60,9 +60,9 @@ TH_API void THTensor_(maskedCopyBool)(THTensor *tensor, THBoolTensor *mask, THTe TH_API ptrdiff_t THTensor_(numel)(THTensor *t); -TH_API void THTensor_(addmv)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat, THTensor *vec); -TH_API void THTensor_(addmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *mat1, THTensor *mat2); -TH_API void THTensor_(addr)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *vec1, THTensor *vec2); +TH_API void THTensor_(addmv)(THTensor *r_, THTensor *t, THTensor *mat, THTensor *vec, scalar_t beta, scalar_t alpha); +TH_API void THTensor_(addmm)(THTensor *r_, THTensor *t, THTensor *mat1, THTensor *mat2, scalar_t beta, scalar_t alpha); +TH_API void THTensor_(addr)(THTensor *r_, THTensor *t, THTensor *vec1, THTensor *vec2, scalar_t beta, scalar_t alpha); #if !defined(TH_REAL_IS_BOOL) TH_API void THTensor_(mul)(THTensor *r_, THTensor *t, scalar_t value); @@ -131,8 +131,8 @@ TH_API void THTensor_(crshift)(THTensor *r_, THTensor *t, THTensor *src); TH_API void THTensor_(cfmod)(THTensor *r_, THTensor *t, THTensor *src); TH_API void THTensor_(cremainder)(THTensor *r_, THTensor *t, THTensor *src); -TH_API void THTensor_(addbmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2); -TH_API void THTensor_(baddbmm)(THTensor *r_, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2); +TH_API void THTensor_(addbmm)(THTensor *r_, THTensor *t, THTensor *batch1, THTensor *batch2, scalar_t beta, scalar_t alpha); +TH_API void THTensor_(baddbmm)(THTensor *r_, THTensor *t, THTensor *batch1, THTensor *batch2, scalar_t beta, scalar_t alpha); TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, scalar_t gain); @@ -177,8 +177,8 @@ TH_API void THTensor_(round)(THTensor *r_, THTensor *t); TH_API void THTensor_(trunc)(THTensor *r_, THTensor *t); TH_API void THTensor_(frac)(THTensor *r_, THTensor *t); -TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim); -TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim); +TH_API void THTensor_(std_single)(THTensor *r_, THTensor *t, int dimension, bool unbiased, int keepdim); +TH_API void THTensor_(var_single)(THTensor *r_, THTensor *t, int dimension, bool unbiased, int keepdim); TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, int keepdim); TH_API void THTensor_(renorm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, scalar_t maxnorm); TH_API accreal THTensor_(dist)(THTensor *a, THTensor *b, scalar_t value); @@ -186,8 +186,8 @@ TH_API void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, sc TH_API void THTensor_(bhistc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue); TH_API accreal THTensor_(meanall)(THTensor *self); -TH_API accreal THTensor_(varall)(THTensor *self, int biased); -TH_API accreal THTensor_(stdall)(THTensor *self, int biased); +TH_API accreal THTensor_(var_all)(THTensor *self, bool unbiased); +TH_API accreal THTensor_(std_all)(THTensor *self, bool unbiased); TH_API accreal THTensor_(normall)(THTensor *t, scalar_t value); #endif #endif diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index 32beb005b018f..282f9fe20dafc 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #ifdef BUILD_NAMEDTENSOR #include #endif @@ -366,7 +367,7 @@ void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t a THTensor_(select)(matrix2, batch2, 0, batch); THTensor_(select)(result_matrix, result, 0, batch); - THTensor_(addmm)(result_matrix, beta, result_matrix, alpha, matrix1, matrix2); + THTensor_(addmm)(result_matrix, result_matrix, matrix1, matrix2, beta, alpha); } c10::raw::intrusive_ptr::decref(matrix1); @@ -1053,7 +1054,7 @@ LAB_IMPLEMENT_BASIC_FUNCTION(rsqrt,TH_MATH_NAME(TH_rsqrt),HYPER_TH_OMP_OVERHEAD_ LAB_IMPLEMENT_VECTORIZED_FUNCTION(sigmoid,TH_MATH_NAME(TH_sigmoid),HYPER_TH_OMP_OVERHEAD_THRESHOLD) -void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim) +void THTensor_(std_single)(THTensor *r_, THTensor *t, int dimension, bool unbiased, int keepdim) { THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d", dimension); @@ -1078,12 +1079,12 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke M2 += delta * delta2; } - if (biased && t_size >= 2) + if (!unbiased && t_size >= 2) { *r__data = TH_MATH_NAME(sqrt)(M2 / t_size); - } else if (!biased && t_size >= 2) { + } else if (unbiased && t_size >= 2) { *r__data = TH_MATH_NAME(sqrt)(M2 / (t_size - 1)); - } else if (biased && t_size == 1) { + } else if (!unbiased && t_size == 1) { *r__data = 0; } else { *r__data = NAN; @@ -1094,7 +1095,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke } } -void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim) +void THTensor_(var_single)(THTensor *r_, THTensor *t, int dimension, bool unbiased, int keepdim) { THArgCheck(dimension >= 0 && dimension < THTensor_(nDimensionLegacyAll)(t), 3, "invalid dimension %d", dimension); @@ -1119,12 +1120,12 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int ke M2 += delta * delta2; } - if (biased && t_size >= 2) + if (!unbiased && t_size >= 2) { *r__data = M2 / t_size; - } else if (!biased && t_size >= 2) { + } else if (unbiased && t_size >= 2) { *r__data = M2 / (t_size - 1); - } else if (biased && t_size == 1) { + } else if (!unbiased && t_size == 1) { *r__data = 0; } else { *r__data = NAN; @@ -1300,18 +1301,18 @@ accreal THTensor_(meanall)(THTensor *tensor) return THTensor_(sumall)(tensor)/THTensor_(nElement)(tensor); } -accreal THTensor_(varall)(THTensor *tensor, int biased) +accreal THTensor_(var_all)(THTensor *tensor, bool unbiased) { accreal mean = THTensor_(meanall)(tensor); accreal sum = 0; TH_TENSOR_APPLY(scalar_t, tensor, sum += (*tensor_data - mean)*(*tensor_data - mean);); - sum /= std::max(0, THTensor_(nElement)(tensor) - (biased ? 0 : 1)); + sum /= std::max(0, THTensor_(nElement)(tensor) - (unbiased ? 1 : 0)); return sum; } -accreal THTensor_(stdall)(THTensor *tensor, int biased) +accreal THTensor_(std_all)(THTensor *tensor, bool unbiased) { - return sqrt(THTensor_(varall)(tensor, biased)); + return sqrt(THTensor_(var_all)(tensor, unbiased)); } void THTensor_(histc)(THTensor *hist, THTensor *tensor, int64_t nbins, scalar_t minvalue, scalar_t maxvalue) diff --git a/aten/src/TH/generic/THTensorRandom.cpp b/aten/src/TH/generic/THTensorRandom.cpp index cd182de9893ff..ea96e5f7bf8a6 100644 --- a/aten/src/TH/generic/THTensorRandom.cpp +++ b/aten/src/TH/generic/THTensorRandom.cpp @@ -39,7 +39,7 @@ void THTensor_(random)(THTensor *self, at::Generator *_generator) } -void THTensor_(clampedRandom)(THTensor *self, at::Generator *_generator, int64_t min, int64_t max) { +void THTensor_(clampedRandom)(THTensor *self, int64_t min, int64_t max, at::Generator *_generator) { THArgCheck(max > min, 2, "max must be greater than min, but got: min = %lld, max = %lld", min, max); uint64_t range = max - min; auto gen = at::get_generator_or_default(_generator, at::detail::getDefaultCPUGenerator()); @@ -54,12 +54,12 @@ void THTensor_(clampedRandom)(THTensor *self, at::Generator *_generator, int64_t TH_TENSOR_APPLY(scalar_t, self, *self_data = static_cast(static_cast((gen->random() % range) + min));) } -void THTensor_(cappedRandom)(THTensor *self, at::Generator *_generator, int64_t max) { +void THTensor_(cappedRandom)(THTensor *self, int64_t max, at::Generator *_generator) { THArgCheck(max > 0, 1, "max must be positive, but got: max = %lld", max); - THTensor_(clampedRandom)(self, _generator, 0, max); + THTensor_(clampedRandom)(self, 0, max, _generator); } -void THTensor_(geometric)(THTensor *self, at::Generator *_generator, double p) +void THTensor_(geometric)(THTensor *self, double p, at::Generator *_generator) { auto gen = at::get_generator_or_default(_generator, at::detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] @@ -76,7 +76,7 @@ void THTensor_(geometric)(THTensor *self, at::Generator *_generator, double p) #define TH_REAL_MIN DBL_MIN #endif -void THTensor_(uniform)(THTensor *self, at::Generator *_generator, double a, double b) +void THTensor_(uniform)(THTensor *self, double a, double b, at::Generator *_generator) { auto gen = at::get_generator_or_default(_generator, at::detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] @@ -91,7 +91,7 @@ void THTensor_(uniform)(THTensor *self, at::Generator *_generator, double a, dou #endif } -void THTensor_(normal)(THTensor *self, at::Generator *_generator, double mean, double stddev) +void THTensor_(normal)(THTensor *self, double mean, double stddev, at::Generator *_generator) { const int64_t size = THTensor_(numel)(self); if (size >= 16 && THTensor_(isContiguous)(self)) { @@ -106,31 +106,31 @@ void THTensor_(normal)(THTensor *self, at::Generator *_generator, double mean, d } } -void THTensor_(normal_means)(THTensor *self, at::Generator *gen, THTensor *means, double stddev) +void THTensor_(normal_means)(THTensor *self, THTensor *means, double stddev, at::Generator *gen) { THTensor_(resizeAs)(self, means); - THTensor_(normal)(self, gen, 0, stddev); + THTensor_(normal)(self, 0, stddev, gen); THTensor_(cadd)(self, self, 1, means); } -void THTensor_(normal_stddevs)(THTensor *self, at::Generator *gen, double mean, THTensor *stddevs) +void THTensor_(normal_stddevs)(THTensor *self, double mean, THTensor *stddevs, at::Generator *gen) { THTensor_(resizeAs)(self, stddevs); - THTensor_(normal)(self, gen, 0, 1); + THTensor_(normal)(self, 0, 1, gen); THTensor_(cmul)(self, self, stddevs); at::Tensor self_wrap = THTensor_wrap(self); self_wrap.add_(mean); } -void THTensor_(normal_means_stddevs)(THTensor *self, at::Generator *gen, THTensor *means, THTensor *stddevs) +void THTensor_(normal_means_stddevs)(THTensor *self, THTensor *means, THTensor *stddevs, at::Generator *gen) { THTensor_(resizeAs)(self, means); - THTensor_(normal)(self, gen, 0, 1); + THTensor_(normal)(self, 0, 1, gen); THTensor_(cmul)(self, self, stddevs); THTensor_(cadd)(self, self, 1, means); } -void THTensor_(exponential)(THTensor *self, at::Generator *_generator, double lambda) +void THTensor_(exponential)(THTensor *self, double lambda, at::Generator *_generator) { auto gen = at::get_generator_or_default(_generator, at::detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] @@ -142,7 +142,7 @@ void THTensor_(exponential)(THTensor *self, at::Generator *_generator, double la #undef TH_REAL_MIN -void THTensor_(cauchy)(THTensor *self, at::Generator *_generator, double median, double sigma) +void THTensor_(cauchy)(THTensor *self, double median, double sigma, at::Generator *_generator) { auto gen = at::get_generator_or_default(_generator, at::detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] @@ -152,7 +152,7 @@ void THTensor_(cauchy)(THTensor *self, at::Generator *_generator, double median, TH_TENSOR_APPLY(scalar_t, self, *self_data = (scalar_t)cauchy(gen);); } -void THTensor_(logNormal)(THTensor *self, at::Generator *_generator, double mean, double stdv) +void THTensor_(logNormal)(THTensor *self, double mean, double stdv, at::Generator *_generator) { auto gen = at::get_generator_or_default(_generator, at::detail::getDefaultCPUGenerator()); // See Note [Acquire lock when using random generators] @@ -252,7 +252,7 @@ void THTensor_(multinomialAliasSetup)(THTensor *probs, THLongTensor *J, THTensor THLongTensor_free(smaller); THLongTensor_free(larger); } -void THTensor_(multinomialAliasDraw)(THLongTensor *self, at::Generator *_generator, THTensor *q, THLongTensor *J, int n_sample) +void THTensor_(multinomialAliasDraw)(THLongTensor *self, THTensor *q, THLongTensor *J, int n_sample, at::Generator *_generator) { THArgCheck(q->dim() == 1, 1, "expected 1-D probability table, got %d-D probability table instead", diff --git a/aten/src/TH/generic/THTensorRandom.h b/aten/src/TH/generic/THTensorRandom.h index 9eaa1395d5c6d..7bdae7f64a709 100644 --- a/aten/src/TH/generic/THTensorRandom.h +++ b/aten/src/TH/generic/THTensorRandom.h @@ -6,22 +6,22 @@ #include TH_API void THTensor_(random)(THTensor *self, at::Generator *_generator); -TH_API void THTensor_(clampedRandom)(THTensor *self, at::Generator *_generator, int64_t min, int64_t max); -TH_API void THTensor_(cappedRandom)(THTensor *self, at::Generator *_generator, int64_t max); -TH_API void THTensor_(geometric)(THTensor *self, at::Generator *_generator, double p); +TH_API void THTensor_(clampedRandom)(THTensor *self, int64_t min, int64_t max, at::Generator *_generator); +TH_API void THTensor_(cappedRandom)(THTensor *self, int64_t max, at::Generator *_generator); +TH_API void THTensor_(geometric)(THTensor *self, double p, at::Generator *_generator); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) TH_API void THTensor_(bernoulli_Tensor)(THTensor *self, at::Generator *_generator, THTensor *p); -TH_API void THTensor_(uniform)(THTensor *self, at::Generator *_generator, double a, double b); -TH_API void THTensor_(normal)(THTensor *self, at::Generator *_generator, double mean, double stdv); -TH_API void THTensor_(normal_means)(THTensor *self, at::Generator *gen, THTensor *means, double stddev); -TH_API void THTensor_(normal_stddevs)(THTensor *self, at::Generator *gen, double mean, THTensor *stddevs); -TH_API void THTensor_(normal_means_stddevs)(THTensor *self, at::Generator *gen, THTensor *means, THTensor *stddevs); -TH_API void THTensor_(exponential)(THTensor *self, at::Generator *_generator, double lambda); -TH_API void THTensor_(cauchy)(THTensor *self, at::Generator *_generator, double median, double sigma); -TH_API void THTensor_(logNormal)(THTensor *self, at::Generator *_generator, double mean, double stdv); +TH_API void THTensor_(uniform)(THTensor *self, double a, double b, at::Generator *_generator); +TH_API void THTensor_(normal)(THTensor *self, double mean, double stdv, at::Generator *_generator); +TH_API void THTensor_(normal_means)(THTensor *self, THTensor *means, double stddev, at::Generator *gen); +TH_API void THTensor_(normal_stddevs)(THTensor *self, double mean, THTensor *stddevs, at::Generator *gen); +TH_API void THTensor_(normal_means_stddevs)(THTensor *self, THTensor *means, THTensor *stddevs, at::Generator *gen); +TH_API void THTensor_(exponential)(THTensor *self, double lambda, at::Generator *_generator); +TH_API void THTensor_(cauchy)(THTensor *self, double median, double sigma, at::Generator *_generator); +TH_API void THTensor_(logNormal)(THTensor *self, double mean, double stdv, at::Generator *_generator); TH_API void THTensor_(multinomialAliasSetup)(THTensor *prob_dist, THLongTensor *J, THTensor *q); -TH_API void THTensor_(multinomialAliasDraw)(THLongTensor *self, at::Generator *_generator, THTensor *q, THLongTensor *J, int n_sample); +TH_API void THTensor_(multinomialAliasDraw)(THLongTensor *self, THTensor *q, THLongTensor *J, int n_sample, at::Generator *_generator); #endif #if defined(TH_REAL_IS_BYTE) diff --git a/aten/src/TH/vector/simd.h b/aten/src/TH/vector/simd.h index 28a0e1d40bb66..8f77d24348847 100644 --- a/aten/src/TH/vector/simd.h +++ b/aten/src/TH/vector/simd.h @@ -72,6 +72,13 @@ static inline uint32_t detectHostSIMDExtensions() #endif +#elif defined(__s390x__) + +static inline uint32_t detectHostSIMDExtensions() +{ + return SIMDExtension_DEFAULT; +} + #elif defined(__PPC64__) #if defined(__VSX__) diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu index be4af89923050..017c2f87dc562 100644 --- a/aten/src/THC/THCBlas.cu +++ b/aten/src/THC/THCBlas.cu @@ -260,7 +260,7 @@ void THCudaBlas_HgemmStridedBatched(THCState *state, char transa, char transb, i (void*)&fBeta, c, rocblas_datatype_f16_r, (int)ldc, strideC, c, rocblas_datatype_f16_r, (int)ldc, strideC, (int) batchCount, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, - 0, 0, NULL, NULL)); + 0, 0)); #else THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); THCublasCheck(cublasGemmStridedBatchedEx(handle, diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh index dbe23b512f182..58585f67a41a3 100644 --- a/aten/src/THC/THCNumerics.cuh +++ b/aten/src/THC/THCNumerics.cuh @@ -292,7 +292,6 @@ struct THCNumerics { static inline __host__ __device__ at::Half cos(at::Half a) { return ::cos(a); } static inline __host__ __device__ at::Half sin(at::Half a) { return ::sin(a); } static inline __host__ __device__ at::Half sqrt(at::Half a) { return ::sqrt(a); } - static inline __host__ __device__ at::Half rsqrt(at::Half a) { return ::rsqrt(a); } static inline __host__ __device__ at::Half floor(at::Half a) { return ::floor(a); } static inline __host__ __device__ at::Half trunc(at::Half a) { return ::trunc(a); } static inline __host__ __device__ at::Half acos(at::Half a) { return ::acos(a); } @@ -305,7 +304,6 @@ struct THCNumerics { static inline __host__ __device__ at::Half erf(at::Half a) { return ::erf(a); } static inline __host__ __device__ at::Half erfc(at::Half a) { return ::erfc(a); } static inline __host__ __device__ at::Half abs(at::Half a) { return std::abs(a); } - static inline __host__ __device__ at::Half round(at::Half a) { return ::nearbyint(a); } static inline __host__ __device__ at::Half frac(at::Half a) { #if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__) @@ -346,7 +344,7 @@ struct THCNumerics { }; // DEPRECATED: use math functions from std and cuda math API (if needed) -// note that the functions exp10,rsqrt,erfinv,frac and cinv +// note that the functions exp10,erfinv,frac and cinv // are not in the std namespace template <> struct THCNumerics { @@ -374,7 +372,6 @@ struct THCNumerics { static inline __host__ __device__ float cos (float a) { return cosf(a); } static inline __host__ __device__ float sin (float a) { return sinf(a); } static inline __host__ __device__ float sqrt (float a) { return sqrtf(a); } - static inline __host__ __device__ float rsqrt(float a) { return rsqrtf(a); } static inline __host__ __device__ float floor(float a) { return floorf(a); } static inline __host__ __device__ float trunc(float a) { return truncf(a); } static inline __host__ __device__ float acos (float a) { return acosf(a); } @@ -389,7 +386,6 @@ struct THCNumerics { static inline __host__ __device__ float erf (float a) { return erff(a); } static inline __host__ __device__ float erfc (float a) { return erfcf(a); } static inline __host__ __device__ float abs (float a) { return fabsf(a); } - static inline __host__ __device__ float round(float a) { return nearbyintf(a); } static inline __host__ __device__ float frac (float a) { return a - truncf(a); } static inline __host__ __device__ float cinv (float a) { return 1.0f / a; } static inline __host__ __device__ float add (float a, float b) { return a + b; } @@ -403,7 +399,7 @@ struct THCNumerics { }; // DEPRECATED: use math functions from std and cuda math API (if needed) -// note that the functions exp10,rsqrt,erfinv,frac and cinv +// note that the functions exp10,erfinv,frac and cinv // are not in the std namespace template <> struct THCNumerics { @@ -431,7 +427,6 @@ struct THCNumerics { static inline __host__ __device__ double cos (double a) { return ::cos(a); } static inline __host__ __device__ double sin (double a) { return ::sin(a); } static inline __host__ __device__ double sqrt (double a) { return ::sqrt(a); } - static inline __host__ __device__ double rsqrt(double a) { return ::rsqrt(a); } static inline __host__ __device__ double floor(double a) { return ::floor(a); } static inline __host__ __device__ double trunc(double a) { return ::trunc(a); } static inline __host__ __device__ double acos (double a) { return ::acos(a); } @@ -446,7 +441,6 @@ struct THCNumerics { static inline __host__ __device__ double erf (double a) { return ::erf(a); } static inline __host__ __device__ double erfc (double a) { return ::erfc(a); } static inline __host__ __device__ double abs (double a) { return fabs(a); } - static inline __host__ __device__ double round(double a) { return ::nearbyint(a); } static inline __host__ __device__ double frac (double a) { return a - ::trunc(a); } static inline __host__ __device__ double cinv (double a) { return 1.0 / a; } static inline __host__ __device__ double add (double a, double b) { return a + b; } diff --git a/aten/src/THC/THCScanUtils.cuh b/aten/src/THC/THCScanUtils.cuh index 839abc8947273..75576238ca360 100644 --- a/aten/src/THC/THCScanUtils.cuh +++ b/aten/src/THC/THCScanUtils.cuh @@ -3,12 +3,7 @@ #include #include - -#if defined(__HIP_PLATFORM_HCC__) -#define SCAN_UTILS_WARP_SIZE 64 -#else -#define SCAN_UTILS_WARP_SIZE 32 -#endif +#include // Collection of in-kernel scan / prefix sum utilities @@ -169,7 +164,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti T carry = __popc(vote); #endif - int warp = threadIdx.x / SCAN_UTILS_WARP_SIZE; + int warp = threadIdx.x / C10_WARP_SIZE; // Per each warp, write out a value if (getLaneId() == 0) { @@ -182,7 +177,7 @@ __device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFuncti // warp shuffle scan for CC 3.0+ if (threadIdx.x == 0) { int current = 0; - for (int i = 0; i < blockDim.x / SCAN_UTILS_WARP_SIZE; ++i) { + for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { T v = smem[i]; smem[i] = binop(smem[i], current); current = binop(current, v); @@ -213,13 +208,11 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi *out -= (T) in; // The outgoing carry for all threads is the last warp's sum - *carry = smem[THCCeilDiv(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1]; + *carry = smem[THCCeilDiv(blockDim.x, C10_WARP_SIZE) - 1]; if (KillWARDependency) { __syncthreads(); } } -#undef SCAN_UTILS_WARP_SIZE - #endif // THC_SCAN_UTILS_INC diff --git a/aten/src/THC/THCTensorMathPointwise.cuh b/aten/src/THC/THCTensorMathPointwise.cuh index b034470791fe1..f5750e2037a0e 100644 --- a/aten/src/THC/THCTensorMathPointwise.cuh +++ b/aten/src/THC/THCTensorMathPointwise.cuh @@ -50,115 +50,6 @@ struct TensorMulOp { } }; -template -struct TensorPowOp { - TensorPowOp(T v) : val(v) {} - __device__ __forceinline__ void operator()(T* out, T* in) { - if (StaticExp == 1) { - *out = *in; - } else if (StaticExp == 2) { - *out = THCNumerics::mul(*in, *in); - } else if (StaticExp == 3) { - T square = THCNumerics::mul(*in, *in); - *out = THCNumerics::mul(square, *in); - } else { - *out = THCNumerics::pow(*in, val); - } - } - - __device__ __forceinline__ void operator()(T* v) { - if (StaticExp == 1) { - *v = *v; - } else if (StaticExp == 2) { - *v = THCNumerics::mul(*v, *v); - } else if (StaticExp == 3) { - *v = THCNumerics::mul(THCNumerics::mul(*v, *v), *v); - } else { - *v = THCNumerics::pow(*v, val); - } - } - - const T val; -}; - -template -struct TensorPowOp { - TensorPowOp(T v) : val(v) {} - __device__ __forceinline__ void operator()(T* out, T* in) { - *out = THCNumerics::cinv(*in); - } - - __device__ __forceinline__ void operator()(T* v) { - *v = THCNumerics::cinv(*v); - } - - const T val; -}; - -template -struct TensorPowOp { - TensorPowOp(T v) : val(v) {} - __device__ __forceinline__ void operator()(T* out, T* in) { - T square = THCNumerics::mul(*in, *in); - *out = THCNumerics::cinv(square); - } - - __device__ __forceinline__ void operator()(T* v) { - T square = THCNumerics::mul(*v, *v); - *v = THCNumerics::cinv(square); - } - - const T val; -}; - -template -struct TensorTPowOp { - TensorTPowOp(T v) : val(v) {} - - __device__ __forceinline__ void operator()(T* out, T* in) { - *out = THCNumerics::pow(val, *in); - } - - __device__ __forceinline__ void operator()(T* v) { - *v = THCNumerics::pow(val, *v); - } - - const T val; -}; - -template -struct TensorCPowOp { - __device__ __forceinline__ void operator()(T* out, T* in) { - *out = THCNumerics::pow(*out, *in); - } - - __device__ __forceinline__ void operator()(T* out, T* in1, T* in2) { - *out = THCNumerics::pow(*in1, *in2); - } -}; - -template <> -struct TensorCPowOp { - __device__ __forceinline__ void operator()(float* out, float* in) { - *out = powf(*out, *in); - } - - __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) { - *out = powf(*in1, *in2); - } -}; - -template <> -struct TensorCPowOp { - __device__ __forceinline__ void operator()(double* out, double* in) { - *out = pow(*out, *in); - } - - __device__ __forceinline__ void operator()(double* out, double* in1, double* in2) { - *out = pow(*in1, *in2); - } -}; - template static __device__ __forceinline__ typename std::enable_if::value, bool>::type @@ -505,98 +396,4 @@ struct TensorBitXorOp { } }; -/* - * The following function was converted to CUDA form from code that comes - * with the following copyright notice. It has been released under the BSD license. - * - * Cephes Math Library Release 2.8: June, 2000 - * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier - */ -template -struct TensorDigammaOp { - __device__ __forceinline__ void - operator()(T* out, T* in) { - using compute_type = typename std::conditional::value, accreal, typename std::conditional::value, accreal, T>::type>::type; - static const double PI_f64 = 3.14159265358979323846; - static const compute_type PSI_10 = 2.25175258906672110764; - static const compute_type A[] = { - 8.33333333333333333333E-2, - -2.10927960927960927961E-2, - 7.57575757575757575758E-3, - -4.16666666666666666667E-3, - 3.96825396825396825397E-3, - -8.33333333333333333333E-3, - 8.33333333333333333333E-2, - }; - - auto x = scalar_cast(*in); - if (x == 0) { - *out = scalar_cast(INFINITY); - return; - } - - bool x_is_integer = x == floor(x); - compute_type result = 0; - if (x < 0) { - if (x_is_integer) { - *out = scalar_cast(INFINITY); - return; - } - // Rounding errors in tan's input can really affect the output - // for extreme values, so we always perform this computation in double. - result = scalar_cast( - - PI_f64 / tan(PI_f64 * scalar_cast(x))); - x = 1 - x; - } - - while (x < 10) { - result -= 1 / x; - x += 1; - } - if (x == 10) { - *out = scalar_cast(result + PSI_10); - return; - } - - compute_type y = 0; - if (x < 1.0e17) { - compute_type z = 1.0 / (x * x); - - compute_type polevl_result = 0; - for (int i = 0; i <= 6; i++) { - polevl_result = polevl_result * z + A[i]; - } - y = z * polevl_result; - } - - *out = scalar_cast(log(x) - (0.5 / x) - y + result); - return; - } -}; - -template -struct TensorTrigammaOp { - using compute_type = typename std::conditional::value, accreal, typename std::conditional::value, accreal, T>::type>::type; - __device__ __forceinline__ void - operator()(T* out, T* in) { - const compute_type PI = 3.14159265358979323846; - compute_type x = ScalarConvert::to(*in); - compute_type sign = +1; - compute_type result = 0; - if (x < 0.5f) { - sign = -1; - compute_type sin_pi_x = THCNumerics::sin(PI * x); - result -= (PI * PI) / (sin_pi_x * sin_pi_x); - x = 1 - x; - } - for (int i = 0; i < 6; ++i) { - result += 1 / (x * x); - x += 1; - } - const compute_type ixx = 1 / (x*x); - result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; - *out = ScalarConvert::to(sign * result); - } -}; - #endif // THC_TENSORMATH_POINTWISE_CUH diff --git a/aten/src/THC/THCTensorMathReduce.cuh b/aten/src/THC/THCTensorMathReduce.cuh index c6feed4c92a2c..970672c043949 100644 --- a/aten/src/THC/THCTensorMathReduce.cuh +++ b/aten/src/THC/THCTensorMathReduce.cuh @@ -96,17 +96,17 @@ struct ReduceWelford { template struct VarianceWelford { - VarianceWelford(const int _biased, const bool _apply_sqrt): biased{_biased}, apply_sqrt(_apply_sqrt) {} + VarianceWelford(const int _unbiased, const bool _apply_sqrt): unbiased{_unbiased}, apply_sqrt(_apply_sqrt) {} inline __device__ T operator()(const WelfordData &a) const { - T res = THCNumerics::div(a.m_2_n_, biased!=0 ? a.count_ : a.count_-1); + T res = THCNumerics::div(a.m_2_n_, unbiased ? a.count_ : a.count_-1); if (apply_sqrt) { return THCNumerics::sqrt(res); } return res; } - const int biased; + const int unbiased; const bool apply_sqrt; }; diff --git a/aten/src/THC/generic/THCStorageCopy.cu b/aten/src/THC/generic/THCStorageCopy.cu index 18a5c89897c24..c3cd7c4904b83 100644 --- a/aten/src/THC/generic/THCStorageCopy.cu +++ b/aten/src/THC/generic/THCStorageCopy.cu @@ -2,11 +2,6 @@ #define THC_GENERIC_FILE "THC/generic/THCStorageCopy.cu" #else -void THCStorage_(rawCopy)(THCState *state, THCStorage *self, scalar_t *src) -{ - THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self), src, self->numel() * sizeof(scalar_t), cudaMemcpyDeviceToDevice, THCState_getCurrentStream(state))); -} - // conversions are delegated to THCTensor implementation #define THC_CUDA_STORAGE_IMPLEMENT_COPY(TYPEC,TYPECUDA) \ void THCStorage_(copyCuda##TYPEC)(THCState *state, THCStorage *self, struct THCuda##TYPECUDA##Storage *src) \ diff --git a/aten/src/THC/generic/THCStorageCopy.h b/aten/src/THC/generic/THCStorageCopy.h index ffb37a048d11c..16c5afd147c82 100644 --- a/aten/src/THC/generic/THCStorageCopy.h +++ b/aten/src/THC/generic/THCStorageCopy.h @@ -4,7 +4,6 @@ /* Support for copy between different Storage types */ -THC_API void THCStorage_(rawCopy)(THCState *state, THCStorage *storage, scalar_t *src); THC_API void THCStorage_(copy)(THCState *state, THCStorage *storage, THCStorage *src); THC_API void THCStorage_(copyByte)(THCState *state, THCStorage *storage, struct THByteStorage *src); THC_API void THCStorage_(copyChar)(THCState *state, THCStorage *storage, struct THCharStorage *src); diff --git a/aten/src/THC/generic/THCTensorMasked.cu b/aten/src/THC/generic/THCTensorMasked.cu index 52f654bdc4057..afdf1382e092b 100644 --- a/aten/src/THC/generic/THCTensorMasked.cu +++ b/aten/src/THC/generic/THCTensorMasked.cu @@ -3,6 +3,7 @@ #else #include +#include void THCTensor_(maskedFill)(THCState* state, diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu index 9cec73e785945..6c1931232135c 100644 --- a/aten/src/THC/generic/THCTensorMathBlas.cu +++ b/aten/src/THC/generic/THCTensorMathBlas.cu @@ -3,9 +3,8 @@ #else #include "ATen/cuda/CUDAContext.h" -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include #define ERROR_ONLY_FP_TYPES(func) \ THError("%s for CUDA tensors only supports floating-point types. Try converting the tensors with .float()", func); @@ -55,7 +54,7 @@ accreal THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src) #endif } -void THCTensor_(addmv)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *mat, THCTensor *vec) +void THCTensor_(addmv)(THCState *state, THCTensor *r_, THCTensor *t, THCTensor *mat, THCTensor *vec, scalar_t beta, scalar_t alpha) { #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, r_, t, mat, vec)); @@ -153,7 +152,7 @@ void THCTensor_(addmv)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor THCTensor *tAsMatrix = THCTensor_(newWithTensor)(state, t); THCTensor_(resize2d)(state, tAsMatrix, THTensor_sizeLegacyNoScalars(tAsMatrix, 0), 1); - THCTensor_(addmm)(state, r_, beta, tAsMatrix, alpha, mat, vecAsMatrix); + THCTensor_(addmm)(state, r_, tAsMatrix, mat, vecAsMatrix, beta, alpha); // r_ will have answer as matrix, need to return a vector THCTensor_(resize1d)(state, r_, THTensor_sizeLegacyNoScalars(r_, 0)); @@ -168,7 +167,7 @@ void THCTensor_(addmv)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor #endif } -void THCTensor_(addr)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *vec1, THCTensor *vec2) +void THCTensor_(addr)(THCState *state, THCTensor *r_, THCTensor *t, THCTensor *vec1, THCTensor *vec2, scalar_t beta, scalar_t alpha) { #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, r_, t, vec1, vec2)); @@ -256,7 +255,7 @@ void THCTensor_(addr)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor * THCTensor *vec1M = THCTensor_(newWithTensor)(state, vec1); THCTensor_(resize2d)(state, vec1M, vec1_size, 1); - THCTensor_(addmm)(state, r_, beta, t, alpha, vec1M, vec2T); + THCTensor_(addmm)(state, r_, t, vec1M, vec2T, beta, alpha); THCTensor_(free)(state, vec2T); THCTensor_(free)(state, vec1M); #endif @@ -265,7 +264,7 @@ void THCTensor_(addr)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor * #endif } -void THCTensor_(addmm)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *m1, THCTensor *m2) +void THCTensor_(addmm)(THCState *state, THCTensor *r_, THCTensor *t, THCTensor *m1, THCTensor *m2, scalar_t beta, scalar_t alpha) { #ifdef BUILD_NAMEDTENSOR // The logic in this function changes around the pointers, so save a copy of the originals. @@ -441,8 +440,8 @@ void THCTensor_(addmm)(THCState *state, THCTensor *r_, scalar_t beta, THCTensor #endif } -void THCTensor_(addbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, - scalar_t alpha, THCTensor *batch1, THCTensor *batch2) { +void THCTensor_(addbmm)(THCState *state, THCTensor *result, THCTensor *t, + THCTensor *batch1, THCTensor *batch2, scalar_t beta, scalar_t alpha) { #if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2)); THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 2, 4, "expected 2D tensor"); @@ -477,7 +476,7 @@ void THCTensor_(addbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTe THCTensor_(select)(state, slice1, batch1, 0, i); THCTensor_(select)(state, slice2, batch2, 0, i); - THCTensor_(addmm)(state, result, beta, result, alpha, slice1, slice2); + THCTensor_(addmm)(state, result, result, slice1, slice2, beta, alpha); beta = ScalarConvert::to(1); } THCTensor_(free)(state, slice1); @@ -505,8 +504,9 @@ __global__ void createBatchGemmBuffer3(const scalar_t** buffer1, const scalar_t } } -void THCTensor_(baddbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, - scalar_t alpha, THCTensor *batch1, THCTensor *batch2) { +void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, + THCTensor *batch1, THCTensor *batch2, + scalar_t beta, scalar_t alpha) { #if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) THCAssertSameGPU(THCTensor_(checkGPU)(state, 4, result, t, batch1, batch2)); THArgCheck(THCTensor_(nDimensionLegacyNoScalars)(state, t) == 3, 4, "expected 3D tensor"); diff --git a/aten/src/THC/generic/THCTensorMathBlas.h b/aten/src/THC/generic/THCTensorMathBlas.h index 39d5e35329d59..98608b016bff1 100644 --- a/aten/src/THC/generic/THCTensorMathBlas.h +++ b/aten/src/THC/generic/THCTensorMathBlas.h @@ -3,10 +3,10 @@ #else THC_API accreal THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(addmv)(THCState *state, THCTensor *self, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *mat, THCTensor *vec); -THC_API void THCTensor_(addmm)(THCState *state, THCTensor *self, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *mat1, THCTensor *mat2); -THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *vec1, THCTensor *vec2); -THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *batch1, THCTensor *batch2); -THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, scalar_t beta, THCTensor *t, scalar_t alpha, THCTensor *batch1, THCTensor *batch2); +THC_API void THCTensor_(addmv)(THCState *state, THCTensor *self, THCTensor *t, THCTensor *mat, THCTensor *vec, scalar_t beta, scalar_t alpha); +THC_API void THCTensor_(addmm)(THCState *state, THCTensor *self, THCTensor *t, THCTensor *mat1, THCTensor *mat2, scalar_t beta, scalar_t alpha); +THC_API void THCTensor_(addr)(THCState *state, THCTensor *self, THCTensor *t, THCTensor *vec1, THCTensor *vec2, scalar_t beta, scalar_t alpha); +THC_API void THCTensor_(addbmm)(THCState *state, THCTensor *result, THCTensor *t, THCTensor *batch1, THCTensor *batch2, scalar_t beta, scalar_t alpha); +THC_API void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t, THCTensor *batch1, THCTensor *batch2, scalar_t beta, scalar_t alpha); #endif diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu index 8883446397a92..fd4b83d7de618 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.cu +++ b/aten/src/THC/generic/THCTensorMathMagma.cu @@ -106,13 +106,14 @@ void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor #endif } -void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvrs) +void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, bool eigenvectors) { #ifdef USE_MAGMA + char jobvrs = eigenvectors ? 'V' : 'N'; THArgCheck(a_->dim() == 2, 3, "A should be 2 dimensional"); THArgCheck(a_->size(0) == a_->size(1), 3, "A should be square"); - magma_vec_t jobvr = jobvrs[0] == 'N' ? MagmaNoVec : MagmaVec; + magma_vec_t jobvr = jobvrs == 'N' ? MagmaNoVec : MagmaVec; int64_t n = a_->size(0); scalar_t *a_data = th_magma_malloc_pinned(n * n); @@ -202,14 +203,15 @@ __global__ void THCTensor_(copyLowerSymmetric)(scalar_t *input, int n, int len) } } -void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo) +void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper) { + char uplo = upper ? 'U' : 'L'; #ifdef USE_MAGMA THArgCheck(!a->is_empty() && a->dim() == 2, 2, "A should be non-empty 2 dimensional"); THArgCheck(a->size(0) == a->size(1), 2, "A should be square"); int64_t n = a->size(0); - magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower; + magma_uplo_t ul = uplo == 'U' ? MagmaUpper : MagmaLower; THCTensor *input = THCTensor_(newColumnMajor)(state, ra_, a); scalar_t *input_data = THCTensor_(data)(state, input); @@ -230,7 +232,7 @@ void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char const int len = n*n; dim3 blocks(std::min(DIVUP(len, 128), 65535)); dim3 threads(128); - if (uplo[0] == 'U') { + if (uplo == 'U') { THCTensor_(copyUpperSymmetric)<<>>(input_data, n, len); } else { THCTensor_(copyLowerSymmetric)<<>>(input_data, n, len); diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h index 49616c4b084d7..ae46a62c9ec61 100644 --- a/aten/src/THC/generic/THCTensorMathMagma.h +++ b/aten/src/THC/generic/THCTensorMathMagma.h @@ -6,8 +6,8 @@ // MAGMA (i.e. CUDA implementation of LAPACK functions) THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_); -THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr); -THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo); +THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, bool eigenvectors); +THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, bool upper); THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_); #endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu index 2e2da7db7d139..8d702d0e61a4a 100644 --- a/aten/src/THC/generic/THCTensorMathPointwise.cu +++ b/aten/src/THC/generic/THCTensorMathPointwise.cu @@ -3,9 +3,8 @@ #else #include -#ifdef BUILD_NAMEDTENSOR #include -#endif +#include void THCTensor_(cbitand)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2) { @@ -209,7 +208,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(expm1, THCNumerics::expm1, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cos, THCNumerics::cos, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( sin, THCNumerics::sin, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( sqrt, THCNumerics::sqrt, Real) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(rsqrt, THCNumerics::rsqrt, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, THCNumerics::floor, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(trunc, THCNumerics::trunc, Real) @@ -222,7 +220,6 @@ IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( atan, THCNumerics::atan, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( tanh, THCNumerics::tanh, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( erf, THCNumerics::erf, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( erfc, THCNumerics::erfc, Real) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( round, THCNumerics::round, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( frac, THCNumerics::frac, Real) IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cinv, THCNumerics::cinv, Real) @@ -292,47 +289,6 @@ void THCTensor_(sigmoid)(THCState* state, THCTensor* self_, THCTensor* src) { #endif } -void THCTensor_(digamma)(THCState* state, THCTensor* self_, THCTensor* src) { - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - if (self_ != src) { - THCTensor_(resizeAs)(state, self_, src); - } - if (!THC_pointwiseApply2(state, self_, src, TensorDigammaOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - - THCudaCheck(cudaGetLastError()); -#ifdef BUILD_NAMEDTENSOR - at::namedinference::propagate_names(self_, src); -#endif -} - -void THCTensor_(polygamma)(THCState* state, THCTensor* self_, int64_t n, THCTensor* src) { - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - if (self_ != src) { - THCTensor_(resizeAs)(state, self_, src); - } - switch (n) { - case 0: - if (!THC_pointwiseApply2(state, self_, src, TensorDigammaOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - break; - case 1: - if (!THC_pointwiseApply2(state, self_, src, TensorTrigammaOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - break; - default: - THError("polygamma(n,x) is not implemented for n>=2"); - } - - THCudaCheck(cudaGetLastError()); -#ifdef BUILD_NAMEDTENSOR - at::namedinference::propagate_names(self_, src); -#endif -} - #endif namespace { @@ -374,118 +330,6 @@ void THCTensor_(cmul)(THCState *state, THCTensor *self_, THCTensor *src1, THCTen at::mul_out(out, at::Tensor(retainTensorImpl(src1)), at::Tensor(retainTensorImpl(src2))); } -void THCTensor_(cpow)(THCState *state, THCTensor *self_, THCTensor *src1, THCTensor *src2) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 3, self_, src1, src2)); - THArgCheck(THCTensor_(nElement)(state, src1) == - THCTensor_(nElement)(state, src2), 3, "sizes do not match"); - - if (self_ == src1) { - // self = pow(self, src2) - if (!THC_pointwiseApply2(state, self_, src2, TensorCPowOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCTensor_(resizeAs)(state, self_, src1); - - // self = pow(src1, src2) - if (!THC_pointwiseApply3(state, self_, src1, src2, TensorCPowOp())) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - -void THCTensor_(pow)(THCState *state, THCTensor *self_, THCTensor *src, scalar_t value) { -#if defined(THC_REAL_IS_BYTE) || defined(THC_REAL_IS_CHAR) || defined(THC_REAL_IS_SHORT) || defined(THC_REAL_IS_INT) || defined(THC_REAL_IS_LONG) - if (THCNumerics::lt(value, ScalarConvert::to(0))) { - THError("Integers to negative integer powers are not allowed."); - } -#endif - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - if (self_ == src) { - if (THCNumerics::eq(value, ScalarConvert::to(1))) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else if (THCNumerics::eq(value, ScalarConvert::to(2))) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else if (THCNumerics::eq(value, ScalarConvert::to(3))) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } -#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_BFLOAT16) - } else if (THCNumerics::eq(value, ScalarConvert::to(-1))) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else if (THCNumerics::eq(value, ScalarConvert::to(-2))) { - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } -#endif - } else { - // fallback implementation using pow - if (!THC_pointwiseApply1(state, self_, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - } else { - THCTensor_(resizeAs)(state, self_, src); - - if (THCNumerics::eq(value, ScalarConvert::to(1))) { - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else if (THCNumerics::eq(value, ScalarConvert::to(2))) { - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else if (THCNumerics::eq(value, ScalarConvert::to(3))) { - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } -#if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_BFLOAT16) - } else if (THCNumerics::eq(value, ScalarConvert::to(-1))) { - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else if (THCNumerics::eq(value, ScalarConvert::to(-2))) { - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } -#endif - } else { - // fallback implementation using pow - if (!THC_pointwiseApply2(state, self_, src, TensorPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - } - - THCudaCheck(cudaGetLastError()); -} - -void THCTensor_(tpow)(THCState *state, THCTensor *self_, scalar_t value, THCTensor *src) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); - if (self_ == src) { - if (!THC_pointwiseApply1(state, self_, TensorTPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCTensor_(resizeAs)(state, self_, src); - - if (!THC_pointwiseApply2(state, self_, src, TensorTPowOp(value))) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} void THCTensor_(cdiv)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2) { auto out = at::Tensor(retainTensorImpl(self_)); diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h index 8ecf1eeabb87c..1ab3cf34b6ab3 100644 --- a/aten/src/THC/generic/THCTensorMathPointwise.h +++ b/aten/src/THC/generic/THCTensorMathPointwise.h @@ -13,17 +13,11 @@ THC_API void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor * #if !defined(THC_REAL_IS_BOOL) -THC_API void THCTensor_(pow)(THCState *state, THCTensor *self, THCTensor *src, scalar_t value); -THC_API void THCTensor_(tpow)(THCState *state, THCTensor *self, scalar_t value, THCTensor *src); -THC_API void THCTensor_(cpow)(THCState *state, THCTensor *self, THCTensor *src1, THCTensor *src2); - #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) THC_API void THCTensor_(sigmoid)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(lgamma)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(digamma)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(polygamma)(THCState *state, THCTensor *self, int64_t n, THCTensor *src); THC_API void THCTensor_(log10)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log1p)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(log2)(THCState *state, THCTensor *self, THCTensor *src); @@ -41,9 +35,7 @@ THC_API void THCTensor_(tanh)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(erf)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(erfc)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(sqrt)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(rsqrt)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(floor)(THCState *state, THCTensor *self, THCTensor *src); -THC_API void THCTensor_(round)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(trunc)(THCState *state, THCTensor *self, THCTensor *src); THC_API void THCTensor_(frac)(THCState *state, THCTensor *self, THCTensor *src); diff --git a/aten/src/THC/generic/THCTensorMathReduce.cu b/aten/src/THC/generic/THCTensorMathReduce.cu index c333b16a62441..60a5a4cf30065 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.cu +++ b/aten/src/THC/generic/THCTensorMathReduce.cu @@ -148,7 +148,7 @@ void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar THCTensor_(free)(state, data); } -void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim) +void THCTensor_(std_single)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, bool unbiased, int keepdim) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); @@ -157,7 +157,7 @@ void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dime if (!THC_reduceDim(state, self_, src, ModifyWelford>{}, ReduceWelford{}, - VarianceWelford{biased, true}, + VarianceWelford{unbiased, true}, init, dimension, keepdim)) { @@ -167,7 +167,7 @@ void THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dime THCudaCheck(cudaGetLastError()); } -void THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim) +void THCTensor_(var_single)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, bool unbiased, int keepdim) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src)); @@ -176,7 +176,7 @@ void THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dime if (!THC_reduceDim(state, self_, src, ModifyWelford>{}, ReduceWelford{}, - VarianceWelford{biased, false}, + VarianceWelford{unbiased, false}, init, dimension, keepdim)) { @@ -186,13 +186,13 @@ void THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dime THCudaCheck(cudaGetLastError()); } -accreal THCTensor_(stdall)(THCState *state, THCTensor *self, int biased) +accreal THCTensor_(std_all)(THCState *state, THCTensor *self, bool unbiased) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self)); - return THCNumerics::sqrt((THCTensor_(varall)(state, self, biased))); + return THCNumerics::sqrt((THCTensor_(var_all)(state, self, unbiased))); } -accreal THCTensor_(varall)(THCState *state, THCTensor *self, int biased) +accreal THCTensor_(var_all)(THCState *state, THCTensor *self, bool unbiased) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self)); accreal mean = THCTensor_(meanall)(state, self); @@ -208,7 +208,7 @@ accreal THCTensor_(varall)(THCState *state, THCTensor *self, int biased) val = THCNumerics::div( val, - scalar_cast(std::max(0, THCTensor_(nElement)(state, self) - (biased ? 0 : 1))) + scalar_cast(std::max(0, THCTensor_(nElement)(state, self) - (unbiased ? 1 : 0))) ); THCudaCheck(cudaGetLastError()); diff --git a/aten/src/THC/generic/THCTensorMathReduce.h b/aten/src/THC/generic/THCTensorMathReduce.h index 24c573d7d341b..c93d4b309bfa5 100644 --- a/aten/src/THC/generic/THCTensorMathReduce.h +++ b/aten/src/THC/generic/THCTensorMathReduce.h @@ -21,13 +21,13 @@ THC_API scalar_t THCTensor_(maxall)(THCState *state, THCTensor *self); #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_BFLOAT16) THC_API void THCTensor_(renorm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, scalar_t max_norm); -THC_API void THCTensor_(std)(THCState *state, THCTensor *self, THCTensor *src, int dim, int biased, int keepdim); +THC_API void THCTensor_(std_single)(THCState *state, THCTensor *self, THCTensor *src, int dim, bool unbiased, int keepdim); THC_API void THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, scalar_t value, int dimension, int keepdim); -THC_API void THCTensor_(var)(THCState *state, THCTensor *self, THCTensor *src, int dim, int biased, int keepdim); +THC_API void THCTensor_(var_single)(THCState *state, THCTensor *self, THCTensor *src, int dim, bool unbiased, int keepdim); -THC_API accreal THCTensor_(stdall)(THCState *state, THCTensor *self, int biased); +THC_API accreal THCTensor_(std_all)(THCState *state, THCTensor *self, bool unbiased); THC_API accreal THCTensor_(normall)(THCState *state, THCTensor *self, scalar_t value); -THC_API accreal THCTensor_(varall)(THCState *state, THCTensor *self, int biased); +THC_API accreal THCTensor_(var_all)(THCState *state, THCTensor *self, bool unbiased); #endif diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu index beeb1ace0ef70..dcd14a59a0f40 100644 --- a/aten/src/THC/generic/THCTensorRandom.cu +++ b/aten/src/THC/generic/THCTensorRandom.cu @@ -32,16 +32,13 @@ void THCTensor_(renormRows)(struct THCState* state, void THCTensor_(multinomial)(struct THCState *state, THCudaLongTensor *self, - at::Generator* gen_, THCTensor *prob_dist, int n_sample, - int with_replacement) + int with_replacement, + at::Generator* gen_) { - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self, prob_dist)); auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); int inputSize = THCTensor_(nDimensionLegacyAll)(state, prob_dist); - THArgCheck(inputSize > 0 && inputSize <= 2, 2, - "prob_dist must be 1 or 2 dim"); // Categories are in the innermost dimension int64_t numDist = @@ -49,21 +46,8 @@ void THCTensor_(multinomial)(struct THCState *state, int64_t numCategoriesLong = inputSize == 1 ? THCTensor_(sizeLegacyNoScalars)(state, prob_dist, 0) : THCTensor_(sizeLegacyNoScalars)(state, prob_dist, 1); - - // Since the index tensor is float, numCategories cannot exceed max - // float integer precision - THArgCheck(numCategoriesLong <= FLOAT32_MAX_CONSECUTIVE_INT, 2, - "number of categories cannot exceed 2^24"); int numCategories = (int) numCategoriesLong; - THArgCheck(n_sample > 0, 3, "cannot sample <= 0 samples"); - - if (!with_replacement) { - THArgCheck(n_sample <= numCategories, 2, - "cannot sample n_sample > prob_dist:size(1) samples without " - "replacement"); - } - int free_prob_dist = 0; // Restructure data for 2d @@ -282,7 +266,7 @@ void THCTensor_(multinomialAliasSetup)(THCState *state, THCTensor *_probs, THCud THCTensor_free(state, probs); } -void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, at::Generator* gen_, THCTensor *_q, THCudaLongTensor *_J, int n_sample){ +void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample, at::Generator* gen_){ THArgCheck(_q->dim() == 1, 1, "expected 1-D probability table, got %d-D probability table instead", _q->dim()); diff --git a/aten/src/THC/generic/THCTensorRandom.h b/aten/src/THC/generic/THCTensorRandom.h index 67244afd010a9..f621bfbd80a3d 100644 --- a/aten/src/THC/generic/THCTensorRandom.h +++ b/aten/src/THC/generic/THCTensorRandom.h @@ -6,9 +6,9 @@ #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF) -THC_API void THCTensor_(multinomial)(struct THCState *state, THCudaLongTensor *self, at::Generator* gen_, THCTensor *prob_dist, int n_sample, int with_replacement); +THC_API void THCTensor_(multinomial)(struct THCState *state, THCudaLongTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement, at::Generator* gen_); THC_API void THCTensor_(multinomialAliasSetup)(struct THCState *state, THCTensor *probs, THCudaLongTensor *J, THCTensor *q); -THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, at::Generator* gen_, THCTensor *_q, THCudaLongTensor *_J, int n_sample); +THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample, at::Generator* gen_); #endif #endif diff --git a/aten/src/THC/generic/THCTensorTopK.cu b/aten/src/THC/generic/THCTensorTopK.cu index 902b2b54a91ee..78ecc2af56e76 100644 --- a/aten/src/THC/generic/THCTensorTopK.cu +++ b/aten/src/THC/generic/THCTensorTopK.cu @@ -2,6 +2,8 @@ #define THC_GENERIC_FILE "THC/generic/THCTensorTopK.cu" #else +#include + void THCTensor_(topk)(THCState* state, THCTensor *topK, THCudaLongTensor *indices, @@ -66,12 +68,6 @@ void THCTensor_(topk)(THCState* state, RUN_DIR(INDEX_T, -1); \ } -#ifdef __HIP_PLATFORM_HCC__ -#define TOPK_WARP_SIZE 64 -#else -#define TOPK_WARP_SIZE 32 -#endif - #define RUN_T(INDEX_T) \ TensorInfo inputInfo = \ getTensorInfo(state, input); \ @@ -105,7 +101,7 @@ void THCTensor_(topk)(THCState* state, THError("Slice to sort is too large"); \ } \ \ - dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) TOPK_WARP_SIZE), (int64_t) 1024)); \ + dim3 block(std::min(THCRoundUp(sliceSize, (int64_t) C10_WARP_SIZE), (int64_t) 1024)); \ \ /* This is used as a template parameter to calculate indices. */ \ /* We only specialize it if all collapsed dim sizes are the */ \ @@ -133,7 +129,6 @@ void THCTensor_(topk)(THCState* state, #undef RUN_DIM #undef RUN_DIR #undef RUN_K -#undef TOPK_WARP_SIZE // Sort the results if the user wants them sorted, since our // selection routine does not ensure sorting diff --git a/aten/src/THCUNN/LookupTable.cu b/aten/src/THCUNN/LookupTable.cu index d7e087a353d3f..cff6638f955d7 100644 --- a/aten/src/THCUNN/LookupTable.cu +++ b/aten/src/THCUNN/LookupTable.cu @@ -6,12 +6,7 @@ #include #include #include - -#ifdef __HIP_PLATFORM_HCC__ -const int WARP_SIZE = 64; -#else -const int WARP_SIZE = 32; -#endif +#include template ::to(gradOutput[gradOutputRow + featureDim]); @@ -152,7 +147,7 @@ __global__ void cunn_LookupTable_accGradParametersKernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int featureDim = startFeature + ii * WARP_SIZE; + int featureDim = startFeature + ii * C10_WARP_SIZE; if (featureDim < stride) { gradWeight[weightRow + featureDim] = ScalarConvert::to(weight[ii]); diff --git a/aten/src/THCUNN/LookupTableBag.cu b/aten/src/THCUNN/LookupTableBag.cu index 335ef9b3c3566..5c758b00a0385 100644 --- a/aten/src/THCUNN/LookupTableBag.cu +++ b/aten/src/THCUNN/LookupTableBag.cu @@ -13,12 +13,8 @@ #include #include #include +#include -#if defined(__HIP_PLATFORM_HCC__) -const int WARP_SIZE = 64; -#else -const int WARP_SIZE = 32; -#endif const int MODE_SUM = 0; const int MODE_MEAN = 1; @@ -109,7 +105,7 @@ __global__ void cunn_LookupTableBag_accGradParametersKernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int featureDim = startFeature + ii * WARP_SIZE; + int featureDim = startFeature + ii * C10_WARP_SIZE; if (featureDim < stride) { gradient[ii] = ScalarConvert::to(gradOutput[gradOutputRow + featureDim]); @@ -129,7 +125,7 @@ __global__ void cunn_LookupTableBag_accGradParametersKernel( #pragma unroll for (int ii = 0; ii < SZ; ii++) { - int featureDim = startFeature + ii * WARP_SIZE; + int featureDim = startFeature + ii * C10_WARP_SIZE; if (featureDim < stride) { gradWeight[weightRow + featureDim] = ScalarConvert::to(weight[ii]); diff --git a/aten/src/THCUNN/SpatialDepthwiseConvolution.cu b/aten/src/THCUNN/SpatialDepthwiseConvolution.cu index 2ee0417738991..2874f7d17afe5 100644 --- a/aten/src/THCUNN/SpatialDepthwiseConvolution.cu +++ b/aten/src/THCUNN/SpatialDepthwiseConvolution.cu @@ -12,16 +12,16 @@ #include #include #include +#include -const int WARP_SIZE = 32; // Crude benchmarks suggest 256 is better than 512 and 1024 // TODO: Autotune/use better heuristics, improve speed more. const int MAX_BLOCK_SIZE = 256; static int getGradParamsNumThreads(int batchSize){ //warp per item in a batch, up to a maximum - return std::min(batchSize * WARP_SIZE, MAX_BLOCK_SIZE); + return std::min(batchSize * C10_WARP_SIZE, MAX_BLOCK_SIZE); } @@ -213,9 +213,9 @@ __global__ void spatialDepthwiseConvolutionAccGradParameters( AccT grad = ScalarConvert::to(0.0); - const int laneId = threadIdx.x % WARP_SIZE; - const int batch = threadIdx.x / WARP_SIZE; - const int nwarps = blockDim.x / WARP_SIZE; + const int laneId = threadIdx.x % C10_WARP_SIZE; + const int batch = threadIdx.x / C10_WARP_SIZE; + const int nwarps = blockDim.x / C10_WARP_SIZE; const int imageElements = outputWidth * outputHeight; // Use warp per item. In the original kernel, a threadblock was used to sum over NHW. // Here, we use a warp to sum values over HW dimension, and if batchSize is larger than the @@ -227,7 +227,7 @@ __global__ void spatialDepthwiseConvolutionAccGradParameters( // bring a nice speed-up. for (int batchIdx = batch; batchIdx < batchSize; batchIdx += nwarps){ // Warp-stride loop over elements in a batch item - for (IndexType idx = laneId; idx < imageElements; idx += WARP_SIZE) { + for (IndexType idx = laneId; idx < imageElements; idx += C10_WARP_SIZE) { // Need to calculate the following: batch position, and offset into the gradOutput // in height, and width. We can intuit the corresponding position in the input from // the other parameters we have diff --git a/aten/src/THCUNN/generic/LookupTable.cu b/aten/src/THCUNN/generic/LookupTable.cu index 77b81b2819b1c..3095ce6b3f3ee 100644 --- a/aten/src/THCUNN/generic/LookupTable.cu +++ b/aten/src/THCUNN/generic/LookupTable.cu @@ -3,6 +3,7 @@ #else #include +#include void THNN_(LookupTable_accGradParameters)( THCState *state, @@ -36,15 +37,14 @@ void THNN_(LookupTable_accGradParameters)( cudaStream_t stream = THCState_getCurrentStream(state); if (numel <= 768 && !scaleGradByFreq) { - const int WARP_SIZE = 32; const int BLOCKDIMY = 32; - dim3 grid(THCCeilDiv(stride, (int64_t)WARP_SIZE)); - dim3 block(WARP_SIZE, BLOCKDIMY); + dim3 grid(THCCeilDiv(stride, (int64_t)C10_WARP_SIZE)); + dim3 block(C10_WARP_SIZE, BLOCKDIMY); cunn_LookupTable_accGradParametersKernelByFeature <<>> (THCIndexTensor_(data)(state, input), THCTensor_(data)(state, gradOutput), diff --git a/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu b/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu index c12dae7fff36c..dd981aaf9542b 100644 --- a/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu +++ b/aten/src/THCUNN/generic/SpatialConvolutionLocal.cu @@ -155,9 +155,11 @@ void THNN_(SpatialConvolutionLocal_updateOutput)( // weight: oH*oW x nOutputPlane x nInputPlane*kH*kW // finput3d: oH*oW x nInputPlane*kH*kW x 1 - THCTensor_(baddbmm)(state, output3d, ScalarConvert::to(1), - output3d, ScalarConvert::to(1), - weight, finput3d); + THCTensor_(baddbmm)(state, output3d, + output3d, + weight, finput3d, + ScalarConvert::to(1), + ScalarConvert::to(1)); // output3d: oH*oW x nOutputPlane x 1 THCTensor_(free)(state, output3d); @@ -260,9 +262,10 @@ void THNN_(SpatialConvolutionLocal_updateGradInput)( // weight: oH*oW x nInputPlane*kH*kW x nOutputPlane // gradOutput3d: oH*oW x nOutputPlane x 1 THCTensor_(baddbmm)(state, fgradInput3d, + fgradInput3d, + tweight, gradOutput3d, ScalarConvert::to(0), - fgradInput3d, ScalarConvert::to(1), - tweight, gradOutput3d); + ScalarConvert::to(1)); // fgradInput3d: oH*oW x nInputPlane*kH*kW x 1 // Unpack columns back into input: @@ -381,8 +384,9 @@ void THNN_(SpatialConvolutionLocal_accGradParameters)( // gradOutput3d: oH*oW x nOutputPlane x 1 // finput3d: oH*oW x 1 x kW*kH*nInputPlane - THCTensor_(baddbmm)(state, gradWeight, ScalarConvert::to(1), - gradWeight, scale, gradOutput3d, finput3d); + THCTensor_(baddbmm)(state, gradWeight, + gradWeight, gradOutput3d, finput3d, + ScalarConvert::to(1), scale); // gradWeight: oH*oW x nOutputPlane x kW*kH*nInputPlane THCTensor_(cadd)(state, gradBias, gradBias, scale, gradOutput_n); diff --git a/aten/src/THCUNN/generic/TemporalConvolution.cu b/aten/src/THCUNN/generic/TemporalConvolution.cu index 7f6b60f0ec284..e9fb2a561c7d6 100644 --- a/aten/src/THCUNN/generic/TemporalConvolution.cu +++ b/aten/src/THCUNN/generic/TemporalConvolution.cu @@ -102,7 +102,7 @@ void THNN_(TemporalConvolution_updateOutput)( THCTensor *tweight = THCTensor_(new)(state); THCTensor_(transpose)(state, tweight, weight, 0, 1); - THCTensor_(addmm)(state, outputWindow, ScalarConvert::to(1), outputWindow, ScalarConvert::to(1), inputWindow, tweight); + THCTensor_(addmm)(state, outputWindow, outputWindow, inputWindow, tweight, ScalarConvert::to(1), ScalarConvert::to(1)); THCTensor_(free)(state, tweight); } } @@ -150,7 +150,7 @@ void THNN_(TemporalConvolution_updateOutput)( THCTensor *tweight = THCTensor_(new)(state); THCTensor_(transpose)(state, tweight, weight, 0, 1); - THCTensor_(addmm)(state, outputWindow, ScalarConvert::to(1), outputWindow, ScalarConvert::to(1), inputWindow, tweight); + THCTensor_(addmm)(state, outputWindow, outputWindow, inputWindow, tweight, ScalarConvert::to(1), ScalarConvert::to(1)); THCTensor_(free)(state, tweight); } } @@ -225,7 +225,7 @@ void THNN_(TemporalConvolution_updateGradInput)( nFrame, inputFrameStride*gradInput->size(1), kW*gradInput->size(1), 1); - THCTensor_(addmm)(state, gradInputWindow, ScalarConvert::to(1), gradInputWindow, ScalarConvert::to(1), gradOutputWindow, weight); + THCTensor_(addmm)(state, gradInputWindow, gradInputWindow, gradOutputWindow, weight, ScalarConvert::to(1), ScalarConvert::to(1)); } } else @@ -257,7 +257,7 @@ void THNN_(TemporalConvolution_updateGradInput)( nFrame, inputFrameStride*gradInputSample->size(1), kW*gradInputSample->size(1), 1); - THCTensor_(addmm)(state, gradInputWindow, ScalarConvert::to(1), gradInputWindow, ScalarConvert::to(1), gradOutputWindow, weight); + THCTensor_(addmm)(state, gradInputWindow, gradInputWindow, gradOutputWindow, weight, ScalarConvert::to(1), ScalarConvert::to(1)); } } THCTensor_(free)(state, gradOutputSample); @@ -336,7 +336,7 @@ void THNN_(TemporalConvolution_accGradParameters)( THCTensor *tgradOutputWindow = THCTensor_(new)(state); THCTensor_(transpose)(state, tgradOutputWindow, gradOutputWindow, 0, 1); - THCTensor_(addmm)(state, gradWeight, ScalarConvert::to(1), gradWeight, scale, tgradOutputWindow, inputWindow); + THCTensor_(addmm)(state, gradWeight, gradWeight, tgradOutputWindow, inputWindow, ScalarConvert::to(1), scale); THCTensor_(free)(state, tgradOutputWindow); } } @@ -379,7 +379,7 @@ void THNN_(TemporalConvolution_accGradParameters)( THCTensor *tgradOutputWindow = THCTensor_(new)(state); THCTensor_(transpose)(state, tgradOutputWindow, gradOutputWindow, 0, 1); - THCTensor_(addmm)(state, gradWeight, ScalarConvert::to(1), gradWeight, scale, tgradOutputWindow, inputWindow); + THCTensor_(addmm)(state, gradWeight, gradWeight, tgradOutputWindow, inputWindow, ScalarConvert::to(1), scale); THCTensor_(free)(state, tgradOutputWindow); } } diff --git a/aten/src/THNN/THNN.h b/aten/src/THNN/THNN.h index 0887df012c395..b75844831db91 100644 --- a/aten/src/THNN/THNN.h +++ b/aten/src/THNN/THNN.h @@ -22,4 +22,7 @@ typedef void THNNState; #include #include +#include +#include + #endif diff --git a/aten/src/THNN/generic/ClassNLLCriterion.c b/aten/src/THNN/generic/ClassNLLCriterion.c index 0d0835794fc46..55e2c91c2243a 100644 --- a/aten/src/THNN/generic/ClassNLLCriterion.c +++ b/aten/src/THNN/generic/ClassNLLCriterion.c @@ -44,7 +44,9 @@ void THNN_(ClassNLLCriterion_updateOutput)( continue; } if (cur_target >= 0 && cur_target < n_classes) { - scalar_t cur_weight = weights ? THTensor_(fastGetLegacy1dNoScalars)(weights, cur_target) : 1.0f; + scalar_t cur_weight = + weights ? THTensor_(fastGetLegacy1dNoScalars)(weights, cur_target) + : (scalar_t)1.0f; THTensor_(fastSet1d)(output, i, -THTensor_(fastGet2d)(input, i, cur_target) * cur_weight); } else { int tmp = -1; @@ -78,7 +80,8 @@ void THNN_(ClassNLLCriterion_updateOutput)( int cur_target = target_data[0]; if (cur_target != ignore_index) { THAssert(cur_target >= 0 && cur_target < n_classes); - total_weight_data[0] = weights ? weights_data[cur_target] : 1.0f; + total_weight_data[0] = + weights ? weights_data[cur_target] : (scalar_t)1.0f; output_data[0] = -input_data[cur_target] * total_weight_data[0]; } } else if (THTensor_(nDimensionLegacyAll)(input) == 2) { @@ -93,7 +96,8 @@ void THNN_(ClassNLLCriterion_updateOutput)( if (cur_target != ignore_index) { THAssert(cur_target >= 0 && cur_target < n_classes); - scalar_t cur_weight = weights ? weights_data[cur_target] : 1.0f; + scalar_t cur_weight = + weights ? weights_data[cur_target] : (scalar_t)1.0f; total_weight_data[0] += cur_weight; output_data[0] -= input_data[i * n_target + cur_target] * cur_weight; } @@ -154,7 +158,9 @@ void THNN_(ClassNLLCriterion_updateGradInput)( if (cur_target == ignore_index) { continue; } - scalar_t weight = weights ? THTensor_(fastGetLegacy1dNoScalars)(weights, cur_target) : 1.0f; + scalar_t weight = + weights ? THTensor_(fastGetLegacy1dNoScalars)(weights, cur_target) + : (scalar_t)1.0f; THTensor_(fastSet2d)(gradInput, i, cur_target, -weight * THTensor_(fastGetLegacy1dNoScalars)(gradOutput, i)); } }); @@ -182,8 +188,9 @@ void THNN_(ClassNLLCriterion_updateGradInput)( if (cur_target != ignore_index) { THAssert(cur_target >= 0 && cur_target < n_classes); - gradInput_data[cur_target] = - (reduction != Reduction::Mean && weights) ? -weights_data[cur_target] : -1; + gradInput_data[cur_target] = (reduction != Reduction::Mean && weights) + ? -weights_data[cur_target] + : (scalar_t)-1; gradInput_data[cur_target] *= gradOutput_value; } @@ -201,7 +208,8 @@ void THNN_(ClassNLLCriterion_updateGradInput)( THAssert(cur_target >= 0 && cur_target < n_classes); gradInput_data[i * n_target + cur_target] = - -(weights ? weights_data[cur_target] : 1.0f) * gradOutput_value; + -(weights ? weights_data[cur_target] : (scalar_t)1.0f) * + gradOutput_value; if (reduction == Reduction::Mean && *total_weight_data) { gradInput_data[i * n_target + cur_target] /= *total_weight_data; diff --git a/aten/src/THNN/generic/SpatialConvolutionMM.c b/aten/src/THNN/generic/SpatialConvolutionMM.c index ea9d9f215b316..a9a50f5ea8269 100644 --- a/aten/src/THNN/generic/SpatialConvolutionMM.c +++ b/aten/src/THNN/generic/SpatialConvolutionMM.c @@ -132,7 +132,7 @@ static void THNN_(SpatialConvolutionMM_updateOutput_frame)( THTensor_(zero)(output); } - THTensor_(addmm)(output2d, 1, output2d, 1, weight, finput); + THTensor_(addmm)(output2d, output2d, weight, finput, 1, 1); c10::raw::intrusive_ptr::decref(output2d); } @@ -236,7 +236,7 @@ static void THNN_(SpatialConvolutionMM_updateGradInput_frame)( (THTensor_getStoragePtr(gradOutput), gradOutput->storage_offset(), gradOutput->size(0), -1, gradOutput->size(1)*gradOutput->size(2), -1); - THTensor_(addmm)(fgradInput, 0, fgradInput, 1, weight, gradOutput2d); + THTensor_(addmm)(fgradInput, fgradInput, weight, gradOutput2d, 0, 1); c10::raw::intrusive_ptr::decref(gradOutput2d); THTensor_(zero)(gradInput); @@ -330,7 +330,7 @@ static void THNN_(SpatialConvolutionMM_accGradParameters_frame)( if (gradWeight) { THTensor *tfinput = THTensor_(new)(); THTensor_(transpose)(tfinput, finput, 0, 1); - THTensor_(addmm)(gradWeight, 1, gradWeight, scale, gradOutput2d, tfinput); + THTensor_(addmm)(gradWeight, gradWeight, gradOutput2d, tfinput, 1, scale); c10::raw::intrusive_ptr::decref(tfinput); } diff --git a/aten/src/THNN/generic/VolumetricConvolutionMM.c b/aten/src/THNN/generic/VolumetricConvolutionMM.c index eb6ab354b4500..b22a148ae6c34 100644 --- a/aten/src/THNN/generic/VolumetricConvolutionMM.c +++ b/aten/src/THNN/generic/VolumetricConvolutionMM.c @@ -214,7 +214,7 @@ static void THNN_(unfolded_copy_vol)( *dst = (h >= 0 && w >= 0 && d >= 0 && h < inputHeight && w < inputWidth && d < inputDepth) ? - input_data[nip*inputDHW+ d*inputHW + h*inputWidth + w] : 0; + input_data[nip*inputDHW+ d*inputHW + h*inputWidth + w] : scalar_t(0); count++; if (count < line_seg_len) { @@ -304,7 +304,7 @@ static void THNN_(VolumetricConvolutionMM_updateOutput_frame)( THTensor_(zero)(output); } - THTensor_(addmm)(output2d, 1, output2d, 1, weight, finput); + THTensor_(addmm)(output2d, output2d, weight, finput, 1, 1); c10::raw::intrusive_ptr::decref(output2d); } @@ -544,7 +544,7 @@ static void THNN_(VolumetricConvolutionMM_updateGradInput_frame)( gradOutput->size(1)*gradOutput->size(2)*gradOutput->size(3), -1 ); - THTensor_(addmm)(fgradInput, 0, fgradInput, 1, weight, gradOutput2d); + THTensor_(addmm)(fgradInput, fgradInput, weight, gradOutput2d, 0, 1); c10::raw::intrusive_ptr::decref(gradOutput2d); THTensor_(zero)(gradInput); @@ -650,7 +650,7 @@ static void THNN_(VolumetricConvolutionMM_accGradParameters_frame)( if (gradWeight){ THTensor *tfinput = THTensor_(new)(); THTensor_(transpose)(tfinput, finput, 0, 1); - THTensor_(addmm)(gradWeight, 1, gradWeight, scale, gradOutput2d, tfinput); + THTensor_(addmm)(gradWeight, gradWeight, gradOutput2d, tfinput, 1, scale); c10::raw::intrusive_ptr::decref(tfinput); } diff --git a/aten/src/THNN/init.cpp b/aten/src/THNN/init.cpp index 453e370a3f762..098ba55fec4e8 100644 --- a/aten/src/THNN/init.cpp +++ b/aten/src/THNN/init.cpp @@ -70,6 +70,9 @@ #include #include +#include +#include + #include #include @@ -127,12 +130,18 @@ #include #include +#include +#include + #include #include #include #include +#include +#include + #include #include @@ -141,3 +150,6 @@ #include #include + +#include +#include diff --git a/benchmarks/operator_benchmark/benchmark_core.py b/benchmarks/operator_benchmark/benchmark_core.py index e6613b7524e84..4bf8295c53bc6 100644 --- a/benchmarks/operator_benchmark/benchmark_core.py +++ b/benchmarks/operator_benchmark/benchmark_core.py @@ -3,12 +3,14 @@ from __future__ import print_function from __future__ import unicode_literals -import cpp_extension # noqa import functools import numpy as np import timeit -import torch import json +import torch + +# needs to be imported after torch +import cpp_extension # noqa import benchmark_utils from collections import namedtuple diff --git a/benchmarks/operator_benchmark/pt/add_test.py b/benchmarks/operator_benchmark/pt/add_test.py index d99fd964918d4..e37ab100f40bd 100644 --- a/benchmarks/operator_benchmark/pt/add_test.py +++ b/benchmarks/operator_benchmark/pt/add_test.py @@ -8,27 +8,27 @@ """Microbenchmarks for element-wise Add operator. Supports both Caffe2/PyTorch.""" -# Configs for PT add operator +# Configs for PT add operator add_long_configs = op_bench.cross_product_configs( M=[8, 64, 128], N=range(2, 10, 3), - K=[2 ** x for x in range(0, 3)], + K=[2 ** x for x in range(0, 3)], tags=["long"] ) add_short_configs = op_bench.config_list( attrs=[ - [8, 16, 32], - [16, 32, 64], + [32, 64, 64], + [64, 64, 64], ], - attr_names=["M", "N", "K"], - tags=["short"], + attr_names=["M", "N", "K"], + tags=["short"], ) class AddBenchmark(op_bench.TorchBenchmarkBase): - def init(self, M, N, K): + def init(self, M, N, K): self.input_one = torch.rand(M, N, K) self.input_two = torch.rand(M, N, K) self.set_module_name("add") diff --git a/benchmarks/operator_benchmark/pt/qconv_test.py b/benchmarks/operator_benchmark/pt/qconv_test.py index 32f04c3cf46eb..bcc80bd602b0d 100644 --- a/benchmarks/operator_benchmark/pt/qconv_test.py +++ b/benchmarks/operator_benchmark/pt/qconv_test.py @@ -78,8 +78,41 @@ def forward(self): return self.qconv2d(self.input) +class QConv2dChainedBenchmark(op_bench.TorchBenchmarkBase): + def init(self, N, IC, OC, H, W, G, kernel, stride, pad): + scale = 1.0 / 255 + zero_point = 0 + X = torch.randn(N, IC, H, W, dtype=torch.float32) + qX = torch.quantize_linear( + X, scale=scale, zero_point=zero_point, dtype=torch.quint8 + ) + W = torch.randn(OC, IC // G, kernel, kernel, dtype=torch.float32) + qW = torch.quantize_linear(W, scale=scale, zero_point=0, dtype=torch.qint8) + + self.input = qX + self.qconv2d = nnq.Conv2d(IC, OC, kernel, stride=stride, padding=pad, groups=G) + self.qconv2d.weight = qW + self.qconv2d.scale = torch.tensor([scale], dtype=torch.double) + self.qconv2d.zero_point = torch.tensor([zero_point], dtype=torch.int) + + W2 = torch.randn(OC, OC // G, kernel, kernel, dtype=torch.float32) + qW2 = torch.quantize_linear(W2, scale=scale, zero_point=0, dtype=torch.qint8) + self.qconv2d2 = nnq.Conv2d(OC, OC, kernel, stride=stride, padding=pad, groups=G) + self.qconv2d2.weight = qW2 + self.qconv2d2.scale = torch.tensor([scale], dtype=torch.double) + self.qconv2d2.zero_point = torch.tensor([zero_point], dtype=torch.int) + self.set_module_name("QConv2dChained") + + def forward(self): + # test that layout propagation works fine + x = self.qconv2d(self.input) + x = x.relu() + return self.qconv2d2(x) + + op_bench.generate_pt_test(qconv_2d_configs, QConv2dBenchmark) op_bench.generate_pt_test(resnext_32_4d_shape_configs, QConv2dBenchmark) +op_bench.generate_pt_test(qconv_2d_configs, QConv2dChainedBenchmark) if __name__ == "__main__": diff --git a/benchmarks/operator_benchmark/pt/softmax_test.py b/benchmarks/operator_benchmark/pt/softmax_test.py index ea315110ffb11..8b4f4e6c4ac27 100644 --- a/benchmarks/operator_benchmark/pt/softmax_test.py +++ b/benchmarks/operator_benchmark/pt/softmax_test.py @@ -17,8 +17,8 @@ # Configs for softmax ops softmax_configs_short = op_bench.config_list( attrs=[ - [1, 3, 32, 32], - [2, 3, 64, 64], + [4, 3, 128, 128], + [8, 3, 256, 256], ], attr_names=[ 'N', 'C', 'H', 'W' @@ -29,7 +29,7 @@ softmax_configs_long = op_bench.config_list( attrs=[ [8, 3, 128, 128], - [16, 512, 14, 14], + [16, 512, 14, 14], [16, 256, 28, 28], ], attr_names=[ @@ -57,8 +57,8 @@ def forward(self): return self.op_func(self.input_one) -op_bench.generate_pt_tests_from_op_list(softmax_ops_list, - softmax_configs_short + softmax_configs_long, +op_bench.generate_pt_tests_from_op_list(softmax_ops_list, + softmax_configs_short + softmax_configs_long, SoftmaxBenchmark) diff --git a/benchmarks/operator_benchmark/pt/unary_test.py b/benchmarks/operator_benchmark/pt/unary_test.py index 79e6375bf97b8..5ae5deec25d85 100644 --- a/benchmarks/operator_benchmark/pt/unary_test.py +++ b/benchmarks/operator_benchmark/pt/unary_test.py @@ -56,7 +56,6 @@ def forward(self): ['cos', torch.cos], ['cos_', torch.cos_], ['cosh', torch.cosh], - ['cosh_', torch.cosh_], ['digamma', torch.digamma], ['erf', torch.erf], ['erf_', torch.erf_], @@ -97,7 +96,6 @@ def forward(self): ['sin', torch.sin], ['sin_', torch.sin_], ['sinh', torch.sinh], - ['sinh_', torch.sinh_], ['sqrt', torch.sqrt], ['sqrt_', torch.sqrt_], ['tan', torch.tan], @@ -111,7 +109,6 @@ def forward(self): ['bernoulli_', lambda t: t.bernoulli_()], ['cauchy_', lambda t: t.cauchy_()], ['digamma_', lambda t: t.digamma_()], - ['erfinv_', lambda t: t.erfinv_()], ['exponential_', lambda t: t.exponential_()], ['normal_', lambda t: t.normal_()], ['random_', lambda t: t.random_()], diff --git a/c10/core/QEngine.h b/c10/core/QEngine.h new file mode 100644 index 0000000000000..2b1cd87f160d6 --- /dev/null +++ b/c10/core/QEngine.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +/** + * QEngine is an enum that is used to select the engine to run quantized ops. + */ +enum class QEngine : uint8_t { + NoQEngine = 0, + FBGEMM = 1, + QNNPACK = 2, +}; + +constexpr auto kNoQEngine = QEngine::NoQEngine; +constexpr auto kFBGEMM = QEngine::FBGEMM; +constexpr auto kQNNPACK = QEngine::QNNPACK; + +inline std::string toString(QEngine qengine) { + switch (qengine) { + case kNoQEngine: + return "NoQEngine"; + case kFBGEMM: + return "FBGEMM"; + case kQNNPACK: + return "QNNPACK"; + default: + TORCH_CHECK( + false, + "Unrecognized Quantized Engine: ", + static_cast(qengine)); + } +} + +} // namespace c10 diff --git a/c10/core/Scalar.cpp b/c10/core/Scalar.cpp index 2e6423feec564..b202b485ff11a 100644 --- a/c10/core/Scalar.cpp +++ b/c10/core/Scalar.cpp @@ -3,6 +3,7 @@ namespace c10 { Scalar Scalar::operator-() const { + TORCH_CHECK(!isBoolean(), "torch boolean negative, the `-` operator, is not suppported."); if (isFloatingPoint()) { return Scalar(-v.d); } else if (isComplex()) { diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index d4c02fb4586b2..6d79f3004cd4a 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -38,8 +38,8 @@ class C10_API Scalar { typename T, typename std::enable_if::value, bool>::type* = nullptr> - Scalar(T vv) : tag(Tag::HAS_i) { - v.i = convert(vv); + Scalar(T vv) : tag(Tag::HAS_b) { + v.i = convert(vv); } #define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \ @@ -61,35 +61,45 @@ class C10_API Scalar { } else if (Tag::HAS_z == tag) { \ return checked_convert>( \ {v.z[0], v.z[1]}, #type); \ + } if (Tag::HAS_b == tag) { \ + return checked_convert(v.i, #type); \ } else { \ return checked_convert(v.i, #type); \ } \ } // TODO: Support ComplexHalf accessor - AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF( - DEFINE_ACCESSOR) + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR) // also support scalar.to(); template - T to(); + T to() const; #undef DEFINE_ACCESSOR bool isFloatingPoint() const { return Tag::HAS_d == tag; } + + C10_DEPRECATED_MESSAGE("isIntegral is deprecated. Please use the overload with 'includeBool' parameter instead.") bool isIntegral() const { return Tag::HAS_i == tag; } + bool isIntegral(bool includeBool) const { + return Tag::HAS_i == tag || (includeBool && isBoolean()); + } + bool isComplex() const { return Tag::HAS_z == tag; } + bool isBoolean() const { + return Tag::HAS_b == tag; + } Scalar operator-() const; private: template::is_integer, bool>::type* = + typename std::enable_if::is_integer && ! std::is_same::value, bool>::type* = nullptr> Scalar(T vv, bool) : tag(Tag::HAS_i) { v.i = convert(vv); @@ -105,7 +115,7 @@ class C10_API Scalar { // We can't set v in the initializer list using the // syntax v{ .member = ... } because it doesn't work on MSVC - enum class Tag { HAS_d, HAS_i, HAS_z }; + enum class Tag { HAS_d, HAS_i, HAS_z, HAS_b }; Tag tag; union { double d; @@ -119,13 +129,13 @@ class C10_API Scalar { // define the scalar.to() specializations template -inline T Scalar::to() { +inline T Scalar::to() const { throw std::runtime_error("to() cast to unexpected type."); } #define DEFINE_TO(T, name) \ template <> \ - inline T Scalar::to() { \ + inline T Scalar::to() const { \ return to##name(); \ } AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO) diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index a87250932fc94..c582b2cb7cf70 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -307,10 +307,46 @@ static inline ScalarType toUnderlying(ScalarType t) { } } +static inline bool isSignedType(ScalarType t) { + #define CASE_SIGNED(ctype, name) \ + case ScalarType::name: \ + return std::numeric_limits::is_signed; + + switch (t) { + AT_FORALL_SCALAR_TYPES_AND(Half, CASE_SIGNED) + default: + AT_ERROR("Unknown ScalarType"); + } + #undef CASE_SIGNED +} + static inline bool isUnderlying(ScalarType type, ScalarType qtype) { return type == toUnderlying(qtype); } +// see tensor_attributes.rst for detailed explanation and examples +// of casting rules. +static inline bool canCast(const ScalarType from, const ScalarType to) { + // We disallow float -> integral, e.g., int_tensor *= float is disallowed. + if (isFloatingType(from) && isIntegralType(to, false)) { + return false; + } + + // Treat bool as a distinct "category," to be consistent with type promotion + // rules (e.g. `bool_tensor + 5 -> int64_tensor`). If `5` was in the same category + // as `bool_tensor`, we would not promote. + // Differing categories implies `bool_tensor += 5` is disallowed. + // + // NB: numpy distinguishes "unsigned" as a category to get the desired + // `bool_tensor + 5 -> int64_tensor` behavior. We don't, because: + // * We don't want the performance hit of checking the runtime sign of Scalars. + // * `uint8_tensor + 5 -> int64_tensor` would be undesirable. + if (from != ScalarType::Bool && to == ScalarType::Bool) { + return false; + } + return true; +} + static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { // This is generated according to NumPy's promote_types constexpr auto u1 = ScalarType::Byte; @@ -329,6 +365,10 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) { if (a == ud || b == ud) { return ScalarType::Undefined; } + if (isComplexType(a) || isComplexType(b)) { + AT_ERROR( + "promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be for ", toString(a), " and ", toString(b)); + } // For QInt types, we only allow exact match if (isQIntType(a) && a == b) { diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 152be49c9297a..dd829a8563766 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -2,6 +2,7 @@ #include #include +#include #include C10_DEFINE_bool( @@ -43,13 +44,13 @@ const at::Tensor& TensorImpl::grad() const { } } -TensorImpl::TensorImpl(Storage&& storage, TensorTypeId type_id) - : TensorImpl(std::move(storage), type_id, storage.dtype(), storage.device()) {} +TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set) + : TensorImpl(std::move(storage), type_set, storage.dtype(), storage.device()) {} -TensorImpl::TensorImpl(TensorTypeId type_id, const caffe2::TypeMeta& data_type, c10::optional device_opt) - : TensorImpl({}, type_id, data_type, std::move(device_opt)) {} +TensorImpl::TensorImpl(TensorTypeSet type_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) + : TensorImpl({}, type_set, data_type, std::move(device_opt)) {} -TensorImpl::TensorImpl(Storage&& storage, TensorTypeId type_id, const caffe2::TypeMeta& data_type, +TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) : storage_(std::move(storage)), sizes_{0}, @@ -57,8 +58,8 @@ TensorImpl::TensorImpl(Storage&& storage, TensorTypeId type_id, const caffe2::Ty numel_(0), data_type_(data_type), device_opt_(device_opt), - type_id_(type_id) { - if (type_id != TensorTypeId::UndefinedTensorId) { + type_set_(type_set.remove(TensorTypeId::VariableTensorId)) { + if (!type_set.empty()) { AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() || device_opt_.has_value()); // UndefinedTensorImpl is a singleton, so we skip logging it @@ -194,62 +195,12 @@ at::DataPtr PlacementDeleteContext::makeDataPtr( AutogradMetaInterface::~AutogradMetaInterface() {} -#ifdef BUILD_NAMEDTENSOR -NamedTensorMetaInterface::~NamedTensorMetaInterface() {} - -std::unique_ptr NamedTensorMetaInterface::clone() const { - TORCH_INTERNAL_ASSERT( - false, - "Attempting to clone a NamedTensorMetaInterface instance."); -} -#endif - -/// NOTE [ Treating Variables as non-Variables in type dispatch ] -/// -/// Previously, in VariableType_*.cpp (generated by gen_variable_type.py), when -/// a function is using the 'use_derived' strategy, we call its implementation -/// on the base non-Variable type (`baseType`), passing unwrapped tensors to the -/// call so that any `.dispatch_type()` calls in the implementation can treat the passed -/// tensors as non-Variables and won't dispatch back to functions in VariableType. -/// -/// However, after the Variable/Tensor merge, there is no concept of unwrapping -/// a tensor anymore, and directly passing variables to the base type calls will -/// cause the `.dispatch_type()` dispatch in the implementation to treat the tensor as a -/// variable, and any function dispatch based on `.dispatch_type()` will dispatch back to -/// VariableType, which is not what we want. -/// -/// The solution to the above problem is to add `at::NonVariableTypeMode`, which -/// when enabled will cause `legacyTensorType()` and `getType()` to always return -/// non-Variable type, even if the tensor being called on is a variable. -/// -/// TODO: Since `torch::NoGradGuard` serves the same purpose in libtorch, we should -/// merge these two thread-local guards. - -/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, -/// thread_local is not supported. In that case, we don't provide -/// `at::NonVariableTypeMode`. -#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY - -thread_local bool NonVariableTypeMode_enabled = false; - -bool NonVariableTypeMode::is_enabled() { - return NonVariableTypeMode_enabled; -} - -void NonVariableTypeMode::set_enabled(bool enabled) { - NonVariableTypeMode_enabled = enabled; -} - -#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) - bool NonVariableTypeMode::is_enabled() { - throw std::runtime_error("NonVariableTypeMode is not supported on mobile"); + return !impl::tls_variable_is_enabled(); } void NonVariableTypeMode::set_enabled(bool enabled) { - throw std::runtime_error("NonVariableTypeMode is not supported on mobile"); + impl::tls_variable_set_enabled(!enabled); } -#endif - } // namespace c10 diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index bb966c62bd5b2..366d2b7cbe6a0 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include @@ -143,12 +143,19 @@ struct C10_API NonVariableTypeMode { static void set_enabled(bool enabled); }; -#ifdef BUILD_NAMEDTENSOR struct C10_API NamedTensorMetaInterface { - virtual ~NamedTensorMetaInterface(); - virtual std::unique_ptr clone() const; + virtual ~NamedTensorMetaInterface() {}; + virtual std::unique_ptr clone() const { + TORCH_INTERNAL_ASSERT( + false, + "Not implemented: NamedTensorMetaInterface::clone"); + }; + virtual int64_t slow_dim() const { + TORCH_INTERNAL_ASSERT( + false, + "Not implemented: NamedTensorMetaInterface::slow_dim"); + }; }; -#endif // NOTE [ Version Counter Sharing ] // @@ -289,19 +296,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Construct a 1-dim 0-size tensor backed by the given storage. */ - TensorImpl(Storage&& storage, TensorTypeId type_id); + TensorImpl(Storage&& storage, TensorTypeSet); /** * Construct a 1-dim 0 size tensor that doesn't have a storage. */ - TensorImpl(TensorTypeId type_id, const caffe2::TypeMeta& data_type, c10::optional device_opt); + TensorImpl(TensorTypeSet, const caffe2::TypeMeta& data_type, c10::optional device_opt); + + // Legacy constructors so I don't have to go update call sites. + // TODO: When Variable is added, delete these constructors + TensorImpl(Storage&& storage, TensorTypeId type_id) + : TensorImpl(std::move(storage), TensorTypeSet(type_id)) {} + TensorImpl(TensorTypeId type_id, const caffe2::TypeMeta& data_type, c10::optional device_opt) + : TensorImpl(TensorTypeSet(type_id), data_type, device_opt) {} private: // This constructor is private, because the data_type is redundant with // storage. Still, we pass it in separately because it's easier to write // the initializer list if we're not worried about storage being moved out // from under us. - TensorImpl(Storage&& storage, TensorTypeId type_id, const caffe2::TypeMeta& data_type, c10::optional); + TensorImpl(Storage&& storage, TensorTypeSet, const caffe2::TypeMeta& data_type, c10::optional); public: TensorImpl(const TensorImpl&) = delete; @@ -316,35 +330,12 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ virtual void release_resources() override; - // TODO: Ideally, type_id() would be the *only* key we need to consult - // to do a dispatch, instead of having to grovel through three different - // variables. Here's what's standing in the way: - // - // - To eliminate ScalarType, we have to allocate a TensorTypeId for - // each ScalarType+Backend combination, and then set it appropriately - // when we initially allocate a TensorImpl. - // - // - To eliminate is_variable, we have to allocate two classes of - // TensorTypeId: ones that are variables, and ones that are not. - // We may not want to eliminate this in the short term, because - // hard-coding variable status into type_id() makes it more difficult - // to do the "thread-local no_grad" trick (where we process Variables - // "as if" they were non-Variables by setting a thread local variable.) - // - // TODO: type() is a very attractive name for a method, but we don't - // actually want people to use it. Rename this to something else. - /** - * Return the TensorTypeId corresponding to this Tensor. In the future, - * this will be the sole piece of information required to dispatch - * to an operator; however, at the moment, it is not used for - * dispatch. - * - * type_id() and type() are NOT in one-to-one correspondence; we only - * have a single type_id() for CPU tensors, but many Types (CPUFloatTensor, - * CPUDoubleTensor...) + * Return the TensorTypeSet corresponding to this Tensor, specifying + * all of the TensorTypeIds that this Tensor identifies as. This is the + * information used to dispatch operations on this tensor. */ - TensorTypeId type_id() const { return type_id_; } + TensorTypeSet type_set() const { return type_set_; } /** * Return a reference to the sizes of this tensor. This reference remains @@ -409,45 +400,44 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_sparse() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - auto tid = type_id(); // NB: At the moment, variables have the same TensorTypeId as their // corresponding tensor, but if this ever changes, we need to modify this. - return tid == TensorTypeId::SparseCPUTensorId || tid == TensorTypeId::SparseCUDATensorId || tid == TensorTypeId::SparseHIPTensorId; + return type_set_.has(TensorTypeId::SparseCPUTensorId) || + type_set_.has(TensorTypeId::SparseCUDATensorId) || + type_set_.has(TensorTypeId::SparseHIPTensorId); } bool is_quantized() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - auto tid = type_id(); // NB: At the moment, variables have the same TensorTypeId as their // corresponding tensor, but if this ever changes, we need to modify this. - return tid == TensorTypeId::QuantizedCPUTensorId; + return type_set_.has(TensorTypeId::QuantizedCPUTensorId); } bool is_cuda() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - auto tid = type_id(); // NB: At the moment, variables have the same TensorTypeId as their // corresponding tensor, but if this ever changes, we need to modify this. - return tid == TensorTypeId::CUDATensorId || tid == TensorTypeId::SparseCUDATensorId; + return type_set_.has(TensorTypeId::CUDATensorId) || + type_set_.has(TensorTypeId::SparseCUDATensorId); } bool is_hip() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - auto tid = type_id(); // NB: At the moment, variables have the same TensorTypeId as their // corresponding tensor, but if this ever changes, we need to modify this. - return tid == TensorTypeId::HIPTensorId || tid == TensorTypeId::SparseHIPTensorId; + return type_set_.has(TensorTypeId::HIPTensorId) || + type_set_.has(TensorTypeId::SparseHIPTensorId); } bool is_mkldnn() const { - return type_id() == TensorTypeId::MkldnnCPUTensorId; + return type_set_.has(TensorTypeId::MkldnnCPUTensorId); } int64_t get_device() const { TORCH_CHECK( device_opt_.has_value(), - "tensor with backend ", toString(tensorTypeIdToBackend(type_id())), - " does not have a device"); + "tensor does not have a device"); // See NOTE [c10::optional operator usage in CUDA] return (*device_opt_).index(); } @@ -455,8 +445,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { Device device() const { TORCH_CHECK( device_opt_.has_value(), - "tensor with backend ", toString(tensorTypeIdToBackend(type_id())), - " does not have a device"); + "tensor does not have a device"); // See NOTE [c10::optional operator usage in CUDA] return *device_opt_; } @@ -843,6 +832,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { */ void set_autograd_meta(std::unique_ptr autograd_meta) { autograd_meta_ = std::move(autograd_meta); + if (autograd_meta_) { + type_set_ = type_set_.add(TensorTypeId::VariableTensorId); + } else { + type_set_ = type_set_.remove(TensorTypeId::VariableTensorId); + } } /** @@ -856,17 +850,21 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * Detach the autograd metadata unique_ptr from this tensor, and return it. */ std::unique_ptr detach_autograd_meta() { + type_set_ = type_set_.remove(TensorTypeId::VariableTensorId); return std::move(autograd_meta_); } -#ifdef BUILD_NAMEDTENSOR /** * Set the pointer to named tensor metadata. */ void set_named_tensor_meta(std::unique_ptr named_tensor_meta) { + TORCH_WARN_ONCE( + "Named tensors and all their associated APIs are an experimental feature ", + "and subject to change. Please do not use them for anything important ", + "until they are released as stable."); #ifdef DEBUG if (named_tensor_meta) { - TORCH_INTERNAL_ASSERT(dim() == named_tensor_meta->names.size()); + TORCH_INTERNAL_ASSERT(named_tensor_meta->slow_dim() == dim()); } #endif named_tensor_meta_ = std::move(named_tensor_meta); @@ -882,7 +880,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { c10::NamedTensorMetaInterface* named_tensor_meta() { return named_tensor_meta_.get(); } -#endif // NOTE [ TensorImpl Shallow-Copying ] @@ -915,17 +912,25 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * One TensorImpl can be copied to another TensorImpl if they have the same - * type_id. The only two special cases (for legacy reason) are: + * TensorTypeSet. The only two special cases (for legacy reason) are: * CPUTensorId is compatible with CUDATensorId and SparseCPUTensorId is * compatible with SparseCUDATensorId. */ - inline bool has_compatible_shallow_copy_type(TensorTypeId from) { - TensorTypeId self = type_id(); - return (self == from) || - ((self == TensorTypeId::CPUTensorId || self == TensorTypeId::CUDATensorId || self == TensorTypeId::HIPTensorId) && - (from == TensorTypeId::CPUTensorId || from == TensorTypeId::CUDATensorId || from == TensorTypeId::HIPTensorId)) || - ((self == TensorTypeId::SparseCPUTensorId || self == TensorTypeId::SparseCUDATensorId || self == TensorTypeId::SparseHIPTensorId) && - (from == TensorTypeId::SparseCPUTensorId || from == TensorTypeId::SparseCUDATensorId || from == TensorTypeId::SparseHIPTensorId)); + inline bool has_compatible_shallow_copy_type(TensorTypeSet from) { + auto is_dense = [](TensorTypeSet ts) { + return ts.has(TensorTypeId::CPUTensorId) || + ts.has(TensorTypeId::CUDATensorId) || + ts.has(TensorTypeId::HIPTensorId); + }; + auto is_sparse = [](TensorTypeSet ts) { + return ts.has(TensorTypeId::SparseCPUTensorId) || + ts.has(TensorTypeId::SparseCUDATensorId) || + ts.has(TensorTypeId::SparseHIPTensorId); + }; + // TODO: This is going to be wrong when we introduce Variable; need to + // factor this to be agnostic to Variable. Maybe the correct fix + // is to introduce another RTTI code for subclasses. + return (type_set_ == from) || (is_dense(type_set_) && is_dense(from)) || (is_sparse(type_set_) && is_sparse(from)); } /** @@ -937,7 +942,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual c10::intrusive_ptr shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { - auto impl = c10::make_intrusive(Storage(storage()), type_id()); + auto impl = c10::make_intrusive(Storage(storage()), type_set_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -1531,17 +1536,23 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; - dest_impl->type_id_ = src_impl->type_id_; + // This may temporarily violate invariant that + // type_set_.has(VariableTensorId) iff autograd_meta_ != nullptr... + dest_impl->type_set_ = src_impl->type_set_; + // ...so refresh Variable in autograd_meta_ + if (dest_impl->autograd_meta_) { + dest_impl->type_set_ = dest_impl->type_set_.add(TensorTypeId::VariableTensorId); + } else { + dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId); + } dest_impl->is_contiguous_ = src_impl->is_contiguous_; dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_; dest_impl->reserved_ = src_impl->reserved_; dest_impl->set_version_counter(version_counter); dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); -#ifdef BUILD_NAMEDTENSOR if (src_impl->named_tensor_meta_ != nullptr) { dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone(); } -#endif } protected: @@ -1552,15 +1563,18 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { static const char * const err_msg_tensor_metadata_change_not_allowed; Storage storage_; + +private: // This pointer points to an AutogradMeta struct that stores autograd-specific fields // (such as grad_ / grad_fn_ / grad_accumulator_). // This pointer always has unique ownership (meaning only one TensorImpl can own it // at a time). + // This is private because we must maintain dispatcher invariants on it + // in type_set_. std::unique_ptr autograd_meta_ = nullptr; -#ifdef BUILD_NAMEDTENSOR +protected: std::unique_ptr named_tensor_meta_ = nullptr; -#endif c10::VariableVersion version_counter_; @@ -1610,9 +1624,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // (which do not have a device.) c10::optional device_opt_; + // The set of TensorTypeIds which describe this tensor + TensorTypeSet type_set_; + // You get to have eight byte-size fields here, before you // should pack this into a bitfield. - TensorTypeId type_id_; bool is_contiguous_ = true; // Tensor is stored in the channels last memory format, when dimensions @@ -1704,15 +1720,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // numel // data type pointer // (optional) device +// tensor type id // miscellaneous bitfield // -#ifdef BUILD_NAMEDTENSOR -#define NWORDS 29 -#else -#define NWORDS 28 -#endif static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit... - sizeof(TensorImpl) == sizeof(int64_t) * NWORDS, + sizeof(TensorImpl) == sizeof(int64_t) * 30, "You changed the size of TensorImpl on 64-bit arch." "See Note [TensorImpl size constraints] on how to proceed."); diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index ff3a74dd896d3..830260888a180 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -348,6 +349,7 @@ struct C10_API TensorOptions { } // Resolves the ATen backend specified by the current construction axes. + // TODO: Deprecate this Backend backend() const noexcept { return at::tensorTypeIdToBackend(computeTensorTypeId()); } @@ -374,6 +376,13 @@ struct C10_API TensorOptions { return r; } + // Resolves the tensor type set specified by the current construction axes. + TensorTypeSet type_set() const noexcept { + auto r = TensorTypeSet(computeTensorTypeId()); + if (is_variable()) r = r.add(TensorTypeId::VariableTensorId); + return r; + } + inline TensorTypeId computeTensorTypeId() const { switch (layout()) { case Layout::Strided: diff --git a/c10/core/TensorTypeId.cpp b/c10/core/TensorTypeId.cpp index 8562800188095..1dd4476ae6ca8 100644 --- a/c10/core/TensorTypeId.cpp +++ b/c10/core/TensorTypeId.cpp @@ -38,6 +38,8 @@ const char* toString(TensorTypeId t) { return "ComplexCPUTensorId"; case TensorTypeId::ComplexCUDATensorId: return "ComplexCUDATensorId"; + case TensorTypeId::VariableTensorId: + return "VariableTensorId"; default: return "UNKNOWN_TENSOR_TYPE_ID"; } diff --git a/c10/core/TensorTypeId.h b/c10/core/TensorTypeId.h index b8d29ff55d095..d01ee9d5f3efb 100644 --- a/c10/core/TensorTypeId.h +++ b/c10/core/TensorTypeId.h @@ -6,13 +6,25 @@ namespace c10 { -// NB: Ordering will be subject to change +// A "bit" in a TensorTypeSet, which may have a unique dispatch handler +// for it. Higher bit indexes get handled by dispatching first (because +// we "count leading zeros") enum class TensorTypeId : uint8_t { - UndefinedTensorId, + // This is not a "real" tensor id, but it exists to give us a "nullopt" + // element we can return for cases when a TensorTypeSet contains no elements. + // You can think a more semantically accurate definition of TensorTypeId is: + // + // using TensorTypeId = optional + // + // and UndefinedTensorId == nullopt. We didn't actually represent + // it this way because optional would take two + // words, when TensorTypeId fits in eight bits. + UndefinedTensorId = 0, + + // This pool of IDs is not really ordered, but it is merged into + // the hierarchy for convenience and performance CPUTensorId, // PyTorch/Caffe2 supported CUDATensorId, // PyTorch/Caffe2 supported - SparseCPUTensorId, // PyTorch only - SparseCUDATensorId, // PyTorch only MKLDNNTensorId, // Caffe2 only OpenGLTensorId, // Caffe2 only OpenCLTensorId, // Caffe2 only @@ -24,9 +36,26 @@ enum class TensorTypeId : uint8_t { MkldnnCPUTensorId, QuantizedCPUTensorId, // PyTorch only ComplexCPUTensorId, // PyTorch only - ComplexCUDATensorId // PyTorch only + ComplexCUDATensorId, // PyTorch only + + // Sparse has multi-dispatch with dense; handle it first + SparseCPUTensorId, // PyTorch only + SparseCUDATensorId, // PyTorch only + + // WARNING! If you add more "wrapper" style tensor ids (tensor + // ids which don't get kernels directly defined in native_functions.yaml; + // examples are tracing or profiling) here, you need to also adjust + // legacyExtractTypeId in c10/core/TensorTypeId.h to mask them out. + + VariableTensorId, + + NumTensorIds, // Sentinel }; +static_assert( + static_cast(TensorTypeId::NumTensorIds) < 64, + "TensorTypeId is used as index into 64-bit bitmask; you must have less than 64 entries"); + C10_API const char* toString(TensorTypeId); C10_API std::ostream& operator<<(std::ostream&, TensorTypeId); diff --git a/c10/core/TensorTypeSet.cpp b/c10/core/TensorTypeSet.cpp new file mode 100644 index 0000000000000..2efa813133fbe --- /dev/null +++ b/c10/core/TensorTypeSet.cpp @@ -0,0 +1,31 @@ +#include + +namespace c10 { + +std::string toString(TensorTypeSet ts) { + std::stringstream ss; + ss << ts; + return ss.str(); +} + +std::ostream& operator<<(std::ostream& os, TensorTypeSet ts) { + if (ts.empty()) { + os << "TensorTypeSet()"; + return os; + } + os << "TensorTypeSet("; + TensorTypeId tid; + bool first = true; + while ((tid = ts.highestPriorityTypeId()) != TensorTypeId::UndefinedTensorId) { + if (!first) { + os << ", "; + } + os << tid; + ts = ts.remove(tid); + first = false; + } + os << ")"; + return os; +} + +} diff --git a/c10/core/TensorTypeSet.h b/c10/core/TensorTypeSet.h new file mode 100644 index 0000000000000..f5ce1af70b529 --- /dev/null +++ b/c10/core/TensorTypeSet.h @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include + +namespace c10 { + +// A representation of a set of TensorTypeIds. A tensor may have multiple +// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the +// TensorTypeSet specifies what type ids apply. The internal representation is +// as a 64-bit bit set (this means only 64 tensor type ids are supported). +// +// Note that TensorTypeIds are ordered; thus, we can ask questions like "what is +// the highest priority TensorTypeId in the set"? (The set itself is not +// ordered; two sets with the same ids will always have the ids ordered in the +// same way.) +// +// At the moment, there are no nontrivial uses of this set; tensors are always +// singletons. In the near future, this set will represent variable? + tensor +// type id. In the far future, it will be requires grad? + profiling? + +// tracing? + lazy? + tensor type id. +// +// (The difference between variable and requires grad, is that +// there are currently three states a tensor can be: +// 1. Not a variable +// 2. Variable with requires_grad=False +// 3. Variable with requires_grad=True +// Eventually, we want to kill state (1), and only dispatch to autograd +// handling code if one of the inputs requires grad.) +// +// An undefined tensor is one with an empty tensor type set. +class TensorTypeSet final { +public: + enum Full { FULL }; + enum Raw { RAW }; + + // NB: default constructor representation as zero is MANDATORY as + // use of TensorTypeSet in TLS requires this. + TensorTypeSet() + : repr_(0) {} + TensorTypeSet(Full) + : repr_(-1) {} + // Public version of TensorTypeSet(uint64_t) API; external users + // must be explicit when they do this! + TensorTypeSet(Raw, uint64_t x) + : repr_(x) {} + explicit TensorTypeSet(TensorTypeId t) + : repr_(t == TensorTypeId::UndefinedTensorId + ? 0 + : 1ULL << (static_cast(t) - 1)) {} + // Test if a TensorTypeId is in the set + bool has(TensorTypeId t) const { + TORCH_INTERNAL_ASSERT(t != TensorTypeId::UndefinedTensorId); + return static_cast(repr_ & TensorTypeSet(t).repr_); + } + // Perform set union + TensorTypeSet operator|(TensorTypeSet other) const { + return TensorTypeSet(repr_ | other.repr_); + } + // Perform set intersection + TensorTypeSet operator&(TensorTypeSet other) const { + return TensorTypeSet(repr_ & other.repr_); + } + // Compute the set difference self - other + TensorTypeSet operator-(TensorTypeSet other) const { + return TensorTypeSet(repr_ & ~other.repr_); + } + // Perform set equality + bool operator==(TensorTypeSet other) const { + return repr_ == other.repr_; + } + // Add a TensorTypeId to the TensorTypeId set. Does NOT mutate, + // returns the extended TensorTypeSet! + C10_NODISCARD TensorTypeSet add(TensorTypeId t) const { + return *this | TensorTypeSet(t); + } + // Remove a TensorTypeId from the TensorTypeId set. This is + // generally not an operation you should be doing (it's + // used to implement operator<<) + C10_NODISCARD TensorTypeSet remove(TensorTypeId t) const { + return TensorTypeSet(repr_ & ~TensorTypeSet(t).repr_); + } + // Is the set empty? (AKA undefined tensor) + bool empty() const { + return repr_ == 0; + } + uint64_t raw_repr() { return repr_; } + // Return the type id in this set with the highest priority (i.e., + // is the largest in the TensorTypeId enum). Intuitively, this + // type id is the one that should handle dispatch (assuming there + // aren't any further exclusions or inclusions). + TensorTypeId highestPriorityTypeId() const { + // TODO: If I put UndefinedTensorId as entry 64 and then adjust the + // singleton constructor to shift from the right, we can get rid of the + // subtraction here. It's modestly more complicated to get right so I + // didn't do it for now. + return static_cast(64 - llvm::countLeadingZeros(repr_)); + } +private: + TensorTypeSet(uint64_t repr) : repr_(repr) {} + uint64_t repr_ = 0; +}; + +C10_API std::string toString(TensorTypeSet); +C10_API std::ostream& operator<<(std::ostream&, TensorTypeSet); + +// Historically, every tensor only had a single TensorTypeId, and it was +// always something like CPUTensorId and not something weird like VariableId. +// For the forseeable future, it will still be possible to extract /that/ +// TensorTypeId, and that's what this function does. It should be used +// for legacy code that is still using TensorTypeId for things like instanceof +// checks; if at all possible, refactor the code to stop using TensorTypeId +// in those cases. +// +// What's the difference between 'legacyExtractTypeId(s) == id' +// and 's.has(id)'? legacyExtractTypeId will NEVER return VariableTensorId; +// but s.has(VariableTensorId) will evaluate to true if s has VariableTensorId. +// For non-VariableTensorId equality tests, they are indistinguishable. +// +// NB: If you add other non-VariableTensorId other keys to this set, you'll +// have to adjust this some more (sorry.) +static inline TensorTypeId legacyExtractTypeId(TensorTypeSet s) { + return s.remove(TensorTypeId::VariableTensorId).highestPriorityTypeId(); +} + +} diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 47290a25d0535..1a78fc69fb87b 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -102,6 +102,13 @@ struct C10_API DeviceGuardImplInterface { */ virtual Stream getStream(Device) const noexcept = 0; + /** + * Get the default stream for a given device. + */ + virtual Stream getDefaultStream(Device) const { + TORCH_CHECK(false, "Backend doesn't support acquiring a default stream.") + } + /** * Set a stream to be the thread local current stream for its device. * Return the previous stream for that device. You are NOT required diff --git a/c10/core/impl/LocalTensorTypeSet.cpp b/c10/core/impl/LocalTensorTypeSet.cpp new file mode 100644 index 0000000000000..9f4de68800960 --- /dev/null +++ b/c10/core/impl/LocalTensorTypeSet.cpp @@ -0,0 +1,42 @@ +#include + +#include + +namespace c10 { +namespace impl { + +namespace { + +/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting, +/// thread_local is not supported. In that case, we don't provide +/// `at::NonVariableTypeMode`. +#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY + +// NB: Zero initialized! +thread_local uint64_t raw_excluded; + +#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) + +uint64_t raw_excluded = 0; + +#endif + +} + +TensorTypeSet tls_excluded_tensor_type_set() { + return TensorTypeSet(TensorTypeSet::RAW, raw_excluded); +} + +bool tls_variable_is_enabled() { + return !tls_excluded_tensor_type_set().has(TensorTypeId::VariableTensorId); +} + +void tls_variable_set_enabled(bool enabled) { + if (enabled) { + raw_excluded = tls_excluded_tensor_type_set().remove(TensorTypeId::VariableTensorId).raw_repr(); + } else { + raw_excluded = tls_excluded_tensor_type_set().add(TensorTypeId::VariableTensorId).raw_repr(); + } +} + +}} // namespace c10::impl diff --git a/c10/core/impl/LocalTensorTypeSet.h b/c10/core/impl/LocalTensorTypeSet.h new file mode 100644 index 0000000000000..b049dbaa86816 --- /dev/null +++ b/c10/core/impl/LocalTensorTypeSet.h @@ -0,0 +1,22 @@ +#include + +// TLS management for TensorTypeSet +// +// This manages thread-local TensorTypeSet of excluded keys which disqualify +// tensor types from dispatch. Keys which are in this set, even if they appear +// in a list of potential valid keys on a tensor, are not considered for +// dispatch. This is used to, for example, turn off autograd after we have +// handled autograd for a top-level element. +// +// Originally, I implemented this as storing the inverted set, but +// TLS is defined to be zero-initialized, so this doesn't actually work +// (you want the set to be -1 initialized). + +namespace c10 { +namespace impl { + +C10_API bool tls_variable_is_enabled(); +C10_API void tls_variable_set_enabled(bool enabled); +C10_API TensorTypeSet tls_excluded_tensor_type_set(); + +}} // namespace c10::impl diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 9cab33729e10d..c6b6420501b9c 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -37,6 +37,9 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { Stream getStream(Device d) const noexcept override { return impl_->getStream(d); } + Stream getDefaultStream(Device d) const override { + return impl_->getDefaultStream(d); + } Stream exchangeStream(Stream s) const noexcept override { return impl_->exchangeStream(s); } diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index 16e4fd579744f..47bcaf4bbc9a2 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -45,7 +45,10 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { C10_CUDA_CHECK_WARN(cudaSetDevice(d.index())); } Stream getStream(Device d) const noexcept override { - return getCurrentCUDAStream().unwrap(); + return getCurrentCUDAStream(d.index()).unwrap(); + } + Stream getDefaultStream(Device d) const override { + return getDefaultCUDAStream(d.index()); } // NB: These do NOT set the current device Stream exchangeStream(Stream s) const noexcept override { diff --git a/c10/macros/Export.h b/c10/macros/Export.h index bf94d6856ae29..c4880408a2506 100644 --- a/c10/macros/Export.h +++ b/c10/macros/Export.h @@ -64,6 +64,11 @@ #define C10_IMPORT C10_EXPORT #endif // _WIN32 +#ifdef NO_EXPORT +#undef C10_EXPORT +#define C10_EXPORT +#endif + // Definition of an adaptive XX_API macro, that depends on whether you are // building the library itself or not, routes to XX_EXPORT and XX_IMPORT. // Basically, you will need to do this for each shared library that you are diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index b324fad558faa..8b671fac658ab 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -193,6 +193,12 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #define C10_HIP_HOST_DEVICE #endif +#ifdef __HIP_PLATFORM_HCC__ +#define C10_WARP_SIZE 64 +#else +#define C10_WARP_SIZE 32 +#endif + #ifdef __APPLE__ #include #endif diff --git a/c10/test/core/TensorTypeSet_test.cpp b/c10/test/core/TensorTypeSet_test.cpp new file mode 100644 index 0000000000000..507707aa15a1c --- /dev/null +++ b/c10/test/core/TensorTypeSet_test.cpp @@ -0,0 +1,55 @@ +#include + +#include + +using namespace c10; + +TEST(TensorTypeSet, Empty) { + TensorTypeSet empty_set; + for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { + auto tid = static_cast(i); + ASSERT_FALSE(empty_set.has(tid)); + } + ASSERT_TRUE(empty_set.empty()); + TensorTypeSet empty_set2; + ASSERT_TRUE(empty_set == empty_set2); + ASSERT_EQ(empty_set.highestPriorityTypeId(), TensorTypeId::UndefinedTensorId); +} + +TEST(TensorTypeSet, Singleton) { + for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { + auto tid = static_cast(i); + TensorTypeSet sing(tid); + ASSERT_EQ(sing, sing); + ASSERT_EQ(sing, TensorTypeSet().add(tid)); + ASSERT_EQ(sing, sing.add(tid)); + ASSERT_EQ(sing, sing | sing); + ASSERT_FALSE(sing.empty()); + ASSERT_TRUE(sing.has(tid)); + ASSERT_EQ(sing.highestPriorityTypeId(), tid); + ASSERT_EQ(sing.remove(tid), TensorTypeSet()); + } +} + +TEST(TensorTypeSet, Doubleton) { + for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { + for (uint8_t j = i + 1; j < static_cast(TensorTypeId::NumTensorIds); j++) { + ASSERT_LT(i, j); + auto tid1 = static_cast(i); + auto tid2 = static_cast(j); + auto doub = TensorTypeSet(tid1).add(tid2); + ASSERT_EQ(doub, TensorTypeSet(tid1) | TensorTypeSet(tid2)); + ASSERT_TRUE(doub.has(tid1)); + ASSERT_TRUE(doub.has(tid2)); + ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j + } + } +} + +TEST(TensorTypeSet, Full) { + TensorTypeSet full(TensorTypeSet::FULL); + for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { + auto tid = static_cast(i); + ASSERT_TRUE(full.has(tid)); + } +} diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index a43929803e7e6..c6cd205413668 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -39,6 +39,25 @@ namespace { } } + TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) { + float in[100]; + for (int i = 0; i < 100; ++i) { + in[i] = i + 1.25; + } + + c10::BFloat16 bfloats[100]; + float out[100]; + + for (int i = 0; i < 100; ++i) { + bfloats[i].x = c10::detail::round_to_nearest_even(in[i]); + out[i] = c10::detail::f32_from_bits(bfloats[i].x); + + // The relative error should be less than 1/(2^7) since BFloat16 + // has 7 bits mantissa. + EXPECT_LE(fabs(out[i] - in[i]) / in[i], 1.0 / 128); + } + } + TEST(BFloat16Conversion, NaN) { float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF); EXPECT_TRUE(std::isnan(inNaN)); @@ -110,4 +129,34 @@ namespace { float res = c10::detail::f32_from_bits(b.x); EXPECT_EQ(res, expected); } + + float BinaryToFloat(uint32_t bytes) { + float res; + std::memcpy(&res, &bytes, sizeof(res)); + return res; + } + + struct BFloat16TestParam { + uint32_t input; + uint16_t rne; + }; + + class BFloat16Test : public ::testing::Test, + public ::testing::WithParamInterface { + }; + + TEST_P(BFloat16Test, BFloat16RNETest) { + float value = BinaryToFloat(GetParam().input); + uint16_t rounded = c10::detail::round_to_nearest_even(value); + EXPECT_EQ(GetParam().rne, rounded); + } + + INSTANTIATE_TEST_CASE_P( + BFloat16Test_Instantiation, BFloat16Test, + ::testing::Values(BFloat16TestParam{0x3F848000, 0x3F84}, + BFloat16TestParam{0x3F848010, 0x3F85}, + BFloat16TestParam{0x3F850000, 0x3F85}, + BFloat16TestParam{0x3F858000, 0x3F86}, + BFloat16TestParam{0x3FFF8000, 0x4000})); + } // namespace diff --git a/c10/test/util/ordered_preserving_dict_test.cpp b/c10/test/util/ordered_preserving_dict_test.cpp new file mode 100644 index 0000000000000..1811910ac7d48 --- /dev/null +++ b/c10/test/util/ordered_preserving_dict_test.cpp @@ -0,0 +1,421 @@ +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +#define ASSERT_EQUAL_PRIM(t1, t2) ASSERT_TRUE(t1 == t2); + +using dict_int_int = ska_ordered::order_preserving_flat_hash_map; + +dict_int_int test_dict(dict_int_int& dict) { + for (int64_t i = 0; i < 100; ++i) { + dict[i] = i + 1; + } + + int64_t i = 0; + for (auto entry: dict) { + TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1); + ++i; + } + + // erase a few entries by themselves + std::unordered_set erase_set = {0, 2, 9, 71}; + for (auto erase: erase_set) { + dict.erase(erase); + } + + // erase via iterators + auto begin = dict.begin(); + for (size_t i = 0; i < 20; ++i) + begin++; + + auto end = begin; + for (size_t i = 0; i < 20; ++i) { + erase_set.insert(end->first); + end++; + } + dict.erase(begin, end); + + std::vector order; + for (size_t i = 0; i < 100; ++i) { + if (!erase_set.count(i)) { + order.push_back(i); + } + } + + i = 0; + for (auto entry: dict) { + TORCH_INTERNAL_ASSERT(order[i] == entry.first); + TORCH_INTERNAL_ASSERT(dict[order[i]] == entry.second); + TORCH_INTERNAL_ASSERT(entry.second == order[i] + 1); + i++; + } + TORCH_INTERNAL_ASSERT(dict.size() == order.size()); + return dict; +} + +TEST(OrderedPreservingDictTest, InsertAndDeleteBasic) { + dict_int_int dict; + test_dict(dict); + dict.clear(); + test_dict(dict); +} + +TEST(OrderedPreservingDictTest, InsertExistingDoesntAffectOrder) { + dict_int_int dict; + dict[0] = 1; + dict[1] = 2; + + TORCH_INTERNAL_ASSERT(dict.begin()->first == 0); + dict[0] = 1; + TORCH_INTERNAL_ASSERT(dict.begin()->first == 0); + dict[0] = 2; + TORCH_INTERNAL_ASSERT(dict.begin()->first == 0); + + dict.erase(0); + TORCH_INTERNAL_ASSERT(dict.begin()->first == 1); +} + + +TEST(OrderedPreservingDictTest, testRefType) { + std::shared_ptr t; + using dict_references = ska_ordered::order_preserving_flat_hash_map>; + + dict_references dict; + + auto ptr = std::make_shared(1); + dict[1] = ptr; + TORCH_INTERNAL_ASSERT(ptr.use_count() == 2); + dict.erase(1); + TORCH_INTERNAL_ASSERT(ptr.use_count() == 1); + + dict[2] = ptr; + dict.clear(); + TORCH_INTERNAL_ASSERT(ptr.use_count() == 1); +} + + +TEST(OrderedPreservingDictTest, DictCollisions) { + struct BadHash { + size_t operator()(const int64_t input) { + return input % 2; + }; + }; + + using bad_hash_dict = + ska_ordered::order_preserving_flat_hash_map; + + for (auto init_dict_size : {27, 34, 41}) { + bad_hash_dict dict; + for (int64_t i = 0; i < init_dict_size; ++i) { + dict[i] = i + 1; + } + + int64_t i = 0; + for (auto entry : dict) { + TORCH_INTERNAL_ASSERT(entry.first == i && entry.second == i + 1); + ++i; + } + + // erase a few entries; + std::unordered_set erase_set = {0, 2, 9}; + for (auto erase : erase_set) { + dict.erase(erase); + } + + // erase a few entries via iterator + auto begin = dict.begin(); + for (size_t i = 0; i < 10; ++i) { + begin++; + } + auto end = begin; + for (size_t i = 0; i < 7; ++i) { + erase_set.insert(end->first); + end++; + } + dict.erase(begin, end); + + std::vector order; + for (int64_t i = 0; i < init_dict_size; ++i) { + if (!erase_set.count(i)) { + order.push_back(i); + } + } + + i = 0; + for (auto entry : dict) { + TORCH_INTERNAL_ASSERT(dict[entry.first] == entry.second); + TORCH_INTERNAL_ASSERT(dict[entry.first] == order[i] + 1); + TORCH_INTERNAL_ASSERT(order[i] == entry.first); + i += 1; + } + TORCH_INTERNAL_ASSERT(dict.size() == order.size()); + } +} + + +// Tests taken from https://github.com/Tessil/ordered-map/blob/master/tests/ordered_map_tests.cpp + +TEST(OrderedPreservingDictTest, test_range_insert) { + // insert x values in vector, range insert x-15 values from vector to map, check values + const int nb_values = 1000; + std::vector> values; + for(int i = 0; i < nb_values; i++) { + values.push_back(std::make_pair(i, i+1)); + } + + dict_int_int map = {{-1, 0}, {-2, 0}}; + map.insert(values.begin() + 10, values.end() - 5); + + TORCH_INTERNAL_ASSERT(map.size(), 987); + + ASSERT_EQUAL_PRIM(map.at(-1), 0); + + ASSERT_EQUAL_PRIM(map.at(-2), 0); + + for(int i = 10, j = 2; i < nb_values - 5; i++, j++) { + ASSERT_EQUAL_PRIM(map.at(i), i+1); + } +} + +TEST(OrderedPreservingDictTest, test_range_erase_all) { + // insert x values, delete all + const std::size_t nb_values = 1000; + dict_int_int map; + for (size_t i = 0; i < nb_values; ++i) { + map[i] = i + 1; + } + auto it = map.erase(map.begin(), map.end()); + ASSERT_TRUE(it == map.end()); + ASSERT_TRUE(map.empty()); +} + +TEST(OrderedPreservingDictTest, test_range_erase) { + // insert x values, delete all with iterators except 10 first and 780 last values + using HMap = ska_ordered::order_preserving_flat_hash_map; + + const std::size_t nb_values = 1000; + HMap map; + for (size_t i = 0; i < nb_values; ++i) { + map[c10::guts::to_string(i)] = i; + auto begin = map.begin(); + for (size_t j = 0; j <= i; ++j, begin++) { + TORCH_INTERNAL_ASSERT(begin->second == j); + } + } + + auto it_first = std::next(map.begin(), 10); + auto it_last = std::next(map.begin(), 220); + + auto it = map.erase(it_first, it_last); + ASSERT_EQUAL_PRIM(std::distance(it, map.end()), 780); + ASSERT_EQUAL_PRIM(map.size(), 790); + ASSERT_EQUAL_PRIM(std::distance(map.begin(), map.end()), 790); + + for(auto& val: map) { + ASSERT_EQUAL_PRIM(map.count(val.first), 1); + } + + // Check order + it = map.begin(); + for(std::size_t i = 0; i < nb_values; i++) { + if(i >= 10 && i < 220) { + continue; + } + auto exp_it = std::pair(c10::guts::to_string(i), i); + TORCH_INTERNAL_ASSERT(*it == exp_it); + ++it; + } +} + +TEST(OrderedPreservingDictTest, test_move_constructor_empty) { + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_move(std::move(map)); + + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_move.empty()); + + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end()); +} + +TEST(OrderedPreservingDictTest, test_move_operator_empty) { + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_move; + map_move = (std::move(map)); + + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_move.empty()); + + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_move.find("") == map_move.end()); +} + +TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_constructor) { + using HMap = ska_ordered::order_preserving_flat_hash_map; + + HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}}; + HMap map_move(std::move(map)); + + ASSERT_EQUAL_PRIM(map_move.size(), 3); + ASSERT_EQUAL_PRIM(map.size(), 0); + + map = {{"Key4", "Value4"}, {"Key5", "Value5"}}; + TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}}))); +} + +TEST(OrderedPreservingDictTest, test_reassign_moved_object_move_operator) { + using HMap = ska_ordered::order_preserving_flat_hash_map; + + HMap map = {{"Key1", "Value1"}, {"Key2", "Value2"}, {"Key3", "Value3"}}; + HMap map_move = std::move(map); + + ASSERT_EQUAL_PRIM(map_move.size(), 3); + ASSERT_EQUAL_PRIM(map.size(), 0); + + map = {{"Key4", "Value4"}, {"Key5", "Value5"}}; + TORCH_INTERNAL_ASSERT(map == (HMap({{"Key4", "Value4"}, {"Key5", "Value5"}}))); +} + +TEST(OrderedPreservingDictTest, test_copy_constructor_and_operator) { + using HMap = ska_ordered::order_preserving_flat_hash_map; + + + const std::size_t nb_values = 100; + HMap map; + for (size_t i = 0; i < nb_values; ++i) { + map[c10::guts::to_string(i)] = c10::guts::to_string(i); + } + + + HMap map_copy = map; + HMap map_copy2(map); + HMap map_copy3; + map_copy3[c10::guts::to_string(0)] = c10::guts::to_string(0); + + map_copy3 = map; + + TORCH_INTERNAL_ASSERT(map == map_copy); + map.clear(); + + TORCH_INTERNAL_ASSERT(map_copy == map_copy2); + TORCH_INTERNAL_ASSERT(map_copy == map_copy3); +} + +TEST(OrderedPreservingDictTest, test_copy_constructor_empty) { + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_copy(map); + + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_copy.empty()); + + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end()); +} + +TEST(OrderedPreservingDictTest, test_copy_operator_empty) { + ska_ordered::order_preserving_flat_hash_map map(0); + ska_ordered::order_preserving_flat_hash_map map_copy(16); + map_copy = map; + + TORCH_INTERNAL_ASSERT(map.empty()); + TORCH_INTERNAL_ASSERT(map_copy.empty()); + + TORCH_INTERNAL_ASSERT(map.find("") == map.end()); + TORCH_INTERNAL_ASSERT(map_copy.find("") == map_copy.end()); +} + + +/** + * at + */ +TEST(OrderedPreservingDictTest, test_at) { + // insert x values, use at for known and unknown values. + const ska_ordered::order_preserving_flat_hash_map map = {{0, 10}, {-2, 20}}; + + ASSERT_EQUAL_PRIM(map.at(0), 10); + ASSERT_EQUAL_PRIM(map.at(-2), 20); + bool thrown = false; + try { + map.at(1); + } catch (...) { + thrown = true; + } + ASSERT_TRUE(thrown); +} + + +/** + * equal_range + */ +TEST(OrderedPreservingDictTest, test_equal_range) { + ska_ordered::order_preserving_flat_hash_map map = {{0, 10}, {-2, 20}}; + + auto it_pair = map.equal_range(0); + ASSERT_EQUAL_PRIM(std::distance(it_pair.first, it_pair.second), 1); + ASSERT_EQUAL_PRIM(it_pair.first->second, 10); + + it_pair = map.equal_range(1); + TORCH_INTERNAL_ASSERT(it_pair.first == it_pair.second); + TORCH_INTERNAL_ASSERT(it_pair.first == map.end()); +} + + +/** + * operator[] + */ +TEST(OrderedPreservingDictTest, test_access_operator) { + // insert x values, use at for known and unknown values. + ska_ordered::order_preserving_flat_hash_map map = {{0, 10}, {-2, 20}}; + + ASSERT_EQUAL_PRIM(map[0], 10); + ASSERT_EQUAL_PRIM(map[-2], 20); + ASSERT_EQUAL_PRIM(map[2], std::int64_t()); + + ASSERT_EQUAL_PRIM(map.size(), 3); +} + +/** + * swap + */ +TEST(OrderedPreservingDictTest, test_swap) { + ska_ordered::order_preserving_flat_hash_map map = {{1, 10}, {8, 80}, {3, 30}}; + ska_ordered::order_preserving_flat_hash_map map2 = {{4, 40}, {5, 50}}; + + using std::swap; + swap(map, map2); + + TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{{4, 40}, {5, 50}})); + TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}})); + + map.insert({6, 60}); + map2.insert({4, 40}); + + TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{{4, 40}, {5, 50}, {6, 60}})); + TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}, {4, 40}})); +} + +TEST(OrderedPreservingDictTest, test_swap_empty) { + ska_ordered::order_preserving_flat_hash_map map = {{1, 10}, {8, 80}, {3, 30}}; + ska_ordered::order_preserving_flat_hash_map map2; + + using std::swap; + swap(map, map2); + + TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{})); + TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}})); + + map.insert({6, 60}); + map2.insert({4, 40}); + + TORCH_INTERNAL_ASSERT(map == (ska_ordered::order_preserving_flat_hash_map{{6, 60}})); + TORCH_INTERNAL_ASSERT(map2 == (ska_ordered::order_preserving_flat_hash_map{{1, 10}, {8, 80}, {3, 30}, {4, 40}})); +} + +} diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 8e70424f34ec4..eea1a48ebf634 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -100,7 +100,7 @@ class ArrayRef final { /// Construct an ArrayRef from a std::initializer_list. /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) - : Data(Vec.begin() == Vec.end() ? static_cast(nullptr) : Vec.begin()), + : Data(std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) : std::begin(Vec)), Length(Vec.size()) {} /// @} diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index ac6f19aaace5f..ab366d000df9d 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -7,7 +7,8 @@ namespace c10 { /// Constructors inline C10_HOST_DEVICE BFloat16::BFloat16(float value) { - x = detail::bits_from_f32(value); + // RNE by default + x = detail::round_to_nearest_even(value); } /// Implicit conversions @@ -203,17 +204,64 @@ namespace std { template <> class numeric_limits { - public: - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool has_infinity = true; - static constexpr bool has_quiet_NaN = true; - static constexpr c10::BFloat16 lowest() { - return at::BFloat16(0xFF7F, at::BFloat16::from_bits()); - } - static constexpr c10::BFloat16 max() { - return at::BFloat16(0x7F7F, at::BFloat16::from_bits()); - } +public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr c10::BFloat16 min() { + return c10::BFloat16(0x0080, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 lowest() { + return c10::BFloat16(0xFF7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 max() { + return c10::BFloat16(0x7F7F, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 epsilon() { + return c10::BFloat16(0x3C00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 round_error() { + return c10::BFloat16(0x3F00, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 infinity() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 quiet_NaN() { + return c10::BFloat16(0x7FC0, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 signaling_NaN() { + return c10::BFloat16(0x7F80, c10::BFloat16::from_bits()); + } + static constexpr c10::BFloat16 denorm_min() { + return c10::BFloat16(0x0001, c10::BFloat16::from_bits()); + } }; +/// Used by vec256::map +inline c10::BFloat16 exp(c10::BFloat16 a) { return std::exp(float(a)); } +inline c10::BFloat16 log(c10::BFloat16 a) { return std::log(float(a)); } + } // namespace std diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index 98f286f6de9ee..f66fb5971e15d 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -43,6 +43,21 @@ namespace detail { return res >> 16; } + + inline C10_HOST_DEVICE uint16_t round_to_nearest_even(float src) { + if (std::isnan(src)) { + return 0x7FC0; + } else { + union { + uint32_t U32; + float F32; + }; + + F32 = src; + uint32_t rounding_bias = ((U32 >> 16) & 1) + 0x7FFF; + return static_cast((U32 + rounding_bias) >> 16); + } + } } // namespace detail struct alignas(2) BFloat16 { diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 1c11e0c0da24c..6210e7d028b31 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -137,6 +137,8 @@ inline std::string if_empty_then(std::string x, std::string y) { // unsigned int (a.k.a uint32_t) and may cause a compile error with the message: // error C2397: conversion from 'long' to 'uint32_t' requires a narrowing conversion // Here the static cast is used to pass the build. +// if this is used inside a lambda the __func__ macro expands to operator(), +// which isn't very useful, but hard to fix in a macro so suppressing the warning. #define C10_THROW_ERROR(err_type, msg) \ throw ::c10::err_type({__func__, __FILE__, static_cast(__LINE__)}, msg) @@ -289,7 +291,7 @@ inline std::string if_empty_then(std::string x, std::string y) { // arguments which are concatenated into the warning message using operator<< // #define TORCH_WARN_ONCE(...) \ - C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [] { \ + C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = [&] { \ ::c10::Warning::warn({__func__, __FILE__, static_cast(__LINE__)}, ::c10::str(__VA_ARGS__)); \ return true; \ }() diff --git a/c10/util/Half-inl.h b/c10/util/Half-inl.h index 3ce4203a59728..f3ef954a58733 100644 --- a/c10/util/Half-inl.h +++ b/c10/util/Half-inl.h @@ -70,7 +70,7 @@ inline C10_HOST_DEVICE Half operator/(const Half& a, const Half& b) { } inline C10_HOST_DEVICE Half operator-(const Half& a) { -#if __CUDA_ARCH__ >= 530 || defined(__HIP_DEVICE_COMPILE__) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) || defined(__HIP_DEVICE_COMPILE__) return __hneg(a); #else return -static_cast(a); diff --git a/c10/util/Half.h b/c10/util/Half.h index 62421e0cc35fc..338f271627def 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -427,9 +427,20 @@ struct Converter< #pragma warning( disable : 4804 ) #endif + +// bool can be converted to any type. +// Without specializing on bool, in pytorch_linux_trusty_py2_7_9_build: +// `error: comparison of constant '255' with boolean expression is always false` +// for `f > limit::max()` below +template +typename std::enable_if::value, bool>::type overflows( + From f) { + return false; +} + // skip isnan and isinf check for integral types template -typename std::enable_if::value, bool>::type overflows( +typename std::enable_if::value && !std::is_same::value, bool>::type overflows( From f) { using limit = std::numeric_limits::type>; if (!limit::is_signed && std::numeric_limits::is_signed) { diff --git a/c10/util/order_preserving_flat_hash_map.h b/c10/util/order_preserving_flat_hash_map.h new file mode 100644 index 0000000000000..b529985517ee2 --- /dev/null +++ b/c10/util/order_preserving_flat_hash_map.h @@ -0,0 +1,1643 @@ +// Taken from https://github.com/skarupke/flat_hash_map/blob/2c4687431f978f02a3780e24b8b701d22aa32d9c/flat_hash_map.hpp +// with fixes applied: +// - https://github.com/skarupke/flat_hash_map/pull/25 +// - https://github.com/skarupke/flat_hash_map/pull/26 +// - replace size_t with uint64_t to fix it for 32bit +// - add "GCC diagnostic" pragma to ignore -Wshadow +// - make sherwood_v3_table::convertible_to_iterator public because GCC5 seems to have issues with it otherwise +// - fix compiler warnings in operator templated_iterator + +// Copyright Malte Skarupke 2017. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +// Modified to maintain insertion and deletion order through a doubly-linked list + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef _MSC_VER +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wshadow" +#endif + +#ifdef _MSC_VER +#define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ +#else +#define SKA_NOINLINE(...) __VA_ARGS__ __attribute__((noinline)) +#endif + +namespace ska_ordered +{ + +struct prime_number_hash_policy; +struct power_of_two_hash_policy; +struct fibonacci_hash_policy; + +namespace detailv3 +{ +template +struct functor_storage : Functor +{ + functor_storage() = default; + functor_storage(const Functor & functor) + : Functor(functor) + { + } + template + Result operator()(Args &&... args) + { + return static_cast(*this)(std::forward(args)...); + } + template + Result operator()(Args &&... args) const + { + return static_cast(*this)(std::forward(args)...); + } +}; +template +struct functor_storage +{ + typedef Result (*function_ptr)(Args...); + function_ptr function; + functor_storage(function_ptr function) + : function(function) + { + } + Result operator()(Args... args) const + { + return function(std::forward(args)...); + } + operator function_ptr &() + { + return function; + } + operator const function_ptr &() + { + return function; + } +}; +template +struct KeyOrValueHasher : functor_storage +{ + typedef functor_storage hasher_storage; + KeyOrValueHasher() = default; + KeyOrValueHasher(const hasher & hash) + : hasher_storage(hash) + { + } + uint64_t operator()(const key_type & key) + { + return static_cast(*this)(key); + } + uint64_t operator()(const key_type & key) const + { + return static_cast(*this)(key); + } + uint64_t operator()(const value_type & value) + { + return static_cast(*this)(value.first); + } + uint64_t operator()(const value_type & value) const + { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair & value) + { + return static_cast(*this)(value.first); + } + template + uint64_t operator()(const std::pair & value) const + { + return static_cast(*this)(value.first); + } +}; +template +struct KeyOrValueEquality : functor_storage +{ + typedef functor_storage equality_storage; + KeyOrValueEquality() = default; + KeyOrValueEquality(const key_equal & equality) + : equality_storage(equality) + { + } + bool operator()(const key_type & lhs, const key_type & rhs) + { + return static_cast(*this)(lhs, rhs); + } + bool operator()(const key_type & lhs, const value_type & rhs) + { + return static_cast(*this)(lhs, rhs.first); + } + bool operator()(const value_type & lhs, const key_type & rhs) + { + return static_cast(*this)(lhs.first, rhs); + } + bool operator()(const value_type & lhs, const value_type & rhs) + { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const key_type & lhs, const std::pair & rhs) + { + return static_cast(*this)(lhs, rhs.first); + } + template + bool operator()(const std::pair & lhs, const key_type & rhs) + { + return static_cast(*this)(lhs.first, rhs); + } + template + bool operator()(const value_type & lhs, const std::pair & rhs) + { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair & lhs, const value_type & rhs) + { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair & lhs, const std::pair & rhs) + { + return static_cast(*this)(lhs.first, rhs.first); + } +}; +static constexpr int8_t min_lookups = 4; +template +struct sherwood_v3_entry +{ + sherwood_v3_entry() + { + } + sherwood_v3_entry(int8_t distance_from_desired) + : distance_from_desired(distance_from_desired) + { + } + ~sherwood_v3_entry() + { + } + + bool has_value() const + { + return distance_from_desired >= 0; + } + bool is_empty() const + { + return distance_from_desired < 0; + } + bool is_at_desired_position() const + { + return distance_from_desired <= 0; + } + template + void emplace(int8_t distance, Args &&... args) + { + new (std::addressof(value)) T(std::forward(args)...); + distance_from_desired = distance; + } + + void destroy_value() + { + value.~T(); + distance_from_desired = -1; + } + + sherwood_v3_entry * prev = nullptr; + sherwood_v3_entry * next = nullptr; + int8_t distance_from_desired = -1; + static constexpr int8_t special_end_value = 0; + union { T value; }; +}; + +inline int8_t log2(uint64_t value) +{ + static constexpr int8_t table[64] = + { + 63, 0, 58, 1, 59, 47, 53, 2, + 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, + 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, + 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, + 44, 24, 15, 8, 23, 7, 6, 5 + }; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + return table[((value - (value >> 1)) * 0x07EDD5E59A4E28C2) >> 58]; +} + +template +struct AssignIfTrue +{ + void operator()(T & lhs, const T & rhs) + { + lhs = rhs; + } + void operator()(T & lhs, T && rhs) + { + lhs = std::move(rhs); + } +}; +template +struct AssignIfTrue +{ + void operator()(T &, const T &) + { + } + void operator()(T &, T &&) + { + } +}; + +inline uint64_t next_power_of_two(uint64_t i) +{ + --i; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + i |= i >> 32; + ++i; + return i; +} + +// Implementation taken from http://en.cppreference.com/w/cpp/types/void_t +// (it takes CWG1558 into account and also works for older compilers) +template struct make_void { typedef void type;}; +template using void_t = typename make_void::type; + +template +struct HashPolicySelector +{ + typedef fibonacci_hash_policy type; +}; +template +struct HashPolicySelector> +{ + typedef typename T::hash_policy type; +}; + +template +class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal +{ + using Entry = detailv3::sherwood_v3_entry; + using AllocatorTraits = std::allocator_traits; + using EntryPointer = typename AllocatorTraits::pointer; + +public: + struct convertible_to_iterator; + + using value_type = T; + using size_type = uint64_t; + using difference_type = std::ptrdiff_t; + using hasher = ArgumentHash; + using key_equal = ArgumentEqual; + using allocator_type = EntryAlloc; + using reference = value_type &; + using const_reference = const value_type &; + using pointer = value_type *; + using const_pointer = const value_type *; + + sherwood_v3_table() + { + } + explicit sherwood_v3_table(size_type bucket_count, const ArgumentHash & hash = ArgumentHash(), const ArgumentEqual & equal = ArgumentEqual(), const ArgumentAlloc & alloc = ArgumentAlloc()) + : EntryAlloc(alloc), Hasher(hash), Equal(equal) + { + rehash(bucket_count); + } + sherwood_v3_table(size_type bucket_count, const ArgumentAlloc & alloc) + : sherwood_v3_table(bucket_count, ArgumentHash(), ArgumentEqual(), alloc) + { + } + sherwood_v3_table(size_type bucket_count, const ArgumentHash & hash, const ArgumentAlloc & alloc) + : sherwood_v3_table(bucket_count, hash, ArgumentEqual(), alloc) + { + } + explicit sherwood_v3_table(const ArgumentAlloc & alloc) + : EntryAlloc(alloc) + { + } + template + sherwood_v3_table(It first, It last, size_type bucket_count = 0, const ArgumentHash & hash = ArgumentHash(), const ArgumentEqual & equal = ArgumentEqual(), const ArgumentAlloc & alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) + { + insert(first, last); + } + template + sherwood_v3_table(It first, It last, size_type bucket_count, const ArgumentAlloc & alloc) + : sherwood_v3_table(first, last, bucket_count, ArgumentHash(), ArgumentEqual(), alloc) + { + } + template + sherwood_v3_table(It first, It last, size_type bucket_count, const ArgumentHash & hash, const ArgumentAlloc & alloc) + : sherwood_v3_table(first, last, bucket_count, hash, ArgumentEqual(), alloc) + { + } + sherwood_v3_table(std::initializer_list il, size_type bucket_count = 0, const ArgumentHash & hash = ArgumentHash(), const ArgumentEqual & equal = ArgumentEqual(), const ArgumentAlloc & alloc = ArgumentAlloc()) + : sherwood_v3_table(bucket_count, hash, equal, alloc) + { + if (bucket_count == 0) + rehash(il.size()); + insert(il.begin(), il.end()); + } + sherwood_v3_table(std::initializer_list il, size_type bucket_count, const ArgumentAlloc & alloc) + : sherwood_v3_table(il, bucket_count, ArgumentHash(), ArgumentEqual(), alloc) + { + } + sherwood_v3_table(std::initializer_list il, size_type bucket_count, const ArgumentHash & hash, const ArgumentAlloc & alloc) + : sherwood_v3_table(il, bucket_count, hash, ArgumentEqual(), alloc) + { + } + sherwood_v3_table(const sherwood_v3_table & other) + : sherwood_v3_table(other, AllocatorTraits::select_on_container_copy_construction(other.get_allocator())) + { + } + sherwood_v3_table(const sherwood_v3_table & other, const ArgumentAlloc & alloc) + : EntryAlloc(alloc), Hasher(other), Equal(other), _max_load_factor(other._max_load_factor) + { + rehash_for_other_container(other); + try + { + insert(other.begin(), other.end()); + } + catch(...) + { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + throw; + } + } + sherwood_v3_table(sherwood_v3_table && other) noexcept + : EntryAlloc(std::move(other)), Hasher(std::move(other)), Equal(std::move(other)) + { + swap_pointers(other); + } + sherwood_v3_table(sherwood_v3_table && other, const ArgumentAlloc & alloc) noexcept + : EntryAlloc(alloc), Hasher(std::move(other)), Equal(std::move(other)) + { + swap_pointers(other); + } + sherwood_v3_table & operator=(const sherwood_v3_table & other) + { + if (this == std::addressof(other)) + return *this; + + clear(); + if (AllocatorTraits::propagate_on_container_copy_assignment::value) + { + if (static_cast(*this) != static_cast(other)) + { + reset_to_empty_state(); + } + AssignIfTrue()(*this, other); + } + _max_load_factor = other._max_load_factor; + static_cast(*this) = other; + static_cast(*this) = other; + rehash_for_other_container(other); + insert(other.begin(), other.end()); + return *this; + } + sherwood_v3_table & operator=(sherwood_v3_table && other) noexcept + { + if (this == std::addressof(other)) + return *this; + else if (AllocatorTraits::propagate_on_container_move_assignment::value) + { + clear(); + reset_to_empty_state(); + AssignIfTrue()(*this, std::move(other)); + swap_pointers(other); + } + else if (static_cast(*this) == static_cast(other)) + { + swap_pointers(other); + } + else + { + clear(); + _max_load_factor = other._max_load_factor; + rehash_for_other_container(other); + for (T & elem : other) + emplace(std::move(elem)); + other.clear(); + } + static_cast(*this) = std::move(other); + static_cast(*this) = std::move(other); + return *this; + } + ~sherwood_v3_table() + { + clear(); + deallocate_data(entries, num_slots_minus_one, max_lookups); + } + + const allocator_type & get_allocator() const + { + return static_cast(*this); + } + const ArgumentEqual & key_eq() const + { + return static_cast(*this); + } + const ArgumentHash & hash_function() const + { + return static_cast(*this); + } + + template + struct templated_iterator + { + templated_iterator() = default; + templated_iterator(EntryPointer current) + : current(current) + { + } + EntryPointer current = EntryPointer(); + + using iterator_category = std::forward_iterator_tag; + using value_type = ValueType; + using difference_type = ptrdiff_t; + using pointer = ValueType *; + using reference = ValueType &; + + friend bool operator==(const templated_iterator & lhs, const templated_iterator & rhs) + { + return lhs.current == rhs.current; + } + friend bool operator!=(const templated_iterator & lhs, const templated_iterator & rhs) + { + return !(lhs == rhs); + } + + templated_iterator & operator++() + { + current = current->next; + return *this; + } + templated_iterator operator++(int) + { + templated_iterator copy(*this); + ++*this; + return copy; + } + + ValueType & operator*() const + { + return current->value; + } + ValueType * operator->() const + { + return std::addressof(current->value); + } + + // the template automatically disables the operator when value_type is already + // const, because that would cause a lot of compiler warnings otherwise. + template::value && !std::is_same::value>::type> + operator templated_iterator() const + { + return { current }; + } + }; + using iterator = templated_iterator; + using const_iterator = templated_iterator; + + iterator begin() + { + return sentinel->next; + } + const_iterator begin() const + { + return sentinel->next; + } + const_iterator cbegin() const + { + return begin(); + } + iterator end() + { + return sentinel; + } + const_iterator end() const + { + return sentinel; + } + const_iterator cend() const + { + return end(); + } + + iterator find(const FindKey & key) + { + uint64_t index = hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer it = entries + ptrdiff_t(index); + for (int8_t distance = 0; it->distance_from_desired >= distance; ++distance, ++it) + { + if (compares_equal(key, it->value)) + return { it }; + } + return end(); + } + const_iterator find(const FindKey & key) const + { + return const_cast(this)->find(key); + } + uint64_t count(const FindKey & key) const + { + return find(key) == end() ? 0 : 1; + } + std::pair equal_range(const FindKey & key) + { + iterator found = find(key); + if (found == end()) + return { found, found }; + else + return { found, std::next(found) }; + } + std::pair equal_range(const FindKey & key) const + { + const_iterator found = find(key); + if (found == end()) + return { found, found }; + else + return { found, std::next(found) }; + } + + template + std::pair emplace(Key && key, Args &&... args) + { + uint64_t index = hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + EntryPointer current_entry = entries + ptrdiff_t(index); + int8_t distance_from_desired = 0; + for (; current_entry->distance_from_desired >= distance_from_desired; ++current_entry, ++distance_from_desired) + { + // insertion of an existing key does not change ordering + if (compares_equal(key, current_entry->value)) + return { { current_entry }, false }; + } + return emplace_new_key(distance_from_desired, current_entry, std::forward(key), std::forward(args)...); + } + + std::pair insert(const value_type & value) + { + return emplace(value); + } + std::pair insert(value_type && value) + { + return emplace(std::move(value)); + } + template + iterator emplace_hint(const_iterator, Args &&... args) + { + return emplace(std::forward(args)...).first; + } + iterator insert(const_iterator, const value_type & value) + { + return emplace(value).first; + } + iterator insert(const_iterator, value_type && value) + { + return emplace(std::move(value)).first; + } + + template + void insert(It begin, It end) + { + for (; begin != end; ++begin) + { + emplace(*begin); + } + } + void insert(std::initializer_list il) + { + insert(il.begin(), il.end()); + } + + void rehash(uint64_t num_buckets) + { + num_buckets = std::max(num_buckets, static_cast(std::ceil(num_elements / static_cast(_max_load_factor)))); + if (num_buckets == 0) + { + reset_to_empty_state(); + return; + } + auto new_prime_index = hash_policy.next_size_over(num_buckets); + if (num_buckets == bucket_count()) + return; + int8_t new_max_lookups = compute_max_lookups(num_buckets); + EntryPointer new_buckets(AllocatorTraits::allocate(*this, num_buckets + new_max_lookups)); + EntryPointer special_end_item = new_buckets + static_cast(num_buckets + new_max_lookups - 1); + for (EntryPointer it = new_buckets; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + std::swap(entries, new_buckets); + std::swap(num_slots_minus_one, num_buckets); + --num_slots_minus_one; + hash_policy.commit(new_prime_index); + int8_t old_max_lookups = max_lookups; + max_lookups = new_max_lookups; + num_elements = 0; + + auto start = sentinel->next; + // point sentinel to itself; + reset_list(); + // reinsert list + for (EntryPointer it = start; it != sentinel;) { + auto next = it->next; + emplace(std::move(it->value)); + it->destroy_value(); + it = next; + } + + deallocate_data(new_buckets, num_buckets, old_max_lookups); + } + + void reserve(uint64_t num_elements) + { + uint64_t required_buckets = num_buckets_for_reserve(num_elements); + if (required_buckets > bucket_count()) + rehash(required_buckets); + } + + void replace_linked_list_position(EntryPointer to_be_replaced, EntryPointer new_node) { + remove_from_list(new_node); + insert_after(new_node, to_be_replaced->prev); + remove_from_list(to_be_replaced); + } + + // the return value is a type that can be converted to an iterator + // the reason for doing this is that it's not free to find the + // iterator pointing at the next element. if you care about the + // next iterator, turn the return value into an iterator + convertible_to_iterator erase(const_iterator to_erase) + { + EntryPointer current = to_erase.current; + remove_from_list(current); + current->destroy_value(); + --num_elements; + + for (EntryPointer next = current + ptrdiff_t(1); !next->is_at_desired_position(); ++current, ++next) + { + // if an entry is being removed, and there are other entries with the + // same hash, the other entries get moved to their desired position by + // reinserting. + current->emplace(next->distance_from_desired - 1, std::move(next->value)); + replace_linked_list_position(next, current); + next->destroy_value(); + } + return { to_erase.current }; + } + + iterator erase(const_iterator begin_it, const_iterator end_it) + { + // whenever an entry is removed and there are other entries with the same + // hash, the other entries must get moved to their desired position. + // any reference to a moved entry is invalidated. + // here, we iterate through the range, and make sure that we update + // the pointer to our next entry in the list or the end of the iterator + // when it is invalidated. + + auto curr_iter = begin_it.current; + auto next_iter = curr_iter->next; + auto end_iter = end_it.current; + + while (curr_iter != end_iter) { + remove_from_list(curr_iter); + curr_iter->destroy_value(); + --num_elements; + + for (EntryPointer next_hash_slot = curr_iter + ptrdiff_t(1); !next_hash_slot->is_at_desired_position(); ++curr_iter, ++next_hash_slot) + { + curr_iter->emplace(next_hash_slot->distance_from_desired - 1, std::move(next_hash_slot->value)); + replace_linked_list_position(next_hash_slot, curr_iter); + next_hash_slot->destroy_value(); + + // we are invalidating next_iter or end_iter + if (next_hash_slot == end_iter) { + end_iter = curr_iter; + } else if (next_hash_slot == next_iter) { + next_iter = curr_iter; + } + } + curr_iter = next_iter; + next_iter = curr_iter->next; + } + + return { end_iter }; + } + + uint64_t erase(const FindKey & key) + { + auto found = find(key); + if (found == end()) + return 0; + else + { + erase(found); + return 1; + } + } + + void clear() + { + for (EntryPointer it = entries, end = it + static_cast(num_slots_minus_one + max_lookups); it != end; ++it) + { + if (it->has_value()) + it->destroy_value(); + } + reset_list(); + num_elements = 0; + } + + void shrink_to_fit() + { + rehash_for_other_container(*this); + } + + void swap(sherwood_v3_table & other) + { + using std::swap; + swap_pointers(other); + swap(static_cast(*this), static_cast(other)); + swap(static_cast(*this), static_cast(other)); + if (AllocatorTraits::propagate_on_container_swap::value) + swap(static_cast(*this), static_cast(other)); + } + + uint64_t size() const + { + return num_elements; + } + uint64_t max_size() const + { + return (AllocatorTraits::max_size(*this)) / sizeof(Entry); + } + uint64_t bucket_count() const + { + return num_slots_minus_one ? num_slots_minus_one + 1 : 0; + } + size_type max_bucket_count() const + { + return (AllocatorTraits::max_size(*this) - min_lookups) / sizeof(Entry); + } + uint64_t bucket(const FindKey & key) const + { + return hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + } + float load_factor() const + { + uint64_t buckets = bucket_count(); + if (buckets) + return static_cast(num_elements) / bucket_count(); + else + return 0; + } + void max_load_factor(float value) + { + _max_load_factor = value; + } + float max_load_factor() const + { + return _max_load_factor; + } + + bool empty() const + { + return num_elements == 0; + } + +private: + EntryPointer entries = empty_default_table(); + uint64_t num_slots_minus_one = 0; + typename HashPolicySelector::type hash_policy; + int8_t max_lookups = detailv3::min_lookups - 1; + float _max_load_factor = 0.5f; + uint64_t num_elements = 0; + std::unique_ptr> sentinel_val; + + // head of doubly linked list + EntryPointer sentinel = initSentinel(); + + EntryPointer initSentinel() { + // needs to be a pointer so that hash map can be used with forward declared types + sentinel_val = c10::guts::make_unique>(); + sentinel = sentinel_val.get(); + reset_list(); + return sentinel; + } + + EntryPointer empty_default_table() + { + EntryPointer result = AllocatorTraits::allocate(*this, detailv3::min_lookups); + EntryPointer special_end_item = result + static_cast(detailv3::min_lookups - 1); + for (EntryPointer it = result; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + return result; + } + + static int8_t compute_max_lookups(uint64_t num_buckets) + { + int8_t desired = detailv3::log2(num_buckets); + return std::max(detailv3::min_lookups, desired); + } + + uint64_t num_buckets_for_reserve(uint64_t num_elements) const + { + return static_cast(std::ceil(num_elements / std::min(0.5, static_cast(_max_load_factor)))); + } + void rehash_for_other_container(const sherwood_v3_table & other) + { + rehash(std::min(num_buckets_for_reserve(other.size()), other.bucket_count())); + } + + void swap_pointers(sherwood_v3_table & other) + { + using std::swap; + swap(hash_policy, other.hash_policy); + swap(entries, other.entries); + swap(num_slots_minus_one, other.num_slots_minus_one); + swap(num_elements, other.num_elements); + swap(max_lookups, other.max_lookups); + swap(_max_load_factor, other._max_load_factor); + swap(sentinel, other.sentinel); + swap(sentinel_val, other.sentinel_val); + } + + void reset_list() { + sentinel->next = sentinel; + sentinel->prev = sentinel; + } + + void remove_from_list(EntryPointer elem) { + elem->prev->next = elem->next; + elem->next->prev = elem->prev; + } + + void insert_after(EntryPointer new_elem, EntryPointer prev) { + auto next = prev->next; + + prev->next = new_elem; + new_elem->prev = prev; + + new_elem->next = next; + next->prev = new_elem; + } + + void swap_adjacent_nodes(EntryPointer before, EntryPointer after) { + // sentinel stays consant, so before->prev cannot equal after + auto before_prev = before->prev; + auto after_next = after->next; + + before_prev->next = after; + after->prev = before_prev; + + after_next->prev = before; + before->next = after_next; + + before->prev = after; + after->next = before; + } + + void swap_positions(EntryPointer p1, EntryPointer p2) { + if (p1 == p2) { + return; + } + if (p1->next == p2) { + return swap_adjacent_nodes(p1, p2); + } else if (p2->next == p1) { + return swap_adjacent_nodes(p2, p1); + } + + auto p1_prev = p1->prev; + auto p1_next = p1->next; + + auto p2_prev = p2->prev; + auto p2_next = p2->next; + + p1_prev->next = p2; + p2->prev = p1_prev; + + p1_next->prev = p2; + p2->next = p1_next; + + p2_prev->next = p1; + p1->prev = p2_prev; + + p2_next->prev = p1; + p1->next = p2_next; + } + + void append_to_list(EntryPointer new_tail) { + insert_after(new_tail, sentinel->prev); + } + + template + SKA_NOINLINE(std::pair) emplace_new_key(int8_t distance_from_desired, EntryPointer current_entry, Key && key, Args &&... args) + { + using std::swap; + if (num_slots_minus_one == 0 || distance_from_desired == max_lookups || num_elements + 1 > (num_slots_minus_one + 1) * static_cast(_max_load_factor)) + { + grow(); + return emplace(std::forward(key), std::forward(args)...); + } + else if (current_entry->is_empty()) + { + current_entry->emplace(distance_from_desired, std::forward(key), std::forward(args)...); + ++num_elements; + append_to_list(current_entry); + return { { current_entry }, true }; + } + value_type to_insert(std::forward(key), std::forward(args)...); + swap(distance_from_desired, current_entry->distance_from_desired); + // We maintain the invariant that: + // - result.current_entry contains the new value we're inserting + // and is in the LinkedList position of to_insert + // - to_insert contains the value that reprseents the position of + // result.current_entry + swap(to_insert, current_entry->value); + iterator result = { current_entry }; + for (++distance_from_desired, ++current_entry;; ++current_entry) + { + if (current_entry->is_empty()) + { + current_entry->emplace(distance_from_desired, std::move(to_insert)); + append_to_list(current_entry); + // now we can swap back the displaced value to its correct position, + // putting the new value we're inserting to the front of the list + swap_positions(current_entry, result.current); + ++num_elements; + return { result, true }; + } + else if (current_entry->distance_from_desired < distance_from_desired) + { + swap(distance_from_desired, current_entry->distance_from_desired); + swap(to_insert, current_entry->value); + // to maintain our invariants we need to swap positions + // of result.current & current_entry: + swap_positions(result.current, current_entry); + ++distance_from_desired; + } + else + { + ++distance_from_desired; + if (distance_from_desired == max_lookups) + { + // the displaced element gets put back into its correct position + // we grow the hash table, and then try again to reinsert the new element + swap(to_insert, result.current->value); + grow(); + return emplace(std::move(to_insert)); + } + } + } + } + + void grow() + { + rehash(std::max(uint64_t(4), 2 * bucket_count())); + } + + void deallocate_data(EntryPointer begin, uint64_t num_slots_minus_one, int8_t max_lookups) + { + AllocatorTraits::deallocate(*this, begin, num_slots_minus_one + max_lookups + 1); + } + + void reset_to_empty_state() + { + deallocate_data(entries, num_slots_minus_one, max_lookups); + entries = empty_default_table(); + num_slots_minus_one = 0; + hash_policy.reset(); + max_lookups = detailv3::min_lookups - 1; + } + + template + uint64_t hash_object(const U & key) + { + return static_cast(*this)(key); + } + template + uint64_t hash_object(const U & key) const + { + return static_cast(*this)(key); + } + template + bool compares_equal(const L & lhs, const R & rhs) + { + return static_cast(*this)(lhs, rhs); + } + +public: + struct convertible_to_iterator + { + EntryPointer it; + + operator iterator() + { + if (it->has_value()) + return { it }; + else + return ++iterator{it}; + } + operator const_iterator() + { + if (it->has_value()) + return { it }; + else + return ++const_iterator{it}; + } + }; + +}; +} + +struct prime_number_hash_policy +{ + static uint64_t mod0(uint64_t) { return 0llu; } + static uint64_t mod2(uint64_t hash) { return hash % 2llu; } + static uint64_t mod3(uint64_t hash) { return hash % 3llu; } + static uint64_t mod5(uint64_t hash) { return hash % 5llu; } + static uint64_t mod7(uint64_t hash) { return hash % 7llu; } + static uint64_t mod11(uint64_t hash) { return hash % 11llu; } + static uint64_t mod13(uint64_t hash) { return hash % 13llu; } + static uint64_t mod17(uint64_t hash) { return hash % 17llu; } + static uint64_t mod23(uint64_t hash) { return hash % 23llu; } + static uint64_t mod29(uint64_t hash) { return hash % 29llu; } + static uint64_t mod37(uint64_t hash) { return hash % 37llu; } + static uint64_t mod47(uint64_t hash) { return hash % 47llu; } + static uint64_t mod59(uint64_t hash) { return hash % 59llu; } + static uint64_t mod73(uint64_t hash) { return hash % 73llu; } + static uint64_t mod97(uint64_t hash) { return hash % 97llu; } + static uint64_t mod127(uint64_t hash) { return hash % 127llu; } + static uint64_t mod151(uint64_t hash) { return hash % 151llu; } + static uint64_t mod197(uint64_t hash) { return hash % 197llu; } + static uint64_t mod251(uint64_t hash) { return hash % 251llu; } + static uint64_t mod313(uint64_t hash) { return hash % 313llu; } + static uint64_t mod397(uint64_t hash) { return hash % 397llu; } + static uint64_t mod499(uint64_t hash) { return hash % 499llu; } + static uint64_t mod631(uint64_t hash) { return hash % 631llu; } + static uint64_t mod797(uint64_t hash) { return hash % 797llu; } + static uint64_t mod1009(uint64_t hash) { return hash % 1009llu; } + static uint64_t mod1259(uint64_t hash) { return hash % 1259llu; } + static uint64_t mod1597(uint64_t hash) { return hash % 1597llu; } + static uint64_t mod2011(uint64_t hash) { return hash % 2011llu; } + static uint64_t mod2539(uint64_t hash) { return hash % 2539llu; } + static uint64_t mod3203(uint64_t hash) { return hash % 3203llu; } + static uint64_t mod4027(uint64_t hash) { return hash % 4027llu; } + static uint64_t mod5087(uint64_t hash) { return hash % 5087llu; } + static uint64_t mod6421(uint64_t hash) { return hash % 6421llu; } + static uint64_t mod8089(uint64_t hash) { return hash % 8089llu; } + static uint64_t mod10193(uint64_t hash) { return hash % 10193llu; } + static uint64_t mod12853(uint64_t hash) { return hash % 12853llu; } + static uint64_t mod16193(uint64_t hash) { return hash % 16193llu; } + static uint64_t mod20399(uint64_t hash) { return hash % 20399llu; } + static uint64_t mod25717(uint64_t hash) { return hash % 25717llu; } + static uint64_t mod32401(uint64_t hash) { return hash % 32401llu; } + static uint64_t mod40823(uint64_t hash) { return hash % 40823llu; } + static uint64_t mod51437(uint64_t hash) { return hash % 51437llu; } + static uint64_t mod64811(uint64_t hash) { return hash % 64811llu; } + static uint64_t mod81649(uint64_t hash) { return hash % 81649llu; } + static uint64_t mod102877(uint64_t hash) { return hash % 102877llu; } + static uint64_t mod129607(uint64_t hash) { return hash % 129607llu; } + static uint64_t mod163307(uint64_t hash) { return hash % 163307llu; } + static uint64_t mod205759(uint64_t hash) { return hash % 205759llu; } + static uint64_t mod259229(uint64_t hash) { return hash % 259229llu; } + static uint64_t mod326617(uint64_t hash) { return hash % 326617llu; } + static uint64_t mod411527(uint64_t hash) { return hash % 411527llu; } + static uint64_t mod518509(uint64_t hash) { return hash % 518509llu; } + static uint64_t mod653267(uint64_t hash) { return hash % 653267llu; } + static uint64_t mod823117(uint64_t hash) { return hash % 823117llu; } + static uint64_t mod1037059(uint64_t hash) { return hash % 1037059llu; } + static uint64_t mod1306601(uint64_t hash) { return hash % 1306601llu; } + static uint64_t mod1646237(uint64_t hash) { return hash % 1646237llu; } + static uint64_t mod2074129(uint64_t hash) { return hash % 2074129llu; } + static uint64_t mod2613229(uint64_t hash) { return hash % 2613229llu; } + static uint64_t mod3292489(uint64_t hash) { return hash % 3292489llu; } + static uint64_t mod4148279(uint64_t hash) { return hash % 4148279llu; } + static uint64_t mod5226491(uint64_t hash) { return hash % 5226491llu; } + static uint64_t mod6584983(uint64_t hash) { return hash % 6584983llu; } + static uint64_t mod8296553(uint64_t hash) { return hash % 8296553llu; } + static uint64_t mod10453007(uint64_t hash) { return hash % 10453007llu; } + static uint64_t mod13169977(uint64_t hash) { return hash % 13169977llu; } + static uint64_t mod16593127(uint64_t hash) { return hash % 16593127llu; } + static uint64_t mod20906033(uint64_t hash) { return hash % 20906033llu; } + static uint64_t mod26339969(uint64_t hash) { return hash % 26339969llu; } + static uint64_t mod33186281(uint64_t hash) { return hash % 33186281llu; } + static uint64_t mod41812097(uint64_t hash) { return hash % 41812097llu; } + static uint64_t mod52679969(uint64_t hash) { return hash % 52679969llu; } + static uint64_t mod66372617(uint64_t hash) { return hash % 66372617llu; } + static uint64_t mod83624237(uint64_t hash) { return hash % 83624237llu; } + static uint64_t mod105359939(uint64_t hash) { return hash % 105359939llu; } + static uint64_t mod132745199(uint64_t hash) { return hash % 132745199llu; } + static uint64_t mod167248483(uint64_t hash) { return hash % 167248483llu; } + static uint64_t mod210719881(uint64_t hash) { return hash % 210719881llu; } + static uint64_t mod265490441(uint64_t hash) { return hash % 265490441llu; } + static uint64_t mod334496971(uint64_t hash) { return hash % 334496971llu; } + static uint64_t mod421439783(uint64_t hash) { return hash % 421439783llu; } + static uint64_t mod530980861(uint64_t hash) { return hash % 530980861llu; } + static uint64_t mod668993977(uint64_t hash) { return hash % 668993977llu; } + static uint64_t mod842879579(uint64_t hash) { return hash % 842879579llu; } + static uint64_t mod1061961721(uint64_t hash) { return hash % 1061961721llu; } + static uint64_t mod1337987929(uint64_t hash) { return hash % 1337987929llu; } + static uint64_t mod1685759167(uint64_t hash) { return hash % 1685759167llu; } + static uint64_t mod2123923447(uint64_t hash) { return hash % 2123923447llu; } + static uint64_t mod2675975881(uint64_t hash) { return hash % 2675975881llu; } + static uint64_t mod3371518343(uint64_t hash) { return hash % 3371518343llu; } + static uint64_t mod4247846927(uint64_t hash) { return hash % 4247846927llu; } + static uint64_t mod5351951779(uint64_t hash) { return hash % 5351951779llu; } + static uint64_t mod6743036717(uint64_t hash) { return hash % 6743036717llu; } + static uint64_t mod8495693897(uint64_t hash) { return hash % 8495693897llu; } + static uint64_t mod10703903591(uint64_t hash) { return hash % 10703903591llu; } + static uint64_t mod13486073473(uint64_t hash) { return hash % 13486073473llu; } + static uint64_t mod16991387857(uint64_t hash) { return hash % 16991387857llu; } + static uint64_t mod21407807219(uint64_t hash) { return hash % 21407807219llu; } + static uint64_t mod26972146961(uint64_t hash) { return hash % 26972146961llu; } + static uint64_t mod33982775741(uint64_t hash) { return hash % 33982775741llu; } + static uint64_t mod42815614441(uint64_t hash) { return hash % 42815614441llu; } + static uint64_t mod53944293929(uint64_t hash) { return hash % 53944293929llu; } + static uint64_t mod67965551447(uint64_t hash) { return hash % 67965551447llu; } + static uint64_t mod85631228929(uint64_t hash) { return hash % 85631228929llu; } + static uint64_t mod107888587883(uint64_t hash) { return hash % 107888587883llu; } + static uint64_t mod135931102921(uint64_t hash) { return hash % 135931102921llu; } + static uint64_t mod171262457903(uint64_t hash) { return hash % 171262457903llu; } + static uint64_t mod215777175787(uint64_t hash) { return hash % 215777175787llu; } + static uint64_t mod271862205833(uint64_t hash) { return hash % 271862205833llu; } + static uint64_t mod342524915839(uint64_t hash) { return hash % 342524915839llu; } + static uint64_t mod431554351609(uint64_t hash) { return hash % 431554351609llu; } + static uint64_t mod543724411781(uint64_t hash) { return hash % 543724411781llu; } + static uint64_t mod685049831731(uint64_t hash) { return hash % 685049831731llu; } + static uint64_t mod863108703229(uint64_t hash) { return hash % 863108703229llu; } + static uint64_t mod1087448823553(uint64_t hash) { return hash % 1087448823553llu; } + static uint64_t mod1370099663459(uint64_t hash) { return hash % 1370099663459llu; } + static uint64_t mod1726217406467(uint64_t hash) { return hash % 1726217406467llu; } + static uint64_t mod2174897647073(uint64_t hash) { return hash % 2174897647073llu; } + static uint64_t mod2740199326961(uint64_t hash) { return hash % 2740199326961llu; } + static uint64_t mod3452434812973(uint64_t hash) { return hash % 3452434812973llu; } + static uint64_t mod4349795294267(uint64_t hash) { return hash % 4349795294267llu; } + static uint64_t mod5480398654009(uint64_t hash) { return hash % 5480398654009llu; } + static uint64_t mod6904869625999(uint64_t hash) { return hash % 6904869625999llu; } + static uint64_t mod8699590588571(uint64_t hash) { return hash % 8699590588571llu; } + static uint64_t mod10960797308051(uint64_t hash) { return hash % 10960797308051llu; } + static uint64_t mod13809739252051(uint64_t hash) { return hash % 13809739252051llu; } + static uint64_t mod17399181177241(uint64_t hash) { return hash % 17399181177241llu; } + static uint64_t mod21921594616111(uint64_t hash) { return hash % 21921594616111llu; } + static uint64_t mod27619478504183(uint64_t hash) { return hash % 27619478504183llu; } + static uint64_t mod34798362354533(uint64_t hash) { return hash % 34798362354533llu; } + static uint64_t mod43843189232363(uint64_t hash) { return hash % 43843189232363llu; } + static uint64_t mod55238957008387(uint64_t hash) { return hash % 55238957008387llu; } + static uint64_t mod69596724709081(uint64_t hash) { return hash % 69596724709081llu; } + static uint64_t mod87686378464759(uint64_t hash) { return hash % 87686378464759llu; } + static uint64_t mod110477914016779(uint64_t hash) { return hash % 110477914016779llu; } + static uint64_t mod139193449418173(uint64_t hash) { return hash % 139193449418173llu; } + static uint64_t mod175372756929481(uint64_t hash) { return hash % 175372756929481llu; } + static uint64_t mod220955828033581(uint64_t hash) { return hash % 220955828033581llu; } + static uint64_t mod278386898836457(uint64_t hash) { return hash % 278386898836457llu; } + static uint64_t mod350745513859007(uint64_t hash) { return hash % 350745513859007llu; } + static uint64_t mod441911656067171(uint64_t hash) { return hash % 441911656067171llu; } + static uint64_t mod556773797672909(uint64_t hash) { return hash % 556773797672909llu; } + static uint64_t mod701491027718027(uint64_t hash) { return hash % 701491027718027llu; } + static uint64_t mod883823312134381(uint64_t hash) { return hash % 883823312134381llu; } + static uint64_t mod1113547595345903(uint64_t hash) { return hash % 1113547595345903llu; } + static uint64_t mod1402982055436147(uint64_t hash) { return hash % 1402982055436147llu; } + static uint64_t mod1767646624268779(uint64_t hash) { return hash % 1767646624268779llu; } + static uint64_t mod2227095190691797(uint64_t hash) { return hash % 2227095190691797llu; } + static uint64_t mod2805964110872297(uint64_t hash) { return hash % 2805964110872297llu; } + static uint64_t mod3535293248537579(uint64_t hash) { return hash % 3535293248537579llu; } + static uint64_t mod4454190381383713(uint64_t hash) { return hash % 4454190381383713llu; } + static uint64_t mod5611928221744609(uint64_t hash) { return hash % 5611928221744609llu; } + static uint64_t mod7070586497075177(uint64_t hash) { return hash % 7070586497075177llu; } + static uint64_t mod8908380762767489(uint64_t hash) { return hash % 8908380762767489llu; } + static uint64_t mod11223856443489329(uint64_t hash) { return hash % 11223856443489329llu; } + static uint64_t mod14141172994150357(uint64_t hash) { return hash % 14141172994150357llu; } + static uint64_t mod17816761525534927(uint64_t hash) { return hash % 17816761525534927llu; } + static uint64_t mod22447712886978529(uint64_t hash) { return hash % 22447712886978529llu; } + static uint64_t mod28282345988300791(uint64_t hash) { return hash % 28282345988300791llu; } + static uint64_t mod35633523051069991(uint64_t hash) { return hash % 35633523051069991llu; } + static uint64_t mod44895425773957261(uint64_t hash) { return hash % 44895425773957261llu; } + static uint64_t mod56564691976601587(uint64_t hash) { return hash % 56564691976601587llu; } + static uint64_t mod71267046102139967(uint64_t hash) { return hash % 71267046102139967llu; } + static uint64_t mod89790851547914507(uint64_t hash) { return hash % 89790851547914507llu; } + static uint64_t mod113129383953203213(uint64_t hash) { return hash % 113129383953203213llu; } + static uint64_t mod142534092204280003(uint64_t hash) { return hash % 142534092204280003llu; } + static uint64_t mod179581703095829107(uint64_t hash) { return hash % 179581703095829107llu; } + static uint64_t mod226258767906406483(uint64_t hash) { return hash % 226258767906406483llu; } + static uint64_t mod285068184408560057(uint64_t hash) { return hash % 285068184408560057llu; } + static uint64_t mod359163406191658253(uint64_t hash) { return hash % 359163406191658253llu; } + static uint64_t mod452517535812813007(uint64_t hash) { return hash % 452517535812813007llu; } + static uint64_t mod570136368817120201(uint64_t hash) { return hash % 570136368817120201llu; } + static uint64_t mod718326812383316683(uint64_t hash) { return hash % 718326812383316683llu; } + static uint64_t mod905035071625626043(uint64_t hash) { return hash % 905035071625626043llu; } + static uint64_t mod1140272737634240411(uint64_t hash) { return hash % 1140272737634240411llu; } + static uint64_t mod1436653624766633509(uint64_t hash) { return hash % 1436653624766633509llu; } + static uint64_t mod1810070143251252131(uint64_t hash) { return hash % 1810070143251252131llu; } + static uint64_t mod2280545475268481167(uint64_t hash) { return hash % 2280545475268481167llu; } + static uint64_t mod2873307249533267101(uint64_t hash) { return hash % 2873307249533267101llu; } + static uint64_t mod3620140286502504283(uint64_t hash) { return hash % 3620140286502504283llu; } + static uint64_t mod4561090950536962147(uint64_t hash) { return hash % 4561090950536962147llu; } + static uint64_t mod5746614499066534157(uint64_t hash) { return hash % 5746614499066534157llu; } + static uint64_t mod7240280573005008577(uint64_t hash) { return hash % 7240280573005008577llu; } + static uint64_t mod9122181901073924329(uint64_t hash) { return hash % 9122181901073924329llu; } + static uint64_t mod11493228998133068689(uint64_t hash) { return hash % 11493228998133068689llu; } + static uint64_t mod14480561146010017169(uint64_t hash) { return hash % 14480561146010017169llu; } + static uint64_t mod18446744073709551557(uint64_t hash) { return hash % 18446744073709551557llu; } + + using mod_function = uint64_t (*)(uint64_t); + + mod_function next_size_over(uint64_t & size) const + { + // prime numbers generated by the following method: + // 1. start with a prime p = 2 + // 2. go to wolfram alpha and get p = NextPrime(2 * p) + // 3. repeat 2. until you overflow 64 bits + // you now have large gaps which you would hit if somebody called reserve() with an unlucky number. + // 4. to fill the gaps for every prime p go to wolfram alpha and get ClosestPrime(p * 2^(1/3)) and ClosestPrime(p * 2^(2/3)) and put those in the gaps + // 5. get PrevPrime(2^64) and put it at the end + static constexpr const uint64_t prime_list[] = + { + 2llu, 3llu, 5llu, 7llu, 11llu, 13llu, 17llu, 23llu, 29llu, 37llu, 47llu, + 59llu, 73llu, 97llu, 127llu, 151llu, 197llu, 251llu, 313llu, 397llu, + 499llu, 631llu, 797llu, 1009llu, 1259llu, 1597llu, 2011llu, 2539llu, + 3203llu, 4027llu, 5087llu, 6421llu, 8089llu, 10193llu, 12853llu, 16193llu, + 20399llu, 25717llu, 32401llu, 40823llu, 51437llu, 64811llu, 81649llu, + 102877llu, 129607llu, 163307llu, 205759llu, 259229llu, 326617llu, + 411527llu, 518509llu, 653267llu, 823117llu, 1037059llu, 1306601llu, + 1646237llu, 2074129llu, 2613229llu, 3292489llu, 4148279llu, 5226491llu, + 6584983llu, 8296553llu, 10453007llu, 13169977llu, 16593127llu, 20906033llu, + 26339969llu, 33186281llu, 41812097llu, 52679969llu, 66372617llu, + 83624237llu, 105359939llu, 132745199llu, 167248483llu, 210719881llu, + 265490441llu, 334496971llu, 421439783llu, 530980861llu, 668993977llu, + 842879579llu, 1061961721llu, 1337987929llu, 1685759167llu, 2123923447llu, + 2675975881llu, 3371518343llu, 4247846927llu, 5351951779llu, 6743036717llu, + 8495693897llu, 10703903591llu, 13486073473llu, 16991387857llu, + 21407807219llu, 26972146961llu, 33982775741llu, 42815614441llu, + 53944293929llu, 67965551447llu, 85631228929llu, 107888587883llu, + 135931102921llu, 171262457903llu, 215777175787llu, 271862205833llu, + 342524915839llu, 431554351609llu, 543724411781llu, 685049831731llu, + 863108703229llu, 1087448823553llu, 1370099663459llu, 1726217406467llu, + 2174897647073llu, 2740199326961llu, 3452434812973llu, 4349795294267llu, + 5480398654009llu, 6904869625999llu, 8699590588571llu, 10960797308051llu, + 13809739252051llu, 17399181177241llu, 21921594616111llu, 27619478504183llu, + 34798362354533llu, 43843189232363llu, 55238957008387llu, 69596724709081llu, + 87686378464759llu, 110477914016779llu, 139193449418173llu, + 175372756929481llu, 220955828033581llu, 278386898836457llu, + 350745513859007llu, 441911656067171llu, 556773797672909llu, + 701491027718027llu, 883823312134381llu, 1113547595345903llu, + 1402982055436147llu, 1767646624268779llu, 2227095190691797llu, + 2805964110872297llu, 3535293248537579llu, 4454190381383713llu, + 5611928221744609llu, 7070586497075177llu, 8908380762767489llu, + 11223856443489329llu, 14141172994150357llu, 17816761525534927llu, + 22447712886978529llu, 28282345988300791llu, 35633523051069991llu, + 44895425773957261llu, 56564691976601587llu, 71267046102139967llu, + 89790851547914507llu, 113129383953203213llu, 142534092204280003llu, + 179581703095829107llu, 226258767906406483llu, 285068184408560057llu, + 359163406191658253llu, 452517535812813007llu, 570136368817120201llu, + 718326812383316683llu, 905035071625626043llu, 1140272737634240411llu, + 1436653624766633509llu, 1810070143251252131llu, 2280545475268481167llu, + 2873307249533267101llu, 3620140286502504283llu, 4561090950536962147llu, + 5746614499066534157llu, 7240280573005008577llu, 9122181901073924329llu, + 11493228998133068689llu, 14480561146010017169llu, 18446744073709551557llu + }; + static constexpr uint64_t (* const mod_functions[])(uint64_t) = + { + &mod0, &mod2, &mod3, &mod5, &mod7, &mod11, &mod13, &mod17, &mod23, &mod29, &mod37, + &mod47, &mod59, &mod73, &mod97, &mod127, &mod151, &mod197, &mod251, &mod313, &mod397, + &mod499, &mod631, &mod797, &mod1009, &mod1259, &mod1597, &mod2011, &mod2539, &mod3203, + &mod4027, &mod5087, &mod6421, &mod8089, &mod10193, &mod12853, &mod16193, &mod20399, + &mod25717, &mod32401, &mod40823, &mod51437, &mod64811, &mod81649, &mod102877, + &mod129607, &mod163307, &mod205759, &mod259229, &mod326617, &mod411527, &mod518509, + &mod653267, &mod823117, &mod1037059, &mod1306601, &mod1646237, &mod2074129, + &mod2613229, &mod3292489, &mod4148279, &mod5226491, &mod6584983, &mod8296553, + &mod10453007, &mod13169977, &mod16593127, &mod20906033, &mod26339969, &mod33186281, + &mod41812097, &mod52679969, &mod66372617, &mod83624237, &mod105359939, &mod132745199, + &mod167248483, &mod210719881, &mod265490441, &mod334496971, &mod421439783, + &mod530980861, &mod668993977, &mod842879579, &mod1061961721, &mod1337987929, + &mod1685759167, &mod2123923447, &mod2675975881, &mod3371518343, &mod4247846927, + &mod5351951779, &mod6743036717, &mod8495693897, &mod10703903591, &mod13486073473, + &mod16991387857, &mod21407807219, &mod26972146961, &mod33982775741, &mod42815614441, + &mod53944293929, &mod67965551447, &mod85631228929, &mod107888587883, &mod135931102921, + &mod171262457903, &mod215777175787, &mod271862205833, &mod342524915839, + &mod431554351609, &mod543724411781, &mod685049831731, &mod863108703229, + &mod1087448823553, &mod1370099663459, &mod1726217406467, &mod2174897647073, + &mod2740199326961, &mod3452434812973, &mod4349795294267, &mod5480398654009, + &mod6904869625999, &mod8699590588571, &mod10960797308051, &mod13809739252051, + &mod17399181177241, &mod21921594616111, &mod27619478504183, &mod34798362354533, + &mod43843189232363, &mod55238957008387, &mod69596724709081, &mod87686378464759, + &mod110477914016779, &mod139193449418173, &mod175372756929481, &mod220955828033581, + &mod278386898836457, &mod350745513859007, &mod441911656067171, &mod556773797672909, + &mod701491027718027, &mod883823312134381, &mod1113547595345903, &mod1402982055436147, + &mod1767646624268779, &mod2227095190691797, &mod2805964110872297, &mod3535293248537579, + &mod4454190381383713, &mod5611928221744609, &mod7070586497075177, &mod8908380762767489, + &mod11223856443489329, &mod14141172994150357, &mod17816761525534927, + &mod22447712886978529, &mod28282345988300791, &mod35633523051069991, + &mod44895425773957261, &mod56564691976601587, &mod71267046102139967, + &mod89790851547914507, &mod113129383953203213, &mod142534092204280003, + &mod179581703095829107, &mod226258767906406483, &mod285068184408560057, + &mod359163406191658253, &mod452517535812813007, &mod570136368817120201, + &mod718326812383316683, &mod905035071625626043, &mod1140272737634240411, + &mod1436653624766633509, &mod1810070143251252131, &mod2280545475268481167, + &mod2873307249533267101, &mod3620140286502504283, &mod4561090950536962147, + &mod5746614499066534157, &mod7240280573005008577, &mod9122181901073924329, + &mod11493228998133068689, &mod14480561146010017169, &mod18446744073709551557 + }; + const uint64_t * found = std::lower_bound(std::begin(prime_list), std::end(prime_list) - 1, size); + size = *found; + return mod_functions[1 + found - prime_list]; + } + void commit(mod_function new_mod_function) + { + current_mod_function = new_mod_function; + } + void reset() + { + current_mod_function = &mod0; + } + + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) const + { + return current_mod_function(hash); + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const + { + return index > num_slots_minus_one ? current_mod_function(index) : index; + } + +private: + mod_function current_mod_function = &mod0; +}; + +struct power_of_two_hash_policy +{ + uint64_t index_for_hash(uint64_t hash, uint64_t num_slots_minus_one) const + { + return hash & num_slots_minus_one; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const + { + return index_for_hash(index, num_slots_minus_one); + } + int8_t next_size_over(uint64_t & size) const + { + size = detailv3::next_power_of_two(size); + return 0; + } + void commit(int8_t) + { + } + void reset() + { + } + +}; + +struct fibonacci_hash_policy +{ + uint64_t index_for_hash(uint64_t hash, uint64_t /*num_slots_minus_one*/) const + { + return (11400714819323198485ull * hash) >> shift; + } + uint64_t keep_in_range(uint64_t index, uint64_t num_slots_minus_one) const + { + return index & num_slots_minus_one; + } + + int8_t next_size_over(uint64_t & size) const + { + size = std::max(uint64_t(2), detailv3::next_power_of_two(size)); + return 64 - detailv3::log2(size); + } + void commit(int8_t shift) + { + this->shift = shift; + } + void reset() + { + shift = 63; + } + +private: + int8_t shift = 63; +}; + +template, typename E = std::equal_to, typename A = std::allocator > > +class order_preserving_flat_hash_map + : public detailv3::sherwood_v3_table + < + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc>> + > +{ + using Table = detailv3::sherwood_v3_table + < + std::pair, + K, + H, + detailv3::KeyOrValueHasher, H>, + E, + detailv3::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc>> + >; +public: + + using key_type = K; + using mapped_type = V; + + using Table::Table; + order_preserving_flat_hash_map() + { + } + + inline V & operator[](const K & key) + { + return emplace(key, convertible_to_value()).first->second; + } + inline V & operator[](K && key) + { + return emplace(std::move(key), convertible_to_value()).first->second; + } + V & at(const K & key) + { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + const V & at(const K & key) const + { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + + using Table::emplace; + std::pair emplace() + { + return emplace(key_type(), convertible_to_value()); + } + template + std::pair insert_or_assign(const key_type & key, M && m) + { + auto emplace_result = emplace(key, std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + std::pair insert_or_assign(key_type && key, M && m) + { + auto emplace_result = emplace(std::move(key), std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + typename Table::iterator insert_or_assign(typename Table::const_iterator, const key_type & key, M && m) + { + return insert_or_assign(key, std::forward(m)).first; + } + template + typename Table::iterator insert_or_assign(typename Table::const_iterator, key_type && key, M && m) + { + return insert_or_assign(std::move(key), std::forward(m)).first; + } + + friend bool operator==(const order_preserving_flat_hash_map & lhs, const order_preserving_flat_hash_map & rhs) + { + if (lhs.size() != rhs.size()) + return false; + for (const typename Table::value_type & value : lhs) + { + auto found = rhs.find(value.first); + if (found == rhs.end()) + return false; + else if (value.second != found->second) + return false; + } + return true; + } + friend bool operator!=(const order_preserving_flat_hash_map & lhs, const order_preserving_flat_hash_map & rhs) + { + return !(lhs == rhs); + } + +private: + struct convertible_to_value + { + operator V() const + { + return V(); + } + }; +}; + +template, typename E = std::equal_to, typename A = std::allocator > +class flat_hash_set + : public detailv3::sherwood_v3_table + < + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc> + > +{ + using Table = detailv3::sherwood_v3_table + < + T, + T, + H, + detailv3::functor_storage, + E, + detailv3::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc> + >; +public: + + using key_type = T; + + using Table::Table; + flat_hash_set() + { + } + + template + std::pair emplace(Args &&... args) + { + return Table::emplace(T(std::forward(args)...)); + } + std::pair emplace(const key_type & arg) + { + return Table::emplace(arg); + } + std::pair emplace(key_type & arg) + { + return Table::emplace(arg); + } + std::pair emplace(const key_type && arg) + { + return Table::emplace(std::move(arg)); + } + std::pair emplace(key_type && arg) + { + return Table::emplace(std::move(arg)); + } + + friend bool operator==(const flat_hash_set & lhs, const flat_hash_set & rhs) + { + if (lhs.size() != rhs.size()) + return false; + for (const T & value : lhs) + { + if (rhs.find(value) == rhs.end()) + return false; + } + return true; + } + friend bool operator!=(const flat_hash_set & lhs, const flat_hash_set & rhs) + { + return !(lhs == rhs); + } +}; + + +template +struct power_of_two_std_hash : std::hash +{ + typedef ska_ordered::power_of_two_hash_policy hash_policy; +}; + +} // end namespace ska + +#ifndef _MSC_VER +#pragma GCC diagnostic pop +#endif diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index dd8c73d1bbbed..2b82999f59881 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -59,7 +59,6 @@ endif() # addressed yet. add_subdirectory(core) -add_subdirectory(proto) add_subdirectory(serialize) add_subdirectory(utils) add_subdirectory(perfkernels) @@ -95,6 +94,7 @@ if (NOT INTERN_BUILD_MOBILE OR BUILD_CAFFE2_MOBILE) endif() endif() add_subdirectory(opt) + add_subdirectory(proto) add_subdirectory(python) add_subdirectory(queue) add_subdirectory(sgd) @@ -170,43 +170,44 @@ if (FALSE) endforeach() endif() -# ---[ List of libraries to link with -add_library(caffe2_protos STATIC $) -add_dependencies(caffe2_protos Caffe2_PROTO) -# If we are going to link protobuf locally inside caffe2 libraries, what we will do is -# to create a helper static library that always contains libprotobuf source files, and -# link the caffe2 related dependent libraries to it. -target_include_directories(caffe2_protos INTERFACE $) -# Reason for this public dependency is as follows: -# (1) Strictly speaking, we should not expose any Protobuf related functions. We should -# only use function interfaces wrapped with our own public API, and link protobuf -# locally. -# (2) However, currently across the Caffe2 codebase, we have extensive use of protobuf -# functionalities. For example, not only libcaffe2.so uses it, but also other -# binaries such as python extensions etc. As a result, we will have to have a -# transitive dependency to libprotobuf. -# -# Good thing is that, if we specify CAFFE2_LINK_LOCAL_PROTOBUF, then we do not need to -# separately deploy protobuf binaries - libcaffe2.so will contain all functionalities -# one needs. One can verify this via ldd. -# -# TODO item in the future includes: -# (1) Enable using lite protobuf -# (2) Properly define public API that do not directly depend on protobuf itself. -# (3) Expose the libprotobuf.a file for dependent libraries to link to. -# -# What it means for users/developers? -# (1) Users: nothing affecting the users, other than the fact that CAFFE2_LINK_LOCAL_PROTOBUF -# avoids the need to deploy protobuf. -# (2) Developers: if one simply uses core caffe2 functionality without using protobuf, -# nothing changes. If one has a dependent library that uses protobuf, then one needs to -# have the right protobuf version as well as linking to libprotobuf.a. -target_link_libraries(caffe2_protos PUBLIC protobuf::libprotobuf) -if (NOT BUILD_SHARED_LIBS) - INSTALL(TARGETS caffe2_protos ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") +if (NOT INTERN_BUILD_MOBILE OR BUILD_CAFFE2_MOBILE) + # ---[ List of libraries to link with + add_library(caffe2_protos STATIC $) + add_dependencies(caffe2_protos Caffe2_PROTO) + # If we are going to link protobuf locally inside caffe2 libraries, what we will do is + # to create a helper static library that always contains libprotobuf source files, and + # link the caffe2 related dependent libraries to it. + target_include_directories(caffe2_protos INTERFACE $) + # Reason for this public dependency is as follows: + # (1) Strictly speaking, we should not expose any Protobuf related functions. We should + # only use function interfaces wrapped with our own public API, and link protobuf + # locally. + # (2) However, currently across the Caffe2 codebase, we have extensive use of protobuf + # functionalities. For example, not only libcaffe2.so uses it, but also other + # binaries such as python extensions etc. As a result, we will have to have a + # transitive dependency to libprotobuf. + # + # Good thing is that, if we specify CAFFE2_LINK_LOCAL_PROTOBUF, then we do not need to + # separately deploy protobuf binaries - libcaffe2.so will contain all functionalities + # one needs. One can verify this via ldd. + # + # TODO item in the future includes: + # (1) Enable using lite protobuf + # (2) Properly define public API that do not directly depend on protobuf itself. + # (3) Expose the libprotobuf.a file for dependent libraries to link to. + # + # What it means for users/developers? + # (1) Users: nothing affecting the users, other than the fact that CAFFE2_LINK_LOCAL_PROTOBUF + # avoids the need to deploy protobuf. + # (2) Developers: if one simply uses core caffe2 functionality without using protobuf, + # nothing changes. If one has a dependent library that uses protobuf, then one needs to + # have the right protobuf version as well as linking to libprotobuf.a. + target_link_libraries(caffe2_protos PUBLIC protobuf::libprotobuf) + if (NOT BUILD_SHARED_LIBS) + INSTALL(TARGETS caffe2_protos ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") + endif() endif() - # ========================================================== # formerly-libtorch # ========================================================== @@ -247,22 +248,32 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) set(GENERATED_CXX_TORCH "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_0.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_1.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_2.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_3.cpp" - "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_4.cpp" "${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_0.cpp" "${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_1.cpp" "${TORCH_SRC_DIR}/csrc/jit/generated/register_aten_ops_2.cpp" ) + if(NOT INTERN_DISABLE_AUTOGRAD) + list(APPEND GENERATED_CXX_TORCH + "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_0.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_1.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_2.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_3.cpp" + "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType_4.cpp" + ) + endif() + set(GENERATED_H_TORCH - "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h" "${TORCH_SRC_DIR}/csrc/autograd/generated/Functions.h" "${TORCH_SRC_DIR}/csrc/autograd/generated/variable_factories.h" ) + if(NOT INTERN_DISABLE_AUTOGRAD) + list(APPEND GENERATED_H_TORCH + "${TORCH_SRC_DIR}/csrc/autograd/generated/VariableType.h" + ) + endif() + set(GENERATED_CXX_PYTHON "${TORCH_SRC_DIR}/csrc/autograd/generated/python_functions.cpp" "${TORCH_SRC_DIR}/csrc/autograd/generated/python_variable_methods.cpp" @@ -292,6 +303,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) "${PYTHON_EXECUTABLE}" tools/setup_helpers/generate_code.py --declarations-path "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" --nn-path "aten/src" + $<$:--disable-autograd> DEPENDS "${CMAKE_BINARY_DIR}/aten/src/ATen/Declarations.yaml" "${CMAKE_CURRENT_LIST_DIR}/../aten/src/THNN/generic/THNN.h" @@ -351,7 +363,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/autograd/record_function.cpp ${TORCH_SRC_DIR}/csrc/autograd/saved_variable.cpp ${TORCH_SRC_DIR}/csrc/autograd/variable.cpp - ${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp ${TORCH_SRC_DIR}/csrc/jit/autodiff.cpp ${TORCH_SRC_DIR}/csrc/jit/attributes.cpp ${TORCH_SRC_DIR}/csrc/jit/argument_spec.cpp @@ -369,7 +380,6 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/ir.cpp ${TORCH_SRC_DIR}/csrc/jit/irparser.cpp ${TORCH_SRC_DIR}/csrc/jit/jit_log.cpp - ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp ${TORCH_SRC_DIR}/csrc/jit/operator.cpp ${TORCH_SRC_DIR}/csrc/jit/register_c10_ops.cpp ${TORCH_SRC_DIR}/csrc/jit/subgraph_matcher.cpp @@ -413,6 +423,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/fuse_linear.cpp ${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp @@ -448,6 +459,12 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/function.cpp ) + if (NOT INTERN_DISABLE_AUTOGRAD) + list(APPEND TORCH_SRCS + ${TORCH_SRC_DIR}/csrc/autograd/VariableTypeManual.cpp + ) + endif() + if (NOT INTERN_BUILD_MOBILE) list(APPEND TORCH_SRCS ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp @@ -456,23 +473,19 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_rref_proto.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_ret.cpp ${TORCH_SRC_DIR}/csrc/jit/export.cpp ${TORCH_SRC_DIR}/csrc/jit/import_legacy.cpp + ${TORCH_SRC_DIR}/csrc/jit/netdef_converter.cpp + ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp ) - if (NOT WIN32) - list(APPEND TORCH_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cpu/fused_kernel.cpp) - endif() endif() if (USE_CUDA) - if (NOT USE_ROCM) - list(APPEND Caffe2_GPU_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp - ) - endif() list(APPEND Caffe2_GPU_SRCS + ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp @@ -485,12 +498,15 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) if (USE_ROCM) list(APPEND Caffe2_HIP_SRCS + ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp + ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp ) # caffe2_nvrtc's stubs to driver APIs are useful for HIP. # See NOTE [ ATen NVRTC Stub and HIP ] add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) + target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${PYTORCH_HIP_HCC_LIBRARIES} ${ROCM_HIPRTC_LIB}) target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) target_include_directories(caffe2_nvrtc PRIVATE ${CUDA_INCLUDE_DIRS}) target_compile_definitions(caffe2_nvrtc PRIVATE USE_ROCM __HIP_PLATFORM_HCC__) @@ -505,17 +521,28 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp ${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp + ${TORCH_SRC_DIR}/csrc/api/src/serialize.cpp ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/init.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/module.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/distance.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/embedding.cpp - ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/functional.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/fold.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/linear.cpp - ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/named_any.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/loss.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/pooling.cpp ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/rnn.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/functional.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/container/named_any.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/options/conv.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp + ${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp ${TORCH_SRC_DIR}/csrc/api/src/optim/adagrad.cpp ${TORCH_SRC_DIR}/csrc/api/src/optim/adam.cpp ${TORCH_SRC_DIR}/csrc/api/src/optim/lbfgs.cpp @@ -572,14 +599,15 @@ ENDIF() # formerly-libtorch flags # ========================================================== -if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) - +if (NOT INTERN_BUILD_MOBILE) # Forces caffe2.pb.h to be generated before its dependents are compiled. # Adding the generated header file to the ${TORCH_SRCS} list is not sufficient # to establish the dependency, since the generation procedure is declared in a different CMake file. # See https://samthursfield.wordpress.com/2015/11/21/cmake-dependencies-between-targets-and-files-and-custom-commands/#custom-commands-in-different-directories add_dependencies(torch Caffe2_PROTO) +endif() +if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) target_compile_definitions(torch PUBLIC _THP_CORE) @@ -768,7 +796,7 @@ ENDIF() DESTINATION share/cmake/Torch) if (USE_DISTRIBUTED) - if (NOT MSVC AND NOT APPLE) + if (NOT MSVC) add_subdirectory(${TORCH_SRC_DIR}/lib/c10d lib_c10d) endif() endif() @@ -830,13 +858,14 @@ if (NOT WIN32 AND NOT USE_ASAN) target_compile_options(torch PRIVATE "-fvisibility=hidden") endif() - -caffe2_interface_library(caffe2_protos caffe2_protos_whole) -target_link_libraries(torch PRIVATE caffe2_protos_whole) -if (${CAFFE2_LINK_LOCAL_PROTOBUF}) - target_link_libraries(torch INTERFACE protobuf::libprotobuf) -else() - target_link_libraries(torch PUBLIC protobuf::libprotobuf) +if (NOT INTERN_BUILD_MOBILE OR BUILD_CAFFE2_MOBILE) + caffe2_interface_library(caffe2_protos caffe2_protos_whole) + target_link_libraries(torch PRIVATE caffe2_protos_whole) + if (${CAFFE2_LINK_LOCAL_PROTOBUF}) + target_link_libraries(torch INTERFACE protobuf::libprotobuf) + else() + target_link_libraries(torch PUBLIC protobuf::libprotobuf) + endif() endif() if (USE_OPENMP AND OPENMP_FOUND) diff --git a/caffe2/contrib/CMakeLists.txt b/caffe2/contrib/CMakeLists.txt index 19306aa06439a..86442ce768975 100644 --- a/caffe2/contrib/CMakeLists.txt +++ b/caffe2/contrib/CMakeLists.txt @@ -1,5 +1,4 @@ add_subdirectory(aten) -add_subdirectory(gloo) add_subdirectory(nccl) add_subdirectory(opencl) add_subdirectory(prof) @@ -8,6 +7,12 @@ if (USE_TENSORRT) add_subdirectory(tensorrt) endif() +# Only build Gloo Caffe2 ops on Linux, as it hardcodes +# the Linux-specific `gloo::transport::tcp` namespace. +if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") +add_subdirectory(gloo) +endif() + # Pass the src lists back to the parent # CPU source, include, deps, test sources, binary sources diff --git a/caffe2/contrib/gloo/gloo_test.py b/caffe2/contrib/gloo/gloo_test.py index 93aa5db89225f..91169e1d974eb 100644 --- a/caffe2/contrib/gloo/gloo_test.py +++ b/caffe2/contrib/gloo/gloo_test.py @@ -633,8 +633,13 @@ def test_close_connection(self, comm_size, device_option): comm_size=comm_size, device_option=device_option, tmpdir=tmpdir) - # Check that test finishes quickly because connections get closed - self.assertLess(time.time() - start_time, 2.0) + # Check that test finishes quickly because connections get closed. + # This assert used to check that the end to end runtime was less + # than 2 seconds, but this may not always be the case if there + # is significant overhead in starting processes. Ideally, this + # assert is replaced by one that doesn't depend on time but rather + # checks the success/failure status of the barrier that is run. + self.assertLess(time.time() - start_time, 20.0) def _test_io_error( self, diff --git a/caffe2/core/CMakeLists.txt b/caffe2/core/CMakeLists.txt index 177cf2259b44b..1a156eb63ccc9 100644 --- a/caffe2/core/CMakeLists.txt +++ b/caffe2/core/CMakeLists.txt @@ -1,3 +1,11 @@ +if (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) + list(APPEND Caffe2_CPU_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/common.cc" + ) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) + return() +endif() + # ---[ GPU files # ------[ cuDNN if (USE_CUDNN) diff --git a/caffe2/core/net_dag_utils.cc b/caffe2/core/net_dag_utils.cc index 8786ac6b7b10d..498a86f9fc7f8 100644 --- a/caffe2/core/net_dag_utils.cc +++ b/caffe2/core/net_dag_utils.cc @@ -381,9 +381,13 @@ std::vector prepareOperatorNodes( const OperatorDef& op_def = net_def->op(idx); VLOG(1) << "Creating operator #" << idx << ": " << op_def.name() << ": " << op_def.type(); - if (!op_def.has_device_option() && net_def_has_device_option) { + if (net_def_has_device_option) { OperatorDef temp_def(op_def); - temp_def.mutable_device_option()->CopyFrom(net_def->device_option()); + + DeviceOption temp_dev(net_def->device_option()); + temp_dev.MergeFrom(op_def.device_option()); + + temp_def.mutable_device_option()->CopyFrom(temp_dev); operator_nodes[idx].operator_ = CreateOperator(temp_def, ws, idx); } else { auto op = CreateOperator(op_def, ws, idx); diff --git a/caffe2/core/net_simple.cc b/caffe2/core/net_simple.cc index 759c2fc06e1ad..38f03b339f278 100644 --- a/caffe2/core/net_simple.cc +++ b/caffe2/core/net_simple.cc @@ -31,12 +31,16 @@ SimpleNet::SimpleNet( VLOG(1) << "Creating operator " << operator_def.name() << ": " << operator_def.type(); std::unique_ptr op{nullptr}; - if (!operator_def.has_device_option() && net_def_has_device_option) { - // In the case that the operator def does not specify a device option but - // the net def has a default option, we copy the device option over to the - // operator def. + if (net_def_has_device_option) { + // In the case when net def specifies device option, final device option + // will be equal to merge of operator and net def device options, with + // preference to settings from the operator. OperatorDef temp_def(operator_def); - temp_def.mutable_device_option()->CopyFrom(net_def->device_option()); + + DeviceOption temp_dev(net_def->device_option()); + temp_dev.MergeFrom(operator_def.device_option()); + + temp_def.mutable_device_option()->CopyFrom(temp_dev); op = CreateOperator(temp_def, ws, idx); } else { op = CreateOperator(operator_def, ws, idx); diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index 745771df3582e..4296209f512c8 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -28,8 +28,8 @@ #include "caffe2/proto/caffe2_pb.h" #include "caffe2/utils/proto_utils.h" -#include #if !defined(CAFFE2_IS_XPLAT_BUILD) +#include #include #endif @@ -532,6 +532,10 @@ class CAFFE2_API OperatorBase : public Observable { } } + virtual std::string debug_info_string() const { + return ""; + } + inline const OperatorDef& debug_def() const { CAFFE_ENFORCE(has_debug_def(), "operator_def was null!"); return *operator_def_; diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index c752f904b59bf..41b2b421c4051 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -6,7 +6,9 @@ #include #include +#if !defined(CAFFE2_IS_XPLAT_BUILD) #include "ATen/core/Tensor.h" +#endif #include #include @@ -117,6 +119,7 @@ class CAFFE2_API Tensor final { * The tensor will share the same instance (data, strides, sizes, etc) but * a different subset of APIs would be available */ +#if !defined(CAFFE2_IS_XPLAT_BUILD) explicit Tensor(at::Tensor tensor) : impl_(std::move(tensor.impl_)) { enforce_invariants(); @@ -129,6 +132,7 @@ class CAFFE2_API Tensor final { explicit operator at::Tensor() && { return at::Tensor::wrap_tensor_impl(std::move(impl_)); } +#endif bool is_same(const Tensor& other) const noexcept { return impl_ == other.impl_; diff --git a/caffe2/ideep/operators/operator_fallback_ideep.cc b/caffe2/ideep/operators/operator_fallback_ideep.cc index c06fd283524d7..1558e9b23e01b 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.cc +++ b/caffe2/ideep/operators/operator_fallback_ideep.cc @@ -50,7 +50,7 @@ #include "caffe2/operators/bbox_transform_op.h" #include "caffe2/operators/box_with_nms_limit_op.h" -#ifdef CAFFE2_USE_GLOO +#if __linux__ && defined(CAFFE2_USE_GLOO) #include #include #include @@ -284,7 +284,7 @@ REGISTER_IDEEP_OPERATOR( BatchMatMul, IDEEPFallbackOp>); -#ifdef CAFFE2_USE_GLOO +#if __linux__ && defined(CAFFE2_USE_GLOO) namespace gloo { // gloo operators REGISTER_IDEEP_OPERATOR( diff --git a/caffe2/operators/channelwise_conv3d_op_cudnn.cu b/caffe2/operators/channelwise_conv3d_op_cudnn.cu index cd40e308fb6a6..9936b8da4d9c9 100644 --- a/caffe2/operators/channelwise_conv3d_op_cudnn.cu +++ b/caffe2/operators/channelwise_conv3d_op_cudnn.cu @@ -99,7 +99,11 @@ __global__ void DepthwiseConv3dGPUKernelNCHW( const int input_offset = (input_offset_temp) + (in_l * in_cols * in_rows) + (in_r * in_cols) + in_c; +#if __CUDA_ARCH__ >= 350 sum += __ldg(input + input_offset) * __ldg(filter_offset + f_c); +#else + sum += input[input_offset] * filter_offset[f_c]; +#endif } } } @@ -120,7 +124,11 @@ __global__ void DepthwiseConv3dGPUKernelNCHW( in_l >= 0 && in_l < in_length) { const int input_offset = (input_offset_temp) + (in_l * in_cols * in_rows) + (in_r * in_cols) + in_c; +#if __CUDA_ARCH__ >= 350 sum += __ldg(input + input_offset) * __ldg(filter_offset + f_c); +#else + sum += input[input_offset] * filter_offset[f_c]; +#endif } } } @@ -181,7 +189,11 @@ __global__ void DepthwiseConv3dBackpropFilterGPUKernelNCHW( (OC * out_length * out_rows * out_cols) + (OL * out_rows * out_cols) + (OH * out_cols) + (OW); +#if __CUDA_ARCH__ >= 350 const T out_bp = __ldg(out_backprop + out_backprop_offset); +#else + const T out_bp = out_backprop[out_backprop_offset]; +#endif if (in_r_start >= 0 && in_c_start >= 0 && in_r_end < in_rows && in_c_end < in_cols && in_l_start >= 0 && in_l_end < in_length) { #pragma unroll @@ -200,7 +212,11 @@ __global__ void DepthwiseConv3dBackpropFilterGPUKernelNCHW( for (int f_c = 0; f_c < filter_cols; ++f_c) { const int in_c = in_c_start + f_c; const int input_offset = input_offset_temp + in_c; +#if __CUDA_ARCH__ >= 350 T partial_sum = __ldg(input + input_offset) * out_bp; +#else + T partial_sum = input[input_offset] * out_bp; +#endif T* addr = filter_backprop + (in_d * filter_rows * filter_cols * filter_length) + (f_l * filter_rows * filter_cols) + (f_c + filter_cols * f_r); @@ -227,7 +243,11 @@ __global__ void DepthwiseConv3dBackpropFilterGPUKernelNCHW( if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols && in_l >= 0 && in_l < in_length) { const int input_offset = input_offset_temp + in_c; +#if __CUDA_ARCH__ >= 350 T partial_sum = __ldg(input + input_offset) * out_bp; +#else + T partial_sum = input[input_offset] * out_bp; +#endif T* addr = filter_backprop + (in_d * filter_rows * filter_cols * filter_length) + (f_l * filter_rows * filter_cols) + (f_c + filter_cols * f_r); @@ -300,8 +320,13 @@ __global__ void DepthwiseConv3dBackpropInputGPUKernelNCHW( (IC * out_length * out_rows * out_cols) + (out_l * out_rows * out_cols) + (out_r * out_cols) + (out_c); +#if __CUDA_ARCH__ >= 350 sum += __ldg(out_backprop + out_backprop_offset) * __ldg(filter + filter_offset); +#else + sum += out_backprop[out_backprop_offset] * + filter[filter_offset]; +#endif } } } diff --git a/caffe2/operators/fused_rowwise_8bit_conversion_ops.cc b/caffe2/operators/fused_rowwise_8bit_conversion_ops.cc index 2bfd2c49add91..f2633173f0cad 100644 --- a/caffe2/operators/fused_rowwise_8bit_conversion_ops.cc +++ b/caffe2/operators/fused_rowwise_8bit_conversion_ops.cc @@ -146,3 +146,16 @@ the original, un-quantized floating point values. NO_GRADIENT(Fused8BitRowwiseQuantizedToHalfFloat); } // namespace caffe2 + +// To workaround comma + +using Fused8BitRowwiseQuantizedToFloatCPUOp = + caffe2::Fused8BitRowwiseQuantizedToFloatOp< + float, + caffe2::convertfp32fp32, + caffe2::CPUContext>; + +C10_EXPORT_CAFFE2_OP_TO_C10_CPU( + Fused8BitRowwiseQuantizedToFloat, + "_caffe2::Fused8BitRowwiseQuantizedToFloat(Tensor scale_bias_quantized_input) -> Tensor", + Fused8BitRowwiseQuantizedToFloatCPUOp); diff --git a/caffe2/operators/fused_rowwise_8bit_conversion_ops.h b/caffe2/operators/fused_rowwise_8bit_conversion_ops.h index bfdb6be27807d..6bb5730e39bd4 100644 --- a/caffe2/operators/fused_rowwise_8bit_conversion_ops.h +++ b/caffe2/operators/fused_rowwise_8bit_conversion_ops.h @@ -2,12 +2,15 @@ #define CAFFE2_OPERATORS_FUSED_ROWWISE_8BIT_CONVERSION_OPS_H_ #include "caffe2/core/context.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/operators/reducer_functors.h" #include "caffe2/utils/eigen_utils.h" #include "caffe2/utils/math.h" +C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(Fused8BitRowwiseQuantizedToFloat); + namespace caffe2 { #define IS_LITTLE_ENDIAN \ diff --git a/caffe2/operators/rnn/recurrent_network_executor.h b/caffe2/operators/rnn/recurrent_network_executor.h index 2022cc838cd20..33fb4575d18a6 100644 --- a/caffe2/operators/rnn/recurrent_network_executor.h +++ b/caffe2/operators/rnn/recurrent_network_executor.h @@ -39,13 +39,14 @@ class RecurrentNetworkExecutorBase { timestep_blob_(timestep_blob) { const bool net_def_has_device_option = step_net_def_.has_device_option(); for (int i = 0; i < step_net_def_.op_size(); i++) { - if (!step_net_def_.op(i).has_device_option() && - net_def_has_device_option) { - // In the case that the operator def does not specify a device option - // but the net def has a default option, we copy the device option over - // to the operator def. - step_net_def_.mutable_op(i)->mutable_device_option()->CopyFrom( - step_net_def_.device_option()); + if (net_def_has_device_option) { + // In the case when net def specifies device option, final device option + // will be equal to merge of operator and net def device options, with + // preference to settings from the operator. + DeviceOption option; + option.CopyFrom(step_net_def_.device_option()); + option.MergeFrom(step_net_def_.op(i).device_option()); + step_net_def_.mutable_op(i)->mutable_device_option()->CopyFrom(option); } op_deps_.push_back(op_deps(i)); } diff --git a/caffe2/operators/sparse_normalize_op.cc b/caffe2/operators/sparse_normalize_op.cc index 3b3f16b7ce28d..1c6dfa5596be7 100644 --- a/caffe2/operators/sparse_normalize_op.cc +++ b/caffe2/operators/sparse_normalize_op.cc @@ -6,9 +6,8 @@ namespace caffe2 { template <> bool SparseNormalizeOp::RunOnDevice() { - return DispatchHelper>::call( - this, Input(INDICES)); + this, Input(INDICES)); } template <> @@ -49,10 +48,14 @@ bool SparseNormalizeOp::DoRunWithType() { REGISTER_CPU_OPERATOR(SparseNormalize, SparseNormalizeOp); OPERATOR_SCHEMA(SparseNormalize) - .NumInputs(2) + .NumInputs(2, 3) .NumOutputs(1) .Input(0, "param", "Parameters to be normalized") .Input(1, "indices", "Sparse indices") + .Input( + 2, + "grad", + "Gradient computed (optional - not used, this argument is for backwards compatibility)") .Output(0, "output_param", "Normalized parameters") .EnforceOneToOneInplace() .Arg( diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index 54a9bcfee821b..d79c73e7cee18 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -53,7 +53,7 @@ void BoundShapeInferencer::InferOps( InferSparseLengthsSum(op); } else if ( op.type() == "FC" || op.type() == "FCTransposed" || - op.type() == "FbFCPacked") { + op.type() == "FbFCPacked" || op.type() == "Int8FC") { InferFC(op); } else if (op.type() == "Concat") { InferConcat(op); @@ -424,12 +424,13 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) { const ShapeInfo& w_shape_info = w_it->second; const auto b_it = shape_info_.find(op.input(2)); CAFFE_ENFORCE( - w_it != shape_info_.end(), + b_it != shape_info_.end(), "Shape of BIAS input of FC ", op.input(2), " needs to be presented"); const ShapeInfo& b_shape_info = b_it->second; bool fp16 = (op.type() == "FbFCPacked"); + bool int8_fc = (op.type() == "Int8FC" || op.engine() == "DNNLOWP"); auto x_it = shape_info_.find(op.input(0)); if (x_it == shape_info_.end()) { // We don't have a hint at the x input we try to deduce it from weight @@ -451,13 +452,21 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) { dims.push_back(K); current_dim_type_ = ShapeInfo::DimType::BATCH; current_max_batch_size_ = spec_.max_batch_size; + TensorProto::DataType w_data_type; + if (fp16) { + w_data_type = TensorProto_DataType_FLOAT; + } else if (int8_fc) { + w_data_type = TensorProto_DataType_UINT8; + } else { + w_data_type = w_shape.data_type(); + } // Note: for FbFCPacked, weight is fp16 but actications are in fp32 CheckAndSetTensorShapeAndType( op.input(0), ShapeInfo::DimType::BATCH, dims, - fp16 ? TensorProto_DataType_FLOAT : w_shape.data_type(), - false); + w_data_type, + int8_fc ? true : false); } else { ShapeInfo& x_shape_info = x_it->second; if (x_shape_info.dim_type != ShapeInfo::DimType::BATCH) { @@ -472,12 +481,20 @@ void BoundShapeInferencer::InferFC(const OperatorDef& op) { shape_info_[op.input(0)].shape, w_shape_info.shape, b_shape_info.shape}; std::vector output_shapes = InferOutput(op, input_shapes); CAFFE_ENFORCE_EQ(output_shapes.size(), 1); + TensorProto::DataType output_data_type; + if (fp16) { + output_data_type = TensorProto_DataType_FLOAT; + } else if (int8_fc) { + output_data_type = TensorProto_DataType_UINT8; + } else { + output_data_type = output_shapes[0].data_type(); + } CheckAndSetTensorShapeAndType( op.output(0), ShapeInfo::DimType::BATCH, ConvertToVec(output_shapes[0].dims()), - fp16 ? TensorProto_DataType_FLOAT : output_shapes[0].data_type(), - false); + output_data_type, + int8_fc ? true : false); } void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { @@ -511,7 +528,8 @@ void BoundShapeInferencer::InferCommonOp(const OperatorDef& op) { {"Int8AveragePool", 0}, {"Int8FC", 1}, {"Int8Conv", 1}, - {"Int8SumRelu", 0}}; + {"Int8SumRelu", 0}, + {"Int8Relu", 0}}; CAFFE_ENFORCE( type_info_from_input.find(op.type()) != type_info_from_input.end(), "Undefined quantized output data type, add it into type_info_from_input"); diff --git a/caffe2/opt/custom/glow_net_transform.cc b/caffe2/opt/custom/glow_net_transform.cc index 679dd11d26082..141522bf25cc7 100644 --- a/caffe2/opt/custom/glow_net_transform.cc +++ b/caffe2/opt/custom/glow_net_transform.cc @@ -125,7 +125,11 @@ void onnxifi( if (kv.size() == 2) { auto dims = caffe2::split(',', kv.back()); TensorShape input; - input.set_data_type(TensorProto_DataType_FLOAT); + if (kv.front().find("int8") != std::string::npos) { + input.set_data_type(TensorProto_DataType_UINT8); + } else { + input.set_data_type(TensorProto_DataType_FLOAT); + } bool valid = true; for (const auto& d : dims) { try { diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index 3ca9ae5a1b152..42c5ddd3bc0be 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -1,3 +1,11 @@ +if (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) + list(APPEND Caffe2_CPU_SRCS + "${CMAKE_CURRENT_SOURCE_DIR}/embedding_lookup_idx.cc" + ) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) + return() +endif() + # ---[ CPU files. file(GLOB common_srcs *.cc) file(GLOB avx_srcs *_avx.cc) diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 3e66525af25da..825251c0aa0cd 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -1,6 +1,8 @@ #include "caffe2/perfkernels/embedding_lookup_idx.h" -#include "caffe2/core/types.h" +#include +#include "caffe2/core/common.h" +#include "caffe2/core/logging.h" #include "caffe2/perfkernels/common.h" namespace caffe2 { diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 6589c9a3b4480..ef2124632abff 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -451,11 +451,12 @@ def GetIndexFromGradientList(g_list, name): OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions']) -GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient']) +GradGenMeta = namedtuple('GradGenMeta', + ['grad_op', 'idx', 'gradient', 'device_option']) SparseGradGenMeta = namedtuple('SparseGradGenMeta', [ 'grad_op_indices', 'idx_indices', 'grad_op_values', 'idx_values', - 'gradient', + 'gradient', 'device_option', ]) @@ -607,9 +608,13 @@ def AppendSparseGenerators(self, sparse_generators): else: # both indices and values are generated assert(len(generators) == 2) - op1_i, idx1_i, op1_v, idx1_v, g1 = generators[0] - op2_i, idx2_i, op2_v, idx2_v, g2 = generators[1] + op1_i, idx1_i, op1_v, idx1_v, g1, dev_1 = generators[0] + op2_i, idx2_i, op2_v, idx2_v, g2, dev_2 = generators[1] assert(g1 == g2) + assert dev_1 == dev_2, ( + "Unequal devices for sparse generators: " + "{} and {}".format(dev1, dev2) + ) assert(op1_i is None or op2_i is None) assert(op1_v is None or op2_v is None) assert(idx1_i == 0 or idx2_i == 0) @@ -617,7 +622,7 @@ def AppendSparseGenerators(self, sparse_generators): generator = SparseGradGenMeta( op1_i or op2_i, idx1_i + idx2_i, op1_v or op2_v, idx1_v + idx2_v, - g1) + g1, dev_1) self.gradient_generators[name][version].append(generator) def BuildGradientGenerators( # NOQA @@ -650,15 +655,17 @@ def BuildGradientGenerators( # NOQA # we'll merge indices and values generators # corresponding to the same gradient in step (3) if g.indices == output: - m = SparseGradGenMeta(grad_op, i, None, 0, g) + m = SparseGradGenMeta( + grad_op, i, None, 0, g, grad_op.device_option) else: assert(g.values == output) - m = SparseGradGenMeta(None, 0, grad_op, i, g) + m = SparseGradGenMeta( + None, 0, grad_op, i, g, grad_op.device_option) sparse_generators[input_name][input_version].append(m) else: self.gradient_generators[input_name][input_version] \ .append(GradGenMeta( - grad_op, i, g)) + grad_op, i, g, grad_op.device_option)) # (3) merge indices and values generators for sparse gradients, and # add them to gradient_generators @@ -678,11 +685,11 @@ def BuildGradientGenerators( # NOQA if str(g.indices) not in locally_generated_blobs and \ str(g.values) not in locally_generated_blobs: self.gradient_generators[input_name][input_version].append( - SparseGradGenMeta(None, 0, None, 0, g)) + SparseGradGenMeta(None, 0, None, 0, g, forward_op.device_option)) else: if str(g) not in locally_generated_blobs: self.gradient_generators[input_name][input_version].append( - GradGenMeta(None, 0, g)) + GradGenMeta(None, 0, g, forward_op.device_option)) # Finally, for the gradients specified in g_input, we update the # gradient frontier to reflect the input versions that the gradients @@ -701,12 +708,12 @@ def remove_suffix(s, suffix): for g in generator: if type(g) is GradGenMeta: - grad_op, idx, _ = g + grad_op, idx, _, _ = g if grad_op: return grad_op.output[idx] else: assert(type(g) is SparseGradGenMeta) - op_i, idx_i, op_v, idx_v, _ = g + op_i, idx_i, op_v, idx_v, _, _ = g if op_i: return remove_suffix(op_i.output[idx_i], '_indices') if op_v: @@ -720,16 +727,12 @@ def _SetSumOpsDeviceOption(self, sum_ops, generators): # we already checked that device options are consistent so we can just # use the first one we find for generator in generators: - grad_op = generator.grad_op if type(generator) is GradGenMeta \ - else generator.grad_op_values or generator.grad_op_indices - if grad_op: - if grad_op.HasField('device_option'): - for op in sum_ops: - op.device_option.CopyFrom(grad_op.device_option) - op.device_option.extra_info.extend([ - "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG) - ]) - break + for op in sum_ops: + op.device_option.CopyFrom(generator.device_option) + op.device_option.extra_info.extend([ + "{}:1".format(IR.IS_AUTO_GEN_SUM_OPS_TAG) + ]) + break def _DisambiguateGradOpOutput(self, grad_op, idx, cnt): new_grad_output = ( @@ -756,7 +759,7 @@ def _MakeDenseSumOps(self, generators, out_base_name): first_grad_op = True for generator in generators: - grad_op, idx, g = generator + grad_op, idx, g, _ = generator assert(type(g) is not GradientSlice) if grad_op: if first_grad_op: @@ -790,7 +793,7 @@ def _MakeSparseSumOps(self, generators, out_base_name): for generator in generators: assert(type(generator) is SparseGradGenMeta) - op_i, idx_i, op_v, idx_v, g = generator + op_i, idx_i, op_v, idx_v, g, _ = generator if op_i: out, cnt_i = self._DisambiguateGradOpOutput(op_i, idx_i, cnt_i) indices_concat_input.append(out) @@ -864,16 +867,14 @@ def _VerifyGradientGenerators(self, generator): all_gradient_names = [] all_device_options = [] for g in generator: + if g.device_option: + all_device_options.append(g.device_option) if type(g) is GradGenMeta: if g.grad_op: all_gradient_names.append(g.gradient) - all_device_options.append(g.grad_op.device_option) else: assert(type(g) is SparseGradGenMeta) - if g.grad_op_indices: - all_device_options.append(g.grad_op_indices.device_option) - if g.grad_op_values: - all_device_options.append(g.grad_op_values.device_option) + if g.gradient.values: all_gradient_names.append(g.gradient.values) # Check if all grad op device options are the same. @@ -935,7 +936,8 @@ def _AppendAutoGradGenerator(self, y, grad, autograd_op): # a ConstantFill operator. Autogeneration for sparse gradients is # not supported generator = GradGenMeta( - autograd_op, 0 if autograd_op else None, str(grad)) + autograd_op, 0 if autograd_op else None, str(grad), + autograd_op.device_option) self.gradient_generators[str(y)][self.frontier[str(y)]].append( generator) diff --git a/caffe2/python/core_gradients_test.py b/caffe2/python/core_gradients_test.py index b6e9817717b54..a7b736e4de284 100644 --- a/caffe2/python/core_gradients_test.py +++ b/caffe2/python/core_gradients_test.py @@ -915,6 +915,44 @@ def testIncorrectOperator(self): except Exception as e: self.assertTrue("schema" in str(e)) + def testDeviceOptionsPropagation(self): + ''' + Test verifies that aggregation operators in a backward path will be in + the same device as the parameter. + ''' + device_0 = 'node:0' + + # init_net. + init_net = core.Net("init_net") + with core.DeviceScope(0, node_name=device_0): + w = init_net.UniformFill([], 'w', shape=[10000, 64]) + ids = init_net.GivenTensorFill( + [], + 'ids', + values=np.random.random_integers(low=0, high=10000, size=10), + ) + ids_2 = init_net.GivenTensorFill( + [], + 'ids_2', + values=np.random.random_integers(low=0, high=10000, size=10), + ) + + # train_net. + train_net = core.Net("train_net") + with core.DeviceScope(0, node_name=device_0): + vals = train_net.Gather([w, ids], "gathered") + r_vals = train_net.ReduceSum([vals], 1, axes=0) + + vals_2 = train_net.Gather([w, ids_2], "gathered_2") + r_vals_2 = train_net.ReduceSum([vals_2], 1, axes=0) + + loss = train_net.Sum([r_vals, r_vals_2], 1) + train_net.AddGradientOperators([loss]) + # All concat operators should be on device_0 + for op in train_net.Proto().op: + if op.type == 'Concat': + self.assertEqual(op.device_option.node_name, device_0) + if __name__ == '__main__': unittest.main() diff --git a/caffe2/python/dyndep.py b/caffe2/python/dyndep.py index be85af542c496..af203fa27b129 100644 --- a/caffe2/python/dyndep.py +++ b/caffe2/python/dyndep.py @@ -7,7 +7,7 @@ import ctypes import os - +from threading import Lock from caffe2.python import core, extension_loader @@ -36,6 +36,7 @@ def InitOpsLibrary(name): _IMPORTED_DYNDEPS = set() +dll_lock = Lock() def GetImportedOpsLibraries(): @@ -43,8 +44,9 @@ def GetImportedOpsLibraries(): def _init_impl(path): - _IMPORTED_DYNDEPS.add(path) - with extension_loader.DlopenGuard(): - ctypes.CDLL(path) - # reinitialize available ops - core.RefreshRegisteredOperators() + with dll_lock: + _IMPORTED_DYNDEPS.add(path) + with extension_loader.DlopenGuard(): + ctypes.CDLL(path) + # reinitialize available ops + core.RefreshRegisteredOperators() diff --git a/caffe2/python/hypothesis_test.py b/caffe2/python/hypothesis_test.py index f2e295b8bb32c..fc0be25c5c1c7 100644 --- a/caffe2/python/hypothesis_test.py +++ b/caffe2/python/hypothesis_test.py @@ -773,13 +773,12 @@ def ftrl(w, nz, i, g, alpha): self.assertReferenceChecks(gc, op, [var, nz, indices, grad, alpha], ftrl) - # TODO: (bddppq) test_unique keeps running into segfault on rocm 1.8.2 @given(input=hu.tensor(max_value=20, max_dim=1, dtype=np.int32, elements=st.integers(min_value=0, max_value=10)), with_remapping=st.booleans(), - **hu.gcs_no_hip) + **hu.gcs) def test_unique(self, input, with_remapping, gc, dc): op = core.CreateOperator( "Unique", diff --git a/caffe2/python/layers/split.py b/caffe2/python/layers/split.py index 449e2c67d0341..50bcdbf88b11e 100644 --- a/caffe2/python/layers/split.py +++ b/caffe2/python/layers/split.py @@ -13,8 +13,8 @@ class Split(ModelLayer): - def __init__(self, model, input_record, num_splits, axis=1, - name='split', **kwargs): + def __init__(self, model, input_record, num_splits=1, axis=1, + name='split', split=None, **kwargs): super(Split, self).__init__(model, name, input_record, **kwargs) self.axis = axis # Assume that first dimension is batch, so actual axis in shape is @@ -28,25 +28,48 @@ def __init__(self, model, input_record, num_splits, axis=1, input_shape = input_record.field_type().shape assert len(input_shape) >= axis - assert input_shape[axis] % num_splits == 0 + if split is None: + assert input_shape[axis] % num_splits == 0 + else: + num_splits = len(split) + assert input_shape[axis] == sum(split) - output_shape = list(input_shape) - output_shape[axis] = int(output_shape[axis] / num_splits) + if split is None: + output_shape = list(input_shape) + output_shape[axis] = int(output_shape[axis] / num_splits) + else: + output_shape = [] + for i in range(num_splits): + output_shape_i = list(input_shape) + output_shape_i[axis] = split[i] + output_shape.append(output_shape_i) data_type = input_record.field_type().base - output_scalars = [ - schema.Scalar( - (data_type, output_shape), - self.get_next_blob_reference('output_{}'.format(i)), - ) - for i in range(num_splits) - ] + + if split is None: + output_scalars = [ + schema.Scalar( + (data_type, output_shape), + self.get_next_blob_reference('output_{}'.format(i)), + ) + for i in range(num_splits) + ] + else: + output_scalars = [ + schema.Scalar( + (data_type, output_shape[i]), + self.get_next_blob_reference('output_{}'.format(i)), + ) + for i in range(num_splits) + ] self.output_schema = schema.Tuple(*output_scalars) + self.split = split def add_ops(self, net): net.Split( self.input_record.field_blobs(), self.output_schema.field_blobs(), + split=self.split, axis=self.axis, ) diff --git a/caffe2/python/onnx/backend.py b/caffe2/python/onnx/backend.py index 4c5b0c88c8539..f7056ed1da5c1 100644 --- a/caffe2/python/onnx/backend.py +++ b/caffe2/python/onnx/backend.py @@ -13,6 +13,7 @@ import os import collections from subprocess import Popen, PIPE +import sys import zipfile import itertools @@ -887,7 +888,7 @@ def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_in cls._dummy_name.reset(cls._all_names_in_graph(init_model.graph) | cls._all_names_in_graph(pred_model.graph)) - success = True + errors = [] for net, model in ( (init_net, init_model), (pred_net, pred_model) ): net.device_option.CopyFrom(device_option) for node in model.graph.node: @@ -895,8 +896,9 @@ def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_in c2ops = cls._onnx_node_to_caffe2_op( init_model, pred_model, node, opset_version) except Exception as e: - success = False - print('ONNX FATAL:', e) + msg = 'Error while processing node: {}. Exception: {}'.format(node, e) + errors.append(msg) + print('ONNX FATAL:', msg, file=sys.stderr) continue init_net.op.extend(c2ops.init_ops) net.op.extend(c2ops.ops) @@ -906,8 +908,10 @@ def _onnx_model_to_caffe2_net(cls, onnx_model, device, opset_version, include_in net.external_input.extend( value_info.name for value_info in model.graph.input) - if not success: - raise RuntimeError('ONNX conversion failed') + if len(errors) > 0: + raise RuntimeError( + "ONNX conversion failed, encountered {} errors:\n\n{}".format( + len(errors), "\n\n".join(errors))) return init_net, pred_net diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index 5883b87304dcc..d5aa7eb285fef 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -30,6 +30,8 @@ '|test_reduce_log_sum.*' # Does not support ReduceLogSum. '|test_reduce_prod.*' # Does not support ReduceProd. '|test_reduce_sum_square.*' # Does not support ReduceSumSquare + '|test_det.*' # Does not support Det + '|test_range.*' # Does not support Range '|test_tile.*' # Tile's Caffe2 implementation needs some tweak '|test_lstm.*' # Seems LSTM case has some problem '|test_simple_rnn.*' # Seems simple RNN case has some problem @@ -76,6 +78,10 @@ '|test_gather_elements.*' # opset 11 is not supported yet '|test_scatter.*' # opset 11 is not supported yet '|test_unique.*' # opset 11 is not supported yet + '|test_gathernd.*' # opset 11 is not supported yet + '|test_sequence_.*' # type sequence is not supported yet + '|test_.*negative_ax.*' # negative axis is not supported yet + '|test_.*negative_ind.*' # negative axis is not supported yet ')') # Quick patch to unbreak master CI, is working on the debugging. diff --git a/caffe2/python/operator_test/learning_rate_op_test.py b/caffe2/python/operator_test/learning_rate_op_test.py index fbb55eb22db83..704bfc7c3c372 100644 --- a/caffe2/python/operator_test/learning_rate_op_test.py +++ b/caffe2/python/operator_test/learning_rate_op_test.py @@ -112,9 +112,11 @@ def ref(iter): self.assertReferenceChecks(gc, op, [iter], ref) - @given(gc=hu.gcs['gc'], - min_num_iter=st.integers(min_value=10, max_value=20), - max_num_iter=st.integers(min_value=50, max_value=100)) + @given( + gc=hu.gcs['gc'], + min_num_iter=st.integers(min_value=10, max_value=20), + max_num_iter=st.integers(min_value=50, max_value=100), + ) def test_composite_learning_rate_op(self, gc, min_num_iter, max_num_iter): np.random.seed(65535) # Generate the iteration numbers for sub policy @@ -128,7 +130,7 @@ def test_composite_learning_rate_op(self, gc, min_num_iter, max_num_iter): accu_iter_num[i] += accu_iter_num[i - 1] total_iter_nums = accu_iter_num[-1] - policy_lr_scale = np.random.uniform(low=2.0, high=2.0, size=num_lr_policy) + policy_lr_scale = np.random.uniform(low=0.1, high=2.0, size=num_lr_policy) # args for StepLRPolicy step_size = np.random.randint(low=2, high=min_num_iter // 2) diff --git a/caffe2/python/operator_test/sparse_normalize_test.py b/caffe2/python/operator_test/sparse_normalize_test.py index 8d17c302a2835..bd8dbd5f7b536 100644 --- a/caffe2/python/operator_test/sparse_normalize_test.py +++ b/caffe2/python/operator_test/sparse_normalize_test.py @@ -24,14 +24,14 @@ def ref_normalize(param_in, use_max_norm, norm): # Suppress filter_too_much health check. # Likely caused by `assume` call falling through too often. @settings(suppress_health_check=[HealthCheck.filter_too_much]) - @given(inputs=hu.tensors(n=1, min_dim=2, max_dim=2), + @given(inputs=hu.tensors(n=2, min_dim=2, max_dim=2), use_max_norm=st.booleans(), norm=st.floats(min_value=1.0, max_value=4.0), data_strategy=st.data(), **hu.gcs_cpu_only) def test_sparse_normalize(self, inputs, use_max_norm, norm, data_strategy, gc, dc): - param = inputs + param, grad = inputs param += 0.02 * np.sign(param) param[param == 0.0] += 0.02 @@ -47,7 +47,7 @@ def test_sparse_normalize(self, inputs, use_max_norm, norm, hypothesis.assume(np.array_equal(np.unique(indices.flatten()), np.sort(indices.flatten()))) - op = core.CreateOperator( + op1 = core.CreateOperator( "SparseNormalize", ["param", "indices"], ["param"], @@ -55,7 +55,18 @@ def test_sparse_normalize(self, inputs, use_max_norm, norm, norm=norm, ) - def ref_sparse_normalize(param, indices): + # Sparsify grad + grad = grad[indices] + + op2 = core.CreateOperator( + "SparseNormalize", + ["param", "indices", "grad"], + ["param"], + use_max_norm=use_max_norm, + norm=norm, + ) + + def ref_sparse_normalize(param, indices, grad=None): param_out = np.copy(param) for _, index in enumerate(indices): param_out[index] = self.ref_normalize( @@ -67,6 +78,11 @@ def ref_sparse_normalize(param, indices): # self.assertDeviceChecks(dc, op, [param, indices], [0]) self.assertReferenceChecks( - gc, op, [param, indices], + gc, op1, [param, indices], + ref_sparse_normalize + ) + + self.assertReferenceChecks( + gc, op2, [param, indices, grad], ref_sparse_normalize ) diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index 648f2097965e0..2e5526cc6bbfc 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -1,14 +1,16 @@ from __future__ import absolute_import, division, print_function, unicode_literals -from caffe2.python import core, workspace -import torch -from hypothesis import given import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st import numpy as np -from scipy.stats import norm +import struct +import torch import unittest +from caffe2.python import core, workspace +from hypothesis import given +from scipy.stats import norm + def generate_rois(roi_counts, im_dims): assert len(roi_counts) == len(im_dims) @@ -68,6 +70,51 @@ def create_bbox_transform_inputs(roi_counts, num_classes, rotated): return rois, deltas, im_info +# Eigen/Python round 0.5 away from 0, Numpy rounds to even +round_to_nearest = np.vectorize(round) + + +def bytes_to_floats(byte_matrix): + floats = np.empty([np.shape(byte_matrix)[0], 1], dtype=np.float32) + for i, byte_values in enumerate(byte_matrix): + floats[i], = struct.unpack('f', bytearray(byte_values)) + return floats + + +def floats_to_bytes(floats): + byte_matrix = np.empty([np.shape(floats)[0], 4], dtype=np.uint8) + for i, value in enumerate(floats): + assert isinstance(value, np.float32), (value, floats) + as_bytes = struct.pack('f', value) + # In Python3 bytes will be a list of int, in Python2 a list of string + if isinstance(as_bytes[0], int): + byte_matrix[i] = list(as_bytes) + else: + byte_matrix[i] = list(map(ord, as_bytes)) + return byte_matrix + + +def fused_rowwise_8bit_quantize_reference(data): + minimum = np.min(data, axis=1, keepdims=True) + maximum = np.max(data, axis=1, keepdims=True) + span = maximum - minimum + bias = minimum + scale = span / 255.0 + inverse_scale = 255.0 / (span + 1e-8) + quantized_data = round_to_nearest((data - bias) * inverse_scale) + scale_bytes = floats_to_bytes(scale.reshape(-1)) + bias_bytes = floats_to_bytes(bias.reshape(-1)) + return np.concatenate([quantized_data, scale_bytes, bias_bytes], axis=1) + + +def fused_rowwise_8bit_quantize_dequantize_reference(data): + fused_quantized = fused_rowwise_8bit_quantize_reference(data) + scale = bytes_to_floats(fused_quantized[:, -8:-4].astype(np.uint8)) + bias = bytes_to_floats(fused_quantized[:, -4:].astype(np.uint8)) + quantized_data = fused_quantized[:, :-8] + return quantized_data * scale + bias + + class TorchIntegration(hu.HypothesisTestCase): @given( roi_counts=st.lists(st.integers(0, 5), min_size=1, max_size=10), @@ -609,6 +656,25 @@ def test_resize_nearest_op_cpu(self): def test_resize_nearest_op_cuda(self): return self._test_resize_nearest_op("cuda") + @given(input_data=hu.tensor(min_dim=2, max_dim=2)) + def test_Fused8BitRowwiseQuantizedToFloat(self, input_data): + QuantizeOp = core.CreateOperator( + "FloatToFused8BitRowwiseQuantized", + ["input_data"], + ["quantized_data"], + ) + + workspace.FeedBlob("input_data", input_data) + workspace.RunOperatorOnce(QuantizeOp) + + quantized_data = workspace.FetchBlob("quantized_data") + + dequantized_data = torch.ops._caffe2.Fused8BitRowwiseQuantizedToFloat( + torch.tensor(quantized_data) + ) + + reference = fused_rowwise_8bit_quantize_dequantize_reference(input_data) + np.testing.assert_array_almost_equal(dequantized_data.numpy(), reference) if __name__ == '__main__': unittest.main() diff --git a/caffe2/python/pybind_state.h b/caffe2/python/pybind_state.h index 637bf47355126..5a3628d2ef22e 100644 --- a/caffe2/python/pybind_state.h +++ b/caffe2/python/pybind_state.h @@ -35,7 +35,7 @@ #else -struct PyArrayObject; // Forward declaring PyArrayObject for safety +struct PyArrayObject; // Forward declaring PyArrayObject for safety #endif // USE_NUMPY @@ -65,8 +65,11 @@ class C10_EXPORT BlobFetcherBase { class BlobFeederBase { public: virtual ~BlobFeederBase(); - virtual void - Feed(const DeviceOption& option, PyArrayObject* array, Blob* blob, bool in_place = false) = 0; + virtual void Feed( + const DeviceOption& option, + PyArrayObject* array, + Blob* blob, + bool in_place = false) = 0; }; C10_DECLARE_TYPED_REGISTRY( @@ -232,7 +235,8 @@ class TensorFeeder : public BlobFeederBase { PyBytes_AsStringAndSize(input[i], &str, &strSize) != -1, "Had a PyBytes object but cannot convert it to a string."); } else if (PyUnicode_Check(input[i])) { // string - str = const_cast(PyUnicode_AsUTF8AndSize(input[i], &strSize)); + str = + const_cast(PyUnicode_AsUTF8AndSize(input[i], &strSize)); CAFFE_ENFORCE( str, "Had a PyUnicode object but cannot convert it to a string."); @@ -327,10 +331,27 @@ class PythonOpBase : public Operator { try { auto pickle = py::reinterpret_steal(PyImport_ImportModule("pickle")); + CAFFE_ENFORCE(pickle); auto loads = pickle.attr("loads").cast(); CAFFE_ENFORCE(loads); - auto builder_call = loads(py::bytes(pickled)).cast(); + py::tuple builder_call; + try { + builder_call = loads(py::bytes(pickled)).cast(); + } catch (const py::error_already_set& e) { +#if PY_MAJOR_VERSION >= 3 + LOG(INFO) << "Cannot unpickle python operator: " << e.what(); + LOG(INFO) << "Try latin1 encoding for python3 run"; + // to use the `_a` literal for arguments + using namespace pybind11::literals; + builder_call = loads(py::bytes(pickled), "encoding"_a = "latin1") + .template cast(); +#else + // for py2, simply re-throw the exception, as there is no encoding + // argument for pickle.loads + throw; +#endif + } CAFFE_ENFORCE(builder_call); CAFFE_ENFORCE_EQ(py::len(builder_call), 3); auto func = builder_call[0].cast(); diff --git a/caffe2/quantization/server/conv_dnnlowp_op.cc b/caffe2/quantization/server/conv_dnnlowp_op.cc index 253658009890a..febab049ff685 100644 --- a/caffe2/quantization/server/conv_dnnlowp_op.cc +++ b/caffe2/quantization/server/conv_dnnlowp_op.cc @@ -1115,6 +1115,10 @@ void ConvDNNLowPOp::ConvNHWCCore_( const int kernel_dim = KernelDim_(); const int Y_HxW = this->GetDimsSize(*Y); + if (N == 0) { + LOG(WARNING) << "The batch size is 0 in ConvNHWCCore_ function!"; + } + if (FLAGS_caffe2_dnnlowp_dump_tensors) { // Dump input activation StoreMatrixInMatrixMarketFormat( diff --git a/caffe2/quantization/server/fbgemm_pack_op.cc b/caffe2/quantization/server/fbgemm_pack_op.cc index 0938fdf2ea96b..04ea04440ce09 100644 --- a/caffe2/quantization/server/fbgemm_pack_op.cc +++ b/caffe2/quantization/server/fbgemm_pack_op.cc @@ -229,7 +229,9 @@ FullyConnectedDNNLowPPackWeightOp::FullyConnectedDNNLowPPackWeightOp( : DNNLowPOp(operator_def, ws), axis_w_(this->GetSingleArgument("axis_w", 1)), quantize_channelwise_( - this->GetSingleArgument("quantize_channelwise", false)) { + this->GetSingleArgument("quantize_channelwise", false)), + save_unpacked_weights_( + this->GetSingleArgument("save_unpacked_weights", false)) { if (this->debug_def().engine() == "DNNLOWP_ROWWISE") { quantize_channelwise_ = true; } @@ -258,6 +260,13 @@ bool FullyConnectedDNNLowPPackWeightOp::RunOnDevice() { QuantizeWeight( InputBlob(0), K, N, Y->qparams, W_quantized, qfactory_.get()); + if (save_unpacked_weights_) { + ReinitializeTensor( + &Y->original_tensor, filter.sizes(), at::dtype().device(CPU)); + auto* buffer = Y->original_tensor.template mutable_data(); + CAFFE_ENFORCE_EQ(Y->original_tensor.numel(), W_quantized.size()); + memcpy(buffer, W_quantized.data(), W_quantized.size() * sizeof(int8_t)); + } if (this->InputIsType(0) && quantize_channelwise_) { static int log_occurences = 0; if (log_occurences < 32) { diff --git a/caffe2/quantization/server/fbgemm_pack_op.h b/caffe2/quantization/server/fbgemm_pack_op.h index 54d3bdaf19909..b615b4735ed9d 100644 --- a/caffe2/quantization/server/fbgemm_pack_op.h +++ b/caffe2/quantization/server/fbgemm_pack_op.h @@ -24,6 +24,7 @@ class FullyConnectedDNNLowPPackWeightOp final int axis_w_; bool quantize_channelwise_; int nbits_in_non_outlier_; // only for DNNLOWP_ACC16 + bool save_unpacked_weights_; INPUT_TAGS(FILTER, BIAS); }; diff --git a/caffe2/quantization/server/pool_dnnlowp_op.cc b/caffe2/quantization/server/pool_dnnlowp_op.cc index cfc8c35ff4aac..6d695bc2853b0 100644 --- a/caffe2/quantization/server/pool_dnnlowp_op.cc +++ b/caffe2/quantization/server/pool_dnnlowp_op.cc @@ -327,54 +327,90 @@ class AveragePoolDnnLowPOp final } break; case 3: + if (is_same::value) { #ifdef _OPENMP #pragma omp parallel for #endif - for (int n = 0; n < X.dim32(0); ++n) { - const T* Xdata_temp = Xdata + n * height * width * depth * channels; - T* Ydata_temp = Ydata + - n * pooled_height * pooled_width * pooled_depth * channels; - for (int ph = 0; ph < pooled_height; ++ph) { - int hstart = ph * stride_h() - pad_t(); - int hend = min(hstart + kernel_h(), height); - hstart = max(hstart, 0); - for (int pw = 0; pw < pooled_width; ++pw) { - int wstart = pw * stride_w() - pad_l(); - int wend = min(wstart + kernel_w(), width); - wstart = max(wstart, 0); - for (int pd = 0; pd < pooled_depth; ++pd) { - int dstart = pd * stride_[2] - pads_[2]; - int dend = min(dstart + kernel_[2], depth); - dstart = max(dstart, 0); - int size = (hend - hstart) * (wend - wstart) * (dend - dstart); - float multiplier = - in_qparams_[0].scale / out_qparams_.scale / size; + for (int n = 0; n < X.dim32(0); ++n) { + average_pool_3d_avx2( + reinterpret_cast(Xdata), + n, + height, + width, + depth, + channels, + pooled_height, + pooled_width, + pooled_depth, + kernel_h(), + kernel_w(), + kernel_[2], + stride_h(), + stride_w(), + stride_[2], + pad_t(), + pad_l(), + pads_[2], + reinterpret_cast(Ydata), + in_qparams_[0].scale, + out_qparams_.scale, + in_qparams_[0].zero_point, + out_qparams_.zero_point, + minimum, + maximum); + } + } else { +#ifdef _OPENMP +#pragma omp parallel for +#endif + for (int n = 0; n < X.dim32(0); ++n) { + const T* Xdata_temp = Xdata + n * height * width * depth * channels; + T* Ydata_temp = Ydata + + n * pooled_height * pooled_width * pooled_depth * channels; + for (int ph = 0; ph < pooled_height; ++ph) { + int hstart = ph * stride_h() - pad_t(); + int hend = min(hstart + kernel_h(), height); + hstart = max(hstart, 0); + for (int pw = 0; pw < pooled_width; ++pw) { + int wstart = pw * stride_w() - pad_l(); + int wend = min(wstart + kernel_w(), width); + wstart = max(wstart, 0); + for (int pd = 0; pd < pooled_depth; ++pd) { + int dstart = pd * stride_[2] - pads_[2]; + int dend = min(dstart + kernel_[2], depth); + dstart = max(dstart, 0); + int size = + (hend - hstart) * (wend - wstart) * (dend - dstart); + float multiplier = + in_qparams_[0].scale / out_qparams_.scale / size; - for (int c = 0; c < channels; ++c) { - const int pool_idx = - ((ph * pooled_width + pw) * pooled_depth + pd) * - channels + - c; - int32_t Yh = -in_qparams_[0].zero_point * size; - for (int h = hstart; h < hend; ++h) { - for (int w = wstart; w < wend; ++w) { - for (int d = dstart; d < dend; ++d) { - const int input_idx = - ((h * width + w) * depth + d) * channels + c; - Yh += Xdata_temp[input_idx]; + for (int c = 0; c < channels; ++c) { + const int pool_idx = + ((ph * pooled_width + pw) * pooled_depth + pd) * + channels + + c; + int32_t Yh = -in_qparams_[0].zero_point * size; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + for (int d = dstart; d < dend; ++d) { + const int input_idx = + ((h * width + w) * depth + d) * channels + c; + Yh += Xdata_temp[input_idx]; + } } } - } - Ydata_temp[pool_idx] = std::min( - std::max( - nearbyint(Yh * multiplier + out_qparams_.zero_point), - minimum), - maximum); - } // channel - } // depth - } // width - } // height - } // for each image + Ydata_temp[pool_idx] = std::min( + std::max( + nearbyint( + Yh * multiplier + out_qparams_.zero_point), + minimum), + maximum); + } // channel + } // depth + } // width + } // height + } // for each image + } break; default: CAFFE_THROW("Unsupported pooling size : ", this->kernel_.size()); diff --git a/caffe2/quantization/server/pool_dnnlowp_op_avx2.cc b/caffe2/quantization/server/pool_dnnlowp_op_avx2.cc index 690546adc99b5..21c301d614b7a 100644 --- a/caffe2/quantization/server/pool_dnnlowp_op_avx2.cc +++ b/caffe2/quantization/server/pool_dnnlowp_op_avx2.cc @@ -199,4 +199,158 @@ void average_pool_avx2( } // ph loop } +void average_pool_3d_avx2( + const uint8_t* Xdata, + int n, + int height, + int width, + int depth, + int channels, + int pooled_height, + int pooled_width, + int pooled_depth, + int kernel_h, + int kernel_w, + int kernel_d, + int stride_h, + int stride_w, + int stride_d, + int pad_t, + int pad_l, + int pad_d, + uint8_t* Ydata, + float in_scale, + float out_scale, + int32_t in_zero_point, + int32_t out_zero_point, + int32_t minimum, + int32_t maximum) { + const uint8_t* Xdata_temp = Xdata + n * height * width * depth * channels; + uint8_t* Ydata_temp = + Ydata + n * pooled_height * pooled_width * pooled_depth * channels; + + const __m256i shuffle_mask_v = _mm256_set_epi8( + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0x0c, + 0x08, + 0x04, + 0x00, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0xff, + 0x0c, + 0x08, + 0x04, + 0x00); + const __m256i permute_mask_v = + _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00); + + const __m256i min_v = _mm256_set1_epi32(minimum); + const __m256i max_v = _mm256_set1_epi32(maximum); + __m256 out_zero_point_v = _mm256_set1_ps(out_zero_point); + + for (int ph = 0; ph < pooled_height; ++ph) { + int hstart = ph * stride_h - pad_t; + int hend = hstart + kernel_h < height ? hstart + kernel_h : height; + hstart = hstart > 0 ? hstart : 0; + for (int pw = 0; pw < pooled_width; ++pw) { + int wstart = pw * stride_w - pad_l; + int wend = wstart + kernel_w < width ? wstart + kernel_w : width; + wstart = wstart > 0 ? wstart : 0; + for (int pd = 0; pd < pooled_depth; ++pd) { + int dstart = pd * stride_d - pad_d; + int dend = dstart + kernel_d < depth ? dstart + kernel_d : depth; + dstart = max(dstart, 0); + + int size = (hend - hstart) * (wend - wstart) * (dend - dstart); + float multiplier = in_scale / out_scale / size; + __m256 multiplier_v = _mm256_set1_ps(multiplier); + + uint8_t* Yh = Ydata_temp + + ((ph * pooled_width + pw) * pooled_depth + pd) * channels; + constexpr int VLEN = 8; + int32_t Yh0 = -in_zero_point * size; + + // vectorized loop + for (int c = 0; c < channels / VLEN * VLEN; c += VLEN) { + __m256i Yh0_v = _mm256_set1_epi32(Yh0); + + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + for (int d = dstart; d < dend; ++d) { + const int input_idx = + ((h * width + w) * depth + d) * channels + c; + const __m256i temp_v = _mm256_cvtepu8_epi32(_mm_loadl_epi64( + reinterpret_cast(Xdata_temp + input_idx))); + Yh0_v = _mm256_add_epi32(Yh0_v, temp_v); + } + } + } + + __m256 Yh0_fp = _mm256_cvtepi32_ps(Yh0_v); + __m256 Y_float_v = + _mm256_fmadd_ps(Yh0_fp, multiplier_v, out_zero_point_v); + __m256i Y_rounded_v = _mm256_cvtps_epi32(Y_float_v); + __m256i Y_clipped_v = + _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, Y_rounded_v)); + + Y_clipped_v = _mm256_shuffle_epi8(Y_clipped_v, shuffle_mask_v); + Y_clipped_v = + _mm256_permutevar8x32_epi32(Y_clipped_v, permute_mask_v); + *reinterpret_cast(Yh + c) = + _mm256_extract_epi64(Y_clipped_v, 0); + } + + // remainder + for (int c = channels / VLEN * VLEN; c < channels; ++c) { + Yh[c] = 0; + } + + for (int c = channels / VLEN * VLEN; c < channels; ++c) { + const int pool_idx = + ((ph * pooled_width + pw) * pooled_depth + pd) * channels + c; + + int32_t Yh_t = -in_zero_point * size; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + for (int d = dstart; d < dend; ++d) { + const int input_idx = + ((h * width + w) * depth + d) * channels + c; + + Yh_t += Xdata_temp[input_idx]; + } + } + } + + Ydata_temp[pool_idx] = std::min( + std::max( + nearbyint(Yh_t * multiplier + out_zero_point), minimum), + maximum); + } + + } // pd loop + } // pw loop + } // ph loop +} + } // namespace caffe2 diff --git a/caffe2/quantization/server/pool_dnnlowp_op_avx2.h b/caffe2/quantization/server/pool_dnnlowp_op_avx2.h index 4ed3932621381..b0e10abb20d21 100644 --- a/caffe2/quantization/server/pool_dnnlowp_op_avx2.h +++ b/caffe2/quantization/server/pool_dnnlowp_op_avx2.h @@ -45,4 +45,31 @@ void average_pool_avx2( int32_t minimum, int32_t maximum); +void average_pool_3d_avx2( + const uint8_t* Xdata, + int n, + int height, + int width, + int depth, + int channels, + int pooled_height, + int pooled_width, + int pooled_depth, + int kernel_h, + int kernel_w, + int kernel_d, + int stride_h, + int stride_w, + int stride_d, + int pad_t, + int pad_l, + int pad_d, + uint8_t* Ydata, + float in_scale, + float out_scale, + int32_t in_zero_point, + int32_t out_zero_point, + int32_t minimum, + int32_t maximum); + } // namespace caffe2 diff --git a/caffe2/sgd/learning_rate_functors.h b/caffe2/sgd/learning_rate_functors.h index aec243560827f..d9633b0216f8a 100644 --- a/caffe2/sgd/learning_rate_functors.h +++ b/caffe2/sgd/learning_rate_functors.h @@ -219,9 +219,13 @@ class HillLearningRate : public LearningRateFunctor { template class CompositeLearningRateItem { public: - CompositeLearningRateItem(int64_t num_iter, LearningRateFunctor* policy) - : num_iter_(num_iter), policy_(policy) {} + CompositeLearningRateItem( + int64_t num_iter, + float lr_scale, + LearningRateFunctor* policy) + : num_iter_(num_iter), lr_scale_(lr_scale), policy_(policy) {} int64_t num_iter_; + float lr_scale_; LearningRateFunctor* policy_; }; @@ -236,6 +240,7 @@ class CompositeLearningRate : public LearningRateFunctor { for (auto it = sub_policies.begin(); it != sub_policies.end(); ++it) { DCHECK_GT(it->num_iter_, 0); sub_policies_[num_iter_start].reset(it->policy_); + sub_policy_lr_scales_[num_iter_start] = it->lr_scale_; num_iter_start += it->num_iter_; } } @@ -243,11 +248,15 @@ class CompositeLearningRate : public LearningRateFunctor { auto sub_policy = sub_policies_.upper_bound(iter); DCHECK(sub_policy != sub_policies_.begin()); --sub_policy; - return (*sub_policy->second)(iter); + auto sub_policy_lr_scale = sub_policy_lr_scales_.upper_bound(iter); + DCHECK(sub_policy_lr_scale != sub_policy_lr_scales_.begin()); + --sub_policy_lr_scale; + return ((*sub_policy->second)(iter)) * (sub_policy_lr_scale->second); } private: std::map>> sub_policies_; + std::map sub_policy_lr_scales_; }; // Cyclical: return a learning rate with period 2 * stepsize and @@ -256,16 +265,94 @@ class CompositeLearningRate : public LearningRateFunctor { template class CyclicalLearningRate : public LearningRateFunctor { public: - CyclicalLearningRate(const T base_lr, const T max_lr, const int stepsize) - : base_lr_(base_lr), max_lr_(max_lr), stepsize_(stepsize) {} + CyclicalLearningRate( + const T base_lr, + const T max_lr, + const int stepsize, + const T decay) + : base_lr_(base_lr), + max_lr_(max_lr), + stepsize_(stepsize), + decay_(decay) {} T operator()(const int64_t iter) const override { - int cycle = static_cast((iter / (2 * stepsize_)) + 1); + int64_t cycle = static_cast((iter / (2 * stepsize_)) + 1); T x = abs(static_cast(iter) / stepsize_ - 2 * cycle + 1); - return (1 + (T(max_lr_) / T(base_lr_) - 1) * std::max(T(0.0), (1 - x))); + return 1 + + (T(abs(max_lr_)) / T(abs(base_lr_)) - 1) * std::max(T(0.0), (1 - x)) * + std::pow(decay_, static_cast(iter / (2 * stepsize_))); } T base_lr_; T max_lr_; int stepsize_; + T decay_; +}; + +// constantThenLinearWarmup: first use a constant multiplier +// and then ramp up to the global lr +template +class ConstantThenLinearWarmupLearningRate : public LearningRateFunctor { + public: + ConstantThenLinearWarmupLearningRate( + const T start_warmup_multiplier, + const int64_t constant_warmup_num_iter, + const int64_t linear_warmup_num_iter) + : constant_warmup_num_iter_(constant_warmup_num_iter), + linear_warmup_num_iter_(linear_warmup_num_iter), + constant_warmup_lr_(start_warmup_multiplier, constant_warmup_num_iter), + linear_warmup_lr_(start_warmup_multiplier, linear_warmup_num_iter) {} + + T operator()(const int64_t iter) const override { + if (iter < constant_warmup_num_iter_) { + return constant_warmup_lr_(iter); + } else if (iter < constant_warmup_num_iter_ + linear_warmup_num_iter_) { + return linear_warmup_lr_(iter - constant_warmup_num_iter_); + } else { + return 1.0; + } + } + int64_t constant_warmup_num_iter_; + int64_t linear_warmup_num_iter_; + ConstantWarmupLearningRate constant_warmup_lr_; + LinearWarmupLearningRate linear_warmup_lr_; +}; + +// CompositeCyclicalLearningRate: first use a constant multiplier +// and then ramp up to the global lr, and then use a cyclical learning rate +template +class CompositeCyclicalLearningRate : public LearningRateFunctor { + public: + CompositeCyclicalLearningRate( + const T base_lr, + const T start_warmup_multiplier, + const int64_t constant_warmup_num_iter, + const int64_t linear_warmup_num_iter, + const T cyclical_max_lr, + const int cyclical_step_size, + const T cyclical_decay) + : constant_warmup_num_iter_(constant_warmup_num_iter), + linear_warmup_num_iter_(linear_warmup_num_iter), + constant_then_linear_warmup_lr_( + start_warmup_multiplier, + constant_warmup_num_iter, + linear_warmup_num_iter), + cyclical_lr_( + base_lr, + cyclical_max_lr, + cyclical_step_size, + cyclical_decay) {} + + T operator()(const int64_t iter) const override { + if (iter < constant_warmup_num_iter_ + linear_warmup_num_iter_) { + return constant_then_linear_warmup_lr_(iter); + } + return cyclical_lr_( + iter - constant_warmup_num_iter_ - linear_warmup_num_iter_); + } + + int64_t constant_warmup_num_iter_; + int64_t linear_warmup_num_iter_; + ConstantThenLinearWarmupLearningRate constant_then_linear_warmup_lr_; + CyclicalLearningRate cyclical_lr_; }; } // namespace caffe2 diff --git a/caffe2/sgd/learning_rate_op.cc b/caffe2/sgd/learning_rate_op.cc index aa497156bfda3..bd2ad56477c95 100644 --- a/caffe2/sgd/learning_rate_op.cc +++ b/caffe2/sgd/learning_rate_op.cc @@ -33,6 +33,8 @@ more exponential. Learning rate is controlled by the following arguments: `hill`: uses those in both `linearWarmup` and `inv`, plus `end_multiplier` `composite`: uses `sub_policy_num_iters` and additional args with format `cyclic`: uses `max_lr`, `stepsize` + `constantThenLinearWarmup`: uses `start_warmup_multiplier`, `constant_warmup_num_iter`, `linear_warmup_num_iter` + `compositeCyclical`: uses `start_warmup_multiplier`, `constant_warmup_num_iter`, `linear_warmup_num_iter`, `cyclical_max_lr`, `cyclical_step_size`, `cyclical_decay` sub_policy_{sub_policy_index}_{sub_policy_arg}, for example: sub_policy_0_policy: "exp", sub_policy_0_gamma: 0.99, sub_policy_0_lr_scale: 1.2 @@ -54,7 +56,12 @@ more exponential. Learning rate is controlled by the following arguments: `m2`: defaults to 0.5, the second piece lr of piece warmup `n2`: defaults to 0, iter threshold of the second piece lr `m3`: defaults to 0.5, the third piece lr of piece warmup - + `start_warmup_multiplier`: defaults to 0.1, part of constantThenLinearWarmup + `constant_warmup_num_iter`: defaults to 10000000, part of constantThenLinearWarmup and constantThenLinearWarmup + `linear_warmup_num_iter`: defaults to 10000000, part of constantThenLinearWarmup and CompositeCyclicalLRPolicy + `cyclical_max_lr`: defaults to 0.05, part of CompositeCyclicalLRPolicy + `cyclical_step_size`: defaults to 1000000, part of CompositeCyclicalLRPolicy + `cyclical_decay`: defaults to 1.0, part of CompositeCyclicalLRPolicy Usage: train_net.LearningRate(*iterations*, "*label*", base_lr=*float*, @@ -101,6 +108,18 @@ Example usage: .Arg("m2", "") .Arg("n2", "") .Arg("m3", "") + .Arg("start_warmup_multiplier", "defaults to 0.1") + .Arg("constant_warmup_num_iter", "defaults to 10000000") + .Arg("linear_warmup_num_iter", "defaults to 10000000") + .Arg( + "cyclical_max_lr", + "defaults to 0.05, part of CompositeCyclicalLRPolicy") + .Arg( + "cyclical_step_size", + "defaults to 1000000, part of CompositeCyclicalLRPolicy") + .Arg( + "cyclical_decay", + "defaults to 0.999, part of CompositeCyclicalLRPolicy") .Input(0, "input", "description needed") .Output(0, "output", "description needed") .DeviceInferenceFunction([](const OperatorDef& def) { @@ -110,4 +129,4 @@ Example usage: }); NO_GRADIENT(LearningRate); -} // namespace caffe2 +} // namespace caffe2 diff --git a/caffe2/sgd/learning_rate_op.h b/caffe2/sgd/learning_rate_op.h index 8a57a31d17e0f..acd11caa719a2 100644 --- a/caffe2/sgd/learning_rate_op.h +++ b/caffe2/sgd/learning_rate_op.h @@ -27,7 +27,7 @@ class LearningRateOp final : public Operator { bool RunOnDevice() override { int64_t iter = OperatorBase::Input(0, CPU).template data()[0]; - T learning_rate = cur_base_lr_ * (*functor_)(iter); + T learning_rate = base_lr_ * (*functor_)(iter); // Write to output. auto* output = Output(0); output->Resize(vector()); @@ -39,17 +39,10 @@ class LearningRateOp final : public Operator { private: unique_ptr> functor_; T base_lr_; - T base_lr_scale_; - T cur_base_lr_; LearningRateFunctor* createLearningRateFunctor( const string& policy, const string& arg_prefix = "") { - if (policy != "composite") { - base_lr_scale_ = - this->template GetSingleArgument(arg_prefix + "lr_scale", 1.0); - cur_base_lr_ = base_lr_scale_ * base_lr_; - } if (policy == "fixed") { return new FixedLearningRate(); } else if (policy == "alter") { @@ -140,7 +133,7 @@ class LearningRateOp final : public Operator { this->template GetSingleArgument(arg_prefix + "n1", 0); T m2 = this->template GetSingleArgument(arg_prefix + "m2", 0.5); int64_t n2 = - this->template GetSingleArgument(arg_prefix + "n1", 0); + this->template GetSingleArgument(arg_prefix + "n2", 0); T m3 = this->template GetSingleArgument(arg_prefix + "m3", 0.5); return new PieceWarmupLearningRate(m1, n1, m2, n2, m3); } else if (policy == "composite") { @@ -166,8 +159,11 @@ class LearningRateOp final : public Operator { "Defining composite LR policy as a subpolicy of composite LR " "policy is not allowed."); } + const float scale_lr = this->template GetSingleArgument( + sub_policy_arg_prefix_str + "lr_scale", 1.0); sub_policies.push_back(CompositeLearningRateItem( sub_policy_num_iters[i], + scale_lr, createLearningRateFunctor(sub_policy, sub_policy_arg_prefix_str))); } return new CompositeLearningRate(sub_policies); @@ -176,9 +172,44 @@ class LearningRateOp final : public Operator { this->template GetSingleArgument(arg_prefix + "max_lr", 0.005); int stepsize = this->template GetSingleArgument(arg_prefix + "stepsize", 0); + T decay = + this->template GetSingleArgument(arg_prefix + "decay", 1.0); DCHECK_GT(stepsize, 0); DCHECK_GE(max_lr, base_lr_); - return new CyclicalLearningRate(base_lr_, max_lr, stepsize); + return new CyclicalLearningRate(base_lr_, max_lr, stepsize, decay); + } else if (policy == "constantThenLinearWarmup") { + T start_warmup_multiplier = this->template GetSingleArgument( + arg_prefix + "start_warmup_multiplier", 0.1); + int64_t constant_warmup_num_iter = this->template GetSingleArgument( + arg_prefix + "constant_warmup_num_iter", 10000000); + int64_t linear_warmup_num_iter = this->template GetSingleArgument( + arg_prefix + "linear_warmup_num_iter", 10000000); + return new ConstantThenLinearWarmupLearningRate( + start_warmup_multiplier, + constant_warmup_num_iter, + linear_warmup_num_iter); + } else if (policy == "compositeCyclical") { + T start_warmup_multiplier = this->template GetSingleArgument( + arg_prefix + "start_warmup_multiplier", 0.1); + int64_t constant_warmup_num_iter = this->template GetSingleArgument( + arg_prefix + "constant_warmup_num_iter", 10000000); + int64_t linear_warmup_num_iter = this->template GetSingleArgument( + arg_prefix + "linear_warmup_num_iter", 10000000); + T cyclical_max_lr = this->template GetSingleArgument( + arg_prefix + "cyclical_max_lr", 0.05); + int cyclical_step_size = this->template GetSingleArgument( + arg_prefix + "cyclical_step_size", 1000000); + T cyclical_decay = this->template GetSingleArgument( + arg_prefix + "cyclical_decay", 1.0); + DCHECK_GE(cyclical_max_lr, base_lr_); + return new CompositeCyclicalLearningRate( + base_lr_, + start_warmup_multiplier, + constant_warmup_num_iter, + linear_warmup_num_iter, + cyclical_max_lr, + cyclical_step_size, + cyclical_decay); } else { CAFFE_THROW("Unknown learning rate policy: ", policy); return NULL; diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index 99b90a4ab9f51..2fc7a19224547 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -1,3 +1,15 @@ +if (INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE) + list(APPEND Caffe2_CPU_SRCS + utils/string_utils.cc + utils/threadpool/pthreadpool.cc + utils/threadpool/pthreadpool_impl.cc + utils/threadpool/ThreadPool.cc + utils/threadpool/ThreadPoolMobile.cc + ) + set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} PARENT_SCOPE) + return() +endif() + list(APPEND Caffe2_CPU_SRCS utils/bench_utils.cc utils/cpuid.cc diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 3efaafe05e562..387304cf74679 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -180,7 +180,8 @@ if (INTERN_BUILD_ATEN_OPS) # that they are equivalent so it must be a dependency of the script set(core_gen_checked_inputs ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/TensorBody.h - ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/TensorMethods.h) + ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/TensorMethods.h + ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/core/OpsAlreadyMovedToC10.cpp) file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/aten/src/ATen) file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/aten/src/ATen/core_tmp) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 0a033dbc9b219..28891e645922c 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -34,7 +34,7 @@ macro(enable_ubsan) endmacro() # ---[ Custom Protobuf -if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) +if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND (NOT INTERN_BUILD_MOBILE OR BUILD_CAFFE2_MOBILE)) disable_ubsan() include(${CMAKE_CURRENT_LIST_DIR}/ProtoBuf.cmake) enable_ubsan() @@ -107,7 +107,7 @@ else() set(AT_MKLDNN_ENABLED 0) set(AT_MKL_ENABLED 0) endif() -set_property(CACHE BLAS PROPERTY STRINGS "Eigen;ATLAS;OpenBLAS;MKL;vecLib") +set_property(CACHE BLAS PROPERTY STRINGS "Eigen;ATLAS;OpenBLAS;MKL;vecLib;FLAME") message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS}) if(BLAS STREQUAL "Eigen") @@ -186,6 +186,28 @@ set(CONFU_DEPENDENCIES_SOURCE_DIR ${PROJECT_BINARY_DIR}/confu-srcs set(CONFU_DEPENDENCIES_BINARY_DIR ${PROJECT_BINARY_DIR}/confu-deps CACHE PATH "Confu-style dependencies binary directory") +# ---[ pthreadpool +# QNNPACK and NNPACK both depend on pthreadpool, but when building with libtorch +# they should use the pthreadpool implementation under caffe2/utils/threadpool +# instead of the default implementation. To avoid confusion, add pthreadpool +# subdirectory explicitly with EXCLUDE_FROM_ALL property prior to QNNPACK/NNPACK +# does so, which will prevent it from installing the default pthreadpool library. +if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE AND (USE_QNNPACK OR USE_NNPACK)) + if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) + set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") + set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory") + endif() + + IF(NOT TARGET pthreadpool) + SET(PTHREADPOOL_BUILD_TESTS OFF CACHE BOOL "") + SET(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE BOOL "") + ADD_SUBDIRECTORY( + "${PTHREADPOOL_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/pthreadpool" + EXCLUDE_FROM_ALL) + ENDIF() +endif() + # ---[ QNNPACK if(USE_QNNPACK) if (IOS) @@ -272,6 +294,67 @@ if(USE_QNNPACK) include_directories(SYSTEM "${CAFFE2_THIRD_PARTY_ROOT}/neon2sse") endif() +# ---[ PYTORCH_QNNPACK +if(USE_PYTORCH_QNNPACK) + if (IOS) + list(LENGTH IOS_ARCH IOS_ARCH_COUNT) + if (IOS_ARCH_COUNT GREATER 1) + message(WARNING + "Multi-architecture (${IOS_ARCH}) builds are not supported in QNNPACK. " + "Specify a single architecture in IOS_ARCH and re-configure, or " + "turn this warning off by USE_PYTORCH_QNNPACK=OFF.") + set(USE_PYTORCH_QNNPACK OFF) + endif() + if (NOT IOS_ARCH MATCHES "^(i386|x86_64|armv7.*|arm64.*)$") + message(WARNING + "Target architecture \"${IOS_ARCH}\" is not supported in QNNPACK. " + "Supported architectures are x86, x86-64, ARM, and ARM64. " + "Turn this warning off by USE_PYTORCH_QNNPACK=OFF.") + set(USE_PYTORCH_QNNPACK OFF) + endif() + else() + if (NOT IOS AND NOT (CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux|Darwin)$")) + message(WARNING + "Target platform \"${CMAKE_SYSTEM_NAME}\" is not supported in QNNPACK. " + "Supported platforms are Android, iOS, Linux, and macOS. " + "Turn this warning off by USE_PYTORCH_QNNPACK=OFF.") + set(USE_PYTORCH_QNNPACK OFF) + endif() + if (NOT IOS AND NOT (CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|AMD64|x86_64|armv[0-9].*|arm64|aarch64)$")) + message(WARNING + "Target architecture \"${CMAKE_SYSTEM_PROCESSOR}\" is not supported in QNNPACK. " + "Supported architectures are x86, x86-64, ARM, and ARM64. " + "Turn this warning off by USE_PYTORCH_QNNPACK=OFF.") + set(USE_PYTORCH_QNNPACK OFF) + endif() + endif() + if (USE_PYTORCH_QNNPACK) + if (NOT DEFINED PYTORCH_QNNPACK_SOURCE_DIR) + set(PYTORCH_QNNPACK_SOURCE_DIR "${PROJECT_SOURCE_DIR}/aten/src/ATen/native/quantized/cpu/qnnpack" CACHE STRING "QNNPACK source directory") + endif() + + if(NOT TARGET pytorch_qnnpack) + set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "") + set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") + set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") + set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "") + set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "") + set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "") + set(CPUINFO_LOG_LEVEL "error" CACHE STRING "") + add_subdirectory( + "${PYTORCH_QNNPACK_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/pytorch_qnnpack") + # We build static versions of QNNPACK and pthreadpool but link + # them into a shared library for Caffe2, so they need PIC. + set_property(TARGET pytorch_qnnpack PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) + endif() + + list(APPEND Caffe2_DEPENDENCY_LIBS pytorch_qnnpack) + endif() +endif() + # ---[ NNPACK if(USE_NNPACK) include(${CMAKE_CURRENT_LIST_DIR}/External/nnpack.cmake) @@ -815,7 +898,7 @@ if(USE_CUDA) caffe2_update_option(USE_NVRTC OFF) endif() if(CAFFE2_USE_CUDNN) - IF(CUDNN_STATIC_LINKAGE) + IF(CUDNN_STATIC) LIST(APPEND Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS caffe2::cudnn "${CUDA_TOOLKIT_ROOT_DIR}/lib64/libculibos.a" "dl") ELSE() @@ -893,7 +976,7 @@ if(USE_ROCM) hip_include_directories(${Caffe2_HIP_INCLUDE}) set(Caffe2_HIP_DEPENDENCY_LIBS - ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES}) + ${PYTORCH_HIP_HCC_LIBRARIES} ${PYTORCH_MIOPEN_LIBRARIES} ${hipcub_LIBRARIES} ${ROCM_HIPRTC_LIB}) # Note [rocblas & rocfft cmake bug] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -941,8 +1024,8 @@ if(USE_CUDA) endif() if(USE_GLOO) - if(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux") - message(WARNING "Gloo can only be used on Linux.") + if(MSVC) + message(WARNING "Gloo can not be used on Windows.") caffe2_update_option(USE_GLOO OFF) elseif(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) message(WARNING "Gloo can only be used on 64-bit systems.") @@ -1035,7 +1118,7 @@ if (USE_ZSTD) endif() # ---[ Onnx -if (CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) +if (CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX) if(EXISTS "${CAFFE2_CUSTOM_PROTOC_EXECUTABLE}") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${CAFFE2_CUSTOM_PROTOC_EXECUTABLE}) endif() diff --git a/cmake/Modules/FindBLAS.cmake b/cmake/Modules/FindBLAS.cmake index d6c03467397c9..e93e98a6095d4 100644 --- a/cmake/Modules/FindBLAS.cmake +++ b/cmake/Modules/FindBLAS.cmake @@ -211,6 +211,20 @@ if((NOT BLAS_LIBRARIES) endif (BLAS_LIBRARIES) endif() +if((NOT BLAS_LIBRARIES) + AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "FLAME"))) + # FLAME's blis library (https://github.com/flame/blis) + check_fortran_libraries( + BLAS_LIBRARIES + BLAS + sgemm + "" + "blis") + if (BLAS_LIBRARIES) + set(BLAS_INFO "FLAME") + endif (BLAS_LIBRARIES) +endif() + # BLAS in ATLAS library? (http://math-atlas.sourceforge.net/) if((NOT BLAS_LIBRARIES) AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "atlas"))) diff --git a/cmake/Modules/FindLAPACK.cmake b/cmake/Modules/FindLAPACK.cmake index a92d3adab09be..c057f207132f1 100644 --- a/cmake/Modules/FindLAPACK.cmake +++ b/cmake/Modules/FindLAPACK.cmake @@ -143,6 +143,21 @@ if(BLAS_FOUND) endif() endif() + # FLAME + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "FLAME")) + check_lapack_libraries( + LAPACK_LIBRARIES + LAPACK + cheev + "" + "flame" + "${BLAS_LIBRARIES}" + ) + if(LAPACK_LIBRARIES) + SET(LAPACK_INFO "FLAME") + endif(LAPACK_LIBRARIES) + endif() + # ACML IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "acml")) SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) diff --git a/cmake/Modules/FindOpenBLAS.cmake b/cmake/Modules/FindOpenBLAS.cmake index 70574ab95b0f0..be9e713d16f56 100644 --- a/cmake/Modules/FindOpenBLAS.cmake +++ b/cmake/Modules/FindOpenBLAS.cmake @@ -24,7 +24,7 @@ SET(Open_BLAS_LIB_SEARCH_PATHS /usr/local/lib64 /usr/local/opt/openblas/lib /opt/OpenBLAS/lib - $ENV{OpenBLAS}cd + $ENV{OpenBLAS} $ENV{OpenBLAS}/lib $ENV{OpenBLAS_HOME} $ENV{OpenBLAS_HOME}/lib diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 0f11986e361c0..297509575de77 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -127,7 +127,6 @@ function (caffe2_print_configuration_summary) if(${USE_DISTRIBUTED}) message(STATUS " USE_MPI : ${USE_MPI}") message(STATUS " USE_GLOO : ${USE_GLOO}") - message(STATUS " USE_GLOO_IBVERBS : ${USE_GLOO_IBVERBS}") endif() message(STATUS " BUILD_NAMEDTENSOR : ${BUILD_NAMEDTENSOR}") diff --git a/cmake/iOS.cmake b/cmake/iOS.cmake index 4dffe50f8af1a..b838a918ef77b 100644 --- a/cmake/iOS.cmake +++ b/cmake/iOS.cmake @@ -158,7 +158,7 @@ set (CMAKE_OSX_SYSROOT ${CMAKE_IOS_SDK_ROOT} CACHE PATH "Sysroot used for iOS su # set the architecture for iOS if (IOS_PLATFORM STREQUAL "OS") - set (DEFAULT_IOS_ARCH "armv7;armv7s;arm64") + set (DEFAULT_IOS_ARCH "arm64") elseif (IOS_PLATFORM STREQUAL "SIMULATOR") set (DEFAULT_IOS_ARCH "x86_64") elseif (IOS_PLATFORM STREQUAL "WATCHOS") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 79a37a5025f05..414b2be3afbae 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -167,6 +167,8 @@ IF(HIP_FOUND) # TODO: miopen_LIBRARIES should return fullpath to the library file, # however currently it's just the lib name FIND_LIBRARY(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) + # hiprtc is part of HIP + FIND_LIBRARY(ROCM_HIPRTC_LIB hiprtc HINTS ${HIP_PATH}/lib) # Necessary includes for building PyTorch since we include HIP headers that depend on hcc/hsa headers. diff --git a/docker/caffe2/jenkins/common/install_rocm.sh b/docker/caffe2/jenkins/common/install_rocm.sh index 82bc4251118d4..6294c7c3b0650 100644 --- a/docker/caffe2/jenkins/common/install_rocm.sh +++ b/docker/caffe2/jenkins/common/install_rocm.sh @@ -23,13 +23,13 @@ install_ubuntu() { rocfft \ miopen-hip \ rocblas \ - rocm-profiler \ - cxlactivitylogger \ hipsparse \ rocrand \ hipcub \ rocthrust \ - rccl + rccl \ + rocprofiler-dev \ + roctracer-dev } install_centos() { @@ -55,13 +55,13 @@ install_centos() { rocfft \ miopen-hip \ rocblas \ - rocm-profiler \ - cxlactivitylogger \ hipsparse \ rocrand \ rccl \ hipcub \ - rocthrust + rocthrust \ + rocprofiler-dev \ + roctracer-dev } # Install Python packages depending on the base OS diff --git a/docs/cpp/source/check-doxygen.sh b/docs/cpp/source/check-doxygen.sh index 4311227cb91dc..5d7b6a893478e 100755 --- a/docs/cpp/source/check-doxygen.sh +++ b/docs/cpp/source/check-doxygen.sh @@ -42,6 +42,7 @@ cp original-doxygen-log.txt doxygen-log.txt # Filter out some warnings. ignore_warning "warning: no uniquely matching class member found for" ignore_warning "warning: explicit link request to 'Item' could not be resolved" +ignore_warning "warning: Included by graph for 'types.h' not generated, too many nodes" # Count the number of remaining warnings. warnings="$(grep 'warning:' doxygen-log.txt | wc -l)" diff --git a/docs/cpp/source/notes/tensor_basics.rst b/docs/cpp/source/notes/tensor_basics.rst index 5d25efcf68de9..09032546a3a9a 100644 --- a/docs/cpp/source/notes/tensor_basics.rst +++ b/docs/cpp/source/notes/tensor_basics.rst @@ -76,20 +76,25 @@ CUDA accessors .. code-block:: cpp __global__ void packed_accessor_kernel( - PackedTensorAccessor foo, + PackedTensorAccessor64 foo, float* trace) { int i=threadIdx.x atomicAdd(trace, foo[i][i]) } - + torch::Tensor foo = torch::rand({12, 12}); // assert foo is 2-dimensional and holds floats. - auto foo_a = foo.packed_accessor(); + auto foo_a = foo.packed_accessor64(); float trace = 0; packed_accessor_kernel<<<1, 12>>>(foo_a, &trace); +In addition to ``PackedTensorAccessor64`` and ``packed_accessor64`` there are +also the corresponding ``PackedTensorAccessor32`` and ``packed_accessor32`` +which use 32-bit integers for indexing. This can be quite a bit faster on CUDA +but may lead to overflows in the indexing calculations. + Note that the template can hold other parameters such as the pointer restriction and the integer type for indexing. See documentation for a thorough template description of *accessors* and *packed accessors*. diff --git a/docs/source/conf.py b/docs/source/conf.py index 165a643a56078..a0bc0ddfeeaa5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -240,11 +240,11 @@ def setup(app): from docutils import nodes from sphinx.util.docfields import TypedField from sphinx import addnodes - +import sphinx.ext.doctest # Without this, doctest adds any example with a `>>>` as a test doctest_test_doctest_blocks = '' - +doctest_default_flags = sphinx.ext.doctest.doctest.ELLIPSIS doctest_global_setup = ''' try: import torchvision diff --git a/docs/source/jit.rst b/docs/source/jit.rst index 7d2b7e14eacdd..beabafd55b324 100644 --- a/docs/source/jit.rst +++ b/docs/source/jit.rst @@ -202,9 +202,9 @@ Modules The :func:`@torch.jit.ignore ` annotation's behavior changes in PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function or method callable from code that is exported. To get this functionality back, - use ``@torch.jit.ignore(drop_on_export=True)``. ``@torch.jit.ignore`` is now equivalent - to ``@torch.jit.ignore(drop_on_export=False)``. See :func:`@torch.jit.ignore ` - for details. + use ``@torch.jit.unused()``. ``@torch.jit.ignore`` is now equivalent + to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore ` + and :func:`@torch.jit.unused` for details. When passed to the :func:`torch.jit.script ` function, a ``torch.nn.Module``\'s data is copied to a ``ScriptModule`` and the TorchScript compiler compiles the module. @@ -354,7 +354,7 @@ Containers are assumed to have type ``Tensor`` and be non-optional (see tell the TorchScript compiler what the type should be. Python 3 style type hints are now supported. -:: +.. testcode:: import torch from typing import Dict, Optional @@ -422,9 +422,9 @@ net models. In particular, TorchScript supports: Unlike Python, each variable in TorchScript function must have a single static type. This makes it easier to optimize TorchScript functions. -.. TODO: test this code with `testcode`, but it looks like that doesn't support exceptions +Example (a type mismatch) -Example (a type mismatch):: +.. testcode:: import torch @@ -434,8 +434,25 @@ Example (a type mismatch):: r = torch.rand(1) else: r = 4 - return r # Type mismatch: r is set to type Tensor in the true branch - # and type int in the false branch + return r + + +.. testoutput:: + + Traceback (most recent call last): + ... + RuntimeError: ... + + Type mismatch: r is set to type Tensor in the true branch and type int in the false branch: + @torch.jit.script + def an_error(x): + if x: + ~~~~~... <--- HERE + r = torch.rand(1) + else: + r = 4 + return r + ... Default Types @@ -493,7 +510,7 @@ use `Python 3 type hints`_. If you are on Python 2, you can use ``torch.jit.anno Example (type annotations for Python 3): -:: +.. testcode:: import torch import torch.nn as nn @@ -908,9 +925,9 @@ Pattern Matching Assignments a, b, *c = a_tuple Multiple Assignments - :: +:: - a = b, c = tup + a = b, c = tup Print Statements ^^^^^^^^^^^^^^^^ @@ -970,9 +987,7 @@ constant by adding the name of the attribute to the ``__constants__`` list for the type. For loops over a ``nn.ModuleList`` will unroll the body of the loop at compile time, with each member of the constant module list. -.. TODO: enable testcode when https://github.com/pytorch/pytorch/pull/24412 lands - -:: +.. testcode:: class SubModule(torch.nn.Module): def __init__(self): @@ -1028,15 +1043,30 @@ is an error to use it after the end of the if statement. Similarly, a variable is not allowed to be used if it is only *defined* along some paths through the function. -.. TODO: Test this code and catch the exception +Example: -Example:: +.. testcode:: @torch.jit.script def foo(x): if x < 0: y = 4 - print(y) # Error: undefined value y + print(y) + +.. testoutput:: + + Traceback (most recent call last): + ... + RuntimeError: ... + + y is not defined in the false branch... + @torch.jit.script... + def foo(x): + if x < 0: + ~~~~~~~~~... <--- HERE + y = 4 + print(y) + ... Non-local variables are resolved to Python values at compile time when the function is defined. These values are then converted into TorchScript values using diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 7ca7725d5d4e1..2ed1b5d661425 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -169,3 +169,5 @@ should write your code this way: :members: .. autoclass:: torch.optim.lr_scheduler.OneCycleLR :members: +.. autoclass:: torch.optim.lr_scheduler.CosineAnnealingWarmRestarts + :members: diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index b746af7f7fdcc..d962fd841e21f 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -127,7 +127,7 @@ Therefore, representation of a SparseTensor of sparse_dim = 0 is simply a dense .. method:: sub .. method:: sub_ .. method:: t_ - .. method:: toDense + .. method:: to_dense .. method:: transpose .. method:: transpose_ .. method:: zero_ diff --git a/docs/source/tensor_attributes.rst b/docs/source/tensor_attributes.rst index d9dfb8aae286c..1717d87fcece6 100644 --- a/docs/source/tensor_attributes.rst +++ b/docs/source/tensor_attributes.rst @@ -34,6 +34,80 @@ Boolean ``torch.bool`` ``torch To find out if a :class:`torch.dtype` is a floating point data type, the property :attr:`is_floating_point` can be used, which returns ``True`` if the data type is a floating point data type. +When the dtypes of inputs to an arithmetic operation (`add`, `sub`, `div`, `mul`) differ, we promote +by finding the minimum dtype that satisfies the following rules: + +* If the type of a scalar operand is of a higher category than tensor operands + (where floating > integral > boolean), we promote to a type with sufficient size to hold + all scalar operands of that category. +* If a zero-dimension tensor operand has a higher category than dimensioned operands, + we promote to a type with sufficient size and category to hold all zero-dim tensor operands of + that category. +* If there are no higher-category zero-dim operands, we promote to a type with sufficient size + and category to hold all dimensioned operands. + +A floating point scalar operand has dtype `torch.get_default_dtype()` and an integral +non-boolean scalar operand has dtype `torch.int64`. Unlike numpy, we do not inspect +values when determining the minimum `dtypes` of an operand. Quantized and complex types +are not yet supported. + +Promotion Examples:: + + >>> float_tensor = torch.ones(1, dtype=torch.float) + >>> double_tensor = torch.ones(1, dtype=torch.double) + >>> int_tensor = torch.ones(1, dtype=torch.int) + >>> long_tensor = torch.ones(1, dtype=torch.long) + >>> uint_tensor = torch.ones(1, dtype=torch.uint8) + >>> double_tensor = torch.ones(1, dtype=torch.double) + >>> bool_tensor = torch.ones(1, dtype=torch.bool) + # zero-dim tensors + >>> long_zerodim = torch.tensor(1, dtype=torch.long) + >>> int_zerodim = torch.tensor(1, dtype=torch.int) + + >>> torch.add(5, 5).dtype + torch.int64 + # 5 is an int64, but does not have higher category than int_tensor so is not considered. + >>> (int_tensor + 5).dtype + torch.int32 + >>> (int_tensor + long_zerodim).dtype + torch.int32 + >>> (long_tensor + int_tensor).dtype + torch.int64 + >>> (bool_tensor + long_tensor).dtype + torch.int64 + >>> (bool_tensor + uint_tensor).dtype + torch.uint8 + >>> (float_tensor + double_tensor).dtype + torch.float64 + >>> (bool_tensor + int_tensor).dtype + torch.int32 + # Since long is a different kind than float, result dtype only needs to be large enough + # to hold the float. + >>> torch.add(long_tensor, float_tensor).dtype + torch.float32 + +When the output tensor of an arithmetic operation is specified, we allow casting to its `dtype` except that: + * An integral output tensor cannot accept a floating point tensor. + * A boolean output tensor cannot accept a non-boolean tensor. + +Casting Examples:: + + # allowed: + >>> float_tensor *= double_tensor + >>> float_tensor *= int_tensor + >>> float_tensor *= uint_tensor + >>> float_tensor *= bool_tensor + >>> float_tensor *= double_tensor + >>> int_tensor *= long_tensor + >>> int_tensor *= uint_tensor + >>> uint_tensor *= int_tensor + + # disallowed (RuntimeError: result type can't be cast to the desired output type): + >>> int_tensor *= float_tensor + >>> bool_tensor *= int_tensor + >>> bool_tensor *= uint_tensor + + .. _device-doc: torch.device diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index fadb92e9153b0..d094525ad78c0 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -374,6 +374,7 @@ view of a storage and defines numeric operations on it. .. automethod:: q_zero_point .. automethod:: q_per_channel_scales .. automethod:: q_per_channel_zero_points + .. automethod:: q_per_channel_axis .. automethod:: random_ .. automethod:: reciprocal .. automethod:: reciprocal_ diff --git a/ios/.gitignore b/ios/.gitignore new file mode 100644 index 0000000000000..7fb7ff6d53afb --- /dev/null +++ b/ios/.gitignore @@ -0,0 +1,37 @@ +## macOS +.DS_Store + +## Build generated +build/ +DerivedData +build.xcarchive + +## Various settings +*.pbxuser +!default.pbxuser +*.mode1v3 +!default.mode1v3 +*.mode2v3 +!default.mode2v3 +*.perspectivev3 +!default.perspectivev3 +xcuserdata + +## Other +*.xccheckout +*.moved-aside +*.xcuserstate +*.xcscmblueprint +*.xcworkspacedata +IDEWorkspaceChecks.plist + +## Obj-C/Swift specific +*.hmap +*.ipa + +# CocoaPods +Pods/ + +# Carthage +Carthage/Checkouts +Carthage/Build diff --git a/ios/LibTorch.h b/ios/LibTorch.h new file mode 100644 index 0000000000000..e75bb1eb8404a --- /dev/null +++ b/ios/LibTorch.h @@ -0,0 +1,6 @@ +#ifndef LibTorch_h +#define LibTorch_h + +#include + +#endif diff --git a/ios/LibTorch.podspec b/ios/LibTorch.podspec new file mode 100644 index 0000000000000..5ace441a5dada --- /dev/null +++ b/ios/LibTorch.podspec @@ -0,0 +1,36 @@ +Pod::Spec.new do |s| + s.name = 'LibTorch' + s.version = '0.0.3' + s.authors = 'PyTorch Team' + s.license = { :type => 'BSD' } + s.homepage = 'https://github.com/pytorch/pytorch' + s.source = { :http => 'https://ossci-ios-build.s3.amazonaws.com/libtorch_ios_nightly_build.zip' } + s.summary = 'The PyTorch C++ library for iOS' + s.description = <<-DESC + The PyTorch C++ library for iOS. + DESC + s.ios.deployment_target = '12.0' + s.default_subspec = 'Core' + s.subspec 'Core' do |ss| + ss.dependency 'LibTorch/Torch' + ss.source_files = 'src/*.{h,cpp,c,cc}' + ss.public_header_files = ['src/LibTorch.h'] + end + s.subspec 'Torch' do |ss| + ss.header_mappings_dir = 'install/include/' + ss.preserve_paths = 'install/include/**/*.{h,cpp,cc,c}' + ss.vendored_libraries = 'install/lib/*.a' + ss.libraries = ['c++', 'stdc++'] + end + s.user_target_xcconfig = { + 'HEADER_SEARCH_PATHS' => '$(inherited) "$(PODS_ROOT)/LibTorch/install/include/"', + 'OTHER_LDFLAGS' => '-force_load "$(PODS_ROOT)/LibTorch/install/lib/libtorch.a"', + 'CLANG_CXX_LANGUAGE_STANDARD' => 'c++11', + 'CLANG_CXX_LIBRARY' => 'libc++' + } + s.pod_target_xcconfig = { + 'HEADER_SEARCH_PATHS' => '$(inherited) "$(PODS_ROOT)/LibTorch/install/include/"', + 'VALID_ARCHS' => 'x86_64 arm64' + } + s.library = ['c++', 'stdc++'] +end \ No newline at end of file diff --git a/ios/README.md b/ios/README.md new file mode 100644 index 0000000000000..7a60bfed99fb2 --- /dev/null +++ b/ios/README.md @@ -0,0 +1,5 @@ +## LibTorch + +The PyTorch C++ static library for iOS. + +(Detailed documentation will be added soon) \ No newline at end of file diff --git a/ios/TestApp/Podfile b/ios/TestApp/Podfile new file mode 100644 index 0000000000000..bf3e0053b9a12 --- /dev/null +++ b/ios/TestApp/Podfile @@ -0,0 +1,5 @@ + +platform :ios, '12.0' +target 'TestApp' do + pod 'LibTorch' +end diff --git a/ios/TestApp/TestApp.xcodeproj/project.pbxproj b/ios/TestApp/TestApp.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000..337c5671af5cc --- /dev/null +++ b/ios/TestApp/TestApp.xcodeproj/project.pbxproj @@ -0,0 +1,341 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 50; + objects = { + +/* Begin PBXBuildFile section */ + A06D4CB5232F0DB200763E16 /* AppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = A06D4CB4232F0DB200763E16 /* AppDelegate.m */; }; + A06D4CB8232F0DB200763E16 /* ViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = A06D4CB7232F0DB200763E16 /* ViewController.mm */; }; + A06D4CBB232F0DB200763E16 /* Main.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = A06D4CB9232F0DB200763E16 /* Main.storyboard */; }; + A06D4CBD232F0DB200763E16 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = A06D4CBC232F0DB200763E16 /* Assets.xcassets */; }; + A06D4CC0232F0DB200763E16 /* LaunchScreen.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = A06D4CBE232F0DB200763E16 /* LaunchScreen.storyboard */; }; + A06D4CC3232F0DB200763E16 /* main.m in Sources */ = {isa = PBXBuildFile; fileRef = A06D4CC2232F0DB200763E16 /* main.m */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + A06D4CB0232F0DB200763E16 /* TestApp.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = TestApp.app; sourceTree = BUILT_PRODUCTS_DIR; }; + A06D4CB3232F0DB200763E16 /* AppDelegate.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + A06D4CB4232F0DB200763E16 /* AppDelegate.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = AppDelegate.m; sourceTree = ""; }; + A06D4CB6232F0DB200763E16 /* ViewController.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = ViewController.h; sourceTree = ""; }; + A06D4CB7232F0DB200763E16 /* ViewController.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = ViewController.mm; sourceTree = ""; }; + A06D4CBA232F0DB200763E16 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/Main.storyboard; sourceTree = ""; }; + A06D4CBC232F0DB200763E16 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; + A06D4CBF232F0DB200763E16 /* Base */ = {isa = PBXFileReference; lastKnownFileType = file.storyboard; name = Base; path = Base.lproj/LaunchScreen.storyboard; sourceTree = ""; }; + A06D4CC1232F0DB200763E16 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + A06D4CC2232F0DB200763E16 /* main.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = main.m; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + A06D4CAD232F0DB200763E16 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + A06D4CA7232F0DB200763E16 = { + isa = PBXGroup; + children = ( + A06D4CB2232F0DB200763E16 /* TestApp */, + A06D4CB1232F0DB200763E16 /* Products */, + ); + sourceTree = ""; + }; + A06D4CB1232F0DB200763E16 /* Products */ = { + isa = PBXGroup; + children = ( + A06D4CB0232F0DB200763E16 /* TestApp.app */, + ); + name = Products; + sourceTree = ""; + }; + A06D4CB2232F0DB200763E16 /* TestApp */ = { + isa = PBXGroup; + children = ( + A06D4CB3232F0DB200763E16 /* AppDelegate.h */, + A06D4CB4232F0DB200763E16 /* AppDelegate.m */, + A06D4CB6232F0DB200763E16 /* ViewController.h */, + A06D4CB7232F0DB200763E16 /* ViewController.mm */, + A06D4CB9232F0DB200763E16 /* Main.storyboard */, + A06D4CBC232F0DB200763E16 /* Assets.xcassets */, + A06D4CBE232F0DB200763E16 /* LaunchScreen.storyboard */, + A06D4CC1232F0DB200763E16 /* Info.plist */, + A06D4CC2232F0DB200763E16 /* main.m */, + ); + path = TestApp; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + A06D4CAF232F0DB200763E16 /* TestApp */ = { + isa = PBXNativeTarget; + buildConfigurationList = A06D4CC6232F0DB200763E16 /* Build configuration list for PBXNativeTarget "TestApp" */; + buildPhases = ( + A06D4CAC232F0DB200763E16 /* Sources */, + A06D4CAD232F0DB200763E16 /* Frameworks */, + A06D4CAE232F0DB200763E16 /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = TestApp; + productName = TestApp; + productReference = A06D4CB0232F0DB200763E16 /* TestApp.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + A06D4CA8232F0DB200763E16 /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 1030; + TargetAttributes = { + A06D4CAF232F0DB200763E16 = { + CreatedOnToolsVersion = 10.3; + }; + }; + }; + buildConfigurationList = A06D4CAB232F0DB200763E16 /* Build configuration list for PBXProject "TestApp" */; + compatibilityVersion = "Xcode 9.3"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = A06D4CA7232F0DB200763E16; + productRefGroup = A06D4CB1232F0DB200763E16 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + A06D4CAF232F0DB200763E16 /* TestApp */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + A06D4CAE232F0DB200763E16 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + A06D4CC0232F0DB200763E16 /* LaunchScreen.storyboard in Resources */, + A06D4CBD232F0DB200763E16 /* Assets.xcassets in Resources */, + A06D4CBB232F0DB200763E16 /* Main.storyboard in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + A06D4CAC232F0DB200763E16 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + A06D4CB8232F0DB200763E16 /* ViewController.mm in Sources */, + A06D4CC3232F0DB200763E16 /* main.m in Sources */, + A06D4CB5232F0DB200763E16 /* AppDelegate.m in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin PBXVariantGroup section */ + A06D4CB9232F0DB200763E16 /* Main.storyboard */ = { + isa = PBXVariantGroup; + children = ( + A06D4CBA232F0DB200763E16 /* Base */, + ); + name = Main.storyboard; + sourceTree = ""; + }; + A06D4CBE232F0DB200763E16 /* LaunchScreen.storyboard */ = { + isa = PBXVariantGroup; + children = ( + A06D4CBF232F0DB200763E16 /* Base */, + ); + name = LaunchScreen.storyboard; + sourceTree = ""; + }; +/* End PBXVariantGroup section */ + +/* Begin XCBuildConfiguration section */ + A06D4CC4232F0DB200763E16 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + }; + name = Debug; + }; + A06D4CC5232F0DB200763E16 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++14"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + CODE_SIGN_IDENTITY = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu11; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 12.4; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + A06D4CC7232F0DB200763E16 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = GW8XWHWQR7; + INFOPLIST_FILE = TestApp/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.pytorch.testApp.TestApp; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + A06D4CC8232F0DB200763E16 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CODE_SIGN_STYLE = Automatic; + DEVELOPMENT_TEAM = GW8XWHWQR7; + INFOPLIST_FILE = TestApp/Info.plist; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + PRODUCT_BUNDLE_IDENTIFIER = com.pytorch.testApp.TestApp; + PRODUCT_NAME = "$(TARGET_NAME)"; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + A06D4CAB232F0DB200763E16 /* Build configuration list for PBXProject "TestApp" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + A06D4CC4232F0DB200763E16 /* Debug */, + A06D4CC5232F0DB200763E16 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + A06D4CC6232F0DB200763E16 /* Build configuration list for PBXNativeTarget "TestApp" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + A06D4CC7232F0DB200763E16 /* Debug */, + A06D4CC8232F0DB200763E16 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = A06D4CA8232F0DB200763E16 /* Project object */; +} diff --git a/ios/TestApp/TestApp/AppDelegate.h b/ios/TestApp/TestApp/AppDelegate.h new file mode 100644 index 0000000000000..2a9bac67c9034 --- /dev/null +++ b/ios/TestApp/TestApp/AppDelegate.h @@ -0,0 +1,8 @@ +#import + +@interface AppDelegate : UIResponder + +@property (strong, nonatomic) UIWindow *window; + +@end + diff --git a/ios/TestApp/TestApp/AppDelegate.m b/ios/TestApp/TestApp/AppDelegate.m new file mode 100644 index 0000000000000..ed6928ac023fe --- /dev/null +++ b/ios/TestApp/TestApp/AppDelegate.m @@ -0,0 +1,43 @@ +#import "AppDelegate.h" + +@interface AppDelegate () + +@end + +@implementation AppDelegate + + +- (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + // Override point for customization after application launch. + return YES; +} + + +- (void)applicationWillResignActive:(UIApplication *)application { + // Sent when the application is about to move from active to inactive state. This can occur for certain types of temporary interruptions (such as an incoming phone call or SMS message) or when the user quits the application and it begins the transition to the background state. + // Use this method to pause ongoing tasks, disable timers, and invalidate graphics rendering callbacks. Games should use this method to pause the game. +} + + +- (void)applicationDidEnterBackground:(UIApplication *)application { + // Use this method to release shared resources, save user data, invalidate timers, and store enough application state information to restore your application to its current state in case it is terminated later. + // If your application supports background execution, this method is called instead of applicationWillTerminate: when the user quits. +} + + +- (void)applicationWillEnterForeground:(UIApplication *)application { + // Called as part of the transition from the background to the active state; here you can undo many of the changes made on entering the background. +} + + +- (void)applicationDidBecomeActive:(UIApplication *)application { + // Restart any tasks that were paused (or not yet started) while the application was inactive. If the application was previously in the background, optionally refresh the user interface. +} + + +- (void)applicationWillTerminate:(UIApplication *)application { + // Called when the application is about to terminate. Save data if appropriate. See also applicationDidEnterBackground:. +} + + +@end diff --git a/ios/TestApp/TestApp/Assets.xcassets/AppIcon.appiconset/Contents.json b/ios/TestApp/TestApp/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000000000..d8db8d65fd79f --- /dev/null +++ b/ios/TestApp/TestApp/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,98 @@ +{ + "images" : [ + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "20x20", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "29x29", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "40x40", + "scale" : "3x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "2x" + }, + { + "idiom" : "iphone", + "size" : "60x60", + "scale" : "3x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "20x20", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "29x29", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "40x40", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "1x" + }, + { + "idiom" : "ipad", + "size" : "76x76", + "scale" : "2x" + }, + { + "idiom" : "ipad", + "size" : "83.5x83.5", + "scale" : "2x" + }, + { + "idiom" : "ios-marketing", + "size" : "1024x1024", + "scale" : "1x" + } + ], + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/ios/TestApp/TestApp/Assets.xcassets/Contents.json b/ios/TestApp/TestApp/Assets.xcassets/Contents.json new file mode 100644 index 0000000000000..da4a164c91865 --- /dev/null +++ b/ios/TestApp/TestApp/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "version" : 1, + "author" : "xcode" + } +} \ No newline at end of file diff --git a/ios/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard b/ios/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard new file mode 100644 index 0000000000000..bfa36129419f8 --- /dev/null +++ b/ios/TestApp/TestApp/Base.lproj/LaunchScreen.storyboard @@ -0,0 +1,25 @@ + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ios/TestApp/TestApp/Base.lproj/Main.storyboard b/ios/TestApp/TestApp/Base.lproj/Main.storyboard new file mode 100644 index 0000000000000..942f0bc452d11 --- /dev/null +++ b/ios/TestApp/TestApp/Base.lproj/Main.storyboard @@ -0,0 +1,24 @@ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/ios/TestApp/TestApp/Info.plist b/ios/TestApp/TestApp/Info.plist new file mode 100644 index 0000000000000..49d8238263452 --- /dev/null +++ b/ios/TestApp/TestApp/Info.plist @@ -0,0 +1,45 @@ + + + + + CFBundleDevelopmentRegion + $(DEVELOPMENT_LANGUAGE) + CFBundleExecutable + $(EXECUTABLE_NAME) + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + $(PRODUCT_NAME) + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSRequiresIPhoneOS + + UILaunchStoryboardName + LaunchScreen + UIMainStoryboardFile + Main + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/ios/TestApp/TestApp/ViewController.h b/ios/TestApp/TestApp/ViewController.h new file mode 100644 index 0000000000000..9c7dfc57ec311 --- /dev/null +++ b/ios/TestApp/TestApp/ViewController.h @@ -0,0 +1,6 @@ +#import + +@interface ViewController : UIViewController + +@end + diff --git a/ios/TestApp/TestApp/ViewController.mm b/ios/TestApp/TestApp/ViewController.mm new file mode 100644 index 0000000000000..df3663615c3eb --- /dev/null +++ b/ios/TestApp/TestApp/ViewController.mm @@ -0,0 +1,16 @@ +#import "ViewController.h" +#import + +@interface ViewController () + +@end + +@implementation ViewController + +- (void)viewDidLoad { + [super viewDidLoad]; + // Do any additional setup after loading the view. +} + + +@end diff --git a/ios/TestApp/TestApp/main.m b/ios/TestApp/TestApp/main.m new file mode 100644 index 0000000000000..81e84cbb78185 --- /dev/null +++ b/ios/TestApp/TestApp/main.m @@ -0,0 +1,8 @@ +#import +#import "AppDelegate.h" + +int main(int argc, char * argv[]) { + @autoreleasepool { + return UIApplicationMain(argc, argv, nil, NSStringFromClass([AppDelegate class])); + } +} diff --git a/scripts/build_android.sh b/scripts/build_android.sh index 047c1f6766dcd..6cc93fae1d42f 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -47,10 +47,6 @@ echo "Caffe2 path: $CAFFE2_ROOT" echo "Using Android NDK at $ANDROID_NDK" echo "Android NDK version: $ANDROID_NDK_VERSION" -# Build protobuf from third_party so we have a host protoc binary. -echo "Building protoc" -$CAFFE2_ROOT/scripts/build_host_protoc.sh - # Now, actually build the Android target. BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_android"} INSTALL_PREFIX=${BUILD_ROOT}/install @@ -59,15 +55,25 @@ cd $BUILD_ROOT CMAKE_ARGS=() +if [ -n "${BUILD_PYTORCH_MOBILE:-}" ]; then + CMAKE_ARGS+=("-DBUILD_CAFFE2_MOBILE=OFF") + CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") + CMAKE_ARGS+=("-DPYTHON_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") + CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") +else + # Build protobuf from third_party so we have a host protoc binary. + echo "Building protoc" + $CAFFE2_ROOT/scripts/build_host_protoc.sh + # Use locally built protoc because we'll build libprotobuf for the + # target architecture and need an exact version match. + CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CAFFE2_ROOT/build_host_protoc/bin/protoc") +fi + # If Ninja is installed, prefer it to Make if [ -x "$(command -v ninja)" ]; then CMAKE_ARGS+=("-GNinja") fi -# Use locally built protoc because we'll build libprotobuf for the -# target architecture and need an exact version match. -CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CAFFE2_ROOT/build_host_protoc/bin/protoc") - # Use android-cmake to build Android project from CMake. CMAKE_ARGS+=("-DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake") diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh index a2f5aef064518..872aa7667d882 100755 --- a/scripts/build_ios.sh +++ b/scripts/build_ios.sh @@ -9,11 +9,6 @@ CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" -# Build protobuf from third_party so we have a host protoc binary. -echo "Building protoc" -BITCODE_FLAGS="-DCMAKE_C_FLAGS=-fembed-bitcode -DCMAKE_CXX_FLAGS=-fembed-bitcode " -$CAFFE2_ROOT/scripts/build_host_protoc.sh --other-flags $BITCODE_FLAGS - # Now, actually build the iOS target. BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_ios"} INSTALL_PREFIX=${BUILD_ROOT}/install @@ -22,9 +17,28 @@ cd $BUILD_ROOT CMAKE_ARGS=() -# Use locally built protoc because we'll build libprotobuf for the -# target architecture and need an exact version match. -CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CAFFE2_ROOT/build_host_protoc/bin/protoc") +if [ -n "${BUILD_PYTORCH_MOBILE:-}" ]; then + CMAKE_ARGS+=("-DBUILD_CAFFE2_MOBILE=OFF") + CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") + CMAKE_ARGS+=("-DPYTHON_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") + CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") + # bitcode + if [ "${ENABLE_BITCODE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") + CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") + fi +else + # Build protobuf from third_party so we have a host protoc binary. + echo "Building protoc" + BITCODE_FLAGS="-DCMAKE_C_FLAGS=-fembed-bitcode -DCMAKE_CXX_FLAGS=-fembed-bitcode " + $CAFFE2_ROOT/scripts/build_host_protoc.sh --other-flags $BITCODE_FLAGS + # Use locally built protoc because we'll build libprotobuf for the + # target architecture and need an exact version match. + CMAKE_ARGS+=("-DCAFFE2_CUSTOM_PROTOC_EXECUTABLE=$CAFFE2_ROOT/build_host_protoc/bin/protoc") + # Bitcode is enabled by default for caffe2 + CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") + CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") +fi # Use ios-cmake to build iOS project from CMake. # This projects sets CMAKE_C_COMPILER to /usr/bin/gcc and @@ -40,11 +54,19 @@ fi # IOS_PLATFORM controls type of iOS platform (see ios-cmake) if [ -n "${IOS_PLATFORM:-}" ]; then CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") + if [ "${IOS_PLATFORM}" == "SIMULATOR" ]; then + # iOS Simulator build is not supported by NNPACK + CMAKE_ARGS+=("-DUSE_NNPACK=OFF") + fi else # IOS_PLATFORM is not set, default to OS, which builds iOS. CMAKE_ARGS+=("-DIOS_PLATFORM=OS") fi +if [ -n "${IOS_ARCH:-}" ]; then + CMAKE_ARGS+=("-DIOS_ARCH=${IOS_ARCH}") +fi + # Don't build binaries or tests (only the library) CMAKE_ARGS+=("-DBUILD_TEST=OFF") CMAKE_ARGS+=("-DBUILD_BINARY=OFF") @@ -57,6 +79,7 @@ CMAKE_ARGS+=("-DUSE_OPENCV=OFF") CMAKE_ARGS+=("-DUSE_LMDB=OFF") CMAKE_ARGS+=("-DUSE_LEVELDB=OFF") CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_NUMPY=OFF") # pthreads CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread") @@ -68,11 +91,9 @@ if [ "${VERBOSE:-}" == '1' ]; then CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") fi -CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") -CMAKE_ARGS+=("-DCMAKE_CXX_FLAGS=-fembed-bitcode") cmake "$CAFFE2_ROOT" \ -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ - -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_BUILD_TYPE=MinSizeRel \ -DBUILD_SHARED_LIBS=OFF \ ${CMAKE_ARGS[@]} \ $@ diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh new file mode 100755 index 0000000000000..dec43a4eda088 --- /dev/null +++ b/scripts/build_mobile.sh @@ -0,0 +1,68 @@ +#!/bin/bash +############################################################################## +# Example command to build the mobile target. +############################################################################## +# +# This script shows how one can build a libtorch library optimized for mobile +# devices using host toolchain. + +set -e + +export PYTORCH_BUILD_MOBILE=1 +CAFFE2_ROOT="$( cd "$(dirname "$0")"/.. ; pwd -P)" + +echo "Bash: $(/bin/bash --version | head -1)" +echo "Caffe2 path: $CAFFE2_ROOT" + +# Now, actually build the Android target. +BUILD_ROOT=${BUILD_ROOT:-"$CAFFE2_ROOT/build_mobile"} +INSTALL_PREFIX=${BUILD_ROOT}/install +mkdir -p $BUILD_ROOT +cd $BUILD_ROOT + +CMAKE_ARGS=() +CMAKE_ARGS+=("-DBUILD_CAFFE2_MOBILE=OFF") +CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") +CMAKE_ARGS+=("-DPYTHON_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") +CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") +CMAKE_ARGS+=("-DBUILD_SHARED_LIBS=OFF") + +# If Ninja is installed, prefer it to Make +if [ -x "$(command -v ninja)" ]; then + CMAKE_ARGS+=("-GNinja") +fi + +# Disable unused dependencies +CMAKE_ARGS+=("-DUSE_CUDA=OFF") +CMAKE_ARGS+=("-DUSE_GFLAGS=OFF") +CMAKE_ARGS+=("-DUSE_OPENCV=OFF") +CMAKE_ARGS+=("-DUSE_LMDB=OFF") +CMAKE_ARGS+=("-DUSE_LEVELDB=OFF") +CMAKE_ARGS+=("-DUSE_MPI=OFF") +CMAKE_ARGS+=("-DUSE_OPENMP=OFF") + +# Only toggle if VERBOSE=1 +if [ "${VERBOSE:-}" == '1' ]; then + CMAKE_ARGS+=("-DCMAKE_VERBOSE_MAKEFILE=1") +fi + +# Use-specified CMake arguments go last to allow overridding defaults +CMAKE_ARGS+=($@) + +cmake "$CAFFE2_ROOT" \ + -DCMAKE_INSTALL_PREFIX=$INSTALL_PREFIX \ + -DCMAKE_BUILD_TYPE=Release \ + "${CMAKE_ARGS[@]}" + +# Cross-platform parallel build +if [ -z "$MAX_JOBS" ]; then + if [ "$(uname)" == 'Darwin' ]; then + MAX_JOBS=$(sysctl -n hw.ncpu) + else + MAX_JOBS=$(nproc) + fi +fi + +echo "Will install headers and libs to $INSTALL_PREFIX for further project usage." +cmake --build . --target install -- "-j${MAX_JOBS}" +echo "Installation completed, now you can copy the headers/libs from $INSTALL_PREFIX to your project directory." diff --git a/scripts/fbcode-dev-setup/onnx_c2_setup.sh b/scripts/fbcode-dev-setup/onnx_c2_setup.sh index e5c05dcc74bd5..f89adec8d0043 100755 --- a/scripts/fbcode-dev-setup/onnx_c2_setup.sh +++ b/scripts/fbcode-dev-setup/onnx_c2_setup.sh @@ -140,7 +140,7 @@ with_proxy python setup.py develop # Build PyTorch and Caffe2 cd "$onnx_root/pytorch" with_proxy pip install -r "requirements.txt" -with_proxy python setup.py build_deps develop +with_proxy python setup.py develop # Sanity checks and useful info cd "$onnx_root" diff --git a/setup.py b/setup.py index 2838b65c0def1..aeb3778c10008 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ # BUILD_CAFFE2_OPS=0 # disable Caffe2 operators build # -# USE_GLOO_IBVERBS +# USE_IBVERBS # toggle features related to distributed support # # USE_OPENCV @@ -181,7 +181,7 @@ import importlib from tools.build_pytorch_libs import build_caffe2 -from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN, IS_LINUX, +from tools.setup_helpers.env import (IS_WINDOWS, IS_DARWIN, check_env_flag, build_type) from tools.setup_helpers.cmake import CMake from tools.setup_helpers.cuda import CUDA_HOME, CUDA_VERSION @@ -352,9 +352,6 @@ def check_file(f): if sys.version_info <= (2, 7): install_requires += ['future'] -if sys.version_info[0] == 2: - install_requires += ['requests'] - missing_pydep = ''' Missing build dependency: Unable to `import {importname}`. Please install it via `conda install {module}` or `pip install {module}` @@ -403,10 +400,10 @@ def run(self): else: report('-- Not using NCCL') if cmake_cache_vars['USE_DISTRIBUTED']: - if IS_LINUX: - report('-- Building with c10d distributed package ') + if IS_WINDOWS: + report('-- Building without distributed package') else: - report('-- Building without c10d distributed package') + report('-- Building with distributed package ') else: report('-- Building without distributed package') @@ -825,7 +822,10 @@ def print_box(msg): 'include/torch/csrc/api/include/torch/detail/*.h', 'include/torch/csrc/api/include/torch/detail/ordered_dict.h', 'include/torch/csrc/api/include/torch/nn/*.h', + 'include/torch/csrc/api/include/torch/nn/functional/*.h', + 'include/torch/csrc/api/include/torch/nn/options/*.h', 'include/torch/csrc/api/include/torch/nn/modules/*.h', + 'include/torch/csrc/api/include/torch/nn/modules/container/*.h', 'include/torch/csrc/api/include/torch/nn/parallel/*.h', 'include/torch/csrc/api/include/torch/optim/*.h', 'include/torch/csrc/api/include/torch/serialize/*.h', diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py new file mode 100644 index 0000000000000..b39054a6db852 --- /dev/null +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -0,0 +1,53 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import sys +import torch +from torch._C import parse_schema + + +def check_bc(new_schema_dict): + existing_schemas = torch._C._jit_get_all_schemas() + for existing_schema in existing_schemas: + print("processing existing schema: ", str(existing_schema)) + new_schemas = new_schema_dict.get(existing_schema.name, []) + found = False + for new_schema in new_schemas: + if new_schema.is_backward_compatible_with(existing_schema): + found = True + break + if not found: + print('Can NOT find backward compatible schemas after changes ' + 'for schema {} from the following candidates:\n[\n{}\n]' + .format( + str(existing_schema), + "\n\t".join(str(s) for s in new_schemas))) + print('The PR is introducing backward incompatible changes to the ' + 'operator library. Please contact PyTorch team to confirm ' + 'whether this change is wanted or not.') + # TODO Print out more details about why candidates don't match. + return False + print('Found backward compatible schemas for all existing schemas') + return True + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument( + '--new-schemas', + help='filename to load new schemas', + type=str, + default='schemas.txt') + args = parser.parse_args() + new_schema_dict = dict() + with open(args.new_schemas, 'r') as f: + line = f.readline() + while line: + s = parse_schema(line.strip()) + line = f.readline() + slist = new_schema_dict.get(s.name, []) + slist.append(s) + new_schema_dict[s.name] = slist + + if not check_bc(new_schema_dict): + sys.exit(1) diff --git a/test/backward_compatibility/dump_all_function_schemas.py b/test/backward_compatibility/dump_all_function_schemas.py new file mode 100644 index 0000000000000..55b4d96dcf091 --- /dev/null +++ b/test/backward_compatibility/dump_all_function_schemas.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import argparse +import torch + + +def dump(filename): + schemas = torch._C._jit_get_all_schemas() + with open(filename, 'w') as f: + for s in schemas: + f.write(str(s)) + f.write('\n') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Process some integers.') + parser.add_argument( + '-f', + '--filename', + help='filename to dump the schemas', + type=str, + default='schemas.txt') + args = parser.parse_args() + dump(args.filename) diff --git a/test/common_device_type.py b/test/common_device_type.py new file mode 100644 index 0000000000000..dda0b8d7911a6 --- /dev/null +++ b/test/common_device_type.py @@ -0,0 +1,344 @@ +import inspect +from functools import wraps +import unittest +import torch +from common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \ + skipCUDANonDefaultStreamIf + +# Note: Generic Device-Type Testing +# +# [WRITING TESTS] +# +# Write your test class as usual except: +# (1) Each test method should have one of two signatures: +# +# (1a) testX(self, device) +# +# (1b) @dtypes() +# testX(self, device, dtype) +# +# Note in the latter case the dtypes decorator with a nonempty list of +# valid dtypes is not optional. +# +# When the test is called it will be given a device, like 'cpu' or +# 'cuda,' and a dtype from the list specified in @dtypes. If +# device-specific dtypes are specified using @dtypesIfCPU or +# @dtypesIfCUDA then those devices will only see the dtypes specified +# for them. +# (2) Prefer using test decorators defined in this file to others. +# For example, using the @skipIfNoLapack decorator instead of the +# @skipCPUIfNoLapack will cause the test to not run on CUDA if +# LAPACK is not available, which is wrong. If you need to use a decorator +# you may want to ask about porting it to this framework. +# +# See the TestTorchDeviceType class in test_torch.py for an example. +# +# [RUNNING TESTS] +# +# After defining your test class call instantiate_device_type_tests on it +# and pass in globals() for the second argument. This will instantiate +# discoverable device-specific test classes from your generic class. It will +# also hide the tests in your generic class so they're not run. +# +# If you device-generic test class is TestClass then new classes with names +# TestClass will be created for each available device type. +# TestClassCPU and TestClassCUDA, for example. Tests in these classes also +# have the device type and dtype, if provided, appended to their original +# name. testX, for instance, becomes testX_ or +# testX__. +# +# More concretely, TestTorchDeviceType becomes TestTorchDeviceTypeCPU, +# TestTorchDeviceTypeCUDA, ... test_diagonal in TestTorchDeviceType becomes +# test_diagonal_cpu, test_diagonal_cuda, ... test_erfinv, which accepts a dtype, +# becomes test_erfinv_cpu_float, test_erfinv_cpu_double, test_erfinv_cuda_half, +# ... +# +# In short, if you write a test signature like +# def textX(self, device) +# You are effectively writing +# def testX_cpu(self, device='cpu') +# def textX_cuda(self, device='cuda') +# def testX_xla(self, device='xla') +# ... +# +# These tests can be run directly like normal tests: +# "python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu" +# +# All the tests for a particular device type can be run using the class, and +# other collections of tests can be run using pytest filtering, like +# +# "pytest test_torch.py -k 'test_diag'" +# +# which will run test_diag on every available device. +# +# To specify particular device types the 'and' keyword can be used: +# +# "pytest test_torch.py -k 'test_erfinv and cpu'" +# +# will run test_erfinv on all cpu dtypes. +# +# [ADDING A DEVICE TYPE] +# +# To add a device type: +# +# (1) Create a new "TestBase" extending DeviceTypeTestBase. +# See CPUTestBase and CUDATestBase below. +# (2) Define the "device_type" attribute of the base to be the +# appropriate string. +# (3) Add logic to this file that appends your base class to +# device_type_test_bases when your device type is available. +# (4) (Optional) Write setUpClass/tearDownClass class methods that +# instantiate dependencies (see MAGMA in CUDATestBase). +# (5) (Optional) Override the "instantiate_test" method for total +# control over how your class creates tests. +# +# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF +# they are run. This makes it useful for initializing devices and dependencies. +# + +# List of device type test bases that can be used to instantiate tests. +# See below for how this list is populated. If you're adding a device type +# you should check if it's available and (if it is) add it to this list. +device_type_test_bases = [] + + +class DeviceTypeTestBase(TestCase): + device_type = "generic_device_type" + + # Returns the dtypes the test has requested. + # Prefers device-specific dtype specifications over generic ones. + @classmethod + def _get_dtypes(cls, test): + if not hasattr(test, 'dtypes'): + return None + return test.dtypes.get(cls.device_type, test.dtypes.get('all', None)) + + # Creates device-specific tests. + @classmethod + def instantiate_test(cls, test): + test_name = test.__name__ + "_" + cls.device_type + + dtypes = cls._get_dtypes(test) + if dtypes is None: # Test has no dtype variants + assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name) + + @wraps(test) + def instantiated_test(self, test=test): + return test(self, cls.device_type) + + setattr(cls, test_name, instantiated_test) + else: # Test has dtype variants + for dtype in dtypes: + dtype_str = str(dtype).split('.')[1] + dtype_test_name = test_name + "_" + dtype_str + assert not hasattr(cls, dtype_test_name), "Redefinition of test {0}".format(dtype_test_name) + + @wraps(test) + def instantiated_test(self, test=test, dtype=dtype): + return test(self, cls.device_type, dtype) + + setattr(cls, dtype_test_name, instantiated_test) + + +class CPUTestBase(DeviceTypeTestBase): + device_type = "cpu" + + +class CUDATestBase(DeviceTypeTestBase): + device_type = "cuda" + _do_cuda_memory_leak_check = True + _do_cuda_non_default_stream = True + + @classmethod + def setUpClass(cls): + # has_magma shows up after cuda is initialized + torch.ones(1).cuda() + cls.no_magma = not torch.cuda.has_magma + + +# Adds available device-type-specific test base classes +device_type_test_bases.append(CPUTestBase) +if torch.cuda.is_available(): + device_type_test_bases.append(CUDATestBase) + + +# Adds 'instantiated' device-specific test cases to the given scope. +# The tests in these test cases are derived from the generic tests in +# generic_test_class. +# See note "Generic Device Type Testing." +def instantiate_device_type_tests(generic_test_class, scope): + # Removes the generic test class from its enclosing scope so its tests + # are not discoverable. + del scope[generic_test_class.__name__] + + # Creates an 'empty' version of the generic_test_class + # Note: we don't inherit from the generic_test_class directly because + # that would add its tests to our test classes and they would be + # discovered (despite not being runnable). Inherited methods also + # can't be removed later, and we can't rely on load_tests because + # pytest doesn't support it (as of this writing). + empty_name = generic_test_class.__name__ + "_base" + empty_class = type(empty_name, generic_test_class.__bases__, {}) + + # Acquires members names + generic_members = set(dir(generic_test_class)) - set(dir(empty_class)) + generic_tests = [x for x in generic_members if x.startswith('test')] + + # Creates device-specific test cases + for base in device_type_test_bases: + class_name = generic_test_class.__name__ + base.device_type.upper() + device_type_test_class = type(class_name, (base, empty_class), {}) + + for name in generic_members: + if name in generic_tests: # Instantiates test member + + # Requires tests be a function for Python2 compat + # (In Python2 tests are type checked methods wrapping functions) + test = getattr(generic_test_class, name) + if hasattr(test, '__func__'): + test = test.__func__ + assert inspect.isfunction(test), "Couldn't extract function from '{0}'".format(name) + + # Instantiates the device-specific tests + device_type_test_class.instantiate_test(test) + else: # Ports non-test member + assert not hasattr(device_type_test_class, name), "Redefinition of non-test member {0}".format(name) + + # Unwraps to functions (when available) for Python2 compat + nontest = getattr(generic_test_class, name) + if hasattr(nontest, '__func__'): + nontest = nontest.__func__ + + setattr(device_type_test_class, name, nontest) + + # Mimics defining the instantiated class in the caller's file + # by setting its module to the given class's and adding + # the module to the given scope. + # This lets the instantiated class be discovered by unittest. + device_type_test_class.__module__ = generic_test_class.__module__ + scope[class_name] = device_type_test_class + + +# Decorator that skips a test if the given condition is true. +# Notes: +# (1) Skip conditions stack. +# (2) Skip conditions can be bools or strings. If a string the +# test base must have defined the corresponding attribute to be False +# for the test to run. If you want to use a string argument you should +# probably define a new decorator instead (see below). +# (3) Prefer the existing decorators to defining the 'device_type' kwarg. +class skipIf(object): + + def __init__(self, dep, reason, device_type=None): + self.dep = dep + self.reason = reason + self.device_type = device_type + + def __call__(self, fn): + + @wraps(fn) + def dep_fn(slf, device, *args, **kwargs): + if self.device_type is None or self.device_type == slf.device_type: + if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or (isinstance(self.dep, bool) and self.dep): + raise unittest.SkipTest(self.reason) + + return fn(slf, device, *args, **kwargs) + return dep_fn + + +# Skips a test on CPU if the condition is true. +class skipCPUIf(skipIf): + + def __init__(self, dep, reason): + super(skipCPUIf, self).__init__(dep, reason, device_type='cpu') + + +# Skips a test on CUDA if the condition is true. +class skipCUDAIf(skipIf): + + def __init__(self, dep, reason): + super(skipCUDAIf, self).__init__(dep, reason, device_type='cuda') + + +class onlyOn(object): + + def __init__(self, device_type): + self.device_type = device_type + + def __call__(self, fn): + + @wraps(fn) + def only_fn(slf, device, *args, **kwargs): + if self.device_type != slf.device_type: + reason = "Only runs on {0}".format(self.device_type) + raise unittest.SkipTest(reason) + + return fn(slf, device, *args, **kwargs) + + return only_fn + + +# Decorator that instantiates a variant of the test for each given dtype. +# Notes: +# (1) Tests that accept the dtype argument MUST use this decorator. +# (2) Can be overriden for the CPU or CUDA, respectively, using dtypesIfCPU +# or dtypesIfCUDA. +# (3) Prefer the existing decorators to defining the 'device_type' kwarg. +class dtypes(object): + + # Note: *args, **kwargs for Python2 compat. + # Python 3 allows (self, *args, device_type='all'). + def __init__(self, *args, **kwargs): + assert args is not None and len(args) != 0, "No dtypes given" + assert all(isinstance(arg, torch.dtype) for arg in args), "Unknown dtype in {0}".format(str(args)) + self.args = args + self.device_type = kwargs.get('device_type', 'all') + + def __call__(self, fn): + d = getattr(fn, 'dtypes', {}) + assert self.device_type not in d, "dtypes redefinition for {0}".format(self.device_type) + d[self.device_type] = self.args + fn.dtypes = d + return fn + + +# Overrides specified dtypes on the CPU. +class dtypesIfCPU(dtypes): + + def __init__(self, *args): + super(dtypesIfCPU, self).__init__(*args, device_type='cpu') + + +# Overrides specified dtypes on CUDA. +class dtypesIfCUDA(dtypes): + + def __init__(self, *args): + super(dtypesIfCUDA, self).__init__(*args, device_type='cuda') + + +def onlyCPU(fn): + return onlyOn('cpu')(fn) + + +def onlyCUDA(fn): + return onlyOn('cuda')(fn) + + +# Skips a test on CPU if LAPACK is not available. +def skipCPUIfNoLapack(fn): + return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) + + +# Skips a test on CPU if MKL is not available. +def skipCPUIfNoMkl(fn): + return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn) + + +# Skips a test on CUDA if MAGMA is not available. +def skipCUDAIfNoMagma(fn): + return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn)) + + +# Skips a test on CUDA when using ROCm. +def skipCUDAIfRocm(fn): + return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn) diff --git a/test/common_nn.py b/test/common_nn.py index dd61a95779f7c..c0d41d8026606 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -1258,11 +1258,13 @@ def fractional_max_pool3d_test(test_case): dict( module_name='MaxPool1d', constructor_args=(4,), + cpp_constructor_args='(4)', input_size=(2, 10, 4), ), dict( module_name='MaxPool1d', constructor_args=(4, 4), + cpp_constructor_args='(torch::nn::MaxPool1dOptions(4).stride(4))', input_size=(2, 10, 4), desc='stride', ), @@ -1376,57 +1378,67 @@ def fractional_max_pool3d_test(test_case): dict( module_name='MaxPool2d', constructor_args=((3, 3), (2, 2), (1, 1)), + cpp_constructor_args='(torch::nn::MaxPool2dOptions({3, 3}).stride({2, 2}).padding({1, 1}))', input_size=(1, 3, 7, 7), ), dict( module_name='AvgPool1d', constructor_args=(2,), + cpp_constructor_args="(2)", input_size=(2, 3, 6), ), dict( module_name='AvgPool1d', constructor_args=((2,), (2,)), + cpp_constructor_args="(torch::nn::AvgPool1dOptions(2).stride(2))", input_size=(2, 3, 6), desc='stride', ), dict( module_name='AvgPool1d', constructor_args=(2, 2, 1), + cpp_constructor_args="(torch::nn::AvgPool1dOptions(2).stride(2).padding(1))", input_size=(2, 3, 6), desc='stride_pad', ), dict( module_name='AvgPool2d', constructor_args=((2, 2),), + cpp_constructor_args="(torch::nn::AvgPool2dOptions({2, 2}))", input_size=(2, 3, 6, 6), ), dict( module_name='AvgPool2d', constructor_args=((2, 2), (2, 2)), + cpp_constructor_args="(torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}))", input_size=(2, 3, 6, 6), desc='stride', ), dict( module_name='AvgPool2d', constructor_args=((2, 2), (2, 2), (1, 1)), + cpp_constructor_args="(torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1}))", input_size=(2, 3, 6, 6), desc='stride_pad', ), dict( fullname='AvgPool2d_divisor', constructor=lambda: nn.AvgPool2d((2, 2), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool2dOptions({2, 2}).divisor_override(1))", input_size=(2, 3, 6, 6), check_with_long_tensor=True, ), dict( fullname='AvgPool2d_divisor_stride', constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).divisor_override(1))", input_size=(2, 3, 6, 6), check_with_long_tensor=True, ), dict( fullname='AvgPool2d_divisor_stride_pad', constructor=lambda: nn.AvgPool2d((2, 2), (2, 2), (1, 1), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool2dOptions({2, 2}).stride({2, 2}).padding({1, 1}).divisor_override(1))", input_size=(2, 3, 6, 6), check_with_long_tensor=True, ), @@ -1580,100 +1592,117 @@ def fractional_max_pool3d_test(test_case): dict( module_name='MaxPool3d', constructor_args=((2, 2, 2),), + cpp_constructor_args='(torch::nn::MaxPool3dOptions({2, 2, 2}))', input_size=(2, 3, 5, 5, 5), ), dict( module_name='MaxPool3d', constructor_args=(2, (2, 2, 2)), + cpp_constructor_args='(torch::nn::MaxPool3dOptions(2).stride({2, 2, 2}))', input_size=(2, 3, 5, 5, 5), desc='stride', ), dict( module_name='MaxPool3d', constructor_args=(2, 2, (1, 1, 1)), + cpp_constructor_args='(torch::nn::MaxPool3dOptions(2).stride(2).padding({1, 1, 1}))', input_size=(2, 3, 5, 5, 5), desc='stride_padding', ), dict( module_name='AvgPool3d', constructor_args=((2, 2, 2),), + cpp_constructor_args="(torch::nn::AvgPool3dOptions({2, 2, 2}))", input_size=(2, 3, 4, 4, 4), ), dict( module_name='AvgPool3d', constructor_args=(2, (2, 2, 2)), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(2).stride({2, 2, 2}))", input_size=(2, 3, 5, 5, 5), desc='stride', ), dict( module_name='AvgPool3d', constructor_args=(2, 2, (1, 1, 1)), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}))", input_size=(2, 3, 5, 5, 5), desc='stride_pad', ), dict( module_name='AvgPool3d', constructor_args=(4, 2, (1, 2, 1)), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1}))", input_size=(2, 3, 5, 5, 5), desc='stride_pad_gpu_fixedkw_output', ), dict( module_name='AvgPool3d', constructor_args=((2, 4, 8), 1, (1, 1, 2)), + cpp_constructor_args="(torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2}))", input_size=(2, 3, 2, 4, 8), desc='stride_pad_gpu_general_output', ), dict( module_name='AvgPool3d', constructor_args=(3, 1, 0), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(3).stride(1).padding(0))", input_size=(2, 3, 4, 4, 4), desc='stride1_pad0_gpu_input', ), dict( module_name='AvgPool3d', constructor_args=(2, 2, (1, 1, 1)), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}))", input_size=(2, 3, 4, 4, 4), desc='stride_pad_gpu_input_nooverlap', ), dict( fullname='AvgPool3d_divisor', constructor=lambda: nn.AvgPool3d((2, 2, 2), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions({2, 2, 2}).divisor_override(1))", input_size=(2, 3, 4, 4, 4), check_with_long_tensor=True, ), dict( fullname='AvgPool3d_divisor_stride', constructor=lambda: nn.AvgPool3d(2, (2, 2, 2), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(2).stride({2, 2, 2}).divisor_override(1))", input_size=(2, 3, 5, 5, 5), check_with_long_tensor=True, ), dict( fullname='AvgPool3d_divisor_stride_pad', constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1))", input_size=(2, 3, 5, 5, 5), check_with_long_tensor=True, ), dict( fullname='AvgPool3d_divisor_stride_pad_gpu_fixedkw_output', constructor=lambda: nn.AvgPool3d(4, 2, (1, 2, 1), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(4).stride(2).padding({1, 2, 1}).divisor_override(1))", input_size=(2, 3, 5, 5, 5), check_with_long_tensor=True, ), dict( fullname='AvgPool3d_divisor_stride_pad_gpu_general_output', constructor=lambda: nn.AvgPool3d((2, 4, 8), 1, (1, 1, 2), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions({2, 4, 8}).stride(1).padding({1, 1, 2}).divisor_override(1))", input_size=(2, 3, 2, 4, 8), check_with_long_tensor=True, ), dict( fullname='AvgPool3d_divisor_stride1_pad0_gpu_input', constructor=lambda: nn.AvgPool3d(3, 1, 0, divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(3).stride(1).padding(0).divisor_override(1))", input_size=(2, 3, 4, 4, 4), check_with_long_tensor=True, ), dict( fullname='AvgPool3d_divisor_stride_pad_gpu_input_nooverlap', constructor=lambda: nn.AvgPool3d(2, 2, (1, 1, 1), divisor_override=1), + cpp_constructor_args="(torch::nn::AvgPool3dOptions(2).stride(2).padding({1, 1, 1}).divisor_override(1))", input_size=(2, 3, 4, 4, 4), check_with_long_tensor=True, ), @@ -2194,6 +2223,7 @@ def fractional_max_pool3d_test(test_case): dict( fullname='Fold', constructor=lambda: nn.Fold((3, 3), (2, 2), (1, 1), (0, 0), (1, 1)), + cpp_constructor_args='(torch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1}))', input_size=(2, 16, 4), check_gradgrad=False, test_cuda=True, @@ -2208,6 +2238,7 @@ def fractional_max_pool3d_test(test_case): dict( fullname='Fold_int_input', constructor=lambda: nn.Fold(3, 2, 1, 0, 1), + cpp_constructor_args='(torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1))', input_size=(2, 16, 4), check_gradgrad=False, test_cuda=True, @@ -2768,6 +2799,7 @@ def padding3d_circular(input, pad): criterion_tests = [ dict( module_name='L1Loss', + cpp_constructor_args='', input_size=(2, 3, 4), target_size=(2, 3, 4), reference_fn=lambda i, t, _: 1. / i.numel() * @@ -3094,6 +3126,7 @@ def padding3d_circular(input, pad): ), dict( module_name='L1Loss', + cpp_constructor_args='', input_size=(), target_size=(), reference_fn=lambda i, t, _: 1. / i.numel() * (i - t).abs().sum(), diff --git a/test/common_quantization.py b/test/common_quantization.py index 3b1827c015344..fb73b498de089 100644 --- a/test/common_quantization.py +++ b/test/common_quantization.py @@ -61,6 +61,7 @@ def prepare_dynamic(model, qconfig_dict=None): # QuantizationTestCase used as a base class for testing quantization on modules class QuantizationTestCase(TestCase): def setUp(self): + super(QuantizationTestCase, self).setUp() self.calib_data = [(torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)] self.train_data = [(torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) for _ in range(2)] self.img_data = [(torch.rand(2, 3, 10, 10, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)) @@ -106,19 +107,16 @@ def checkWrappedQuantizedLinear(self, mod): has Quantize and DeQuantize submodules """ self.assertEqual(type(mod.module), nnq.Linear) - self.assertEqual(mod.module.bias.dtype, torch.qint32) self.checkQuantDequant(mod) def checkQuantizedLinear(self, mod): self.assertEqual(type(mod), nnq.Linear) - self.assertEqual(mod.bias.dtype, torch.qint32) def checkDynamicQuantizedLinear(self, mod): r"""Checks that mod has been swapped for an nnqd.Linear module, the bias is float. """ self.assertEqual(type(mod), nnqd.Linear) - self.assertEqual(mod.bias.dtype, torch.float) def checkLinear(self, mod): self.assertEqual(type(mod), torch.nn.Linear) @@ -281,8 +279,8 @@ def __init__(self): 'dtype': torch.quint8, 'qscheme': torch.per_tensor_affine } - custom_qconfig = QConfig(weight=default_weight_observer(), - activation=default_observer(**custom_options)) + custom_qconfig = QConfig(activation=default_observer(**custom_options), + weight=default_weight_observer()) self.sub2.fc1.qconfig = custom_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) @@ -494,3 +492,34 @@ def forward(self, x): out = self.relu2(out) out = self.avgpool(out) return out + +class ModelMultipleOps(torch.nn.Module): + def __init__(self): + super(ModelMultipleOps, self).__init__() + norm_layer = nn.BatchNorm2d + inplanes = 3 + self.conv1 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.conv2 = nn.Conv2d(inplanes, inplanes, (1, 1), bias=False) + self.bn1 = norm_layer(inplanes) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.downsample = torch.nn.Identity() + self.skip_add = nn.quantized.FloatFunctional() + self.cat = nn.quantized.FloatFunctional() + self.avgpool = nn.AdaptiveAvgPool2d((4, 4)) + self.fc = nn.Linear(12, 6) + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + identity = self.downsample(x) + out = self.skip_add.add(out, identity) + out = self.relu2(out) + out = self.avgpool(out) + out = self.conv2(out) + out = torch.nn.functional.max_pool2d(out, 2, 2) + out = self.cat.cat([out, out]) + out = out.view(-1, 3 * 2 * 2) + out = self.fc(out) + return out diff --git a/test/common_quantized.py b/test/common_quantized.py index dbd8b8375b842..131049a05111f 100644 --- a/test/common_quantized.py +++ b/test/common_quantized.py @@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import numpy as np import torch +from contextlib import contextmanager """Computes the output shape given convolution parameters.""" def _conv_output_shape(input_size, kernel_size, padding, stride, dilation, @@ -61,3 +62,11 @@ def _calculate_dynamic_qparams(X, dtype): zero_point = max(qmin, zero_point) zero_point = min(qmax, zero_point) return [float(scale), int(zero_point)] + +@contextmanager +def enable_mobile_quantized_engine(): + torch.backends.quantized.engine = torch.qnnpack + try: + yield + finally: + torch.backends.quantized.engine = torch.fbgemm diff --git a/test/common_utils.py b/test/common_utils.py index bb5b826f3acd1..51997be0579e6 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -22,7 +22,7 @@ import time from collections import OrderedDict from contextlib import contextmanager -from functools import wraps, partial +from functools import wraps from itertools import product from copy import deepcopy from numbers import Number @@ -43,7 +43,7 @@ torch.set_default_tensor_type('torch.DoubleTensor') -torch.backends.cudnn.disable_global_flags() +torch.backends.disable_global_flags() parser = argparse.ArgumentParser(add_help=False) @@ -120,6 +120,7 @@ def add_to_test_cases(suite_or_case): PY34 = sys.version_info >= (3, 4) IS_WINDOWS = sys.platform == "win32" +IS_MACOS = sys.platform == "darwin" IS_PPC = platform.machine() == "ppc64le" # Environment variable `IS_PYTORCH_CI` is set in `.jenkins/common.sh`. @@ -184,7 +185,7 @@ def _check_module_exists(name): TEST_WITH_TSAN = os.getenv('PYTORCH_TEST_WITH_TSAN', '0') == '1' TEST_WITH_UBSAN = os.getenv('PYTORCH_TEST_WITH_UBSAN', '0') == '1' TEST_WITH_ROCM = os.getenv('PYTORCH_TEST_WITH_ROCM', '0') == '1' - +TEST_WITH_QNNPACK = os.getenv('PYTORCH_TEST_WITH_QNNPACK', '0') == '1' # Enables tests that are slow to run (disabled by default) TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' @@ -215,52 +216,6 @@ def run_test_function(self): return run_test_function -class torchtest(): - """Allows to generate and run per-device unittests. - - This decorator class allows to generate and run per-device unittest. - - Example: - - class _TestTorchMixin(torchtest): - - @torchtest.for_all_device_types() - def test_zeros_like(self, device): - expected = torch.zeros((100, 100,), device=device) - - Will execute: - - test_zeros_like (__main__.TestTorch) ... skipped 'Look at test_zeros_like_cpu, test_zeros_like_cuda results.' - test_zeros_like_cpu (__main__.TestTorch) ... ok - test_zeros_like_cuda (__main__.TestTorch) ... ok - - To work properly, test class should be inherited from `torchtest`. - for_all_device_types decorator does not guarantee proper functionality in - combination with other decorators. - - Please do not extend this decorator to support other cases (such as dtype, - layouts, etc) without consulting with bigger group. Devices is the special - case as build flags control additions/removals (see - https://github.com/pytorch/pytorch/pull/23824 for the reference). - """ - @classmethod - def for_all_device_types(cls): - def wrapper(fn): - test_names = [] - - for device in torch.testing.get_all_device_types(): - test_name = fn.__name__ + '_' + device - assert not hasattr(cls, test_name), "Duplicated test name: " + test_name - setattr(cls, test_name, _test_function(fn, device)) - test_names.append(test_name) - - @wraps(fn) - def empty_test(*args, **kwargs): - raise unittest.SkipTest("Look at {} results.".format(", ".join(test_names))) - return empty_test - return wrapper - - def skipIfNoLapack(fn): @wraps(fn) def wrapper(*args, **kwargs): @@ -459,6 +414,65 @@ def __exit__(self, exec_type, exec_value, traceback): warnings.warn('{} leaked {} bytes ROCm memory on device {}'.format( self.name, after - before, i), RuntimeWarning) +# "min_satisfying_examples" setting has been deprecated in hypythesis +# 3.56.0 and removed in hypothesis 4.x +try: + import hypothesis + if hypothesis.version.__version_info__ >= (3, 56, 0): + hypothesis.settings.register_profile( + "pytorch_ci", + hypothesis.settings( + derandomize=True, + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=100)) + hypothesis.settings.register_profile( + "dev", + hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=10, + verbosity=hypothesis.Verbosity.verbose)) + hypothesis.settings.register_profile( + "debug", + hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=1000, + verbosity=hypothesis.Verbosity.verbose)) + else: + hypothesis.settings.register_profile( + "pytorch_ci", + hypothesis.settings( + derandomize=True, + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=100, + min_satisfying_examples=1)) + hypothesis.settings.register_profile( + "dev", + hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=10, + min_satisfying_examples=1, + verbosity=hypothesis.Verbosity.verbose)) + hypothesis.settings.register_profile( + "debug", + hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.too_slow], + database=None, + max_examples=1000, + min_satisfying_examples=1, + verbosity=hypothesis.Verbosity.verbose)) + + hypothesis.settings.load_profile( + "pytorch_ci" if IS_PYTORCH_CI else os.getenv('PYTORCH_HYPOTHESIS_PROFILE', + 'dev') + ) +except ImportError: + print('Fail to import hypothesis in common_utils, tests are not derandomized') + class TestCase(expecttest.TestCase): precision = 1e-5 maxDiff = None @@ -477,7 +491,7 @@ def __init__(self, method_name='runTest'): # Wraps the tested method if we should enforce non default CUDA stream. self._do_cuda_non_default_stream &= getattr(test_method, '_do_cuda_non_default_stream', True) - if self._do_cuda_non_default_stream and not IS_WINDOWS: + if self._do_cuda_non_default_stream and not IS_WINDOWS and not TEST_WITH_ROCM: self.wrap_with_cuda_policy(method_name, self.enforceNonDefaultStream) def assertLeaksNoCudaTensors(self, name=None): @@ -1014,9 +1028,9 @@ def random_symmetric_psd_matrix(l, *batches): return torch.matmul(A, A.transpose(-2, -1)) -def random_symmetric_pd_matrix(l, *batches): - A = torch.randn(*(batches + (l, l))) - return torch.matmul(A, A.transpose(-2, -1)) + torch.eye(l) * 1e-5 +def random_symmetric_pd_matrix(matrix_size, *batch_dims): + A = torch.randn(*(batch_dims + (matrix_size, matrix_size))) + return torch.matmul(A, A.transpose(-2, -1)) + torch.eye(matrix_size) * 1e-5 def make_nonzero_det(A, sign=None, min_singular_value=0.1): @@ -1048,31 +1062,37 @@ def random_fullrank_matrix_distinct_singular_value(matrix_size, *batch_dims, **k return u.matmul(s.expand(batch_dims + (matrix_size, matrix_size)).matmul(v.transpose(-2, -1))) -def random_linalg_solve_processed_inputs(A_dims, b_dims, gen_fn, transform_fn, cast_fn): - """ - For solve methods, this returns the following values: - RHS tensor: generated using torch.randn - LHS tensor: generated using gen_fn - Transformed LHS tensor(s): returned after calling transform_fn. - This can be a tuple or a single tensor depending on transform_fn - For instance, if transform_fn == torch.cholesky, then the return value - is a single tensor. If transform_fn == torch.lu, then the return value - is a tuple of tensors - """ - RHS = cast_fn(torch.randn(*b_dims)) - LHS = cast_fn(gen_fn(*A_dims)) - transformed_LHS = transform_fn(LHS) - return RHS, LHS, transformed_LHS - - def lu_solve_test_helper(self, A_dims, b_dims, cast, pivot): - b, A, (LU_data, LU_pivots, info) = random_linalg_solve_processed_inputs( - A_dims, b_dims, random_fullrank_matrix_distinct_singular_value, - partial(torch.lu, get_infos=True, pivot=pivot), cast) + b = cast(torch.randn(*b_dims)) + A = cast(random_fullrank_matrix_distinct_singular_value(*A_dims)) + LU_data, LU_pivots, info = torch.lu(A, get_infos=True, pivot=pivot) self.assertEqual(info, torch.zeros_like(info)) return b, A, LU_data, LU_pivots +def cholesky_solve_test_helper(A_dims, b_dims, cast, upper): + b = cast(torch.randn(*b_dims)) + A = cast(random_symmetric_pd_matrix(*A_dims)) + L = torch.cholesky(A, upper=upper) + return b, A, L + + +def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular): + triangle_function = torch.triu if upper else torch.tril + b = cast(torch.randn(*b_dims)) + A = cast(torch.randn(*A_dims)) + A_triangular = triangle_function(A) + if unitriangular: + A_triangular.diagonal(dim1=-2, dim2=-1).fill_(1.) + return b, A_triangular + + +def solve_test_helper(A_dims, b_dims, cast): + b = cast(torch.randn(*b_dims)) + A = cast(random_fullrank_matrix_distinct_singular_value(*A_dims)) + return b, A + + def brute_pdist(inp, p=2): """Computes the same as torch.pdist using primitives""" n = inp.shape[-2] diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index b815dd6f16f20..2b933817f6cc7 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -5,6 +5,7 @@ set(TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/any.cpp ${TORCH_API_TEST_DIR}/dataloader.cpp ${TORCH_API_TEST_DIR}/expanding-array.cpp + ${TORCH_API_TEST_DIR}/functional.cpp ${TORCH_API_TEST_DIR}/integration.cpp ${TORCH_API_TEST_DIR}/init.cpp ${TORCH_API_TEST_DIR}/jit.cpp diff --git a/test/cpp/api/any.cpp b/test/cpp/api/any.cpp index 02a32abb4b141..c99714958ef9b 100644 --- a/test/cpp/api/any.cpp +++ b/test/cpp/api/any.cpp @@ -1,8 +1,6 @@ #include -#include -#include -#include +#include #include diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp index c3c79aa1884ab..98ad89843760f 100644 --- a/test/cpp/api/autograd.cpp +++ b/test/cpp/api/autograd.cpp @@ -350,8 +350,6 @@ TEST(CustomAutogradTest, InvalidGradients) { ASSERT_THROWS_WITH( MyFunction::apply(input1).sum().backward(), "expected shape"); auto input2 = torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true)); - ASSERT_THROWS_WITH( - MyFunction::apply(input2).sum().backward(), "expected type"); } TEST(CustomAutogradTest, NoGradInput) { diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp new file mode 100644 index 0000000000000..001dc6caa5322 --- /dev/null +++ b/test/cpp/api/functional.cpp @@ -0,0 +1,65 @@ +#include + +#include + +#include + +namespace F = torch::nn::functional; + +using namespace torch::nn; + +struct FunctionalTest : torch::test::SeedingFixture {}; + +TEST_F(FunctionalTest, MaxPool1d) { + auto x = torch::ones({1, 1, 5}, torch::requires_grad()); + auto y = F::max_pool1d(x, MaxPool1dOptions(3).stride(2)); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1 ,2}))); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); +} + +TEST_F(FunctionalTest, MaxPool2d) { + auto x = torch::ones({2, 5, 5}, torch::requires_grad()); + auto y = F::max_pool2d(x, MaxPool2dOptions(3).stride(2)); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2 ,2}))); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); +} + +TEST_F(FunctionalTest, MaxPool3d) { + auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); + auto y = F::max_pool3d(x, MaxPool3dOptions(3).stride(2)); + + ASSERT_EQ(y.ndimension(), 4); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2})); +} + +TEST_F(FunctionalTest, AvgPool1d) { + auto x = torch::ones({1, 1, 5}, torch::requires_grad()); + auto y = F::avg_pool1d(x, AvgPool1dOptions(3).stride(2)); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); +} + +TEST_F(FunctionalTest, AvgPool2d) { + auto x = torch::ones({2, 5, 5}, torch::requires_grad()); + auto y = F::avg_pool2d(x, AvgPool2dOptions(3).stride(2)); + + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); +} + +TEST_F(FunctionalTest, AvgPool3d) { + auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); + auto y = F::avg_pool3d(x, AvgPool3dOptions(3).stride(2)); + + ASSERT_EQ(y.ndimension(), 4); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2})); +} diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp index 096891175b510..7e36bb9f31118 100644 --- a/test/cpp/api/misc.cpp +++ b/test/cpp/api/misc.cpp @@ -56,8 +56,9 @@ TEST(NoGradTest, SetsGradModeCorrectly) { auto y = model->forward(x); torch::Tensor s = y.sum(); - s.backward(); - ASSERT_FALSE(model->weight.grad().defined()); + // Mimicking python API behavior: + ASSERT_THROWS_WITH(s.backward(), + "element 0 of tensors does not require grad and does not have a grad_fn") } struct AutogradTest : torch::test::SeedingFixture { @@ -70,7 +71,7 @@ struct AutogradTest : torch::test::SeedingFixture { }; TEST_F(AutogradTest, CanTakeDerivatives) { - z.backward(); + z.backward(torch::ones_like(z)); ASSERT_TRUE(x.grad().allclose(y)); } diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp index 76a76e4375ac5..de8c9b7566a48 100644 --- a/test/cpp/api/module.cpp +++ b/test/cpp/api/module.cpp @@ -1,11 +1,6 @@ #include -#include -#include -#include -#include -#include -#include +#include #include @@ -115,6 +110,17 @@ TEST_F(ModuleTest, ReplaceModule) { ASSERT_EQ(model->l1.get(), model->named_modules()["l1"]->as()); } +TEST_F(ModuleTest, UnregisterModule) { + struct TestModel : public torch::nn::Module {}; + TestModel model; + ASSERT_THROWS_WITH( + model.unregister_module("linear"), + "No Module with name `linear` is registered"); + model.register_module("linear", torch::nn::Linear(3, 4)); + model.unregister_module("linear"); + ASSERT_TRUE(model.children().empty()); +} + TEST_F(ModuleTest, RegisterParameterThrowsForEmptyOrDottedName) { struct TestModel : public torch::nn::Module {}; ASSERT_THROWS_WITH( diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index d2119ab70ff3f..818573f8b1dd1 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -1,14 +1,6 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include @@ -103,6 +95,110 @@ TEST_F(ModulesTest, Conv3d) { ASSERT_TRUE(model->weight.grad().numel() == 3 * 2 * 3 * 3 * 3); } +TEST_F(ModulesTest, MaxPool1d) { + MaxPool1d model(MaxPool1dOptions(3).stride(2)); + auto x = torch::ones({1, 1, 5}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1 ,2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); +} + +TEST_F(ModulesTest, MaxPool2dEven) { + MaxPool2d model(MaxPool2dOptions(3).stride(2)); + auto x = torch::ones({2, 5, 5}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2 ,2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); +} + +TEST_F(ModulesTest, MaxPool2dUneven) { + MaxPool2d model(MaxPool2dOptions({3, 2}).stride({2, 2})); + auto x = torch::ones({2, 5, 4}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); +} + +TEST_F(ModulesTest, MaxPool3d) { + MaxPool3d model(MaxPool3dOptions(3).stride(2)); + auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 4); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2})); +} + +TEST_F(ModulesTest, AvgPool1d) { + AvgPool1d model(AvgPool1dOptions(3).stride(2)); + auto x = torch::ones({1, 1, 5}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({1, 1, 2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({1, 1, 2})); +} + +TEST_F(ModulesTest, AvgPool2dEven) { + AvgPool2d model(AvgPool2dOptions(3).stride(2)); + auto x = torch::ones({2, 5, 5}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); +} + +TEST_F(ModulesTest, AvgPool2dUneven) { + AvgPool2d model(AvgPool2dOptions({3, 2}).stride({2, 2})); + auto x = torch::ones({2, 5, 4}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 3); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2})); +} + +TEST_F(ModulesTest, AvgPool3d) { + AvgPool3d model(AvgPool3dOptions(3).stride(2)); + auto x = torch::ones({2, 5, 5, 5}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 4); + ASSERT_TRUE(torch::allclose(y, torch::ones({2, 2, 2, 2}))); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.sizes(), torch::IntArrayRef({2, 2, 2, 2})); +} + TEST_F(ModulesTest, Linear) { Linear model(5, 2); auto x = torch::randn({10, 5}, torch::requires_grad()); @@ -118,6 +214,21 @@ TEST_F(ModulesTest, Linear) { ASSERT_EQ(model->weight.grad().numel(), 2 * 5); } +TEST_F(ModulesTest, Fold) { + Fold model(FoldOptions({4, 5}, {2, 2})); + auto x = torch::randn({1, 3 * 2 * 2, 12}, torch::requires_grad()); + auto y = model(x); + torch::Tensor s = y.sum(); + + s.backward(); + ASSERT_EQ(y.ndimension(), 4); + ASSERT_EQ(s.ndimension(), 0); + ASSERT_EQ(y.size(0), 1); + ASSERT_EQ(y.size(1), 3); + ASSERT_EQ(y.size(2), 4); + ASSERT_EQ(y.size(3), 5); +} + TEST_F(ModulesTest, SimpleContainer) { auto model = std::make_shared(); auto l1 = model->add(Linear(10, 3), "l1"); @@ -129,7 +240,7 @@ TEST_F(ModulesTest, SimpleContainer) { x = l2(x).clamp_min(0); x = l3(x).clamp_min(0); - x.backward(); + x.backward(torch::ones_like(x)); ASSERT_EQ(x.ndimension(), 2); ASSERT_EQ(x.size(0), 1000); ASSERT_EQ(x.size(1), 100); @@ -177,7 +288,7 @@ TEST_F(ModulesTest, Dropout) { torch::Tensor x = torch::ones(100, torch::requires_grad()); torch::Tensor y = dropout(x); - y.backward(); + y.backward(torch::ones_like(y)); ASSERT_EQ(y.ndimension(), 1); ASSERT_EQ(y.size(0), 100); ASSERT_LT(y.sum().item(), 130); // Probably @@ -326,6 +437,49 @@ TEST_F(ModulesTest, Linear2_CUDA) { ASSERT_EQ(model->weight.grad().numel(), 2 * 5); } +TEST_F(ModulesTest, L1Loss) { + L1Loss loss; + auto input = torch::randn({5,6}, torch::requires_grad()); + auto target = torch::empty({5,6}).random_(2); + auto output = loss->forward(torch::sigmoid(input), target); + auto s = output.sum(); + s.backward(); + + ASSERT_EQ(output.sizes(), torch::IntArrayRef()); + ASSERT_EQ(input.sizes(), input.grad().sizes()); +} + +TEST_F(ModulesTest, CosineSimilarity) { + CosineSimilarity cos(CosineSimilarityOptions().dim(1)); + float data1[] = {1, 2, 3, 4, 5, 6}; + auto input1 = torch::from_blob(data1, {2, 3}, torch::requires_grad()); + float data2[] = {1, 8, 3, 2, 1, 6}; + auto input2 = torch::from_blob(data2, {2, 3}, torch::requires_grad()); + auto output = cos->forward(input1, input2); + float data3[] = {0.8078, 0.8721}; + auto expected = torch::from_blob(data3, {2}); + auto s = output.sum(); + s.backward(); + + ASSERT_TRUE(output.allclose(expected, 1e-04)); + ASSERT_EQ(input1.sizes(), input1.grad().sizes()); +} + +TEST_F(ModulesTest, PairwiseDistance) { + PairwiseDistance dist(PairwiseDistanceOptions(1)); + float data1[] = {1, 2, 3, 4, 5, 6}; + auto input1 = torch::from_blob(data1, {2, 3}, torch::requires_grad()); + float data2[] = {1, 8, 3, 2, 1, 6}; + auto input2 = torch::from_blob(data2, {2, 3}, torch::requires_grad()); + auto output = dist->forward(input1, input2); + auto expected = torch::full({2}, 6); + auto s = output.sum(); + s.backward(); + + ASSERT_TRUE(output.allclose(expected)); + ASSERT_EQ(input1.sizes(), input1.grad().sizes()); +} + TEST_F(ModulesTest, PrettyPrintLinear) { ASSERT_EQ( c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)"); @@ -342,12 +496,49 @@ TEST_F(ModulesTest, PrettyPrintConv) { c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))), "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])"); - const auto options = Conv2dOptions(3, 4, torch::IntArrayRef{5, 6}).stride({1, 2}); + const auto options = + Conv2dOptions(3, 4, torch::IntArrayRef{5, 6}).stride({1, 2}); ASSERT_EQ( c10::str(Conv2d(options)), "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])"); } +TEST_F(ModulesTest, PrettyPrintMaxPool) { + ASSERT_EQ( + c10::str(MaxPool1d(5)), + "torch::nn::MaxPool1d(kernel_size=5, stride=5)"); + ASSERT_EQ( + c10::str(MaxPool2d(5)), + "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[5, 5])"); + ASSERT_EQ( + c10::str(MaxPool2d(MaxPool2dOptions(5).stride(2))), + "torch::nn::MaxPool2d(kernel_size=[5, 5], stride=[2, 2])"); + + const auto options = + MaxPool2dOptions(torch::IntArrayRef{5, 6}).stride({1, 2}); + ASSERT_EQ( + c10::str(MaxPool2d(options)), + "torch::nn::MaxPool2d(kernel_size=[5, 6], stride=[1, 2])"); +} + +TEST_F(ModulesTest, PrettyPrintAvgPool) { + ASSERT_EQ( + c10::str(AvgPool1d(5)), + "torch::nn::AvgPool1d(kernel_size=5, stride=5)"); + ASSERT_EQ( + c10::str(AvgPool2d(5)), + "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[5, 5])"); + ASSERT_EQ( + c10::str(AvgPool2d(AvgPool2dOptions(5).stride(2))), + "torch::nn::AvgPool2d(kernel_size=[5, 5], stride=[2, 2])"); + + const auto options = + AvgPool2dOptions(torch::IntArrayRef{5, 6}).stride({1, 2}); + ASSERT_EQ( + c10::str(AvgPool2d(options)), + "torch::nn::AvgPool2d(kernel_size=[5, 6], stride=[1, 2])"); +} + TEST_F(ModulesTest, PrettyPrintDropout) { ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)"); ASSERT_EQ( @@ -372,6 +563,24 @@ TEST_F(ModulesTest, PrettyPrintEmbedding) { "torch::nn::Embedding(count=10, dimension=2)"); } +TEST_F(ModulesTest, PrettyPrintCosineSimilarity) { + ASSERT_EQ( + c10::str(CosineSimilarity()), + "torch::nn::CosineSimilarity(dim=1, eps=1e-08)"); + ASSERT_EQ( + c10::str(CosineSimilarity(CosineSimilarityOptions().dim(0).eps(0.5))), + "torch::nn::CosineSimilarity(dim=0, eps=0.5)"); +} + +TEST_F(ModulesTest, PrettyPrintPairwiseDistance) { + ASSERT_EQ( + c10::str(PairwiseDistance()), + "torch::nn::PairwiseDistance(p=2, eps=1e-06, keepdim=false)"); + ASSERT_EQ( + c10::str(PairwiseDistance(PairwiseDistanceOptions(3).eps(0.5).keepdim(true))), + "torch::nn::PairwiseDistance(p=3, eps=0.5, keepdim=true)"); +} + TEST_F(ModulesTest, PrettyPrintNestedModel) { struct InnerTestModule : torch::nn::Module { InnerTestModule() diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp index 47ccfb49b75b4..ab747ca4a5f0e 100644 --- a/test/cpp/api/optim.cpp +++ b/test/cpp/api/optim.cpp @@ -1,12 +1,6 @@ #include -#include -#include -#include -#include -#include -#include -#include +#include #include #include diff --git a/test/cpp/api/ordered_dict.cpp b/test/cpp/api/ordered_dict.cpp index 7b59a2ff8e296..d44c2786c1e5d 100644 --- a/test/cpp/api/ordered_dict.cpp +++ b/test/cpp/api/ordered_dict.cpp @@ -137,6 +137,20 @@ TEST(OrderedDictTest, CanIterateItems) { ASSERT_EQ(iterator, dict.end()); } +TEST(OrderedDictTest, EraseWorks) { + OrderedDict dict = {{"a", 1}, {"b", 2}, {"c", 3}}; + dict.erase("b"); + ASSERT_FALSE(dict.contains("b")); + ASSERT_EQ(dict["a"], 1); + ASSERT_EQ(dict["c"], 3); + dict.erase("a"); + ASSERT_FALSE(dict.contains("a")); + ASSERT_EQ(dict["c"], 3); + dict.erase("c"); + ASSERT_FALSE(dict.contains("c")); + ASSERT_TRUE(dict.is_empty()); +} + TEST(OrderedDictTest, ClearMakesTheDictEmpty) { OrderedDict dict = {{"a", 1}, {"b", 2}}; ASSERT_FALSE(dict.is_empty()); diff --git a/test/cpp/api/parallel.cpp b/test/cpp/api/parallel.cpp index 89891be7ee05e..5cbd3fe2d2d69 100644 --- a/test/cpp/api/parallel.cpp +++ b/test/cpp/api/parallel.cpp @@ -37,7 +37,7 @@ TEST_F(ParallelTest, DifferentiableScatter_MultiCUDA) { .allclose(input)); torch::Tensor sum = output[0].to({torch::kCUDA, 1}) + output[1]; - sum.backward(); + sum.backward(torch::ones_like(sum)); ASSERT_TRUE(input.grad().defined()); ASSERT_TRUE(input.grad().device().is_cpu()); @@ -61,7 +61,7 @@ TEST_F(ParallelTest, DifferentiableGather_MultiCUDA) { ASSERT_TRUE(chunks[0].to({torch::kCUDA, 0}).allclose(a)); ASSERT_TRUE(chunks[1].allclose(b)); - output.backward(); + output.backward(torch::ones_like(output)); ASSERT_TRUE(a.grad().defined()); ASSERT_EQ(a.grad().device(), torch::Device(torch::kCUDA, 0)); diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp index 0bd23d0d99eca..5061021fd4b99 100644 --- a/test/cpp/api/rnn.cpp +++ b/test/cpp/api/rnn.cpp @@ -375,8 +375,8 @@ TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) { // Copy weights and biases from CPU GRU to CUDA GRU { at::NoGradGuard guard; - const auto num_directions = gru_cpu->options.bidirectional_ ? 2 : 1; - for (int64_t layer = 0; layer < gru_cpu->options.layers_; layer++) { + const auto num_directions = gru_cpu->options.bidirectional() ? 2 : 1; + for (int64_t layer = 0; layer < gru_cpu->options.layers(); layer++) { for (auto direction = 0; direction < num_directions; direction++) { const auto layer_idx = (layer * num_directions) + direction; copyParameters(gru_cuda, layer_idx, gru_cpu, layer_idx); @@ -430,8 +430,8 @@ TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) { // Copy weights and biases from CPU LSTM to CUDA LSTM { at::NoGradGuard guard; - const auto num_directions = lstm_cpu->options.bidirectional_ ? 2 : 1; - for (int64_t layer = 0; layer < lstm_cpu->options.layers_; layer++) { + const auto num_directions = lstm_cpu->options.bidirectional() ? 2 : 1; + for (int64_t layer = 0; layer < lstm_cpu->options.layers(); layer++) { for (auto direction = 0; direction < num_directions; direction++) { const auto layer_idx = (layer * num_directions) + direction; copyParameters(lstm_cuda, layer_idx, lstm_cpu, layer_idx); diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp index cbb65dbdaedc9..87c6a57860cbf 100644 --- a/test/cpp/api/sequential.cpp +++ b/test/cpp/api/sequential.cpp @@ -1,14 +1,6 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include #include diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 993a33045a809..f34516193fc63 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -2,14 +2,7 @@ #include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include diff --git a/test/cpp/api/static.cpp b/test/cpp/api/static.cpp index a4cae7f71a12e..df90d1fc03763 100644 --- a/test/cpp/api/static.cpp +++ b/test/cpp/api/static.cpp @@ -1,11 +1,8 @@ #include #include -#include -#include -#include - #include +#include #include #include diff --git a/test/cpp/api/tensor.cpp b/test/cpp/api/tensor.cpp index b7638d7707109..aa55543aacdd2 100644 --- a/test/cpp/api/tensor.cpp +++ b/test/cpp/api/tensor.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -8,6 +9,8 @@ #include #include +#include + template bool exactly_equal(at::Tensor left, T right) { return left.item() == right; @@ -194,6 +197,80 @@ TEST(TensorTest, ContainsCorrectValuesForManyValuesVariable) { ASSERT_TRUE(almost_equal(tensor[2], 3.125)); } +TEST(TensorTest, MultidimTensorCtor) { + { + auto tensor = torch::tensor({{1, 2}, {3, 4}}); + ASSERT_EQ(tensor.dtype(), torch::kInt); + ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({2, 2})); + ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kInt).view(tensor.sizes()))); + ASSERT_FALSE(tensor.requires_grad()); + } + { + auto tensor = torch::tensor({{1, 2}, {3, 4}}, torch::dtype(torch::kFloat).requires_grad(true)); + ASSERT_EQ(tensor.dtype(), torch::kFloat); + ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({2, 2})); + ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 5, torch::kFloat).view(tensor.sizes()))); + ASSERT_TRUE(tensor.requires_grad()); + } + { + auto tensor = torch::tensor({{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}}); + ASSERT_EQ(tensor.dtype(), torch::kDouble); + ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({1, 1, 3, 1, 1, 1, 1, 3})); + ASSERT_TRUE(torch::allclose(tensor, torch::arange(1, 10, torch::kDouble).view(tensor.sizes()))); + ASSERT_FALSE(tensor.requires_grad()); + } + { + ASSERT_THROWS_WITH(torch::tensor({{{2, 3, 4}, {{5, 6}, {7}}}}), + "Expected all sub-lists to have sizes: 2 (e.g. {5, 6}), but got sub-list {7} with sizes: 1"); + } + { + ASSERT_THROWS_WITH(torch::tensor({{{1, 2.0}, {1, 2.0}}}), + "Expected all elements of the tensor to have the same scalar type: Int, but got element of scalar type: Double"); + } + { + ASSERT_THROWS_WITH(torch::tensor({{{true, 2.0, 3}, {true, 2.0, 3}}}), + "Expected all elements of the tensor to have the same scalar type: Bool, but got element of scalar type: Double"); + } +} + +TEST(TensorTest, MultidimTensorCtor_CUDA) { + { + auto tensor = torch::tensor( + {{{{{{{{1.0, 2.0, 3.0}}}}}, {{{{{4.0, 5.0, 6.0}}}}}, {{{{{7.0, 8.0, 9.0}}}}}}}}, + torch::dtype(torch::kDouble).device(torch::kCUDA)); + ASSERT_TRUE(tensor.device().is_cuda()); + ASSERT_EQ(tensor.dtype(), torch::kDouble); + ASSERT_EQ(tensor.sizes(), torch::IntArrayRef({1, 1, 3, 1, 1, 1, 1, 3})); + ASSERT_TRUE(torch::allclose( + tensor, + torch::arange(1, 10, torch::kDouble).view(tensor.sizes()).to(torch::kCUDA))); + ASSERT_FALSE(tensor.requires_grad()); + } +} + +TEST(TensorTest, PrettyPrintListInitTensor) { + { + ASSERT_EQ( + c10::str(torch::detail::ListInitTensor(1.1)), + "1.1"); + } + { + ASSERT_EQ( + c10::str(torch::detail::ListInitTensor({1.1, 2.2})), + "{1.1, 2.2}"); + } + { + ASSERT_EQ( + c10::str(torch::detail::ListInitTensor({{1, 2}, {3, 4}})), + "{{1, 2}, {3, 4}}"); + } + { + ASSERT_EQ( + c10::str(torch::detail::ListInitTensor({{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}})), + "{{{{{{{{1.1, 2.2, 3.3}}}}}, {{{{{4.4, 5.5, 6.6}}}}}, {{{{{7.7, 8.8, 9.9}}}}}}}}"); + } +} + TEST(TensorTest, ContainsCorrectValuesWhenConstructedFromVector) { std::vector v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; auto tensor = at::tensor(v); @@ -311,3 +388,56 @@ TEST(TensorTest, DataPtr) { ASSERT_EQ(tensor_not_copy.data_ptr(), tensor.data_ptr()); ASSERT_EQ(tensor_not_copy.data_ptr(), tensor.data_ptr()); } + +TEST(TensorTest, Data) { + const auto tensor = torch::empty({3, 3}); + ASSERT_TRUE(torch::equal(tensor, tensor.data())); + + const auto tensor2 = at::empty({3, 3}); + ASSERT_THROW(tensor2.data(), c10::Error); +} + +TEST(TensorTest, BackwardAndGrad) { + auto x = torch::tensor({5}, at::TensorOptions().requires_grad(true)); + auto y = x * x; + y.backward(); + ASSERT_EQ(x.grad().item(), 10.0); + + x = at::tensor({5}); + y = x * x; + ASSERT_THROWS_WITH(y.backward(), "backward is not implemented for Tensor"); + ASSERT_THROWS_WITH(x.grad(), "grad is not implemented for Tensor"); +} + +TEST(TensorTest, BackwardCreatesOnesGrad) { + const auto x = torch::tensor({5}, at::TensorOptions().requires_grad(true)); + x.backward(); + ASSERT_TRUE(torch::equal(x.grad(), + torch::ones_like(x))); +} + +TEST(TensorTest, IsLeaf) { + auto x = torch::tensor({5}, at::TensorOptions().requires_grad(true)); + auto y = x * x; + ASSERT_TRUE(x.is_leaf()); + ASSERT_FALSE(y.is_leaf()); + + x = at::tensor({5}); + y = x * x; + const auto message = "is_leaf is not implemented for Tensor"; + ASSERT_THROWS_WITH(y.is_leaf(), message); + ASSERT_THROWS_WITH(x.is_leaf(), message); +} + +TEST(TensorTest, OutputNr) { + auto x = torch::tensor({5}, at::TensorOptions().requires_grad(true)); + auto y = x * x; + ASSERT_EQ(x.output_nr(), 0); + ASSERT_EQ(y.output_nr(), 0); + + x = at::tensor({5}); + y = x * x; + const auto message = "output_nr is not implemented for Tensor"; + ASSERT_THROWS_WITH(y.output_nr(), message); + ASSERT_THROWS_WITH(x.output_nr(), message); +} diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index 16b6aba7cc26c..2bb1945a5775f 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -19,6 +19,15 @@ if (USE_CUDA) ${TORCH_CUDA_LIBRARIES}) target_compile_definitions(test_jit PRIVATE USE_CUDA) +elseif (USE_ROCM) + target_link_libraries(test_jit PRIVATE + ${ROCM_HIPRTC_LIB} + ${PYTORCH_HIP_HCC_LIBRARIES} + ${TORCH_CUDA_LIBRARIES}) + + target_link_libraries(test_jit PRIVATE caffe2_gpu) + + target_compile_definitions(test_jit PRIVATE USE_ROCM) endif() if (INSTALL_TEST) diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index a0c282a98931d..4f68aa5b4cc28 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -1119,7 +1119,7 @@ void testAliasRegistration() { }) .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE)); }, - "Tried to register operator foo::rand3(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); + "Tried to register operator foo::rand3(Tensor(a) arg1) -> (Tensor(b)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); } { expectThrows( @@ -1132,7 +1132,7 @@ void testAliasRegistration() { }) .aliasAnalysis(AliasAnalysisKind::CONSERVATIVE)); }, - "Tried to register operator foo::rand4(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); + "Tried to register operator foo::rand4(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); } { expectThrows( @@ -1145,7 +1145,7 @@ void testAliasRegistration() { }) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)); }, - "Tried to register operator foo::rand5 with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred"); + "Tried to register operator foo::rand5(Tensor _0) -> (Tensor _0) with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred"); } { auto registry = torch::RegisterOperators().op( @@ -1235,7 +1235,7 @@ void testAliasRegistration() { [](at::Tensor t) -> at::Tensor { return t * 2; }) .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION)); }, - "Tried to register operator foo::rand11(Tensor(a) arg1) -> Tensor(a) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); + "Tried to register operator foo::rand11(Tensor(a) arg1) -> (Tensor(a)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); } { expectThrows( @@ -1247,7 +1247,7 @@ void testAliasRegistration() { [](at::Tensor t) -> at::Tensor { return t * 2; }) .aliasAnalysis(AliasAnalysisKind::PURE_FUNCTION)); }, - "Tried to register operator foo::rand12(Tensor(a) arg1) -> Tensor(b) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); + "Tried to register operator foo::rand12(Tensor(a) arg1) -> (Tensor(b)) with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA"); } } diff --git a/test/cpp/jit/test_constant_propagation.cpp b/test/cpp/jit/test_constant_propagation.cpp index ae0a4a02d6b01..dc561abdce342 100644 --- a/test/cpp/jit/test_constant_propagation.cpp +++ b/test/cpp/jit/test_constant_propagation.cpp @@ -77,10 +77,7 @@ graph(): c10::List list; auto li = IValue(list); std::vector tup = {li}; - push( - stack, - c10::ivalue::Tuple::create( - tup, TupleType::create({ListType::ofFloats()}))); + push(stack, c10::ivalue::Tuple::create(tup)); return 0; }; }, diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 9d7ade23a402c..e6432279fc6c6 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -741,14 +741,14 @@ void testRecordFunction() { auto t = torch::randn({1, 2, 3}, at::kCPU); t.set_requires_grad(true); auto t2 = invokeTestRecordFunction(t); - t2.backward(); + t2.backward(torch::ones_like(t2)); auto eager_inputs = traced_inputs; traced_inputs.clear(); t = torch::randn({1, 2, 3}, at::kCPU); t.set_requires_grad(true); t2 = invokeTestRecordFunctionJIT(t); - t2.backward(); + t2.backward(torch::ones_like(t2)); auto jit_inputs = traced_inputs; traced_inputs.clear(); @@ -864,7 +864,7 @@ void testThreadLocalDebugInfo() { auto t = torch::randn({1, 2, 3}, at::kCPU); t.set_requires_grad(true); auto t2 = t.pow(2); - t2.backward(); + t2.backward(torch::ones_like(t2)); } autograd::profiler::popCallback(); diff --git a/test/cpp/jit/test_schema_matching.cpp b/test/cpp/jit/test_schema_matching.cpp new file mode 100644 index 0000000000000..d86a8d46235ff --- /dev/null +++ b/test/cpp/jit/test_schema_matching.cpp @@ -0,0 +1,92 @@ +#include +#include +#include +#include "test/cpp/jit/test_base.h" +#include "torch/csrc/jit/custom_operator.h" + +#include +#include + +namespace torch { +namespace jit { + +void testSchemaMatching() { + { + RegisterOperators reg({ + Operator( + "aten::test_vartype(t[] a, t b) -> (t)", + [](const Node* node) { + return [](Stack& stack) { + c10::List list; + double a; + pop(stack, list, a); + push(stack, a); + return 0; + }; + }), + }); + script::Module m("m"); + m.define(R"( + def test(self): + a = (1.0, 2.0) + return torch.test_vartype(a, 2.0) + )"); + auto result = m.run_method("test"); + TORCH_INTERNAL_ASSERT(result.toDouble() == 2.0); + + const std::string error_example = R"JIT( + def test_2(self): + a = (1.0, 2.0) + non_float = (1, 1) + return torch.test_vartype(a, non_float) + )JIT"; + + std::string err = ""; + try { + m.define(error_example); + } catch (const std::exception &e) { + err = e.what(); + } + TORCH_INTERNAL_ASSERT(err.find("previously matched to type") != std::string::npos); + } + { + RegisterOperators reg({ + Operator( + "aten::test_vartype2(t a, t[] b) -> (t[])", + [](const Node* node) { + return [](Stack& stack) { + double a; + c10::List list; + pop(stack, a, list); + push(stack, a); + return 0; + }; + }), + }); + script::Module m("m"); + m.define(R"JIT( + def test(self): + a = (1.0, 2.0) + return torch.test_vartype2(3.0, a) + )JIT"); + auto result = m.run_method("test"); + TORCH_INTERNAL_ASSERT(result.toDouble() == 3.0); + + static const auto error_exam2 = R"JIT( + def test_2(self): + a = (1, 2) + return torch.test_vartype2(3.0, a) + )JIT"; + + + std::string err = ""; + try { + m.define(error_exam2); + } catch (const std::exception &e) { + err = e.what(); + } + TORCH_INTERNAL_ASSERT(err.find("previously matched to type") != std::string::npos); + } +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/jit/test_subgraph_rewriter.cpp b/test/cpp/jit/test_subgraph_rewriter.cpp new file mode 100644 index 0000000000000..36310f1cda008 --- /dev/null +++ b/test/cpp/jit/test_subgraph_rewriter.cpp @@ -0,0 +1,108 @@ +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +using namespace testing; + +void testFilterMatch() { + auto graph = std::make_shared(); + + script::parseIR( + R"IR( +graph(%0): + %a = a::aaa(%0) + %b = prim::Constant[value=1]() + %c = c::ccc(%a, %b) + return (%c))IR", + graph.get()); + + std::string pattern = R"IR( +graph(%a, %b): + %c = c::ccc(%a, %b) + return (%c))IR"; + Graph pattern_graph; + std::unordered_map vmap; + + script::parseIR( + pattern, + &pattern_graph, + vmap); + + auto filter = [](const Match& match, + const std::unordered_map& vmap) { + const auto& match_vmap = match.values_map; + auto b_node = match_vmap.at(vmap.at("b"))->node(); + return b_node->kind() == prim::Constant; + }; + + std::string replacement = R"IR( +graph(%a, %b): + %d = d::ddd(%a, %b) + return (%d))IR"; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(pattern, replacement); + rewriter.runOnGraph(graph, filter); + + FileCheck().check("d::ddd") + ->check_not("c::ccc") + ->run(*graph); +} + +void testFilterNoMatch() { + auto graph = std::make_shared(); + script::parseIR( + R"IR( +graph(%0): + %a = a::aaa(%0) + %b = prim::Constant[value=1]() + %c = c::ccc(%a, %b) + return (%c))IR", + graph.get()); + + std::string pattern = R"IR( +graph(%a, %b): + %c = c::ccc(%a, %b) + return (%c))IR"; + Graph pattern_graph; + std::unordered_map vmap; + + script::parseIR( + pattern, + &pattern_graph, + vmap); + + auto filter = [](const Match& match, + const std::unordered_map& vmap) { + const auto& match_vmap = match.values_map; + auto b_node = match_vmap.at(vmap.at("b"))->node(); + // b_node is not Constant, so this won't match and we'll skip the rewrite + return b_node->kind() == prim::Assign; + }; + + std::string replacement = R"IR( +graph(%a, %b): + %d = d::ddd(%a, %b) + return (%d))IR"; + + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(pattern, replacement); + rewriter.runOnGraph(graph, filter); + + FileCheck().check("c::ccc") + ->check_not("d::ddd") + ->run(*graph); + +} + + +void testSubgraphRewriter() { + testFilterMatch(); + testFilterNoMatch(); +} + +}} diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index ef1cbd81c26dd..a08d589cc92d5 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -19,6 +19,7 @@ namespace jit { _(CustomOperatorAliasing) \ _(IValueKWargs) \ _(CustomFusion) \ + _(SchemaMatching) \ _(Differentiate) \ _(DifferentiateWithRequiresGrad) \ _(FromQualString) \ @@ -52,6 +53,7 @@ namespace jit { _(RecordFunction) \ _(ThreadLocalDebugInfo) \ _(SubgraphMatching) \ + _(SubgraphRewriter) \ _(ModuleDefine) \ _(QualifiedName) \ _(ClassImport) \ diff --git a/test/cpp_api_parity/__init__.py b/test/cpp_api_parity/__init__.py index 38aca9b096485..d45862630569e 100644 --- a/test/cpp_api_parity/__init__.py +++ b/test/cpp_api_parity/__init__.py @@ -5,11 +5,9 @@ [ 'module_name', 'module_variant_name', - 'python_constructor_args', + 'test_instance', 'cpp_constructor_args', - 'example_inputs', 'has_parity', - 'python_module_class', 'cpp_sources', 'num_attrs_recursive', 'device', @@ -20,8 +18,17 @@ ParityStatus = namedtuple('ParityStatus', ['has_impl_parity', 'has_doc_parity']) -TorchNNModuleMetadata = namedtuple('TorchNNModuleMetadata', ['cpp_default_constructor_args', 'num_attrs_recursive', 'cpp_sources']) -TorchNNModuleMetadata.__new__.__defaults__ = (None, None, '') +TorchNNModuleMetadata = namedtuple( + 'TorchNNModuleMetadata', + [ + 'cpp_default_constructor_args', + 'num_attrs_recursive', + 'python_legacy_constructor_args', + 'python_optional_attribute_to_jit_type', + 'cpp_sources', + ] +) +TorchNNModuleMetadata.__new__.__defaults__ = (None, None, [], {}, '') ''' This function expects the parity tracker Markdown file to have the following format: diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md index 04b62cfe375d5..8619782e80964 100644 --- a/test/cpp_api_parity/parity-tracker.md +++ b/test/cpp_api_parity/parity-tracker.md @@ -16,16 +16,16 @@ torch.nn.ConvTranspose1d|No|No torch.nn.ConvTranspose2d|No|No torch.nn.ConvTranspose3d|No|No torch.nn.Unfold|No|No -torch.nn.Fold|No|No -torch.nn.MaxPool1d|No|No -torch.nn.MaxPool2d|No|No -torch.nn.MaxPool3d|No|No +torch.nn.Fold|Yes|No +torch.nn.MaxPool1d|Yes|No +torch.nn.MaxPool2d|Yes|No +torch.nn.MaxPool3d|Yes|No torch.nn.MaxUnpool1d|No|No torch.nn.MaxUnpool2d|No|No torch.nn.MaxUnpool3d|No|No -torch.nn.AvgPool1d|No|No -torch.nn.AvgPool2d|No|No -torch.nn.AvgPool3d|No|No +torch.nn.AvgPool1d|Yes|No +torch.nn.AvgPool2d|Yes|No +torch.nn.AvgPool3d|Yes|No torch.nn.FractionalMaxPool2d|No|No torch.nn.LPPool1d|No|No torch.nn.LPPool2d|No|No @@ -100,7 +100,7 @@ torch.nn.Embedding|No|No torch.nn.EmbeddingBag|No|No torch.nn.CosineSimilarity|No|No torch.nn.PairwiseDistance|No|No -torch.nn.L1Loss|No|No +torch.nn.L1Loss|Yes|No torch.nn.MSELoss|No|No torch.nn.CrossEntropyLoss|No|No torch.nn.CTCLoss|No|No diff --git a/test/cpp_api_parity/sample_module.py b/test/cpp_api_parity/sample_module.py index 965a6fbd504f6..59a9e46fe24c6 100644 --- a/test/cpp_api_parity/sample_module.py +++ b/test/cpp_api_parity/sample_module.py @@ -15,13 +15,25 @@ class SampleModule(torch.nn.Module): def __init__(self, has_parity, has_submodule, int_option=0, double_option=0.1, - bool_option=False, string_option='0', tensor_option=torch.empty(1)): + bool_option=False, string_option='0', tensor_option=torch.zeros(1), + int_or_tuple_option=0): super(SampleModule, self).__init__() self.has_parity = has_parity - self.register_parameter('param', torch.nn.Parameter(torch.empty(3, 4))) - self.register_buffer('buffer', torch.empty(4, 5)) if has_submodule: self.submodule = SampleModule(self.has_parity, False) + + # The following attributes will be included in the `num_attrs_recursive` count. + self.has_submodule = has_submodule + self.int_option = int_option + self.double_option = double_option + self.bool_option = bool_option + self.string_option = string_option + self.tensor_option = tensor_option + self.int_or_tuple_option = int_or_tuple_option + self.register_parameter('param', torch.nn.Parameter(torch.empty(3, 4))) + self.register_buffer('buffer', torch.empty(4, 5)) + self.attr = 0 + self.reset_parameters() def reset_parameters(self): @@ -47,17 +59,18 @@ def forward(self, x): struct C10_EXPORT SampleModuleOptions { SampleModuleOptions(bool has_submodule) : has_submodule_(has_submodule) {} TORCH_ARG(bool, has_submodule); - TORCH_ARG(int64_t, int_option); - TORCH_ARG(double, double_option); - TORCH_ARG(bool, bool_option); - TORCH_ARG(std::string, string_option); - TORCH_ARG(torch::Tensor, tensor_option); + TORCH_ARG(int64_t, int_option) = 0; + TORCH_ARG(double, double_option) = 0.1; + TORCH_ARG(bool, bool_option) = false; + TORCH_ARG(std::string, string_option) = "0"; + TORCH_ARG(torch::Tensor, tensor_option) = torch::zeros({1}); + TORCH_ARG(ExpandingArray<2>, int_or_tuple_option) = 0; }; struct C10_EXPORT SampleModuleImpl : public torch::nn::Cloneable { SampleModuleImpl(bool has_submodule) : SampleModuleImpl(SampleModuleOptions(has_submodule)) {} - explicit SampleModuleImpl(SampleModuleOptions options) { - if (options.has_submodule_) { + explicit SampleModuleImpl(SampleModuleOptions options) : options(std::move(options)) { + if (options.has_submodule()) { submodule = register_module("submodule", std::make_shared(false)); } reset(); @@ -70,6 +83,7 @@ def forward(self, x): torch::Tensor forward(torch::Tensor x) { return x + param * 2 + (submodule ? submodule->forward(x) : torch::zeros_like(x)); } + SampleModuleOptions options; torch::Tensor param; torch::Tensor buffer; int attr; @@ -84,25 +98,24 @@ def forward(self, x): module_tests = [ dict( module_name='SampleModule', + desc='has_parity', constructor_args=(True, True), cpp_constructor_args='(true)', input_size=(3, 4), - desc='has_parity', has_parity=True, ), dict( - module_name='SampleModule', - constructor_args=(False, True), + fullname='SampleModule_no_parity', + constructor=lambda: SampleModule(False, True), cpp_constructor_args='(true)', input_size=(3, 4), - desc='no_parity', has_parity=False, ), ] torch_nn_modules.module_metadata_map['SampleModule'] = TorchNNModuleMetadata( cpp_default_constructor_args='(true)', - num_attrs_recursive=6, + num_attrs_recursive=20, cpp_sources=SAMPLE_MODULE_CPP_SOURCE, ) diff --git a/test/cpp_api_parity/torch_nn_modules.py b/test/cpp_api_parity/torch_nn_modules.py index 0c02308a5af58..4875c04205244 100644 --- a/test/cpp_api_parity/torch_nn_modules.py +++ b/test/cpp_api_parity/torch_nn_modules.py @@ -1,3 +1,5 @@ +import torch + from cpp_api_parity import TorchNNModuleMetadata # NOTE: In order to let Python/C++ API parity test pass for any of the modules here, @@ -20,8 +22,15 @@ # as the Python module constructor. # # `num_attrs_recursive`: the number of attributes (including parameters, buffers and non-tensor -# attributes) of a module. If the module contains any submodule, the submodule's attributes -# also need to be counted. +# attributes) of the Python module. If the module contains any submodule, the submodule's +# attributes also need to be counted. +# +# `python_legacy_constructor_args`: (optional) list of legacy Python constructor args that are +# ignored in Python/C++ API parity test. +# +# `python_optional_attribute_to_jit_type`: (optional) map between Python None-able module +# attribute to its corresponding JIT type. For example, in `AvgPool2d`: +# { "divisor_override": torch._C.OptionalType(torch._C.IntType.get()) } module_metadata_map = { 'Conv1d': TorchNNModuleMetadata(), 'Conv2d': TorchNNModuleMetadata(), @@ -30,16 +39,43 @@ 'ConvTranspose2d': TorchNNModuleMetadata(), 'ConvTranspose3d': TorchNNModuleMetadata(), 'Unfold': TorchNNModuleMetadata(), - 'Fold': TorchNNModuleMetadata(), - 'MaxPool1d': TorchNNModuleMetadata(), - 'MaxPool2d': TorchNNModuleMetadata(), - 'MaxPool3d': TorchNNModuleMetadata(), + 'Fold': TorchNNModuleMetadata( + cpp_default_constructor_args="(3, 2)", + num_attrs_recursive=5, + ), + 'MaxPool1d': TorchNNModuleMetadata( + cpp_default_constructor_args="(2)", + num_attrs_recursive=6, + ), + 'MaxPool2d': TorchNNModuleMetadata( + cpp_default_constructor_args="(2)", + num_attrs_recursive=6, + ), + 'MaxPool3d': TorchNNModuleMetadata( + cpp_default_constructor_args="(2)", + num_attrs_recursive=6, + ), 'MaxUnpool1d': TorchNNModuleMetadata(), 'MaxUnpool2d': TorchNNModuleMetadata(), 'MaxUnpool3d': TorchNNModuleMetadata(), - 'AvgPool1d': TorchNNModuleMetadata(), - 'AvgPool2d': TorchNNModuleMetadata(), - 'AvgPool3d': TorchNNModuleMetadata(), + 'AvgPool1d': TorchNNModuleMetadata( + cpp_default_constructor_args="(2)", + num_attrs_recursive=5, + ), + 'AvgPool2d': TorchNNModuleMetadata( + cpp_default_constructor_args="(2)", + num_attrs_recursive=6, + python_optional_attribute_to_jit_type={ + "divisor_override": torch._C.OptionalType(torch._C.IntType.get()), + } + ), + 'AvgPool3d': TorchNNModuleMetadata( + cpp_default_constructor_args="(2)", + num_attrs_recursive=6, + python_optional_attribute_to_jit_type={ + "divisor_override": torch._C.OptionalType(torch._C.IntType.get()), + } + ), 'FractionalMaxPool2d': TorchNNModuleMetadata(), 'LPPool1d': TorchNNModuleMetadata(), 'LPPool2d': TorchNNModuleMetadata(), @@ -114,7 +150,11 @@ 'EmbeddingBag': TorchNNModuleMetadata(), 'CosineSimilarity': TorchNNModuleMetadata(), 'PairwiseDistance': TorchNNModuleMetadata(), - 'L1Loss': TorchNNModuleMetadata(), + 'L1Loss': TorchNNModuleMetadata( + cpp_default_constructor_args="()", + num_attrs_recursive=1, + python_legacy_constructor_args=['size_average', 'reduce'], + ), 'MSELoss': TorchNNModuleMetadata(), 'CrossEntropyLoss': TorchNNModuleMetadata(), 'CTCLoss': TorchNNModuleMetadata(), diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/msnpu_extension.cpp index c265905e12ee1..2abc1a6b4a7b6 100644 --- a/test/cpp_extensions/msnpu_extension.cpp +++ b/test/cpp_extensions/msnpu_extension.cpp @@ -1,6 +1,6 @@ #include -#include +#include using namespace at; @@ -48,22 +48,24 @@ std::tuple fake_convolution_backward( } void init_msnpu_extension() { - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", - &empty_override); - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", - &add_override); - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", - &fake_convolution); - globalATenDispatch().registerOp( - Backend::MSNPU, - "aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", - &fake_convolution_backward); + static auto registry = torch::RegisterOperators() + .op(torch::RegisterOperators::options() + .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + .op(torch::RegisterOperators::options() + .schema("aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)") + .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) + ; } // TODO: Extend this to exercise multi-device setting. In that case, diff --git a/test/jit_utils.py b/test/jit_utils.py index 80a3d203b7e3c..02776580aeee0 100644 --- a/test/jit_utils.py +++ b/test/jit_utils.py @@ -158,6 +158,9 @@ def extract_files(buffer): for a, b in zip(code_files, code_files_2): self.assertMultiLineEqual(a, b) + if isinstance(m, torch._C.ScriptModule): + self.assertTrue(torch._C._ivalue_tags_match(m, imported._c)) + def emitFunctionHook(self, func): # func has invalid names for export, skip the jitter check @@ -488,6 +491,24 @@ def runAndSaveRNG(self, func, inputs, kwargs=None): results = func(*inputs, **kwargs) return results + def checkModule(self, nn_module, args): + """ + Check that a nn.Module's results in Script mode match eager and that it + can be exported + """ + sm = torch.jit.script(nn_module) + + with freeze_rng_state(): + eager_out = nn_module(*args) + + with freeze_rng_state(): + script_out = sm(*args) + + self.assertEqual(eager_out, script_out) + self.assertExportImportModule(sm, args) + + return sm + @contextmanager def enable_profiling_mode(): torch._C._jit_set_profiling_mode(True) diff --git a/test/onnx/expect/TestOperators.test_dyn_arange.expect b/test/onnx/expect/TestOperators.test_dyn_arange.expect new file mode 100644 index 0000000000000..c4f208507fef5 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_dyn_arange.expect @@ -0,0 +1,129 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.2" +graph { + node { + output: "1" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 7 + raw_data: "\000\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "0" + output: "2" + op_type: "Shape" + } + node { + input: "2" + input: "1" + output: "3" + op_type: "Gather" + attribute { + name: "axis" + i: 0 + type: INT + } + } + node { + input: "3" + output: "4" + op_type: "Unsqueeze" + attribute { + name: "axes" + ints: 0 + type: INTS + } + } + node { + input: "4" + output: "5" + op_type: "ConstantOfShape" + attribute { + name: "value" + t { + dims: 1 + data_type: 7 + raw_data: "\001\000\000\000\000\000\000\000" + } + type: TENSOR + } + } + node { + input: "5" + output: "6" + op_type: "NonZero" + } + node { + input: "6" + output: "7" + op_type: "Transpose" + attribute { + name: "perm" + ints: 1 + ints: 0 + type: INTS + } + } + node { + input: "7" + output: "8" + op_type: "Squeeze" + attribute { + name: "axes" + ints: 1 + type: INTS + } + } + node { + input: "8" + output: "9" + op_type: "Cast" + attribute { + name: "to" + i: 7 + type: INT + } + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 5 + } + dim { + dim_value: 3 + } + dim { + dim_value: 2 + } + } + } + } + } + output { + name: "9" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_gelu.expect b/test/onnx/expect/TestOperators.test_gelu.expect new file mode 100644 index 0000000000000..54c270d9b60b7 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_gelu.expect @@ -0,0 +1,118 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.2" +graph { + node { + output: "1" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\363\004\265?" + } + type: TENSOR + } + } + node { + input: "x" + input: "1" + output: "2" + op_type: "Div" + } + node { + input: "2" + output: "3" + op_type: "Erf" + } + node { + output: "4" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\200?" + } + type: TENSOR + } + } + node { + input: "3" + input: "4" + output: "5" + op_type: "Add" + } + node { + input: "x" + input: "5" + output: "6" + op_type: "Mul" + } + node { + output: "7" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\000?" + } + type: TENSOR + } + } + node { + input: "6" + input: "7" + output: "8" + op_type: "Mul" + } + name: "torch-jit-export" + input { + name: "x" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } + output { + name: "8" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + dim { + dim_value: 5 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_mm.expect b/test/onnx/expect/TestOperators.test_mm.expect index a94f2b8766360..967960d5fecae 100644 --- a/test/onnx/expect/TestOperators.test_mm.expect +++ b/test/onnx/expect/TestOperators.test_mm.expect @@ -10,7 +10,7 @@ graph { t { dims: 1 data_type: 1 - raw_data: "\000\000\000\000" + raw_data: "\000\000\200?" } type: TENSOR } diff --git a/test/onnx/expect/TestOperators.test_round.expect b/test/onnx/expect/TestOperators.test_round.expect new file mode 100644 index 0000000000000..fe281c826ec19 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_round.expect @@ -0,0 +1,40 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.2" +graph { + node { + input: "0" + output: "1" + op_type: "Round" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 11 +} diff --git a/test/onnx/expect/TestOperators.test_rsqrt.expect b/test/onnx/expect/TestOperators.test_rsqrt.expect new file mode 100644 index 0000000000000..f1987e4ce3dc9 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_rsqrt.expect @@ -0,0 +1,64 @@ +ir_version: 4 +producer_name: "pytorch" +producer_version: "1.2" +graph { + node { + input: "0" + output: "1" + op_type: "Sqrt" + } + node { + output: "2" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\200?" + } + type: TENSOR + } + } + node { + input: "2" + input: "1" + output: "3" + op_type: "Div" + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "3" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_upsample_nearest.expect b/test/onnx/expect/TestOperators.test_upsample_nearest.expect index d8988640af495..62fa6edec6807 100644 --- a/test/onnx/expect/TestOperators.test_upsample_nearest.expect +++ b/test/onnx/expect/TestOperators.test_upsample_nearest.expect @@ -31,26 +31,36 @@ graph { } } node { + input: "3" output: "4" + op_type: "Cast" + attribute { + name: "to" + i: 11 + type: INT + } + } + node { + output: "5" op_type: "Constant" attribute { name: "value" t { - data_type: 7 - raw_data: "\002\000\000\000\000\000\000\000" + data_type: 11 + raw_data: "\000\000\000\000\000\000\000@" } type: TENSOR } } node { - input: "3" input: "4" - output: "5" + input: "5" + output: "6" op_type: "Mul" } node { - input: "5" - output: "6" + input: "6" + output: "7" op_type: "Cast" attribute { name: "to" @@ -59,12 +69,12 @@ graph { } } node { - input: "6" - output: "7" + input: "7" + output: "8" op_type: "Floor" } node { - output: "8" + output: "9" op_type: "Constant" attribute { name: "value" @@ -77,13 +87,13 @@ graph { } node { input: "input" - output: "9" + output: "10" op_type: "Shape" } node { + input: "10" input: "9" - input: "8" - output: "10" + output: "11" op_type: "Gather" attribute { name: "axis" @@ -92,26 +102,36 @@ graph { } } node { - output: "11" + input: "11" + output: "12" + op_type: "Cast" + attribute { + name: "to" + i: 11 + type: INT + } + } + node { + output: "13" op_type: "Constant" attribute { name: "value" t { - data_type: 7 - raw_data: "\002\000\000\000\000\000\000\000" + data_type: 11 + raw_data: "\000\000\000\000\000\000\000@" } type: TENSOR } } node { - input: "10" - input: "11" - output: "12" + input: "12" + input: "13" + output: "14" op_type: "Mul" } node { - input: "12" - output: "13" + input: "14" + output: "15" op_type: "Cast" attribute { name: "to" @@ -120,13 +140,13 @@ graph { } } node { - input: "13" - output: "14" + input: "15" + output: "16" op_type: "Floor" } node { - input: "7" - output: "15" + input: "8" + output: "17" op_type: "Unsqueeze" attribute { name: "axes" @@ -135,8 +155,8 @@ graph { } } node { - input: "14" - output: "16" + input: "16" + output: "18" op_type: "Unsqueeze" attribute { name: "axes" @@ -145,9 +165,9 @@ graph { } } node { - input: "15" - input: "16" - output: "17" + input: "17" + input: "18" + output: "19" op_type: "Concat" attribute { name: "axis" @@ -156,7 +176,7 @@ graph { } } node { - output: "18" + output: "20" op_type: "Constant" attribute { name: "value" @@ -169,8 +189,8 @@ graph { } } node { - input: "17" - output: "19" + input: "19" + output: "21" op_type: "Cast" attribute { name: "to" @@ -180,12 +200,12 @@ graph { } node { input: "input" - output: "20" + output: "22" op_type: "Shape" } node { - input: "20" - output: "21" + input: "22" + output: "23" op_type: "Slice" attribute { name: "axes" @@ -204,8 +224,8 @@ graph { } } node { - input: "21" - output: "22" + input: "23" + output: "24" op_type: "Cast" attribute { name: "to" @@ -214,15 +234,15 @@ graph { } } node { - input: "19" - input: "22" - output: "23" + input: "21" + input: "24" + output: "25" op_type: "Div" } node { - input: "18" - input: "23" - output: "24" + input: "20" + input: "25" + output: "26" op_type: "Concat" attribute { name: "axis" @@ -232,8 +252,8 @@ graph { } node { input: "input" - input: "24" - output: "25" + input: "26" + output: "27" op_type: "Upsample" attribute { name: "mode" @@ -265,7 +285,7 @@ graph { } } output { - name: "25" + name: "27" type { tensor_type { elem_type: 1 diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index 2b7fcf29228c4..0efe3dd875d63 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -271,6 +271,7 @@ def forward(self, x): ops = [{"op_name" : "Constant"}, {"op_name" : "ConstantOfShape"}, + {"op_name" : "Cast"}, {"op_name" : "Add"}] ops = {9 : ops, 10 : ops} x = torch.tensor(12) @@ -395,41 +396,6 @@ def forward(self, input): x = torch.randn(2, 3, 4) check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) - def test_advanced_index(self): - class MyModule(Module): - def forward(self, x): - return x[:, torch.tensor([[0, 2], [1, 1]]), :, torch.tensor([2, 1]), torch.tensor([0, 3])] - - x = torch.randn(3, 4, 5, 6, 7) - - ops = [{'op_name': 'Constant'}, - {'op_name': 'Constant'}, - {'op_name': 'Constant'}, - {'op_name': 'Shape'}, - {'op_name': 'Constant'}, - {'attributes': [{'i': 0, 'name': 'axis', 'type': 2}], 'op_name': 'Gather'}, - {'op_name': 'Constant'}, - {'attributes': [{'i': 0, 'name': 'axis', 'type': 2}], 'op_name': 'Gather'}, - {'op_name': 'Constant'}, - {'attributes': [{'i': 0, 'name': 'axis', 'type': 2}], 'op_name': 'Gather'}, - {'op_name': 'Constant'}, - {'attributes': [{'i': 0, 'name': 'axis', 'type': 2}], 'op_name': 'Gather'}, - {'attributes': [{'ints': [1, 3, 4, 0, 2], 'name': 'perm', 'type': 7}], 'op_name': 'Transpose'}, - {'attributes': [{'i': 3, 'name': 'axis', 'type': 2}], 'op_name': 'Flatten'}, - {'op_name': 'Mul'}, - {'op_name': 'Add'}, - {'op_name': 'Mul'}, - {'op_name': 'Mul'}, - {'op_name': 'Add'}, - {'attributes': [{'i': 0, 'name': 'axis', 'type': 2}], 'op_name': 'Gather'}, - {'op_name': 'Shape'}, - {'attributes': [{'i': 0, 'name': 'axis', 'type': 2}], 'op_name': 'Concat'}, - {'op_name': 'Reshape'}] - - ops = {9 : ops, 10 : ops} - - check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10]) - if __name__ == '__main__': run_tests() diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 375d043dc5aa4..e6cfef1827bf0 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -433,6 +433,10 @@ def test_sqrt(self): x = torch.randn(3, 4, requires_grad=True) self.assertONNX(lambda x: torch.sqrt(x), x) + def test_rsqrt(self): + x = torch.randn(3, 4, requires_grad=True) + self.assertONNX(lambda x: torch.rsqrt(x), x) + def test_equal(self): x = torch.randn(1, 2, 3, 1, requires_grad=False).int() y = torch.randn(1, 4, requires_grad=False).int() @@ -660,7 +664,6 @@ def test_master_opset(self): def test_std(self): x = torch.randn(2, 3, 4).float() - y = torch.randn(2, 3, 4).float() self.assertONNX(lambda x: torch.std(x, dim=(0, 1), unbiased=True, keepdim=True), x) def test_cumsum(self): @@ -709,6 +712,14 @@ def forward(self, scores, bbox_deltas, im_info, anchors): inputs = (scores, bbox_deltas, im_info, anchors) self.assertONNX(model, inputs) + def test_dyn_arange(self): + class TestModel(torch.nn.Module): + def forward(self, input): + return torch.arange(input.shape[0]) + + input = torch.randn(5, 3, 2) + self.assertONNX(TestModel(), input) + def test_layer_norm_aten(self): model = torch.nn.LayerNorm([10, 10]) x = torch.randn(20, 5, 10, 10) @@ -723,11 +734,19 @@ def test_frobenius_norm(self): x = torch.randn(2, 3, 4).float() self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x) + def test_gelu(self): + x = torch.randn(2, 3, 4, 5, requires_grad=True) + self.assertONNX(lambda x: torch.nn.functional.gelu(x), x) + def test_unique(self): x = torch.randint(3, (2, 3, 4, 5)).float() self.assertONNX(lambda x: torch.unique(x, dim=0, sorted=True, return_inverse=False, return_counts=True), x, opset_version=11) + def test_round(self): + x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True) + self.assertONNX(lambda x: torch.round(x), x, opset_version=11) + if __name__ == '__main__': no_onnx_dep_flag = '--no-onnx' diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 874f067e73fd4..cb85b1ed7348c 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -684,6 +684,14 @@ def forward(self, input): input = torch.empty(BATCH_SIZE, 10, 10).uniform_(4, 9) self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE) + def test_rsqrt(self): + class MyModel(torch.nn.Module): + def forward(self, input): + return input.rsqrt() + + input = torch.randn(4, 2, 3, requires_grad=True) + self.run_model_test(MyModel(), train=False, input=input, batch_size=BATCH_SIZE) + def test_log(self): class MyModel(torch.nn.Module): def __init__(self): @@ -906,6 +914,45 @@ def forward(self, ma, m1, m2): m2 = torch.randn(4, 5) self.run_model_test(MyModel(), train=False, input=(ma, m1, m2), batch_size=BATCH_SIZE, use_gpu=False) + def test_scalar_type(self): + class ArithmeticModel(torch.nn.Module): + def forward(self, x): + return x.size(0) * 2 * x + + x = torch.ones(2, 3, dtype=torch.float32) + self.run_model_test(ArithmeticModel(), input=x, train=False, batch_size=BATCH_SIZE) + + class ReciprocalModel(torch.nn.Module): + def forward(self, x): + return torch.reciprocal(x) + + x = torch.tensor([2.0, 4.0], dtype=torch.double) + self.run_model_test(ReciprocalModel(), input=x, train=False, batch_size=BATCH_SIZE) + + class ComparisonModel(torch.nn.Module): + def forward(self, x, y): + return x.ge(0.5) & y.le(2) + + x = torch.ones(2, 3, dtype=torch.int32) + y = torch.ones(2, 3, dtype=torch.float32) + self.run_model_test(ComparisonModel(), input=(x, y), train=False, batch_size=BATCH_SIZE) + + # TODO: re-enable the two tests after https://github.com/pytorch/pytorch/issues/26328 is resolved. + class MatMulModel(torch.nn.Module): + def forward(self, x, y): + return torch.mm(x, y) + + x = torch.ones(3, 4) + y = torch.ones(4, 5) + # self.run_model_test(MatMulModel(), input=(x, y), train=False, batch_size=BATCH_SIZE) + + class AddMMModel(torch.nn.Module): + def forward(self, x): + return torch.mm(x, x) + x + + x = torch.ones(3, 3) + # self.run_model_test(AddMMModel(), input=x, train=False, batch_size=BATCH_SIZE) + # test for a pytorch optimization pass, see https://github.com/pytorch/pytorch/pull/7872 def test_consecutive_transposes(self): class MyModel(torch.nn.Module): @@ -1338,6 +1385,28 @@ def forward(self, x): self.run_model_test(FullClass(), train=False, input=(x,), batch_size=BATCH_SIZE, use_gpu=False, example_outputs=FullClass()(x)) + def test_clamp(self): + class ClampModel(torch.nn.Module): + def forward(self, x): + return x.clamp(-0.5, 0.5) + + x = torch.randn(3, 4) + self.run_model_test(ClampModel(), train=False, input=(x,), batch_size=BATCH_SIZE) + + class ClampMinModel(torch.nn.Module): + def forward(self, x): + return x.clamp(min=-0.5) + + x = torch.randn(3, 4) + self.run_model_test(ClampMinModel(), train=False, input=(x,), batch_size=BATCH_SIZE) + + class ClampMaxModel(torch.nn.Module): + def forward(self, x): + return x.clamp(max=0.5) + + x = torch.randn(3, 4) + self.run_model_test(ClampMaxModel(), train=False, input=(x,), batch_size=BATCH_SIZE) + @skipIfUnsupportedMinOpsetVersion(9) def test_where_functional(self): class WhereFunctional(torch.nn.Module): @@ -2110,6 +2179,18 @@ def forward(self, x): x = torch.arange(16).view(2, 2, 4).to(torch.float32) self.run_model_test(MaskedFillModel2(), input=(x, ), train=False, batch_size=BATCH_SIZE) + @skipIfUnsupportedMinOpsetVersion(9) + def test_gelu(self): + class GeluModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.gelu(x) + + model = GeluModel() + inputs = torch.randn(2, 4, 5, 6, requires_grad=True) + outputs = model(inputs) + self.run_model_test(model, train=False, input=(inputs,), batch_size=BATCH_SIZE, + example_outputs=(outputs,)) + # a bit of metaprogramming to set up all the rnn tests diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index f1dbafe1dd705..4f39595176072 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -19,9 +19,34 @@ import model_defs.word_language_model as word_language_model +def ort_test_with_input(ort_sess, input, output, rtol, atol): + input, _ = torch.jit._flatten(input) + output, _ = torch.jit._flatten(output) + + def to_numpy(tensor): + if tensor.requires_grad: + return tensor.detach().cpu().numpy() + else: + return tensor.cpu().numpy() + + inputs = list(map(to_numpy, input)) + outputs = list(map(to_numpy, output)) + + ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs)) + ort_outs = ort_sess.run(None, ort_inputs) + + # compare onnxruntime and PyTorch results + assert len(outputs) == len(ort_outs), "number of outputs differ" + + # compare onnxruntime and PyTorch results + [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)] + + def run_model_test(self, model, batch_size=2, state_dict=None, input=None, use_gpu=True, rtol=0.001, atol=1e-7, - example_outputs=None, do_constant_folding=True): + example_outputs=None, do_constant_folding=True, + dynamic_axes=None, test_with_inputs=None, + input_names=None, output_names=None): model.eval() if input is None: @@ -40,31 +65,26 @@ def run_model_test(self, model, batch_size=2, state_dict=None, opset_version=self.opset_version, example_outputs=output, do_constant_folding=do_constant_folding, - keep_initializers_as_inputs=self.keep_initializers_as_inputs) - - input, _ = torch.jit._flatten(input) - output, _ = torch.jit._flatten(output) - - def to_numpy(tensor): - if tensor.requires_grad: - return tensor.detach().cpu().numpy() - else: - return tensor.cpu().numpy() - - inputs = list(map(to_numpy, input)) - outputs = list(map(to_numpy, output)) + keep_initializers_as_inputs=self.keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + input_names=input_names, output_names=output_names) # compute onnxruntime output prediction ort_sess = onnxruntime.InferenceSession(f.getvalue()) - ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs)) - ort_outs = ort_sess.run(None, ort_inputs) + ort_test_with_input(ort_sess, input, output, rtol, atol) - # compare onnxruntime and PyTorch results - assert len(outputs) == len(ort_outs), "number of outputs differ" + # if addiional test inputs are provided run the onnx + # model with these inputs and check the outputs + if test_with_inputs is not None: + for test_input in test_with_inputs: + if isinstance(test_input, torch.Tensor): + test_input = (test_input,) + output = model(*test_input) + if isinstance(output, torch.Tensor): + output = (output,) - # compare onnxruntime and PyTorch results - [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)] + ort_test_with_input(ort_sess, test_input, output, rtol, atol) class TestONNXRuntime(unittest.TestCase): @@ -72,10 +92,20 @@ class TestONNXRuntime(unittest.TestCase): opset_version = _export_onnx_opset_version keep_initializers_as_inputs = True # For IR version 3 type export. - def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, batch_size=2, use_gpu=True): + def setUp(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + np.random.seed(seed=0) + + def run_test(self, model, input, rtol=1e-3, atol=1e-7, do_constant_folding=True, + batch_size=2, use_gpu=True, dynamic_axes=None, test_with_inputs=None, + input_names=None, output_names=None): run_model_test(self, model, batch_size=batch_size, input=input, use_gpu=use_gpu, rtol=rtol, atol=atol, - do_constant_folding=do_constant_folding) + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, test_with_inputs=test_with_inputs, + input_names=input_names, output_names=output_names) def run_word_language_model(self, model_name): ntokens = 50 @@ -116,6 +146,55 @@ def test_index_2d_sliceint(self): def test_index_2d_neg_slice(self): self._test_index_generic(lambda input: input[0:-1, :]) + def test_clamp(self): + class ClampModel(torch.nn.Module): + def forward(self, x): + return x.clamp(-0.5, 0.5) + + x = torch.randn(3, 4) + self.run_test(ClampModel(), x) + + class ClampMinModel(torch.nn.Module): + def forward(self, x): + return x.clamp(min=-0.5) + + x = torch.randn(3, 4) + self.run_test(ClampMinModel(), x) + + class ClampMaxModel(torch.nn.Module): + def forward(self, x): + return x.clamp(max=0.5) + + x = torch.randn(3, 4) + self.run_test(ClampMaxModel(), x) + + @skipIfUnsupportedMinOpsetVersion(11) + def test_clamp_dyn(self): + class ClampMaxModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return x.clamp(None, x.size(0)) + + x = torch.arange(16).view(4, 4).float() + self.run_test(ClampMaxModel(), x) + + + class ClampMinModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return x.clamp(x.size(0), None) + + x = torch.arange(16).view(4, 4).float() + self.run_test(ClampMinModel(), x) + + class ClampMinMaxModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, x): + return x.clamp(x.size(0), x.size(1)) + + x = torch.arange(16).view(2, 8).float() + self.run_test(ClampMinMaxModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) def test_full_trace(self): class FullModel(torch.nn.Module): @@ -274,6 +353,22 @@ def forward(self, x): x = torch.rand(5, 5, 5) self.run_test(DynamicSliceExportMod(), x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_arange(self): + class ArangeModel(torch.nn.Module): + def forward(self, input): + return torch.arange(input.shape[0]), \ + torch.arange(12), \ + torch.arange(start=input.shape[0], end=input.shape[0] + 5) + + x = torch.randn(5, 3, 2) + y = torch.randn(8, 3, 2) + self.run_test(ArangeModel(), x, test_with_inputs=[y], + input_names=['input_1'], + output_names=['output_1', 'output_2', 'output_3'], + dynamic_axes={'input_1': [0], + 'output_1': [0]}) + def _test_index_generic(self, fn): class MyModel(torch.nn.Module): def __init__(self): @@ -743,6 +838,52 @@ def forward(self, x): x = torch.randn(2, 16, 4, 3, requires_grad=True) self.run_test(PixelShuffle(), x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_scalar_type(self): + class ArithmeticModel(torch.nn.Module): + def forward(self, x): + return x.size(0) * 2 * x + + x = torch.ones(2, 3, dtype=torch.float32) + self.run_test(ArithmeticModel(), x) + + class ReciprocalModel(torch.nn.Module): + def forward(self, x): + return torch.reciprocal(x) + + x = torch.tensor([2.0, 4.0], dtype=torch.double) + self.run_test(ReciprocalModel(), x) + + class ComparisonModel(torch.nn.Module): + def forward(self, x, y): + return x.ge(0.5) & y.le(2) + + x = torch.ones(2, 3, dtype=torch.int32) + y = torch.ones(2, 3, dtype=torch.float32) + self.run_test(ComparisonModel(), (x, y)) + + # TODO: re-enable the two tests after https://github.com/pytorch/pytorch/issues/26328 is resolved. + class MatMulModel(torch.nn.Module): + def forward(self, x): + return (torch.mm(x, x) + x + torch.mm(x, x) + x) + + x = torch.ones(3, 3) + # self.run_test(MatMulModel(), x) + + class AddMMModel(torch.nn.Module): + def forward(self, x): + return torch.mm(x, x) + x + + x = torch.ones(3, 3) + # self.run_test(AddMMModel(), x) + + class FullModel(torch.nn.Module): + # add is used for exporting full + def forward(self, x): + return torch.full((3, 4), x) + x = torch.tensor(12) + self.run_test(FullModel(), x) + def test_frobenius_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -759,6 +900,30 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @skipIfUnsupportedMinOpsetVersion(9) + def test_gelu(self): + class GeluModel(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.gelu(x) + + x = torch.randn(2, 4, 5, 6, requires_grad=True) + self.run_test(GeluModel(), x) + + def test_rsqrt(self): + class RsqrtModel(torch.nn.Module): + def forward(self, x): + return x.rsqrt() + + x = torch.randn(4, 2, 3, requires_grad=True).to(dtype=torch.float64) + self.run_test(RsqrtModel(), x) + + def test_rsqrt_zeros(self): + class RsqrtModel(torch.nn.Module): + def forward(self, x): + return x.rsqrt() + x = torch.zeros(4, 2, 3, requires_grad=True).to(dtype=torch.float64) + self.run_test(RsqrtModel(), x) + # TODO: enable opset 11 test once ORT support for unique is in @skipIfUnsupportedOpsetVersion([11]) @skipIfUnsupportedMinOpsetVersion(11) @@ -792,6 +957,33 @@ def forward(self, input): model = CumSum() self.run_test(model, x) + def test_log(self): + class Log(torch.nn.Module): + def forward(self, input): + return torch.log(input) + x = torch.rand(2, 3, 4) + model = Log() + self.run_test(model, x) + + def test_log1p(self): + class Log1p(torch.nn.Module): + def forward(self, input): + return torch.log1p(input) + x = torch.rand(2, 3, 4) + model = Log1p() + self.run_test(model, x) + + # TODO: remove the skip tag once ORT implementation is in place + @skipIfUnsupportedMinOpsetVersion(11) + @skipIfUnsupportedOpsetVersion([11]) + def test_round(self): + class Round(torch.nn.Module): + def forward(self, x): + return torch.round(x) + + x = torch.tensor([0.9920, -1.0362, -1.5000, 3.5000], requires_grad=True) + self.run_test(Round(), x) + def _dispatch_rnn_test(self, name, *args, **kwargs): if name == 'elman': self._elman_rnn_test(*args, **kwargs) @@ -832,7 +1024,7 @@ def make_input(batch_size): return input input = make_input(RNN_BATCH_SIZE) - self.run_test(model, input, batch_size=RNN_BATCH_SIZE, atol=1e-7) + self.run_test(model, input, batch_size=RNN_BATCH_SIZE) # test that the model still runs with a different batch size other_input = make_input(RNN_BATCH_SIZE + 1) @@ -908,11 +1100,11 @@ def make_input(batch_size): return input input = make_input(RNN_BATCH_SIZE) - self.run_test(model, input, batch_size=RNN_BATCH_SIZE, atol=1e-5) + self.run_test(model, input, batch_size=RNN_BATCH_SIZE) # test that the model still runs with a different batch size other_input = make_input(RNN_BATCH_SIZE + 1) - self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1, atol=1e-5) + self.run_test(model, other_input, batch_size=RNN_BATCH_SIZE + 1) def make_test(name, base, layer, bidirectional, initial_state, diff --git a/test/run_test.py b/test/run_test.py index 4349d3a8bfe5c..08ad1eb1bbc25 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -60,7 +60,9 @@ 'jit_fuser', 'tensorboard', 'namedtensor', + 'type_promotion', 'jit_disabled', + 'function_schema', ] # skip < 3.6 b/c fstrings added in 3.6 @@ -80,11 +82,7 @@ 'nccl', ] -DISTRIBUTED_TESTS_CONFIG = { - 'gloo': { - 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3' - }, -} +DISTRIBUTED_TESTS_CONFIG = {} if dist.is_available(): @@ -96,7 +94,10 @@ DISTRIBUTED_TESTS_CONFIG['nccl'] = { 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3' } - + if dist.is_gloo_available(): + DISTRIBUTED_TESTS_CONFIG['gloo'] = { + 'WORLD_SIZE': '2' if torch.cuda.device_count() == 2 else '3' + } # https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python SIGNALS_TO_NAMES_DICT = {getattr(signal, n): n for n in dir(signal) @@ -385,6 +386,8 @@ def get_selected_tests(options): target_arch = os.environ.get('VSCMD_ARG_TGT_ARCH') if target_arch != 'x64': WINDOWS_BLACKLIST.append('cpp_extensions') + WINDOWS_BLACKLIST.append('jit') + WINDOWS_BLACKLIST.append('jit_fuser') selected_tests = exclude_tests(WINDOWS_BLACKLIST, selected_tests, 'on Windows') diff --git a/test/test_autograd.py b/test/test_autograd.py index 32e8add0ead9d..e7706c5b2b19d 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -21,7 +21,7 @@ from torch.utils.checkpoint import checkpoint from common_utils import (TEST_MKL, TestCase, run_tests, skipIfNoLapack, suppress_warnings, skipIfRocm, slowTest, - load_tests, random_symmetric_pd_matrix, random_symmetric_matrix, IS_WINDOWS) + load_tests, random_symmetric_pd_matrix, random_symmetric_matrix, IS_WINDOWS, IS_MACOS) from common_cuda import TEST_CUDA from torch.autograd import Variable, Function, detect_anomaly from torch.autograd.function import InplaceFunction @@ -34,6 +34,8 @@ exclude_tensor_method, mask_not_all_zeros, S) +from common_device_type import (instantiate_device_type_tests, skipCUDAIfRocm, + onlyCUDA) # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -201,9 +203,6 @@ def backward(ctx, grad_output): with self.assertRaisesRegex(RuntimeError, 'expected shape'): input = torch.randn(5, 5, dtype=torch.float, requires_grad=True) MyFunction.apply(input).sum().backward() - with self.assertRaisesRegex(RuntimeError, 'expected type'): - input = torch.randn(10, dtype=torch.double, requires_grad=True) - MyFunction.apply(input).sum().backward() def test_accumulate_grad(self): grad_output = torch.ones(5, 5) @@ -685,47 +684,6 @@ def test_sparse_mm_backward(self): "calculating the gradient of a sparse Tensor argument to mm is not supported."): z.sum().backward() - # NOTE: flaky on ROCm CI - @skipIfRocm - def test_sparse_ctor_getter_backward(self): - # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test - def test(size, sparse_dim, nnz, device): - v_size = [nnz] + list(size[sparse_dim:]) - i = torch.rand(sparse_dim, nnz) - i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i)) - i = i.to(torch.long) - - inp = torch.randn(v_size, requires_grad=True) - other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True)[0] - other = other.to(device) - - def fn(v): - x = torch.sparse_coo_tensor(i, v, size, device=device) - y = (x + other).coalesce() - yv = y.values() - new_v = yv.tanh() - z = torch.sparse_coo_tensor(y.indices(), new_v, y.size()) - return z.coalesce().values() - - gradcheck(fn, (inp,)) - # FIXME: make gradgradcheck work. - # gradgradcheck(fn, (inp,)) - - # assert that _values is non-differentiable - with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"): - other.detach().requires_grad_()._values().backward(torch.ones_like(other._values())) - - devices = ['cpu'] - - if torch.cuda.is_available(): - devices.append('cuda') - - for empty_i, empty_v, empty_nnz in product([True, False], repeat=3): - sparse_size = [] if empty_i else [2, 1] - dense_size = [1, 0, 2] if empty_v else [1, 2] - nnz = 0 if empty_nnz else 5 - for device in devices: - test(sparse_size + dense_size, len(sparse_size), nnz, device) def test_multi_backward(self): x = torch.randn(5, 5, requires_grad=True) @@ -900,7 +858,10 @@ def test_no_unnecessary_save(self): for i in range(3): x.detach_() x.copy_(mu + i) - loss += (x * torch.tensor([float(i)])).sum() + ft = torch.tensor([float(i)]) + multiplied = x * ft + s = multiplied.sum() + loss += s loss.backward() def test_no_grad(self): @@ -1712,19 +1673,6 @@ def test_sparse_gather_x_scalar(self): def test_sparse_gather_both_scalar(self): self._test_sparse_gather((), (), 0) - # autograd tests via common_method_invocations don't allow input tensors to - # be sparse (RuntimeError: gradcheck expects all tensor inputs are dense when - # check_sparse_nnz is set to False.) - def test_sparse_mask_autograd(self): - for device in ['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']: - tensor = torch.randn(3, requires_grad=True, device=device) - mask = torch.ones(3, device=device) - mask[1] = 0 - mask = mask.to_sparse() - converted = tensor.sparse_mask(mask).to_dense() - converted.sum().backward() - self.assertEqual(tensor.grad, mask.to_dense()) - def test_gc_in_destructor(self): """ Previously, if a Function destructor triggered a garbage collection, @@ -1924,7 +1872,6 @@ def backward(ctx, grad_output): torch.autograd.grad(y, x) # should not error! @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU") - @skipIfRocm def test_unused_output_gpu(self): from torch.nn.parallel._functions import Broadcast x = Variable(torch.randn(5, 5).float().cuda(), requires_grad=True) @@ -1953,7 +1900,6 @@ def backward(ctx, grad_output): self.assertEqual(device[0], 1) @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU") - @skipIfRocm def test_inputbuffer_add_multigpu(self): input = torch.randn(1).cuda(0).requires_grad_() output = input.cuda(1) + input.cuda(1) @@ -2066,65 +2012,6 @@ def test_type_conversions(self): self._test_type_conversion_backward(lambda x: x.cuda(0)) self._test_type_conversion_backward(lambda x: x.cuda(1)) - def _test_pyscalar_conversions(self, t, integral_conv): - # integral -> integral - l = t(torch.zeros(1, 1, 1, dtype=torch.long)) - pyscalar = -12345 - l[0] = pyscalar - self.assertEqual(integral_conv(l), pyscalar) - - # floating point -> floating point - f = Variable(t(torch.randn(1, 1))) - pyscalar = -12345.1 - f[0] = pyscalar - self.assertEqual(float(f), pyscalar) - f[0] = nan - self.assertTrue(math.isnan(float(f))) - f[0] = inf - self.assertEqual(float(f), inf, allow_inf=True) - f[0] = -inf - self.assertEqual(float(f), -inf, allow_inf=True) - - # integral -> floating point - # check we can convert something that loses precision - pyscalar = 1234567890123456789 - self.assertNotEqual(pyscalar, integral_conv(float(pyscalar))) - l[0] = pyscalar - self.assertEqual(float(l), float(pyscalar)) - - # floating point -> integral - f[0] = nan - self.assertRaises(ValueError, lambda: integral_conv(f[0])) - f[0] = inf - self.assertRaises(OverflowError, lambda: integral_conv(f[0])) - f[0] = -inf - self.assertRaises(OverflowError, lambda: integral_conv(f[0])) - f[0] = sys.float_info.max - self.assertEqual(integral_conv(f), sys.float_info.max) - - # bool, nonzero - def test_nonzero(tensor, value, expected): - tensor[0] = value - self.assertEqual(expected, bool(tensor)) - self.assertEqual(expected, True if tensor else False) - - test_nonzero(l, 0, False) - test_nonzero(l, -2, True) - test_nonzero(f, 0.0, False) - test_nonzero(f, sys.float_info.min, True) - test_nonzero(f, nan, bool(nan)) - test_nonzero(f, inf, bool(inf)) - test_nonzero(f, -inf, bool(-inf)) - - def test_pyscalar_conversions(self): - self._test_pyscalar_conversions(lambda x: x, lambda x: int(x)) - if sys.version_info[0] == 2: - self._test_pyscalar_conversions(lambda x: x, lambda x: long(x)) - if torch.cuda.is_available(): - self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: int(x)) - if sys.version_info[0] == 2: - self._test_pyscalar_conversions(lambda x: x.cuda(), lambda x: long(x)) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_pin_memory(self): x = torch.randn(2, 2, requires_grad=True) @@ -2475,38 +2362,6 @@ def test_trapz(self): lambda y, x: torch.trapz(y, x), True, f_args_variable, f_args_tensor) - # skip this test if running on rocm, because in cdist - # we use __shfl_down_sync on CUDA for fast reduction - # and it gives incorrect results on rocm platform - @skipIfRocm - def test_cdist(self): - def _test_cdist_for_size(sizes): - devices = torch.testing.get_all_device_types() - for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: - for device in devices: - x = torch.randn(sizes, device=device, dtype=torch.double) - y = torch.randn(sizes, device=device, dtype=torch.double) - - eps = 1e-6 - # to avoid extremum - x = x - (((x - y) < eps).double() * 2 * eps) - x.requires_grad = True - y.requires_grad = True - - f_args_variable = (x, y) - - def f(a, b): - return torch.cdist(a, b, p) - - f_args_tensor = deepcopy(unpack_variables(f_args_variable)) - run_functional_checks(self, "test_cdist", "cdist", f, - True, f_args_variable, f_args_tensor) - - _test_cdist_for_size((S, S)) - _test_cdist_for_size((S, S, S)) - _test_cdist_for_size((3, 5)) - _test_cdist_for_size((2, 3, 5)) - _test_cdist_for_size((1, 2, 3)) def test_var_mean_differentiable(self): dim = [2, 4] @@ -2546,6 +2401,27 @@ def run_test(upper, dims): run_test(upper, dims) run_test(upper, dims) + @skipIfNoLapack + def test_cholesky_solve(self): + def _test_with_size(A_dims, B_dims, upper): + root = torch.rand(*A_dims).requires_grad_() + b = torch.rand(*B_dims).requires_grad_() + + def func(root, b, upper): + if upper: + A = root.triu() + else: + A = root.tril() + return torch.cholesky_solve(b, A, upper) + + gradcheck(func, [root, b, upper]) + gradgradcheck(func, [root, b, upper]) + + for (a_size, b_size), upper in product([((3, 3), (3, 4)), ((3, 3), (3, 2)), + ((2, 3, 3), (2, 3, 4)), ((2, 3, 3), (2, 3, 2))], + [True, False]): + _test_with_size(a_size, b_size, upper) + @skipIfNoLapack def test_symeig(self): def func(root, upper): @@ -2727,22 +2603,6 @@ def test_pow_scalar_base(self): a = torch.arange(1, 13, dtype=torch.double).view(3, 4).requires_grad_() gradcheck(lambda a: torch.pow(2, a), (a,)) - # test for backward in https://github.com/pytorch/pytorch/issues/15511 - def test_pdist_large(self): - def func(x): - return torch.pdist(x, p=2) - - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: - # shape[0] should be able to be (roughly) arbitrarily large, but the kernel - # is currently limited to smaller sizes (see issue above); this is just testing - # a floor. - shape = (1000, 1) - x = torch.randn(shape, device=device).requires_grad_() - output = torch.pdist(x, p=2) - # just run a single backward, as gradcheck/gradgradcheck is expensive here - output.sum().backward() - @skipIfNoLapack def test_pinverse(self): # Why is pinverse tested this way, and not ordinarily as other linear algebra methods? @@ -2974,29 +2834,6 @@ def closure(x): # test select on expanded input case test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0) - def _test_where_functional(self, t): - x = Variable(t(torch.randn(5, 5)), requires_grad=True) - y = Variable(t(torch.randn(5, 5)), requires_grad=True) - cond = Variable(t(mask_not_all_zeros((5, 5))), requires_grad=False) - - def where(cond, x, y): - return torch.where(cond, x, y) - - gradcheck(where, [cond, x, y], raise_exception=True) - gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5)))]) - - x = Variable(t(torch.randn(5, 1, 5)), requires_grad=True) - y = Variable(t(torch.randn(5, 5, 1)), requires_grad=True) - gradcheck(where, [cond, x, y], raise_exception=True) - gradgradcheck(where, [cond, x, y], [Variable(t(torch.randn(5, 5, 5)))]) - - def test_where_functional(self): - self._test_where_functional(lambda t: t) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_where_functional_cuda(self): - self._test_where_functional(lambda t: t.cuda()) - def _test_lerp_tensor_weights(self, cast): def construct_inputs(*shapes): start = cast(torch.randn(shapes[0])).requires_grad_() @@ -3268,43 +3105,6 @@ def test_diagonal_derivative_requires_grad(self): d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True) self.assertTrue(d.requires_grad) - @staticmethod - def _test_set_requires_grad_only_for_floats(self, cuda): - dtypes = [torch.int64, torch.int32, torch.int16, torch.int8, - torch.float, torch.double] - if cuda: - dtypes.append(torch.half) - - def f1(dt): - a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu') - a.requires_grad_() - - def f2(dt): - a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu') - a.requires_grad = True - - def f3(dt): - torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu', requires_grad=True) - - for dt in dtypes: - a = torch.ones(1, dtype=dt, device='cuda' if cuda else 'cpu') - a.requires_grad = False # should always work - a.requires_grad_(False) - - for f in [f1, f2, f3]: - if dt.is_floating_point: - f(dt) - else: - with self.assertRaisesRegex(RuntimeError, 'floating point', - msg="dt: {} device: {}".format(a.dtype, a.device)): - f(dt) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") - def test_set_requires_grad_only_for_floats_cuda(self): - self._test_set_requires_grad_only_for_floats(self, True) - - def test_set_requires_grad_only_for_floats(self): - self._test_set_requires_grad_only_for_floats(self, False) @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") def test_rnn_backward_to_input_but_not_parameters_cuda(self): @@ -3554,6 +3354,7 @@ def backward(ctx, grad): s = TestCase.runWithPytorchAPIUsageStderr(code) self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown") + @unittest.skipIf(IS_MACOS, "Fails with SIGBUS on macOS; https://github.com/pytorch/pytorch/issues/25941") def test_deep_reentrant(self): class DeepReentrant(Function): @@ -3577,15 +3378,6 @@ def backward(ctx, x): # in the same thread recursively DeepReentrant.apply(v).sum().backward() - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_advanced_indexing_backwards_large(self): - # See https://github.com/pytorch/pytorch/issues/22843 - n = (1 << 16) - x = torch.rand(n, 1, device='cuda', requires_grad=True) - a = x[:, [0]] - a.sum().backward() - self.assertEqual(x.grad, torch.ones(n, 1, device='cuda')) - def test_reentrant_priority(self): order = [] @@ -3880,5 +3672,215 @@ def fn(*inputs): for test in method_tests(): add_test(*test) +# Generic device type autograd tests. +class TestAutogradDeviceType(TestCase): + + # skip this test if running on rocm, because in cdist + # we use __shfl_down_sync on CUDA for fast reduction + # and it gives incorrect results on rocm platform + @skipCUDAIfRocm + def test_cdist(self, device): + def _test_cdist_for_size(sizex, sizey=None): + if sizey is None: + sizey = sizex + for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: + x = torch.randn(sizex, device=device, dtype=torch.double) + y = torch.randn(sizey, device=device, dtype=torch.double) + eps = 1e-6 + # to avoid extremum + x = x - (((x - y) < eps).double() * 2 * eps) + x.requires_grad = True + y.requires_grad = True + f_args_variable = (x, y) + + def f(a, b): + return torch.cdist(a, b, p) + f_args_tensor = deepcopy(unpack_variables(f_args_variable)) + run_functional_checks(self, "test_cdist", "cdist", f, + True, f_args_variable, f_args_tensor) + _test_cdist_for_size((S, S)) + _test_cdist_for_size((S, S, S)) + _test_cdist_for_size((3, 5)) + _test_cdist_for_size((2, 3, 5)) + _test_cdist_for_size((1, 2, 3)) + _test_cdist_for_size((1, 1), (S, 1)) + + + # NOTE: flaky on ROCm CI + @skipCUDAIfRocm + def test_sparse_ctor_getter_backward(self, device): + # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test + def _test(size, sparse_dim, nnz, device): + v_size = [nnz] + list(size[sparse_dim:]) + i = torch.rand(sparse_dim, nnz) + i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i)) + i = i.to(torch.long) + + inp = torch.randn(v_size, requires_grad=True) + other = self.genSparseTensor(size, sparse_dim, nnz, is_uncoalesced=True)[0] + other = other.to(device) + + def fn(v): + x = torch.sparse_coo_tensor(i, v, size, device=device) + y = (x + other).coalesce() + yv = y.values() + new_v = yv.tanh() + z = torch.sparse_coo_tensor(y.indices(), new_v, y.size()) + return z.coalesce().values() + + gradcheck(fn, (inp,)) + # FIXME: make gradgradcheck work. + # gradgradcheck(fn, (inp,)) + + # assert that _values is non-differentiable + with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"): + other.detach().requires_grad_()._values().backward(torch.ones_like(other._values())) + + for empty_i, empty_v, empty_nnz in product([True, False], repeat=3): + sparse_size = [] if empty_i else [2, 1] + dense_size = [1, 0, 2] if empty_v else [1, 2] + nnz = 0 if empty_nnz else 5 + _test(sparse_size + dense_size, len(sparse_size), nnz, device) + + # autograd tests via common_method_invocations don't allow input tensors to + # be sparse (RuntimeError: gradcheck expects all tensor inputs are dense when + # check_sparse_nnz is set to False.) + def test_sparse_mask_autograd(self, device): + tensor = torch.randn(3, requires_grad=True, device=device) + mask = torch.ones(3, device=device) + mask[1] = 0 + mask = mask.to_sparse() + converted = tensor.sparse_mask(mask).to_dense() + converted.sum().backward() + self.assertEqual(tensor.grad, mask.to_dense()) + + def test_pyscalar_conversions(self, device): + def _test_pyscalar_conversions(t, integral_conv): + # integral -> integral + l = t(torch.zeros(1, 1, 1, dtype=torch.long)) + pyscalar = -12345 + l[0] = pyscalar + self.assertEqual(integral_conv(l), pyscalar) + + # floating point -> floating point + f = Variable(t(torch.randn(1, 1))) + pyscalar = -12345.1 + f[0] = pyscalar + self.assertEqual(float(f), pyscalar) + f[0] = nan + self.assertTrue(math.isnan(float(f))) + f[0] = inf + self.assertEqual(float(f), inf, allow_inf=True) + f[0] = -inf + self.assertEqual(float(f), -inf, allow_inf=True) + + # integral -> floating point + # check we can convert something that loses precision + pyscalar = 1234567890123456789 + self.assertNotEqual(pyscalar, integral_conv(float(pyscalar))) + l[0] = pyscalar + self.assertEqual(float(l), float(pyscalar)) + + # floating point -> integral + f[0] = nan + self.assertRaises(ValueError, lambda: integral_conv(f[0])) + f[0] = inf + self.assertRaises(OverflowError, lambda: integral_conv(f[0])) + f[0] = -inf + self.assertRaises(OverflowError, lambda: integral_conv(f[0])) + f[0] = sys.float_info.max + self.assertEqual(integral_conv(f), sys.float_info.max) + + # bool, nonzero + def test_nonzero(tensor, value, expected): + tensor[0] = value + self.assertEqual(expected, bool(tensor)) + self.assertEqual(expected, True if tensor else False) + + test_nonzero(l, 0, False) + test_nonzero(l, -2, True) + test_nonzero(f, 0.0, False) + test_nonzero(f, sys.float_info.min, True) + test_nonzero(f, nan, bool(nan)) + test_nonzero(f, inf, bool(inf)) + test_nonzero(f, -inf, bool(-inf)) + + + _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x)) + if sys.version_info[0] == 2: + _test_pyscalar_conversions(lambda x: x.to(device), lambda x: long(x)) + + def test_set_requires_grad_only_for_floats(self, device): + dtypes = [torch.int64, torch.int32, torch.int16, torch.int8, + torch.float, torch.double] + if device == 'cuda': + dtypes.append(torch.half) + + def f1(dt): + a = torch.ones(1, dtype=dt, device=device) + a.requires_grad_() + + def f2(dt): + a = torch.ones(1, dtype=dt, device=device) + a.requires_grad = True + + def f3(dt): + torch.ones(1, dtype=dt, device=device, requires_grad=True) + + for dt in dtypes: + a = torch.ones(1, dtype=dt, device=device) + a.requires_grad = False # should always work + a.requires_grad_(False) + + for f in [f1, f2, f3]: + if dt.is_floating_point: + f(dt) + else: + with self.assertRaisesRegex(RuntimeError, 'floating point', + msg="dt: {} device: {}".format(a.dtype, a.device)): + f(dt) + + @onlyCUDA + def test_advanced_indexing_backwards_large(self, device): + # See https://github.com/pytorch/pytorch/issues/22843 + n = (1 << 16) + x = torch.rand(n, 1, device=device, requires_grad=True) + a = x[:, [0]] + a.sum().backward() + self.assertEqual(x.grad, torch.ones(n, 1, device=device)) + + # test for backward in https://github.com/pytorch/pytorch/issues/15511 + def test_pdist_large(self, device): + def func(x): + return torch.pdist(x, p=2) + + # shape[0] should be able to be (roughly) arbitrarily large, but the kernel + # is currently limited to smaller sizes (see issue above); this is just testing + # a floor. + shape = (1000, 1) + x = torch.randn(shape, device=device).requires_grad_() + output = torch.pdist(x, p=2) + # just run a single backward, as gradcheck/gradgradcheck is expensive here + output.sum().backward() + + def test_where_functional(self, device): + x = torch.randn(5, 5, device=device, requires_grad=True) + y = torch.randn(5, 5, device=device, requires_grad=True) + cond = mask_not_all_zeros((5, 5)).to(device=device) + + def where(cond, x, y): + return torch.where(cond, x, y) + + gradcheck(where, [cond, x, y], raise_exception=True) + gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)]) + + x = torch.randn(5, 1, 5, device=device, requires_grad=True) + y = torch.randn(5, 5, 1, device=device, requires_grad=True) + gradcheck(where, [cond, x, y], raise_exception=True) + gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)]) + + +instantiate_device_type_tests(TestAutogradDeviceType, globals()) + if __name__ == '__main__': run_tests() diff --git a/test/test_c10d.py b/test/test_c10d.py index 538dbe7593e2b..e572068cf809e 100644 --- a/test/test_c10d.py +++ b/test/test_c10d.py @@ -11,6 +11,7 @@ import time import unittest from datetime import timedelta +from sys import platform from itertools import groupby from functools import partial, reduce @@ -38,6 +39,12 @@ sys.exit(0) +if platform == 'darwin': + LOOPBACK = 'lo0' +else: + LOOPBACK = 'lo' + + def gpus_for_rank(world_size): """Multigpu tests are designed to simulate the multi nodes with multi GPUs on each node. Nccl backend requires equal #GPUs in each process. @@ -511,7 +518,7 @@ def test_default_store_timeout_gloo(self): class ProcessGroupGlooTest(MultiProcessTestCase): def opts(self, threads=2): opts = c10d.ProcessGroupGloo.Options() - opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + opts.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] opts.timeout = 5.0 opts.threads = threads return opts @@ -521,8 +528,8 @@ def test_multi_device_constructor(self): opts = c10d.ProcessGroupGloo.Options() opts.timeout = 5.0 opts.devices = [ - c10d.ProcessGroupGloo.create_tcp_device(interface="lo"), - c10d.ProcessGroupGloo.create_tcp_device(interface="lo"), + c10d.ProcessGroupGloo.create_device(interface=LOOPBACK), + c10d.ProcessGroupGloo.create_device(interface=LOOPBACK), ] pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts) @@ -645,8 +652,7 @@ def _test_broadcast_stress(self, inputs): (i * self.world_size) + (i % self.world_size) ]), inputs[i], - None, - "Mismatch in iteration %d" % i, + message=("Mismatch in iteration %d" % i), ) def test_broadcast_stress(self): @@ -728,8 +734,7 @@ def _test_allreduce_stress(self, inputs): (self.world_size * (self.world_size - 1) / 2) ]), inputs[i], - None, - "Mismatch in iteration %d" % i, + message=("Mismatch in iteration %d" % i), ) def test_allreduce_stress(self): @@ -981,8 +986,7 @@ def _test_scatter_stress(self, inputs, fn): self.assertEqual( torch.Tensor([iter + root]), outputs[iter][root], - None, - "Mismatch in iteration %d for rank %d" % (iter, root) + message=("Mismatch in iteration %d for rank %d" % (iter, root)), ) def test_scatter_stress(self): @@ -1128,8 +1132,7 @@ def _test_gather_stress(self, inputs, fn): self.assertEqual( expected_outputs[iter], outputs[iter], - None, - "Mismatch in iteration %d for root %d" % (iter, root) + message=("Mismatch in iteration %d for root %d" % (iter, root)) ) def test_gather_stress(self): @@ -1229,8 +1232,7 @@ def _test_allgather_stress(self, inputs, fn): self.assertEqual( expected_outputs[i], outputs[i], - None, - "Mismatch in iteration %d" % i + message=("Mismatch in iteration %d" % i), ) def test_allgather_stress(self): @@ -1318,8 +1320,7 @@ def _test_reduce_stress(self, inputs): (self.world_size * (self.world_size - 1) / 2) ]), outputs[i], - None, - "Mismatch in iteration %d with root rank %d" % (iter, root), + message=("Mismatch in iteration %d with root rank %d" % (iter, root)), ) def test_reduce_stress(self): @@ -1369,6 +1370,7 @@ def test_send_recv_all_to_all(self): continue self.assertEqual(torch.Tensor([i]), outputs[i]) + @unittest.skipIf(platform == 'darwin', 'ProcessGroup timeout not yet supported on macOS') def test_timeout_kwarg(self): store = c10d.FileStore(self.file.name, self.world_size) pg = c10d.ProcessGroupGloo( @@ -1838,7 +1840,7 @@ def update_parameters(model): def _test_gloo_backend(self, devices, device_ids, multi_device=False): store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) self._test_ddp_with_process_group(process_group, devices, device_ids, multi_device) @@ -1980,7 +1982,7 @@ def test_dist_broadcast_coalesced_nccl(self): def test_dist_broadcast_coalesced_gloo(self): store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) device = torch.device('cuda') @@ -2019,7 +2021,7 @@ def test_dist_broadcast_coalesced_gloo(self): def test_sync_params_no_buffers(self): store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) # Use all available devices on every process here (data is small, so should be fine). @@ -2046,7 +2048,7 @@ def test_sync_params_no_buffers(self): def test_sync_params_with_buffers(self): store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) devices = gpus_for_rank(self.world_size)[self.rank] @@ -2991,7 +2993,11 @@ def test_nccl_errors_nonblocking(self): def _test_nccl_errors_blocking(self, func): os.environ["NCCL_BLOCKING_WAIT"] = "1" store = c10d.FileStore(self.file.name, self.world_size) - process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size, "", timeout=timedelta(seconds=self.op_timeout_sec)) + process_group = c10d.ProcessGroupNCCL( + store, + self.rank, + self.world_size, + timeout=timedelta(seconds=self.op_timeout_sec)) process_group.allreduce(torch.rand(10).cuda(self.rank)) if self.rank == 0: work = process_group.allreduce(torch.rand(10).cuda(self.rank)) @@ -3053,7 +3059,7 @@ def test_broadcast_coalesced_nccl(self): def test_broadcast_coalesced_gloo_cuda(self): store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) device = torch.device('cuda:%d' % self.rank) self._test_broadcast_coalesced(process_group, device) @@ -3062,7 +3068,7 @@ def test_broadcast_coalesced_gloo_cuda(self): def test_broadcast_coalesced_gloo_cpu(self): store = c10d.FileStore(self.file.name, self.world_size) options = c10d.ProcessGroupGloo.Options() - options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + options.devices = [c10d.ProcessGroupGloo.create_device(interface=LOOPBACK)] process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options) device = torch.device('cpu') self._test_broadcast_coalesced(process_group, device) diff --git a/test/test_c10d_spawn.py b/test/test_c10d_spawn.py index 8004ed4d2206c..fe83ddf695982 100644 --- a/test/test_c10d_spawn.py +++ b/test/test_c10d_spawn.py @@ -34,7 +34,7 @@ class ProcessGroupShareTensorTest(TestCase): @classmethod def opts(cls, threads=2): opts = c10d.ProcessGroupGloo.Options() - opts.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")] + opts.devices = [c10d.ProcessGroupGloo.create_device(interface="lo")] opts.timeout = 5.0 opts.threads = threads return opts diff --git a/test/test_cpp_api_parity.py b/test/test_cpp_api_parity.py index ac7ca9b558297..0354dc681f1dc 100644 --- a/test/test_cpp_api_parity.py +++ b/test/test_cpp_api_parity.py @@ -5,6 +5,7 @@ import unittest import warnings import inspect +import re import torch from torch._six import PY2 @@ -35,22 +36,54 @@ tensor1.allclose(tensor2); } -bool check_ivalue_equality(const c10::IValue& ivalue1, const c10::IValue& ivalue2) { - if (ivalue1.tagKind() != ivalue2.tagKind()) { - AT_ERROR("Value type mismatch: ", "ivalue1: ", ivalue1.tagKind(), ", ivalue2: ", ivalue2.tagKind()); +bool check_ivalue_equality(const c10::IValue& ivalue_python, const c10::IValue& ivalue_cpp) { + // For Python modules, we allow the use of `int` to represent attributes that + // are multidimensional but have the same value in all dimensions. The corresponding + // data type for C++ modules is `ExpandingArray` (which is converted to `IntList` by the + // `IValue` constructor), and here we check that all elements in the `ExpandingArray` + // are equal to the Python `int` attribute. + if (ivalue_python.isInt() && ivalue_cpp.isIntList()) { + auto ivalue_cpp_list = ivalue_cpp.toIntListRef(); + std::vector ivalue_python_vec(ivalue_cpp_list.size()); + std::fill(ivalue_python_vec.begin(), ivalue_python_vec.end(), ivalue_python.toInt()); + return ivalue_python_vec == ivalue_cpp_list; } - if (ivalue1.isInt()) { - return ivalue1.toInt() == ivalue2.toInt(); - } else if (ivalue1.isDouble()) { - return ivalue1.toDouble() == ivalue2.toDouble(); - } else if (ivalue1.isBool()) { - return ivalue1.toBool() == ivalue2.toBool(); - } else if (ivalue1.isString()) { - return ivalue1.toString() == ivalue2.toString(); - } else if (ivalue1.isTensor()) { - return check_tensor_equality(ivalue1.toTensor(), ivalue2.toTensor()); + + // For Python modules, we allow the use of "none" / "mean" / "sum" to represent the reduction type. + // The corresponding data type for C++ modules is `Reduction::Reduction` enum, and here we map the + // reduction types between Python version and C++ version. + if (ivalue_python.isString() && ivalue_cpp.isInt()) { + auto& ivalue_python_str = ivalue_python.toStringRef(); + auto ivalue_cpp_int = ivalue_cpp.toInt(); + if (ivalue_python_str == "none") { + return ivalue_cpp_int == Reduction::None; + } else if (ivalue_python_str == "mean") { + return ivalue_cpp_int == Reduction::Mean; + } else if (ivalue_python_str == "sum") { + return ivalue_cpp_int == Reduction::Sum; + } + } + + if (ivalue_python.tagKind() != ivalue_cpp.tagKind()) { + AT_ERROR("Value type mismatch: ", "from Python: ", ivalue_python.tagKind(), ", from C++: ", ivalue_cpp.tagKind()); + } + + if (ivalue_python.isInt()) { + return ivalue_python.toInt() == ivalue_cpp.toInt(); + } else if (ivalue_python.isDouble()) { + return ivalue_python.toDouble() == ivalue_cpp.toDouble(); + } else if (ivalue_python.isBool()) { + return ivalue_python.toBool() == ivalue_cpp.toBool(); + } else if (ivalue_python.isString()) { + return ivalue_python.toStringRef() == ivalue_cpp.toStringRef(); + } else if (ivalue_python.isTensor()) { + return check_tensor_equality(ivalue_python.toTensor(), ivalue_cpp.toTensor()); + } else if (ivalue_python.isIntList()) { + return ivalue_python.toIntListRef() == ivalue_cpp.toIntListRef(); + } else if (ivalue_python.isNone()) { + return ivalue_cpp.isNone(); } else { - AT_ERROR("Unsupported value type: ", ivalue1.tagKind()); + AT_ERROR("Unsupported value type: ", ivalue_python.tagKind()); } } """ @@ -82,19 +115,25 @@ CHECK_MODULE_ATTR_EQUALITY = Template("""\ TORCH_CHECK( check_ivalue_equality( - ${script_module_prefix}.get_attribute("${attr_name}"), c10::IValue(${cpp_module_prefix}->${attr_name})), + ${script_module_prefix}.get_attribute("${python_attr_name}"), c10::IValue(${cpp_module_prefix}->${cpp_attr_name})), GENERATE_PARITY_TEST_ERROR_MSG( - "`${cpp_module_prefix}->${attr_name}`", - ${cpp_module_prefix}->${attr_name}, - ${script_module_prefix}.get_attribute("${attr_name}"))); + "`${cpp_module_prefix}->${cpp_attr_name}`", + c10::IValue(${cpp_module_prefix}->${cpp_attr_name}), + ${script_module_prefix}.get_attribute("${python_attr_name}"))); """) TORCH_NN_MODULE_TEST_CTOR_ARGS = Template("""\n void ${module_name}_test_ctor_args() { ${module_qualified_name} m_init_by_cpp(${module_option}); + + ${extra_stmts} } """) +TORCH_NN_MODULE_TEST_OPTIONS_ARG = Template("""\ +m_init_by_cpp->options.${options_arg_name}(); +""") + TORCH_NN_MODULE_TEST_INIT = Template("""\n void ${module_variant_name}_test_init( const std::string& saved_module_path, @@ -179,14 +218,24 @@ def _python_arg_to_cpp_arg(self, python_arg): elif type(python_arg) == bool: return CppArg(type='bool', value=str(python_arg).lower()) elif type(python_arg) == str: - return CppArg(type='std::string', value='"{}"'.format(python_arg)) + # if `python_arg` is one of the reduction types, we use the corresponding `Reduction::Reduction` enum. + if python_arg in ['none', 'mean', 'sum']: + if python_arg == 'none': + cpp_arg = 'Reduction::None' + elif python_arg == 'mean': + cpp_arg = 'Reduction::Mean' + elif python_arg == 'sum': + cpp_arg = 'Reduction::Sum' + return CppArg(type='Reduction::Reduction', value='{}'.format(cpp_arg)) + else: + return CppArg(type='std::string', value='"{}"'.format(python_arg)) elif type(python_arg) == torch.Tensor: return CppArg( type='torch::Tensor', value='torch::empty({})'.format(str(list(python_arg.shape)).replace('[', '{').replace(']', '}'))) else: raise RuntimeError( - "{} is not a supported arg type for C++ module methods".format(type(python_default_value))) + "{} is not a supported arg type for C++ module methods".format(type(python_arg))) def _compile_cpp_code_inline(self, name, cpp_sources, functions): # Just-in-time compile the C++ test code @@ -198,18 +247,70 @@ def _compile_cpp_code_inline(self, name, cpp_sources, functions): ) return cpp_module - # This tests that Python and C++ torch.nn modules have matching constructor arg names and types. - def _test_torch_nn_module_ctor_args(self, module_name): + def _get_python_module_init_arg_spec(self, module_name): python_module_class = getattr(torch.nn, module_name) - module_metadata = torch_nn_modules.module_metadata_map[module_name] - cpp_default_constructor_args_str = module_metadata.cpp_default_constructor_args if PY2: init_arg_spec = inspect.getargspec(python_module_class.__init__) else: init_arg_spec = inspect.getfullargspec(python_module_class.__init__) + return init_arg_spec + + def _prepare_tensors_for_module_input_or_target(self, test_params, tensors): + if type(tensors) == tuple: + tensors = list(tensors) + elif type(tensors) == torch.Tensor: + tensors = [tensors] + else: + raise RuntimeError("Unexpected input type: {}".format(type(tensors))) + + if test_params.device != 'cuda' or TEST_CUDA: + tensors = [x.to(test_params.device) for x in tensors] + + return tensors + + def _get_example_inputs(self, test_params): + example_inputs = test_params.test_instance._get_input() + example_inputs = self._prepare_tensors_for_module_input_or_target(test_params, example_inputs) + + # We set all inputs to torch.nn module to requires grad, so that the backward test can always be run. + # However, we skip embedding layers for now, becuase they only accept LongTensor as inputs, + # And LongTensor cannot require grad. + if test_params.module_name not in ["Embedding", "Embedding_sparse", "EmbeddingBag", "EmbeddingBag_sparse"]: + example_inputs = [x.requires_grad_() for x in example_inputs] + + return example_inputs + + def _get_example_targets(self, test_params): + example_targets = test_params.test_instance._get_target() + example_targets = self._prepare_tensors_for_module_input_or_target(test_params, example_targets) + return example_targets + + def _get_forward_input_args(self, test_params): + example_inputs = self._get_example_inputs(test_params) + if isinstance(test_params.test_instance, common_nn.CriterionTest): + example_targets = self._get_example_targets(test_params) + else: + example_targets = [] + + input_args = () + for example_input in example_inputs: + input_args += (example_input, ) + for example_target in example_targets: + input_args += (example_target, ) + + return input_args + + # This tests that Python and C++ torch.nn modules have matching constructor arg names and types. + def _test_torch_nn_module_ctor_args(self, module_name): + module_metadata = torch_nn_modules.module_metadata_map[module_name] + cpp_default_constructor_args_str = module_metadata.cpp_default_constructor_args + init_arg_spec = self._get_python_module_init_arg_spec(module_name) init_kwargs_defaults = init_arg_spec.defaults python_default_constructor_arg_names = [x for x in init_arg_spec.args[1:-len(init_kwargs_defaults)] if x != 'has_parity'] - cpp_default_constructor_arg_values = cpp_default_constructor_args_str.strip('()').split(',') + # NOTE: the regex is used here to split up e.g. `(1, {2, 3}, 4)` into `['1', '{2, 3}', '4']` + cpp_default_constructor_arg_values = re.findall(r'{[^}]*}|[^,\s()]+', cpp_default_constructor_args_str) + + # Step 1: Check that the # of non-keyword args in C++ module constructor is equal to that in Python module constructor. self.assertEqual( len(cpp_default_constructor_arg_values), len(python_default_constructor_arg_names), @@ -222,16 +323,30 @@ def _test_torch_nn_module_ctor_args(self, module_name): len(python_default_constructor_arg_names), python_default_constructor_arg_names)) + # Step 2: Generate code to construct C++ module options using values from `cpp_default_constructor_args`. cpp_module_option = 'torch::nn::{}Options{}'.format(module_name, cpp_default_constructor_args_str) init_kwargs = init_arg_spec.args[-len(init_kwargs_defaults):] for arg_name, python_default_value in zip(init_kwargs, init_kwargs_defaults): - cpp_module_option += '.{}({})'.format(arg_name, self._python_arg_to_cpp_arg(python_default_value).value) - + # NOTE: If a Python module constructor arg's default value is None, we don't test its corresponding + # options arg in C++ module (because the way to set the C++ options arg to an empty value is to not + # specify it, which means we can't test that the options arg exists). + # Instead, we test that all options args exist by calling their accessors after constructing the + # C++ module with the options. + if arg_name not in module_metadata.python_legacy_constructor_args and python_default_value is not None: + cpp_module_option += '.{}({})'.format(arg_name, self._python_arg_to_cpp_arg(python_default_value).value) + + # Step 3: Generate code to check existence of all Python module constructor args in the C++ module options. + extra_stmts = [TORCH_NN_MODULE_TEST_OPTIONS_ARG.substitute(options_arg_name=arg_name) + for arg_name in python_default_constructor_arg_names + init_kwargs + if arg_name not in module_metadata.python_legacy_constructor_args] + + # Step 4: Compile the test code and run the tests. cpp_sources = TORCH_NN_MODULE_COMMON_TEST_HARNESS + module_metadata.cpp_sources cpp_sources += TORCH_NN_MODULE_TEST_CTOR_ARGS.substitute( module_name=module_name, module_qualified_name='torch::nn::{}'.format(module_name), - module_option=cpp_module_option) + module_option=cpp_module_option, + extra_stmts=''.join(extra_stmts)) cpp_test_name = module_name + '_test_ctor_args' cpp_module = self._compile_cpp_code_inline( name=cpp_test_name, cpp_sources=cpp_sources, functions=cpp_test_name) @@ -240,8 +355,8 @@ def _test_torch_nn_module_ctor_args(self, module_name): def _test_torch_nn_module_variant(self, test_params): def generate_test_cpp_sources(test_params, template, extra_stmts): - example_inputs = test_params.example_inputs - input_arg_types = [self._python_arg_to_cpp_arg(arg).type for arg in example_inputs] + input_args = self._get_forward_input_args(test_params) + input_arg_types = [self._python_arg_to_cpp_arg(arg).type for arg in list(input_args)] input_args = ['arg{}'.format(str(i)) for i in range(len(input_arg_types))] input_arg_declarations = ['{} {}'.format(arg_type, arg_name) for arg_type, arg_name in zip(input_arg_types, input_args)] test_cpp_sources = template.substitute( @@ -275,21 +390,31 @@ def generate_attr_equality_checks(module, script_module_prefix=script_module_prefix, cpp_module_prefix=cpp_module_prefix, buffer_name=name)) + + init_arg_spec = self._get_python_module_init_arg_spec(module.__class__.__name__) + # NOTE: `init_arg_spec.args[0]` is `self`, which is not counted as a constructor arg in the API parity test. + python_constructor_arg_names = [x for x in init_arg_spec.args[1:] if x != 'has_parity'] for name, attr in module.__dict__.items(): if name not in TORCH_NN_MODULE_IGNORED_ATTRS: + # Every constructor arg of the Python module must have + # a corresponding C++ module options arg. + if name in python_constructor_arg_names: + cpp_attr_name = 'options.{}()'.format(name) + else: + cpp_attr_name = name stmts.append(CHECK_MODULE_ATTR_EQUALITY.substitute( script_module_prefix=script_module_prefix, cpp_module_prefix=cpp_module_prefix, - attr_name=name)) + python_attr_name=name, + cpp_attr_name=cpp_attr_name)) return stmts device = test_params.device - python_module_class = test_params.python_module_class - python_constructor_args = test_params.python_constructor_args - example_inputs = test_params.example_inputs + python_constructor = test_params.test_instance.constructor + python_constructor_args = test_params.test_instance.constructor_args torch.manual_seed(2) - module = python_module_class(*python_constructor_args).to(device) + module = python_constructor(*python_constructor_args).to(device) extra_stmts = generate_attr_equality_checks(module) assert len(extra_stmts) == test_params.num_attrs_recursive @@ -300,27 +425,27 @@ def generate_attr_equality_checks(module, def setup_forward_test(test_params): device = test_params.device - python_module_class = test_params.python_module_class - python_constructor_args = test_params.python_constructor_args - example_inputs = test_params.example_inputs + python_constructor = test_params.test_instance.constructor + python_constructor_args = test_params.test_instance.constructor_args + input_args = self._get_forward_input_args(test_params) torch.manual_seed(2) - module = python_module_class(*python_constructor_args).to(device) - python_output = module(*example_inputs) + module = python_constructor(*python_constructor_args).to(device) + python_output = module(*input_args) - return (([module], device, python_output, example_inputs), + return (([module], device, python_output, input_args), generate_test_cpp_sources( test_params=test_params, template=TORCH_NN_MODULE_TEST_FORWARD, extra_stmts='')) def setup_backward_test(test_params): device = test_params.device - python_module_class = test_params.python_module_class - python_constructor_args = test_params.python_constructor_args - example_inputs = test_params.example_inputs + python_constructor = test_params.test_instance.constructor + python_constructor_args = test_params.test_instance.constructor_args + input_args = self._get_forward_input_args(test_params) torch.manual_seed(2) - module = python_module_class(*python_constructor_args).to(device) - python_output = module(*example_inputs) + module = python_constructor(*python_constructor_args).to(device) + python_output = module(*input_args) python_output.sum().backward() # JIT tracing does not save a module's parameters' gradients into ScriptModule. # Instead, we create another module `grad_module` with the same structure as `module`, @@ -332,11 +457,13 @@ def setup_backward_test(test_params): if param.grad is not None: grad_param.data = param.grad - return (([module, grad_module], device, example_inputs), + return (([module, grad_module], device, input_args), generate_test_cpp_sources( test_params=test_params, template=TORCH_NN_MODULE_TEST_BACKWARD, extra_stmts='')) - def trace_module(module, example_inputs): + def trace_module(module, input_args): + module_metadata = torch_nn_modules.module_metadata_map[module.__class__.__name__] + # JIT tracing does not automatically save a module's non-parameter / non-buffer attributes # into a ScriptModule's slots, which means we can't access them via `get_attributes()` in C++. # Here, we manually register these attributes into the ScriptModule so that we can access them @@ -346,11 +473,22 @@ def register_attrs(module, script_module): register_attrs(sub_module, sub_script_module) for key, value in module.__dict__.items(): if key not in TORCH_NN_MODULE_IGNORED_ATTRS: - script_module._c._register_attribute( - key, torch.jit.annotations.ann_to_type(type(value)), value) + if value is None: + value_type = module_metadata.python_optional_attribute_to_jit_type[key] + elif type(value) == tuple: + assert all(isinstance(x, type(value[0])) for x in value), \ + "All elements in a tuple attribute of a Python torch.nn module must have the same type." + # Here, we set the Python tuple attribute's type to `ListType` in the ScriptModule, + # which will automatically be converted to `IntList` later and match the type + # of the corresponding attribute in C++ module (which is initially an `ExpandingArray` + # and is converted to `IntList` by the `IValue` constructor). + value_type = torch._C.ListType(torch.jit.annotations.ann_to_type(type(value[0]))) + else: + value_type = torch.jit.annotations.ann_to_type(type(value)) + script_module._c._register_attribute(key, value_type, value) # We use JIT tracing to serialize Python module state, so that we can load it into C++ - traced_script_module = torch.jit.trace(module, example_inputs) + traced_script_module = torch.jit.trace(module, input_args) register_attrs(module, traced_script_module) return traced_script_module @@ -361,11 +499,8 @@ def serialize_module_into_file(script_module): return module_file.name def test_methods(test_params): - device = test_params.device - python_module_class = test_params.python_module_class - python_constructor_args = test_params.python_constructor_args module_variant_name = test_params.module_variant_name - example_inputs = test_params.example_inputs + input_args = self._get_forward_input_args(test_params) args_map = {} @@ -390,12 +525,14 @@ def test_methods(test_params): for method_name, _ in torch_nn_test_methods: args = args_map[method_name] modules = args[0] - script_modules = [trace_module(module, example_inputs) for module in modules] + script_modules = [trace_module(module, input_args) for module in modules] module_file_names = [serialize_module_into_file(script_module) for script_module in script_modules] cpp_args = module_file_names[:] for arg in args[1:]: - if isinstance(arg, list): + if isinstance(arg, tuple): + cpp_args += list(arg) + elif isinstance(arg, list): cpp_args += arg else: cpp_args.append(arg) @@ -433,33 +570,21 @@ def _compute_module_name(test_params_dict): return module_name -def _process_test_params(test_params_dict, module_metadata, device): +def _process_test_params(test_params_dict, module_metadata, device, is_criterion): module_name = _compute_module_name(test_params_dict) - desc = test_params_dict.get('desc', None) - python_module_class = getattr(torch.nn, module_name) - - test_params_dict['constructor'] = test_params_dict.get('constructor', python_module_class) - test = common_nn.TestBase(**test_params_dict) - module_variant_name = test.get_name()[5:] + (('_' + device) if device != 'cpu' else '') - example_inputs = test._get_input() - - if type(example_inputs) == tuple: - example_inputs = list(example_inputs) - elif type(example_inputs) == torch.Tensor: - example_inputs = [example_inputs] + test_params_dict['constructor'] = test_params_dict.get('constructor', getattr(torch.nn, module_name)) + if is_criterion: + test = common_nn.CriterionTest(**test_params_dict) else: - raise RuntimeError("Unexpected input type: {}".format(type(example_inputs))) + test = common_nn.ModuleTest(**test_params_dict) + module_variant_name = test.get_name()[5:] + (('_' + device) if device != 'cpu' else '') - if device != 'cuda' or TEST_CUDA: - example_inputs = [x.to(device) for x in example_inputs] return TorchNNTestParams( module_name=module_name, module_variant_name=module_variant_name, - python_constructor_args=test.constructor_args, + test_instance=test, cpp_constructor_args=test_params_dict.get('cpp_constructor_args'), - example_inputs=example_inputs, has_parity=test_params_dict.get('has_parity', True), - python_module_class=python_module_class, cpp_sources=module_metadata.cpp_sources, num_attrs_recursive=module_metadata.num_attrs_recursive, device=device, @@ -477,70 +602,74 @@ def add_test(test_name, test_fn): torch_nn_test_params_map = {} -all_module_tests = sample_module.module_tests + \ - common_nn.module_tests + \ - common_nn.new_module_tests + \ - common_nn.criterion_tests + \ - common_nn.new_criterion_tests -for test_params_dict in all_module_tests: - # We skip all `torch.nn.functional` tests for now - if 'FunctionalModule' in str(test_params_dict.get('constructor', '')): - continue +def add_torch_nn_module_tests(module_tests, is_criterion): + for test_params_dict in module_tests: + # We skip all `torch.nn.functional` tests for now + if 'FunctionalModule' in str(test_params_dict.get('constructor', '')): + continue - module_name = _compute_module_name(test_params_dict) + module_name = _compute_module_name(test_params_dict) - assert hasattr(torch.nn, module_name), \ - "`torch.nn` doesn't have module `{}`. ".format(module_name) + \ - "If you are adding a new test, please set `fullname` using format `ModuleName_desc`, " + \ - "or set `module_name` using format `ModuleName`." + assert hasattr(torch.nn, module_name), \ + "`torch.nn` doesn't have module `{}`. ".format(module_name) + \ + "If you are adding a new test, please set `fullname` using format `ModuleName_desc`, " + \ + "or set `module_name` using format `ModuleName`." - module_full_name = 'torch.nn.' + module_name - if module_full_name not in parity_table['torch.nn']: - raise RuntimeError( - 'Module `{}` is not found in Python / C++ API parity table. Please update parity table at {}.'.format( - module_full_name, parity_table_path)) + module_full_name = 'torch.nn.' + module_name + if module_full_name not in parity_table['torch.nn']: + raise RuntimeError( + 'Module `{}` is not found in Python / C++ API parity table. Please update parity table at {}.'.format( + module_full_name, parity_table_path)) - has_impl_parity, _ = parity_table['torch.nn'][module_full_name] + has_impl_parity, _ = parity_table['torch.nn'][module_full_name] - def add_ctor_args_test_for_module(module_name, has_impl_parity): - ctor_args_test_name = 'test_torch_nn_{}_ctor_args'.format(module_name) + def add_ctor_args_test_for_module(module_name, has_impl_parity): + ctor_args_test_name = 'test_torch_nn_{}_ctor_args'.format(module_name) - def ctor_args_test(self): - self._test_torch_nn_module_ctor_args( - module_name=self._testMethodName.replace('test_torch_nn_', '').replace('_ctor_args', '')) + def ctor_args_test(self): + self._test_torch_nn_module_ctor_args( + module_name=self._testMethodName.replace('test_torch_nn_', '').replace('_ctor_args', '')) - if not has_impl_parity: - ctor_args_test = unittest.expectedFailure(ctor_args_test) + if not has_impl_parity: + ctor_args_test = unittest.expectedFailure(ctor_args_test) - # We only run one constructor args test per module - if not has_test(ctor_args_test_name): - add_test(ctor_args_test_name, ctor_args_test) + # We only run one constructor args test per module + if not has_test(ctor_args_test_name): + add_test(ctor_args_test_name, ctor_args_test) - def add_variant_test_for_module(module_name, test_params_dict, has_impl_parity): - module_metadata = torch_nn_modules.module_metadata_map[module_name] - for device in devices: - test_params = _process_test_params( - test_params_dict=test_params_dict, - module_metadata=module_metadata, - device=device) - test_name = 'test_torch_nn_{}'.format(test_params.module_variant_name) - torch_nn_test_params_map[test_name] = test_params + def add_variant_test_for_module(module_name, test_params_dict, has_impl_parity): + module_metadata = torch_nn_modules.module_metadata_map[module_name] + for device in devices: + test_params = _process_test_params( + test_params_dict=test_params_dict, + module_metadata=module_metadata, + device=device, + is_criterion=is_criterion) + test_name = 'test_torch_nn_{}'.format(test_params.module_variant_name) + torch_nn_test_params_map[test_name] = test_params - def test_fn(self): - self._test_torch_nn_module_variant(test_params=torch_nn_test_params_map[self._testMethodName]) + def test_fn(self): + self._test_torch_nn_module_variant(test_params=torch_nn_test_params_map[self._testMethodName]) - if device == 'cuda': - test_fn = unittest.skipIf(not TEST_CUDA, "CUDA unavailable")(test_fn) + if device == 'cuda': + test_fn = unittest.skipIf(not TEST_CUDA, "CUDA unavailable")(test_fn) - if not has_impl_parity: - test_fn = unittest.expectedFailure(test_fn) + if not has_impl_parity: + test_fn = unittest.expectedFailure(test_fn) + + add_test(test_name, test_fn) - add_test(test_name, test_fn) + add_ctor_args_test_for_module(module_name, has_impl_parity) + add_variant_test_for_module(module_name, test_params_dict, has_impl_parity) - add_ctor_args_test_for_module(module_name, has_impl_parity) - add_variant_test_for_module(module_name, test_params_dict, has_impl_parity) +add_torch_nn_module_tests( + sample_module.module_tests + common_nn.module_tests + common_nn.new_module_tests, + is_criterion=False) +add_torch_nn_module_tests( + common_nn.criterion_tests + common_nn.new_criterion_tests, + is_criterion=True) # Assert that there exists auto-generated tests for SampleModule. assert len([name for name in TestCppApiParity.__dict__ if 'SampleModule' in name]) == \ diff --git a/test/test_cuda.py b/test/test_cuda.py index 999794b75fd71..ecf807a474b52 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -23,8 +23,8 @@ from common_methods_invocations import tri_tests_args, tri_large_tests_args, \ _compare_trilu_indices, _compare_large_trilu_indices from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \ - PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, \ - load_tests, slowTest, skipCUDANonDefaultStreamIf + PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, \ + TEST_WITH_ROCM, load_tests, slowTest, skipCUDANonDefaultStreamIf # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -534,7 +534,6 @@ def tmp(t): 'dot': 1e-2, 'erf': 1e-3, 'erfc': 1e-3, - 'erfinv': 1e-3, 'exp': 1e-2, 'expm1': 1e-2, 'fill': 1e-3, @@ -591,7 +590,6 @@ def tmp(t): 'cosh', 'erf', 'erfc', - 'erfinv', 'exp', 'expm1', 'reciprocal', @@ -665,15 +663,7 @@ def tmp(self): class TestCuda(TestCase): _do_cuda_memory_leak_check = True - # See https://github.com/pytorch/pytorch/issues/21589 - # We used to have this turned on for the tests in this file which - # we had tested to be OK, but when people added new tests to - # this file, it would trigger nondeterministic failures that - # are hard to debug. Since there are KNOWN bugs with our - # stream handling, we shouldn't turn this on by default. - # If you decide to make this True, be sure to run the test suite - # under cuda-memcheck - _do_cuda_non_default_stream = False + _do_cuda_non_default_stream = True FIFTY_MIL_CYCLES = 50000000 @staticmethod @@ -1072,7 +1062,8 @@ def test_mul(dtype): self.assertEqual(x * y, 4.5) self.assertEqual(y * x, 4.5) - with self.assertRaisesRegex(RuntimeError, "doesn't match the desired"): + + with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): y *= x x *= y self.assertEqual(x, 4.5) @@ -1095,21 +1086,6 @@ def test_abs_zero(self): for num in abs_zeros: self.assertGreater(math.copysign(1.0, num), 0.0) - def test_neg(self): - _TestTorchMixin._test_neg(self, lambda t: t.cuda()) - - def test_bitwise_not(self): - _TestTorchMixin._test_bitwise_not(self, 'cuda') - - def test_logical_not(self): - _TestTorchMixin._test_logical_not(self, 'cuda') - - def test_logical_xor(self): - _TestTorchMixin._test_logical_xor(self, 'cuda') - - def test_isinf(self): - _TestTorchMixin._test_isinf(self, lambda t: t.cuda()) - @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") def test_arithmetic_large_tensor(self): x = torch.empty(2**30, device='cuda') @@ -1402,9 +1378,6 @@ def test_cat_autogpu(self): z = torch.cat([x, y], 0) self.assertEqual(z.get_device(), x.get_device()) - def test_clamp(self): - _TestTorchMixin._test_clamp(self, 'cuda') - def test_cat(self): SIZE = 10 for dim in range(-3, 3): @@ -1426,12 +1399,6 @@ def test_cat(self): z = torch.cat([x, y]) self.assertEqual(z.size(), (21, SIZE, SIZE)) - def test_cat_empty_legacy(self): - _TestTorchMixin._test_cat_empty_legacy(self, use_cuda=True) - - def test_cat_empty(self): - _TestTorchMixin._test_cat_empty(self, use_cuda=True) - def test_bernoulli(self): _TestTorchMixin._test_bernoulli(self, torch.float32, torch.float64, 'cuda') _TestTorchMixin._test_bernoulli(self, torch.float32, torch.float16, 'cuda') @@ -1903,6 +1870,7 @@ def _test_stream_event_nogil(self, sync_func, p2c, c2p): c2p.put(sync_func(self, TestCuda.FIFTY_MIL_CYCLES)) @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") + # Flaky on the ROCm CI @skipIfRocm def test_stream_event_nogil(self): for sync_func in [TestCuda._stream_synchronize, @@ -1966,7 +1934,6 @@ def test_events_wait(self): self.assertTrue(s1.query()) @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") - @skipIfRocm def test_events_multi_gpu_query(self): d0 = torch.device('cuda:0') d1 = torch.device('cuda:1') @@ -1974,6 +1941,7 @@ def test_events_multi_gpu_query(self): with torch.cuda.device(d0): s0 = torch.cuda.current_stream() e0 = s0.record_event() + s0.synchronize() with torch.cuda.device(d1): s1 = torch.cuda.current_stream() @@ -2156,12 +2124,12 @@ def test_sum_cpu_gpu_mismatch(self): x = torch.randn(20, dtype=torch.float32, device='cuda:0') y = torch.randn(1, dtype=torch.float32) with self.assertRaisesRegex(RuntimeError, - 'expected device cpu and dtype Float but got device cuda:0 and dtype Float'): + 'expected device cpu but got device cuda:0'): torch.sum(x, dim=[0], dtype=torch.float32, out=y) # makeing sure half to float promotion is also properly working. x = x.half() with self.assertRaisesRegex(RuntimeError, - 'expected device cpu and dtype Float but got device cuda:0 and dtype Half'): + 'expected dtype Float but got dtype Half'): torch.sum(x, dim=[0], dtype=torch.float32, out=y) @skipIfRocm @@ -2207,101 +2175,6 @@ def test_prod_large(self): def _select_broadcastable_dims(dims_full=None): return _TestTorchMixin._select_broadcastable_dims(dims_full) - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_inverse(self): - _TestTorchMixin._test_inverse(self, lambda t: t.cuda()) - - @slowTest - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_inverse_many_batches(self): - _TestTorchMixin._test_inverse_slow(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_pinverse(self): - _TestTorchMixin._test_pinverse(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_matrix_rank(self): - _TestTorchMixin._test_matrix_rank(self, lambda x: x.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_matrix_power(self): - _TestTorchMixin._test_matrix_power(self, conv_fn=lambda t: t.cuda()) - - def test_chain_matmul(self): - _TestTorchMixin._test_chain_matmul(self, cast=lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_det_logdet_slogdet(self): - _TestTorchMixin._test_det_logdet_slogdet(self, 'cuda') - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_det_logdet_slogdet_batched(self): - _TestTorchMixin._test_det_logdet_slogdet_batched(self, 'cuda') - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_solve(self): - _TestTorchMixin._test_solve(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_solve_batched(self): - _TestTorchMixin._test_solve_batched(self, lambda t: t.cuda()) - - @slowTest - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_solve_batched_many_batches(self): - _TestTorchMixin._test_solve_batched_many_batches(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_solve_batched_dims(self): - _TestTorchMixin._test_solve_batched_dims(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_solve(self): - _TestTorchMixin._test_cholesky_solve(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_solve_batched(self): - _TestTorchMixin._test_cholesky_solve_batched(self, lambda t: t.cuda()) - - @slowTest - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_solve_batched_many_batches(self): - _TestTorchMixin._test_cholesky_solve_batched_many_batches(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_solve_batched_dims(self): - _TestTorchMixin._test_cholesky_solve_batched_dims(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_inverse(self): - _TestTorchMixin._test_cholesky_inverse(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky(self): - _TestTorchMixin._test_cholesky(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_batched(self): - _TestTorchMixin._test_cholesky_batched(self, lambda t: t.cuda()) - - @slowTest - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_cholesky_batched_many_batches(self): - _TestTorchMixin._test_cholesky_batched_many_batches(self, lambda t: t.cuda()) - - def test_view(self): - _TestTorchMixin._test_view(self, lambda t: t.cuda()) - - def test_flip(self): - _TestTorchMixin._test_flip(self, use_cuda=True) - - def test_rot90(self): - _TestTorchMixin._test_rot90(self, use_cuda=True) - - def test_signal_window_functions(self): - _TestTorchMixin._test_signal_window_functions(self, device=torch.device('cuda')) - @skipIfRocm def test_fft_ifft_rfft_irfft(self): _TestTorchMixin._test_fft_ifft_rfft_irfft(self, device=torch.device('cuda')) @@ -2373,11 +2246,6 @@ def plan_cache_max_size(n, device=None): self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10) # default is cuda:0 self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11) # default is cuda:1 - # passes on ROCm w/ python 2.7, fails w/ python 3.6 - @skipIfRocm - def test_stft(self): - _TestTorchMixin._test_stft(self, device=torch.device('cuda')) - def test_multinomial(self): _TestTorchMixin._test_multinomial(self, torch.cuda.FloatTensor) @@ -2414,10 +2282,6 @@ def test_multinomial(self): samples = probs.multinomial(1000000, replacement=True) self.assertGreater(probs[samples].min().item(), 0) - @skipCUDANonDefaultStreamIf(True) - def test_multinomial_alias(self): - _TestTorchMixin._test_multinomial_alias(self, lambda t: t.cuda()) - @staticmethod def mute(): os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno()) @@ -2455,25 +2319,6 @@ def test_multinomial_invalid_probs_cuda(self): self._spawn_method(test_method, torch.Tensor([1, 1, nan])) self._spawn_method(test_method, torch.Tensor([0, 1, 0])) - def test_broadcast(self): - _TestTorchMixin._test_broadcast(self, lambda t: t.cuda()) - - def test_contiguous(self): - _TestTorchMixin._test_contiguous(self, lambda t: t.cuda()) - - def test_broadcast_fused_matmul(self): - _TestTorchMixin._test_broadcast_fused_matmul(self, lambda t: t.cuda()) - - def test_broadcast_batched_matmul(self): - _TestTorchMixin._test_broadcast_batched_matmul(self, lambda t: t.cuda()) - - def test_index(self): - _TestTorchMixin._test_index(self, lambda t: t.cuda()) - - @skipCUDANonDefaultStreamIf(True) - def test_advancedindex(self): - _TestTorchMixin._test_advancedindex(self, lambda t: t.cuda()) - def test_advancedindex_mixed_cpu_cuda(self): def test(x, ia, ib): # test getitem @@ -2522,9 +2367,6 @@ def test(x, ia, ib): ib = ib.to(other_device) test(x, ia, ib) - def test_advancedindex_big(self): - _TestTorchMixin._test_advancedindex_big(self, lambda t: t.cuda()) - @slowTest @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") def test_huge_index(self): @@ -2534,14 +2376,6 @@ def test_huge_index(self): res_cpu = src.cpu()[idx.cpu()] self.assertEqual(res.cpu(), res_cpu) - def test_kthvalue(self): - _TestTorchMixin._test_kthvalue(self, device='cuda') - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_lu(self): - _TestTorchMixin._test_lu(self, lambda t: t.cuda(), pivot=False) - _TestTorchMixin._test_lu(self, lambda t: t.cuda(), pivot=True) - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") def test_lu_solve(self): _TestTorchMixin._test_lu_solve(self, lambda t: t.cuda(), pivot=False) @@ -2552,29 +2386,11 @@ def test_lu_solve_batched(self): _TestTorchMixin._test_lu_solve_batched(self, lambda t: t.cuda(), pivot=False) _TestTorchMixin._test_lu_solve_batched(self, lambda t: t.cuda(), pivot=True) - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lu_solve_batched_non_contiguous(self): - _TestTorchMixin._test_lu_solve_batched_non_contiguous(self, lambda t: t.cuda()) - - @slowTest - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_lu_solve_batched_many_batches(self): - _TestTorchMixin._test_lu_solve_batched_many_batches(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lu_solve_batched_broadcasting(self): - _TestTorchMixin._test_lu_solve_batched_broadcasting(self, lambda t: t.cuda()) - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") def test_lu_unpack(self): _TestTorchMixin._test_lu_unpack(self, lambda t: t.cuda(), pivot=False) _TestTorchMixin._test_lu_unpack(self, lambda t: t.cuda(), pivot=True) - def test_dim_reduction(self): - _TestTorchMixin._test_dim_reduction(self, lambda t: t.cuda()) - def test_tensor_gather(self): _TestTorchMixin._test_gather(self, lambda t: t.cuda(), False) @@ -2606,12 +2422,6 @@ def test_max_with_inf(self): def test_min_with_inf(self): _TestTorchMixin._test_min_with_inf(self, (torch.half, torch.float, torch.double), 'cuda') - def test_rpow(self): - _TestTorchMixin._test_rpow(self, lambda x: x.cuda()) - - def test_remainder_overflow(self): - _TestTorchMixin._test_remainder_overflow(self, dtype=torch.int64, device='cuda') - def test_var(self): cpu_tensor = torch.randn(2, 3, 3) gpu_tensor = cpu_tensor.cuda() @@ -2661,7 +2471,6 @@ def test_var_stability(self): tensor = tensor.unsqueeze(1) self.assertEqual(tensor.var(0), 0.03125) - @skipIfRocm def test_digamma(self): def test(use_double=False): cpu_tensor = torch.randn(10, 10, 10) @@ -2690,7 +2499,6 @@ def test(use_double=False): norm_errors = (gpu_out - cpu_out.cuda()) / gpu_out self.assertEqual(norm_errors, expected_errors) - @skipIfRocm def test_polygamma(self): def test(use_double=False): cpu_tensor = torch.randn(10, 10, 10) @@ -2709,18 +2517,6 @@ def test(use_double=False): test(True) test(False) - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_symeig(self): - _TestTorchMixin._test_symeig(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_svd(self): - _TestTorchMixin._test_svd(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_svd_no_singularvectors(self): - _TestTorchMixin._test_svd_no_singularvectors(self, lambda t: t.cuda()) - def test_arange(self): for t in ['IntTensor', 'LongTensor', 'FloatTensor', 'DoubleTensor']: a = torch.cuda.__dict__[t]() @@ -2744,65 +2540,11 @@ def test_logspace(self): b = torch.logspace(1, 10, 10, 2) self.assertEqual(a, b.cuda()) - def test_lerp(self): - _TestTorchMixin._test_lerp(self, lambda t: t.cuda()) - - def test_diagonal(self): - _TestTorchMixin._test_diagonal(self, dtype=torch.float32, device='cuda') - - def test_diagflat(self): - _TestTorchMixin._test_diagflat(self, dtype=torch.float32, device='cuda') - - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") @skipCUDANonDefaultStreamIf(True) - def test_norm(self): - _TestTorchMixin._test_norm(self, device='cuda') - - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - @skipCUDANonDefaultStreamIf(True) - def test_nuclear_norm_axes_small_brute_force(self): - _TestTorchMixin._test_nuclear_norm_axes(self, device='cuda') - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - @skipCUDANonDefaultStreamIf(True) - def test_nuclear_norm_exceptions(self): - _TestTorchMixin._test_nuclear_norm_exceptions(self, device='cuda') - - def test_dist(self): - _TestTorchMixin._test_dist(self, device='cuda') - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_geqrf(self): - _TestTorchMixin._test_geqrf(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - @skipCUDANonDefaultStreamIf(True) - def test_triangular_solve(self): - _TestTorchMixin._test_triangular_solve(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") def test_triangular_solve_batched(self): _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t.cuda()) - @slowTest - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_triangular_solve_batched_many_batches(self): - _TestTorchMixin._test_triangular_solve_batched_many_batches(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_triangular_solve_batched_dims(self): - _TestTorchMixin._test_triangular_solve_batched_dims(self, lambda t: t.cuda()) - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_lstsq(self): - _TestTorchMixin._test_lstsq(self, 'cuda') - - @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected") - def test_qr(self): - _TestTorchMixin._test_qr(self, lambda t: t.cuda()) - @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") def test_get_set_rng_state_all(self): states = torch.cuda.get_rng_state_all() @@ -2821,12 +2563,6 @@ def test_nvtx(self): torch.cuda.nvtx.mark("bar") torch.cuda.nvtx.range_pop() - def test_randperm_cuda(self): - _TestTorchMixin._test_randperm(self, device='cuda') - - def test_random_neg_values(self): - _TestTorchMixin._test_random_neg_values(self, use_cuda=True) - def test_bincount_cuda(self): _TestTorchMixin._test_bincount(self, device='cuda') # ensure CUDA code coverage @@ -2848,6 +2584,13 @@ def test_bincount_cuda(self): self.assertEqual(t.cpu().bincount(), t.bincount()) self.assertEqual(t.cpu().bincount(w_cpu), t.bincount(w)) + t = torch.zeros([10], dtype=torch.int32, device='cuda') + # 35488 * 65536 as int32 would cause overflow to negative value + # giving negative bin offset + t[0] = 35488 + counted = t.bincount(minlength=65536) + self.assertEqual(torch.sum(counted), 10) + def test_tiny_half_norm_(self): a = torch.arange(25).cuda().float() a /= 100000000 @@ -2910,9 +2653,6 @@ def test_large_trilu_indices(self): for test_args in tri_large_tests_args: _compare_large_trilu_indices(self, *test_args, device='cuda') - def test_triu_tril(self): - _TestTorchMixin._test_triu_tril(self, lambda t: t.cuda()) - def test_cuda_round(self): # test half-to-even a = [-5.8, -3.5, -2.3, -1.5, -0.5, 0.5, 1.5, 2.3, 3.5, 5.8] @@ -2953,6 +2693,84 @@ def test_cuda_kernel_loop_overflow_large(self): torch.cuda.synchronize() self.assertEqual(y[0, 0, 0, 2**31 - 2], expected) + @skipCUDANonDefaultStreamIf(True) + def test_streaming_backwards_sync(self): + default_stream = torch.cuda.current_stream() + stream = torch.cuda.Stream() + + class MultiplyInStream(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return x * 2 + + @staticmethod + def backward(ctx, grad): + self.assertEqual(torch.cuda.current_stream(), stream) + # delays the operation in the the background stream + torch.cuda._sleep(1000 * 1000) + return grad * 2 + + x = torch.randn(5, 5, device='cuda', requires_grad=True) + with torch.cuda.stream(stream): + stream.wait_stream(default_stream) + output = MultiplyInStream.apply(x) + output.sum().backward() + + self.assertEqual(x.grad, torch.ones_like(x) * 2) + self.assertEqual(torch.cuda.current_stream(), default_stream) + + def test_streaming_backwards_multiple_streams(self): + + class StreamModel(torch.nn.Module): + def __init__(self): + super(StreamModel, self).__init__() + self.event = torch.cuda.Event() + self.stream0 = torch.cuda.Stream() + self.stream1 = torch.cuda.Stream() + + def forward(self, x): + x0 = x.clone() + torch._C._cuda_setStream(self.stream0._cdata) + y0 = x0 * 2 + self.event.record(stream=torch.cuda.current_stream()) + + torch._C._cuda_setStream(self.stream1._cdata) + y1 = x * 3 + self.stream1.wait_event(self.event) + return y0 + y1 + + stream = torch.cuda.Stream() + + def accum_hook(grad): + self.assertEqual(torch.cuda.current_stream(), stream) + + with torch.cuda.stream(stream): + x = torch.randn(5, 5, device='cuda', requires_grad=True) + x.register_hook(accum_hook) + torch.cuda.current_stream().wait_stream(stream) + model = StreamModel().cuda() + model(x).sum().backward() + + self.assertEqual(x.grad, torch.ones_like(x) * 5) + + @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") + def test_cuda_init_race(self): + # See https://github.com/pytorch/pytorch/issues/16559 + import subprocess + subprocess.check_call([sys.executable, '-c', """\ +import torch +import threading + +def worker(rank): + torch.tensor([1.]).cuda(rank) + +t1 = threading.Thread(target=worker, args=(0,)) +t2 = threading.Thread(target=worker, args=(1,)) +t1.start() +t2.start() +"""]) + + def load_ignore_file(): from os.path import join, dirname global ignores diff --git a/test/test_dataloader.py b/test/test_dataloader.py index b24ffc5c7a94e..d450e6005dc85 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -683,6 +683,31 @@ def error_worker_init_fn(_): raise RuntimeError("Error in worker_init_fn") +class BulkLoadingDataset(Dataset): + def __init__(self, length): + self.length = length + + def __getitem__(self, indices): + assert isinstance(indices, (list, tuple)) + return torch.as_tensor(indices) + + def __len__(self): + return self.length + + +class BulkLoadingSampler(torch.utils.data.Sampler): + def __init__(self, dataset, batch_size): + self.dataset = dataset + self.batch_size = batch_size + + def __iter__(self): + for x in torch.randperm(len(self.dataset)).split(self.batch_size): + yield x.tolist() + + def __len__(self): + return int(math.ceil(len(self.dataset) / float(self.batch_size))) + + @unittest.skipIf( TEST_WITH_TSAN, "Fails with TSAN with the following error: starting new threads after multi-threaded " @@ -764,6 +789,19 @@ def test_sequential_batch(self): self._test_sequential(DataLoader(self.dataset)) self._test_sequential(DataLoader(self.dataset, batch_size=2)) + def test_bulk_loading_nobatch(self): + n = 35 + bs = 4 + ds = BulkLoadingDataset(n) + sampler = BulkLoadingSampler(ds, batch_size=4) + + for num_workers in [0, 4]: + dl = DataLoader(ds, num_workers=num_workers, batch_size=None, sampler=sampler, pin_memory=TEST_CUDA) + self.assertFalse(dl._auto_collation) + samples = list(dl) + self.assertEqual(samples[0].is_pinned(), TEST_CUDA) + self.assertEqual(set(torch.cat(samples, 0).tolist()), set(range(n))) + def test_growing_dataset(self): dataset = [torch.ones(4) for _ in range(4)] dataloader_seq = DataLoader(dataset, shuffle=False) @@ -834,6 +872,15 @@ def test_invalid_ctor_args_combinations(self): with self.assertRaisesRegex(ValueError, "timeout option should be non-negative"): DataLoader(self.dataset, timeout=-1) + + # disable auto-batching + with self.assertRaisesRegex(ValueError, + "batch_size=None option disables auto-batching and is mutually exclusive"): + DataLoader(self.dataset, batch_size=None, shuffle=True) + with self.assertRaisesRegex(ValueError, + "batch_size=None option disables auto-batching and is mutually exclusive"): + DataLoader(self.dataset, batch_size=None, drop_last=True) + if torch.multiprocessing._supports_context: valid_ctx = list(torch.multiprocessing.get_all_start_methods())[-1] with self.assertRaisesRegex(ValueError, r"multi-process loading \(num_workers > 0\), but got"): @@ -1605,7 +1652,7 @@ def test_pin_memory(self): class NamedTupleDataset(Dataset): from collections import namedtuple - Batch = namedtuple('Batch', ['data', 'label']) + Batch = namedtuple('Batch', ['data', 'label', 'random_tensor']) Data = namedtuple('Data', ['positive', 'negative']) def __len__(self): @@ -1613,7 +1660,7 @@ def __len__(self): def __getitem__(self, ndx): return self.Batch(data=self.Data(positive=ndx, negative=-ndx), - label=str(ndx)) + label=str(ndx), random_tensor=torch.randn(3)) @unittest.skipIf( @@ -1625,12 +1672,22 @@ def setUp(self): super(TestNamedTupleDataLoader, self).setUp() self.dataset = NamedTupleDataset() - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_collate_and_pin_memory_with_namedtuple(self): - loader = DataLoader(self.dataset, batch_size=2, pin_memory=True) + def test_dataloader_with_namedtuple(self): + # auto-collation + loader = DataLoader(self.dataset, batch_size=2, pin_memory=TEST_CUDA) + for batch in loader: + self.assertIsInstance(batch, NamedTupleDataset.Batch) + self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA) + self.assertIsInstance(batch.data, NamedTupleDataset.Data) + self.assertIsInstance(batch.data.positive, torch.Tensor) + self.assertEqual(batch.data.positive.is_pinned(), TEST_CUDA) + # no auto-collation + loader = DataLoader(self.dataset, batch_size=None, pin_memory=TEST_CUDA) for batch in loader: self.assertIsInstance(batch, NamedTupleDataset.Batch) + self.assertEqual(batch.random_tensor.is_pinned(), TEST_CUDA) self.assertIsInstance(batch.data, NamedTupleDataset.Data) + self.assertNotIsInstance(batch.data.positive, torch.Tensor) class SimpleCustomBatch(object): diff --git a/test/test_distributions.py b/test/test_distributions.py index 1e5dd5fdf387b..09073e970090b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -651,7 +651,7 @@ def is_all_nan(tensor): class TestDistributions(TestCase): _do_cuda_memory_leak_check = True - _do_cuda_non_default_stream = False + _do_cuda_non_default_stream = True def _gradcheck_log_prob(self, dist_ctor, ctor_params): # performs gradient checks on log_prob @@ -2908,7 +2908,7 @@ def test_halfcauchy_shape_scalar_params(self): self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(), torch.Size((3, 2))) self.assertEqual(halfcauchy.log_prob(self.scalar_sample).size(), - torch.Size()) + torch.Size()) self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2))) self.assertEqual(halfcauchy.log_prob(self.tensor_sample_2).size(), diff --git a/test/test_docs_coverage.py b/test/test_docs_coverage.py index 8189f1bcbaf21..8b95091d5a460 100644 --- a/test/test_docs_coverage.py +++ b/test/test_docs_coverage.py @@ -43,7 +43,7 @@ def test_torch(self): # below are some jit functions 'wait', 'fork', 'parse_type_comment', 'import_ir_module', 'import_ir_module_from_buffer', 'merge_type_from_type_comment', - 'parse_ir', + 'parse_ir', 'parse_schema', # below are symbols mistakely binded to torch.*, but should # go to torch.nn.functional.* instead diff --git a/test/test_fake_quant.py b/test/test_fake_quant.py index 1b7454e7edfcc..f1b6636139ef0 100644 --- a/test/test_fake_quant.py +++ b/test/test_fake_quant.py @@ -2,12 +2,11 @@ import torch.cuda import torch.jit import numpy as np -import unittest from hypothesis import given from hypothesis import strategies as st import hypothesis_utils as hu from hypothesis_utils import no_deadline -from common_utils import run_tests +from common_utils import run_tests, TestCase from torch.quantization import FakeQuantize # Reference method for fake quantize @@ -27,7 +26,7 @@ def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, qu NP_RANDOM_SEED = 19 tolerance = 1e-6 -class TestFakeQuantizePerTensorAffine(unittest.TestCase): +class TestFakeQuantizePerTensorAffine(TestCase): # NOTE: Tests in this class are decorated with no_deadline # to prevent spurious failures due to cuda runtime initialization. diff --git a/test/test_function_schema.py b/test/test_function_schema.py new file mode 100644 index 0000000000000..0f0a8640e9db5 --- /dev/null +++ b/test/test_function_schema.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import torch +from common_utils import TestCase, run_tests +from torch._C import parse_schema + + +class TestFunctionSchema(TestCase): + def test_serialize_and_deserialize(self): + schemas = torch._C._jit_get_all_schemas() + # so far we have around 1700 registered schemas + self.assertGreater(len(schemas), 1000) + for schema in schemas: + parsed_schema = parse_schema(str(schema)) + self.assertEqual(parsed_schema, schema) + self.assertTrue(parsed_schema.is_backward_compatible_with(schema)) + + def test_backward_compatible_args(self): + old_schema = parse_schema('any(Tensor self, int dim) -> Tensor') + new_schema = parse_schema('any(Tensor self, int? dim) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dim=5) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_compatible_kwargs(self): + old_schema = parse_schema('any(Tensor self, *, Tensor out) -> Tensor') + new_schema = parse_schema('any(Tensor self, *, bool extra1=True, Tensor out, bool extra2=False) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, Tensor out) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_compatible_ret(self): + old_schema = parse_schema('any(Tensor self) -> Tensor?') + new_schema = parse_schema('any(Tensor self) -> Tensor') + self.assertTrue(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_incompatible_name(self): + old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') + new_schema = parse_schema('any_(Tensor self, int dim, bool keepdim=False) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_incompatible_vararg(self): + old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False, ...) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_incompatible_returns(self): + old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor') + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, ...)') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> int') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor?') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertTrue(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor out') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_incompatible_args(self): + old_schema = parse_schema('any(Tensor self, int[] dims, bool keepdim=False) -> Tensor') + new_schema = parse_schema('any(Tensor s, int[] dims, bool keepdim=False) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int[3] dims, bool keepdim=False) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int[](a) dims, bool keepdim=False) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int dims, bool keepdim=False) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int[] dim, bool keepdim=False, bool? extra) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + def test_backward_incompatible_kwargs(self): + old_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim=False) -> Tensor') + new_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertTrue(old_schema.is_backward_compatible_with(new_schema)) + new_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim=False, bool extra) -> Tensor') + self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + self.assertFalse(old_schema.is_backward_compatible_with(new_schema)) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index 36d34ca4159bb..315670afa8b22 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -39,7 +39,7 @@ from test_module.no_future_div import div_int_nofuture, div_float_nofuture # Standard library -from collections import namedtuple +from collections import namedtuple, OrderedDict from copy import deepcopy from functools import wraps from itertools import product, chain @@ -86,7 +86,23 @@ RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1 PY35 = sys.version_info >= (3, 5) -WINDOWS = sys.platform == 'win32' + +def default_tensor_type(type): + type_str = torch.typename(type) + + def decorator(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + old_type = torch.Tensor().type() + torch.set_default_tensor_type(type_str) + try: + return fn(*args, **kwargs) + finally: + torch.set_default_tensor_type(old_type) + + return wrapper + + return decorator def LSTMCellF(input, hx, cx, *params): @@ -902,6 +918,8 @@ def forward(self, x): def get_forward_graph(m): return m._get_method("forward").graph torch._C._jit_pass_constant_propagation(get_forward_graph(m._c)) + # TODO: change to use module level constant prop + torch._C._jit_pass_constant_propagation(m._c._get_module('conv')._get_method('conv2d_forward').graph) qconfig_dict = { '': QConfig( @@ -910,14 +928,12 @@ def get_forward_graph(m): } torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict) assert len([x for x, _ in m._c._get_modules() - if x.startswith('observer_for_')]) == 2, \ - 'Expected to have 2 observer submodules' - FileCheck().check('ClassType = prim::GetAttr[name="observer_for_') \ - .check_next('prim::CallMethod[name="forward"](%observer_for_') \ + if x.startswith('observer_for_')]) == 0, \ + 'Expected to have 0 observer submodules' + FileCheck().check_not('ClassType = prim::GetAttr[name="observer_for_') \ .check('ClassType = prim::GetAttr[name="conv"](%self)') \ .check_next('Tensor = prim::CallMethod[name="forward"]') \ - .check('ClassType = prim::GetAttr[name="observer_for_') \ - .check_next('prim::CallMethod[name="forward"](%observer_for_') \ + .check_not('ClassType = prim::GetAttr[name="observer_for_') \ .run(str(get_forward_graph(m._c))) assert len([x for x, _ in m._c._get_module('conv')._get_modules() if x.startswith('observer_for_')]) == 3, \ @@ -926,11 +942,10 @@ def get_forward_graph(m): .check_next('prim::CallMethod[name="forward"](%observer_for_') \ .check('ClassType = prim::GetAttr[name="observer_for_') \ .check_next('prim::CallMethod[name="forward"](%observer_for_') \ - .check_next('Tensor = prim::CallMethod[name="conv2d_forward"](%self') \ + .check('Tensor = aten::conv2d') \ .check('ClassType = prim::GetAttr[name="observer_for_') \ .check_next('prim::CallMethod[name="forward"](%observer_for_') \ - .run(str(get_forward_graph(m._c._get_module("conv")))) - + .run(str(m._c._get_module("conv")._get_method('conv2d_forward').graph)) @_tmp_donotuse_dont_inline_everything def test_insert_observers_child_qconfig(self): @@ -993,8 +1008,10 @@ def get_forward(c): qconfig_dict) # check m is not observed check_not_observed(get_forward(m._c).graph) - # check conv is observed - check_observed(get_forward(m._c._get_module('conv')).graph) + # check conv.forward is observed + check_not_observed(get_forward(m._c._get_module('conv')).graph) + # check conv.conv2d_forward is observed + check_observed(m._c._get_module('conv')._get_method('conv2d_forward').graph) # check sub is not observed check_not_observed(get_forward(m._c._get_module('sub')).graph) # check forward of sub.linear is observed @@ -1035,11 +1052,9 @@ def forward(self, x): def get_forward(m): return m._c._get_method("forward") - def test_module(module, relu_call): + def test_module(module, relu_call, num_observers): m = torch.jit.script(module()) observer = torch.jit.script(Observer()) - - torch._C._jit_pass_constant_propagation(get_forward(m).graph) qconfig_dict = { '': QConfig( @@ -1048,19 +1063,20 @@ def test_module(module, relu_call): } torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict) assert len([x for x, _ in m._c._get_modules() - if x.startswith('observer_for_')]) == 2, \ - 'Expected to have 2 observer submodules' - FileCheck().check('ClassType = prim::GetAttr[name="observer_for_') \ - .check_next('prim::CallMethod[name="forward"](%observer_for_') \ - .check('ClassType = prim::GetAttr[name="conv"]') \ - .check_next('prim::CallMethod[name="forward"]') \ - .check_not('ClassType = prim::GetAttr[name="observer_for_') \ - .check(relu_call) \ - .check('ClassType = prim::GetAttr[name="observer_for_') \ - .check_next('prim::CallMethod[name="forward"](%observer_for_') \ - .run(str(get_forward(m).graph)) - test_module(M, 'prim::CallFunction(') - test_module(M2, 'prim::CallMethod[name="forward"]') + if x.startswith('observer_for_')]) == num_observers, \ + 'Expected to have ' + str(num_observers) + ' observer submodules' + c = FileCheck().check('ClassType = prim::GetAttr[name="conv"]') \ + .check_next('prim::CallMethod[name="forward"]') \ + .check_not('ClassType = prim::GetAttr[name="observer_for_') \ + .check(relu_call) + if num_observers == 1: + c = c.check('ClassType = prim::GetAttr[name="observer_for_') \ + .check_next('prim::CallMethod[name="forward"](%observer_for_') + c.run(str(get_forward(m).graph)) + # TODO: add checks for conv and relu later, graph looks correct but this pr + # has too many changes already + test_module(M, 'prim::CallFunction(', 1) + test_module(M2, 'prim::CallMethod[name="forward"]', 0) @_tmp_donotuse_dont_inline_everything def test_insert_quant_dequant(self): @@ -1085,7 +1101,6 @@ def forward(self, x): m = torch.jit.script(M()) observer = torch.jit.script(Observer()) - torch._C._jit_pass_constant_propagation(m.graph) qconfig_dict = { '': QConfig( @@ -1103,20 +1118,27 @@ def get_forward(m): m._c = torch._C._jit_pass_insert_quant_dequant(m._c, "forward") get_forward(m)(data) + FileCheck().check_not("aten::quantize_linear") \ + .check("prim::CallMethod[name=\"forward\"]") \ + .check_not("aten::quantize_linear") \ + .check("return") \ + .run(str(get_forward(m).graph)) FileCheck().check("aten::quantize_linear") \ .check_next("aten::int_repr") \ .check_next("aten::_dequantize_linear") \ - .check("prim::CallMethod[name=\"forward\"]") \ + .check("aten::conv2d") \ .check("aten::quantize_linear") \ .check_next("aten::int_repr") \ .check_next("aten::_dequantize_linear") \ .check("return") \ - .run(str(get_forward(m).graph)) + .run(str(m._c._get_module('conv')._get_method('conv2d_forward').graph)) def test_quant_fusion(self): - input_str = """ + input_strs = [ + # aten::conv2d --> quantized::conv2d + """ graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, -%b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f): +%r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f): %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype) # CHECK-NOT: aten::int_repr %a_intrepr = aten::int_repr(%a_quant) @@ -1127,27 +1149,104 @@ def test_quant_fusion(self): %w_intrepr = aten::int_repr(%w_quant) # CHECK-NOT: aten::_dequantize_linear %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + # CHECK: quantized::conv_prepack + # CHECK: quantized::conv2d + # CHECK-NOT: aten::conv2d + %r = aten::conv2d(%a_dequant, %w_dequant, %b, %c, %d, %e, %f) + # CHECK-NOT: aten::quantize_linear + %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) + # CHECK: aten::int_repr + %r_intrepr = aten::int_repr(%r_quant) + # CHECK: aten::_dequantize_linear + %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype) + return (%r_dequant)""", + # addmm -> quantized::linear + """ +graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, +%r_scale, %r_zero_point, %r_dtype, %4): + %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype) # CHECK-NOT: aten::int_repr - %b_quant = aten::quantize_linear(%b, %b_scale, %b_zero_point, %b_dtype) - %b_intrepr = aten::int_repr(%b_quant) + %a_intrepr = aten::int_repr(%a_quant) # CHECK-NOT: aten::_dequantize_linear - %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype) - # CHECK: quantized::fbgemm_conv_prepack - # CHECK: quantized::fbgemm_conv2d - # CHECK-NOT: aten::conv2d - %r = aten::conv2d(%a_dequant, %w_dequant, %b_dequant, %c, %d, %e, %f) + %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) + %w_quant = aten::quantize_linear(%w, %w_scale, %w_zero_point, %w_dtype) + # CHECK-NOT: aten::int_repr + %w_intrepr = aten::int_repr(%w_quant) + # CHECK-NOT: aten::_dequantize_linear + %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + # CHECK: aten::t + # CHECK: quantized::linear_prepack + # CHECK: quantized::linear + # CHECK-NOT: aten::addmm + %r = aten::addmm(%b, %a_dequant, %w_dequant, %4, %4) # CHECK-NOT: aten::quantize_linear %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) # CHECK: aten::int_repr %r_intrepr = aten::int_repr(%r_quant) # CHECK: aten::_dequantize_linear %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype) - return (%r_dequant) -) -""" - graph = parse_ir(input_str) - torch._C._jit_pass_quant_fusion(graph) - FileCheck().run(input_str, graph) + return (%r_dequant)""", + # matmul(with bias) -> quantized::linear + """ +graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, +%r_scale, %r_zero_point, %r_dtype, %4): + %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype) + # CHECK-NOT: aten::int_repr + %a_intrepr = aten::int_repr(%a_quant) + # CHECK-NOT: aten::_dequantize_linear + %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) + %w_quant = aten::quantize_linear(%w, %w_scale, %w_zero_point, %w_dtype) + # CHECK-NOT: aten::int_repr + %w_intrepr = aten::int_repr(%w_quant) + # CHECK-NOT: aten::_dequantize_linear + %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + # CHECK-NOT: aten::int_repr + # CHECK-NOT: aten::_dequantize_linear + # CHECK: aten::t + # CHECK: quantized::linear_prepack + # CHECK: quantized::linear + # CHECK-NOT: aten::addmm + %output = aten::matmul(%a_dequant, %w_dequant) + %r = aten::add_(%output, %b, %4) + # CHECK-NOT: aten::quantize_linear + %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) + # CHECK: aten::int_repr + %r_intrepr = aten::int_repr(%r_quant) + # CHECK: aten::_dequantize_linear + %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype) + return (%r_dequant)""", + # matmul(without bias) -> quantized::linear + """ +graph(%a, %w, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, +%r_scale, %r_zero_point, %r_dtype): + %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype) + # CHECK-NOT: aten::int_repr + %a_intrepr = aten::int_repr(%a_quant) + # CHECK-NOT: aten::_dequantize_linear + %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) + %w_quant = aten::quantize_linear(%w, %w_scale, %w_zero_point, %w_dtype) + # CHECK-NOT: aten::int_repr + %w_intrepr = aten::int_repr(%w_quant) + # CHECK-NOT: aten::_dequantize_linear + %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + # CHECK: aten::t + # CHECK: prim::Constant() + # CHECK: quantized::linear_prepack + # CHECK: quantized::linear + # CHECK-NOT: aten::matmul + %r = aten::matmul(%a_dequant, %w_dequant) + # CHECK-NOT: aten::quantize_linear + %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) + # CHECK: aten::int_repr + %r_intrepr = aten::int_repr(%r_quant) + # CHECK: aten::_dequantize_linear + %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype) + return (%r_dequant)""" + ] + for input_str in input_strs: + graph = parse_ir(input_str) + torch._C._jit_pass_quant_fusion(graph) + FileCheck().run(input_str, graph) @_tmp_donotuse_dont_inline_everything def test_foldbn_trivial(self): @@ -1264,6 +1363,53 @@ def forward(self, x): FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \ .run(str(get_forward(m.sub).graph)) + def test_fuse_linear(self): + input_strs = [""" +graph(%input, %weight, %bias, %4): + # CHECK-NOT: aten::t + # CHECK-NOT: aten::addmm + # CHECK: aten::linear + %weight_t = aten::t(%weight) + %res = aten::addmm(%bias, %input, %weight_t, %4, %4) + return (%res)""", """ +graph(%input, %weight, %bias, %4): + # CHECK-NOT: aten::t + # CHECK-NOT: aten::matmul + # CHECK-NOT: aten::add_ + # CHECK: aten::linear + %weight_t = aten::t(%weight) + %output = aten::matmul(%input, %weight_t) + %res = aten::add_(%output, %bias, %4) + return (%res)""", """ +graph(%input, %weight): + # CHECK-NOT: aten::t + # CHECK-NOT: aten::matmul + # CHECK: aten::linear + %weight_t = aten::t(%weight) + %output = aten::matmul(%input, %weight_t) + return (%output)"""] + for input_str in input_strs: + graph = parse_ir(input_str) + torch._C._jit_pass_fuse_linear(graph) + FileCheck().run(input_str, graph) + + @_tmp_donotuse_dont_inline_everything + def test_fold_quantize(self): + class M(torch.nn.Module): + def __init__(self): + super(M, self).__init__() + self.weight = torch.nn.Parameter(torch.tensor([2], dtype=torch.float)) + + def forward(self, x): + return torch.quantize_linear(self.weight, 2.0, 0, torch.quint8) + + m = torch.jit.script(M()) + torch._C._jit_pass_fold_quantize(m._c, 'forward') + self.assertTrue(m._c._has_attribute('_quantized_weight')) + FileCheck().check_not('GetAttr[name="weight"]') \ + .check('GetAttr[name="_quantized_weight"]') \ + .run(m._c._get_method('forward').graph) + def test_pattern_based_rewrite(self): # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) --> # --> mulmul(mulmul(x,y,z), x, y) @@ -1549,6 +1695,51 @@ def test_trace_size(self): def test_trace_size_with_grad(self): self.do_trace_size(True) + def do_trace_arange(self, requires_grad): + def arange(x): + return torch.arange(x.shape[0]) + + def arange_scalar(x): + return torch.arange(12) + + def arange_start_end(x): + return torch.arange(start=x.shape[0], end=x.shape[0] + 5) + + x = torch.randn(5, 3, 2, requires_grad=requires_grad) + y = torch.randn(8, 2, 4, requires_grad=requires_grad) + + # Check that it behaves as expected + traced_arange = torch.jit.trace(arange, x) + self.assertEqual(traced_arange(y), arange(y)) + self.assertEqual(traced_arange(x), arange(x)) + + traced_arange_scalar = torch.jit.trace(arange_scalar, x) + self.assertEqual(traced_arange_scalar(y), arange_scalar(y)) + self.assertEqual(traced_arange_scalar(x), arange_scalar(x)) + + traced_arange_start_end = torch.jit.trace(arange_start_end, x) + self.assertEqual(traced_arange_start_end(y), arange_start_end(y)) + self.assertEqual(traced_arange_start_end(x), arange_start_end(x)) + + def test_trace_arange(self): + self.do_trace_arange(False) + + # test the different graph_executor path that happens when + # gradients are required and sizes are involved + def test_trace_arange_with_grad(self): + self.do_trace_arange(True) + + # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant + def test_trace_full_dynamic_shape(self): + def full_with_shape_like(x): + return torch.full(x.shape, 2) + + x = torch.randn(3, 4) + ge = torch.jit.trace(full_with_shape_like, example_inputs=x) + y = torch.randn(2, 7) + self.assertEqual(ge(y).shape, y.shape) + self.assertEqual(ge(x).shape, x.shape) + def test_trace_casts(self): casts = [ lambda x: x.byte(), @@ -1835,7 +2026,6 @@ def doit(x, y): for node in g.nodes(): self.assertTrue(g2.findNode(node.kind()) is not None) - @unittest.skipIf(IS_WINDOWS, "NYI: JIT tests not yet supported on windows") @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle") @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda") @skipIfRocm @@ -1845,7 +2035,6 @@ def test_cpp(self): torch._C._jit_run_cpp_tests(run_cuda=False) tests_setup.shutdown() - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") @skipIfRocm def test_cpp_cuda(self): @@ -1870,8 +2059,6 @@ def test_dropout(self): self.assertEqual(outputs, m(*inputs)) @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA") - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") - @skipIfRocm def test_dropout_cuda(self): # Dropout AD is dispatched to _fused_dropout in CUDA case, # which is not included in TestJitGeneratedFunctional @@ -2039,12 +2226,11 @@ def foo(a): def test_ge_unoptimized(self): self.run_ge_tests(False, False) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") @enable_cpu_fuser def test_ge_optimized(self): self.run_ge_tests(True, False) - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "requires CUDA") def test_ge_cuda(self): self.run_ge_tests(True, True) @@ -2135,7 +2321,6 @@ def outer(x, y): self.assertGraphContains(fn.graph, kind='aten::einsum') self.assertEqual(fn(x, y), outer(x, y)) - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") @unittest.skipIf(not RUN_CUDA, "calls .cuda()") def test_traced_module_cuda(self): class Model(nn.Module): @@ -3185,6 +3370,22 @@ def foo(x): else: cu.define(full) + def test_namedtuple_python(self): + MyTuple = namedtuple('MyTuple', ['a']) + + @torch.jit.unused + def fn(): + # type: () -> MyTuple + return MyTuple(1) + + # Only check compilation + @torch.jit.script + def fn2(): + # type: () -> MyTuple + return fn() + + FileCheck().check("NamedTuple").run(fn2.graph) + def test_inherit_method(self): class A(torch.jit.ScriptModule): def __init__(self): @@ -3857,12 +4058,22 @@ def f_grad(x): self.checkScript(f_grad, (y,)) def test_tensor_data(self): - x = torch.randn(3, 4) + x = torch.randn(3, 4, requires_grad=True) + y = torch.randn(4, 5) def f_data(x): return x.data - self.checkScript(f_data, (x,)) + scripted_f_data = torch.jit.script(f_data) + + scripted_x = scripted_f_data(x) + self.assertEqual(scripted_x, f_data(x)) + self.assertEqual(scripted_x.requires_grad, False) + + scripted_y = scripted_f_data(y) + self.assertEqual(scripted_y, f_data(y)) + self.assertEqual(scripted_x.requires_grad, False) + def test_tensor_dtype(self): x_byte = torch.empty(34, 56, 78, dtype=torch.uint8) @@ -4848,7 +5059,11 @@ def fn(x): x = torch.zeros(3, 4, dtype=torch.long) graph = _propagate_shapes(fn.graph, (x,), False) - FileCheck().check('Long(*, *) = aten::add').run(graph) + default = torch.get_default_dtype() + if(default == torch.float): + FileCheck().check('Float(*, *) = aten::add').run(graph) + else: + FileCheck().check('Double(*, *) = aten::add').run(graph) def test_integral_shape_inference(self): cu = torch.jit.CompilationUnit(''' @@ -4861,7 +5076,7 @@ def test_integral_shape_inference(a): self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs) @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") @enable_cpu_fuser def test_batchnorm_fuser_cpu(self): code = ''' @@ -4890,7 +5105,7 @@ def test_batchnorm_fuser_cpu(self): FileCheck().check('sqrtf').run(code) @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") @enable_cpu_fuser def test_fuser_double_float_codegen(self): fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf', @@ -4944,7 +5159,7 @@ def test_dispatch(op, expects, dtype, binary=False): test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True) @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") @enable_cpu_fuser def test_fuser_double_literal_precision(self): code = ''' @@ -6046,7 +6261,7 @@ def func(t): return {lhs} {op} {rhs} ''') - def test(op, const, swap_args): + def test(op, tensor, const, swap_args, template=template): args = ('t', const) if swap_args: args = (const, 't') @@ -6055,7 +6270,11 @@ def test(op, const, swap_args): scope = {} execWrapper(code, globals(), scope) cu = torch.jit.CompilationUnit(code) - self.assertEqual(cu.func(tensor), scope['func'](tensor)) + message = 'with code `{} {} {}` and t={}'.format(args[0], op, args[1], tensor) + res1 = cu.func(tensor) + res2 = scope['func'](tensor) + self.assertEqual(res1, res2, message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) + self.assertEqual(res1.dtype, res2.dtype, message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) var_int = [2, -2] var_float = [1.4321, -1.2] @@ -6080,7 +6299,7 @@ def test(op, const, swap_args): if op == '%' and swap_args is True: continue - test(op, const, swap_args) + test(op, tensor, const, swap_args) def test_tensor_number_math(self): self._test_tensor_number_math() @@ -6205,7 +6424,7 @@ def func(): return ten1 ''') - lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", + lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]", "(1, 1)", "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"] dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half", @@ -7334,17 +7553,29 @@ def test_check_not(): with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): fb.run("22 1 22") - def _dtype_to_expect(self, dtype, dim=0): - param = ', '.join(['*'] * dim) - param = '(' + param + ')' + def _dtype_to_jit_name(self, dtype): if(dtype == torch.float32): - return "Float" + param + return "Float" if(dtype == torch.float64): - return "Double" + param + return "Double" if(dtype == torch.int64): - return "Long" + param + return "Long" if(dtype == torch.int32): - return "Int" + param + return "Int" + if(dtype == torch.bool): + return "Bool" + raise RuntimeError('dtype not handled') + + def _dtype_to_expect(self, dtype, dim=0): + param = ', '.join(['*'] * dim) + param = '(' + param + ')' + jit_type = self._dtype_to_jit_name(dtype) + if dim >= 0: + return jit_type + param + # special case representing wrapped number + else: + return jit_type.lower() + def _test_dtype_op_shape(self, ops, args, input_dims=1): if input_dims < 1: @@ -7396,6 +7627,107 @@ def test_dtype_op_shape2(self): self._test_dtype_op_shape(ops, args=[1], input_dims=4) + + def _test_binary_op_shape(self, ops, input_dims=1): + + dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool] + + if input_dims == 0: + shape = '1' + else: + shape = '[' + ('1,' * 4) + ']' + for _ in range(1, input_dims): + shape = '[' + ",".join([shape] * 4) + ']' + + template = dedent(''' + def func(): + arg1 = {} + arg2 = {} + return torch.{}(arg1, arg2) + ''') + + args = [] + for dtype in dtypes: + args = args + ["torch.tensor({}, dtype={})".format(shape, dtype)] + args = args + [1, 1.5] + + def isBool(arg): + return type(arg) == bool or (type(arg) == str and "torch.bool" in arg) + + for op in ops: + for first_arg in args: + for second_arg in args: + # subtract not supported for bool + if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)): + continue + # div not implemneted correctly for mixed-type or in params + if (op == 'div' and (type(first_arg) != type(second_arg) or type(first_arg) == int)): + continue + return_line = "torch.{}({}, {})".format(op, first_arg, second_arg) + # uncomment for debugging a failed test: + # print("testing {}".format(return_line)) + code = template.format(first_arg, second_arg, op) + scope = {} + exec(code, globals(), scope) + non_jit_result = scope['func']() + + cu = torch.jit.CompilationUnit(code) + graph = cu.func.graph + torch._C._jit_pass_complete_shape_analysis(graph, (), False) + # use dim=-1 to represent a python/jit scalar. + dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim() + dtype = non_jit_result.dtype + # jit only supports int/float scalars. + if dim < 0: + if dtype == torch.int64: + dtype = torch.int32 + if dtype == torch.float64: + dtype = torch.float32 + expect = self._dtype_to_expect(dtype, dim) + jit_output = next(graph.outputs()) + + check = FileCheck() + check.check(expect).run(str(jit_output)) + + def test_binary_op_shape(self): + self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0) + self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3) + + @default_tensor_type(torch.FloatTensor) + def test_wrapped_number(self): + # Scalar's get converted to 'wrapped' tensors of default tensor type. + # Wrapped tensors behave differently in certain promotion operations: + # float_tensor * double -> float but wrapped_float * double -> double. + # This can cause issues in check-trace if not handled correctly in + # `aten::isclose()`. + + def foobar(): + x = -10000.0 + result = x * torch.ones(1, dtype=torch.float) + return result + scripted = torch.jit.trace(foobar, (), check_trace=True) + + def test_no_dtype_shape(self): + + @torch.jit.script + def foo(x): + scalar_number = x.item() + return x.add(scalar_number) + + @torch.jit.script + def foo2(x): + scalar_number = x.item() + return torch.tensor(1).add(scalar_number) + + t = torch.tensor(5) + g = foo.graph_for(t) + type = next(g.outputs()) + self.assertTrue(type.type() == torch._C.TensorType.get()) + g2 = foo2.graph_for(t) + type = next(g.outputs()) + self.assertTrue(type.type() == torch._C.TensorType.get()) + + def test_filecheck_parse(self): def test_check(): file = """ @@ -7638,6 +7970,59 @@ def forward(self): m = M() self.assertEqual(m(), 10) + def test_moduledict(self): + from collections import OrderedDict + + class Inner(torch.nn.Module): + def forward(self, x): + return x + 10 + + class Inner2(torch.nn.Module): + def forward(self, x): + return x * 2 + + class Inner3(torch.nn.Module): + def forward(self, x): + return (x - 4) * 3 + + class M(torch.nn.Module): + __constants__ = ['moduledict'] + + def __init__(self): + super(M, self).__init__() + modules = OrderedDict([ + ('one', Inner()), + ('two', Inner2()), + ('three', Inner3()), + ]) + self.moduledict = nn.ModuleDict(modules) + + def forward(self, x, skip_name): + # type: (Tensor, str) + names = torch.jit.annotate(List[str], []) + values = [] + for name in self.moduledict: + names.append(name) + + for name, mod in self.moduledict.items(): + if name != skip_name: + names.append(name) + x = mod(x) + values.append(x) + + for mod in self.moduledict.values(): + x = mod(x) + values.append(x) + + for key in self.moduledict.keys(): + names.append(key) + + return x, names + + for name in ["", "one", "two", "three"]: + inp = torch.tensor(1) + self.checkModule(M(), (inp, name)) + def test_script_module_for2(self): class Sub(torch.jit.ScriptModule): def __init__(self): @@ -8349,6 +8734,15 @@ def test_if_tracing(x): self.checkScript(test_if_tracing, (inp,)) + def test_is_scripting(self): + def foo(): + return torch.jit.is_scripting() + + self.assertFalse(foo()) + scripted = torch.jit.script(foo) + FileCheck().check("is_scripting").run(scripted.graph) + self.assertTrue(scripted()) + def test_script_outputs(self): with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"): @torch.jit.script @@ -8905,7 +9299,6 @@ def foo(self, bar, input): m = M() self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) - @unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows") def test_trace_of_script(self): @torch.jit.script def foo(a, c): @@ -9934,6 +10327,11 @@ def create(self): self.assertEqual(r.dtype, torch.float) self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r) + def fn(): + return torch.zeros((1, 2, 3)) + + self.checkScript(fn, ()) + def test_vararg_zeros(self): def foo(): return torch.zeros(3, 4, 5, dtype=torch.int) @@ -11822,13 +12220,6 @@ def forward(self, x): FooMod(), (torch.rand(3, 4),), f, operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) - def test_trace_checker_arange_as_constant(self): - with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'): - @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)]) - def foo(x): - y = torch.arange(0, x.shape[0]).double() - return x + y.unsqueeze(1) - @suppress_warnings def test_trace_checker_dot_data(self): with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value ' @@ -14003,37 +14394,71 @@ def forward(self, x): out = m(torch.ones(5, 5, 5).cuda()) self.assertTrue(out[0].is_cuda) - def test_ignore_decorator(self): - class M(torch.jit.ScriptModule): - def __init__(self): - super(M, self).__init__() - tensor = torch.zeros(1, requires_grad=False) - self.register_buffer('some_state', torch.nn.Parameter(tensor)) + with warnings.catch_warnings(record=True) as warns: + class M(torch.jit.ScriptModule): + def __init__(self): + super(M, self).__init__() + tensor = torch.zeros(1, requires_grad=False) + self.register_buffer('some_state', torch.nn.Parameter(tensor)) - @torch.jit.script_method - def forward(self, x): - self.ignored_code(x) - return x + @torch.jit.script_method + def forward(self, x): + self.ignored_code(x) + return x + + @torch.jit.ignore(drop_on_export=True) + def ignored_code(self, x): + self.some_state = torch.tensor((100,)) - @torch.jit.ignore(drop_on_export=True) - def ignored_code(self, x): - self.some_state = torch.tensor((100,)) + if not PY2: + FileCheck().check("TorchScript will now drop the function").run(str(warns[0])) # Assert ignored code is run m = M() - self.assertEqual(m.some_state, torch.zeros(1)) - m(torch.ones(1)) - self.assertEqual(m.some_state, torch.zeros(1) + 100) m2 = self.getExportImportCopy(m) pp = str(m2.forward.code) - self.assertIn('IgnoredPythonOp', pp) self.assertNotIn('ignored_code', pp) - with self.assertRaisesRegex(torch.jit.Error, "This Python function is annotated to be ignored"): + with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): m2.forward(torch.ones(1)) + def test_ignored_as_value(self): + class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + @torch.jit.unused + def tuple_ignored(self, x): + # type: (Tensor) -> Tuple[Tensor, Tensor] + return x, x + + @torch.jit.unused + def single_val_ignored(self, x, y): + # type: (Tensor, Tensor) -> Tensor + return x + + def forward(self, x, use_ignore_path): + # type: (Tensor, bool) -> Tuple[Tensor, Tensor] + if False: + return self.tuple_ignored(x) + if use_ignore_path: + return self.single_val_ignored(x, x), self.single_val_ignored(x, x) + return x, x + + original = Model() + scripted = torch.jit.script(original) + self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5))) + + buffer = io.BytesIO() + torch.jit.save(scripted, buffer) + buffer.seek(0) + loaded = torch.jit.load(buffer) + + with self.assertRaisesRegex(torch._C.JITException, "annotated to be ignored and cannot be run"): + loaded(torch.tensor(.5), True) + def test_module_error(self): class MyModule(torch.nn.Module): def __init__(self): @@ -14615,11 +15040,11 @@ def __init__(self, in_features, out_features): [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8) self.register_buffer('_packed_weight', - torch.ops.quantized.fbgemm_linear_prepack(qweight)) + torch.ops.quantized.linear_prepack(qweight)) @torch.jit.export def __getstate__(self): - return torch.ops.quantized.fbgemm_linear_unpack(self._packed_weight) + return torch.ops.quantized.linear_unpack(self._packed_weight)[0] def forward(self): return self._packed_weight @@ -14627,15 +15052,15 @@ def forward(self): @torch.jit.export def __setstate__(self, state): self._packed_weight.set_( - torch.ops.quantized.fbgemm_linear_prepack(state)) + torch.ops.quantized.linear_prepack(state)) @property def weight(self): - return torch.ops.quantized.fbgemm_linear_unpack(self._packed_weight) + return torch.ops.quantized.linear_unpack(self._packed_weight)[0] @weight.setter def weight(self, w): - self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(w) + self._packed_weight = torch.ops.quantized.linear_prepack(w) with torch.jit._disable_emit_hooks(): x = torch.jit.script(Linear(10, 10)) @@ -14675,24 +15100,6 @@ def foo(a): class TestRecursiveScript(JitTestCase): - def checkModule(self, nn_module, args): - """ - Check that a nn.Module's results in Script mode match eager and that it - can be exported - """ - sm = torch.jit.script(nn_module) - - with freeze_rng_state(): - eager_out = nn_module(*args) - - with freeze_rng_state(): - script_out = sm(*args) - - self.assertEqual(eager_out, script_out) - self.assertExportImportModule(sm, args) - - return sm - def test_init_error(self): class M(nn.Module): def __init__(self): @@ -14791,6 +15198,36 @@ def forward(self, z): self.checkModule(M(), (torch.randn(2, 2),)) + def test_module_repr(self): + class Submodule(nn.Module): + def forward(self, x): + return x + + class MyModule(nn.Module): + def __init__(self): + super(MyModule, self).__init__() + self.conv = nn.Conv2d(10, 10, 3) + self.lin = nn.Linear(10, 10) + self.sub = Submodule() + + def forward(self, x): + return self.lin(x) + self.sub(x) + self.conv(x) + + m = torch.jit.script(MyModule()) + + with self.capture_stdout() as out: + print(m) + + f = FileCheck() + f.check('MyModule') + f.check('Conv2d') + f.check('Linear') + f.check('Submodule') + f.run(out[0]) + + + self.assertEqual(m.original_name, 'MyModule') + def test_class_compile(self): def other_fn(a, b): # type: (int, Tensor) -> Tensor @@ -16443,8 +16880,8 @@ class TestJitGeneratedFunctional(JitTestCase): ('unfold', (S, S, S, S), ([2, 3]),), ('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),), ('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),), - ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])), - ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax'], ['aten::neg', 'aten::add', 'aten::div'])), + ('gumbel_softmax', (S, S), (2.,), '', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), + ('gumbel_softmax', (S, S), (2., True,), 'hard', (True, ['aten::softmax', 'aten::add', 'aten::div'], ['aten::neg'])), ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),), ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), 1, 1., non_differentiable(torch.randn(S))),), @@ -16556,7 +16993,7 @@ def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_ # We enable the CPU fuser during these checks for more consistent # behavior. Otherwise, we are going to have to analyze the graph to # see if producer values are Dimension - @enable_cpu_fuser_if(not (IS_SANDCASTLE or IS_WINDOWS)) + @enable_cpu_fuser_if(not IS_SANDCASTLE) def check(name): set_rng_seed(2) is_magic_method = name[:2] == '__' and name[-2:] == '__' @@ -16591,8 +17028,7 @@ def fn(*inputs, **kwargs): check_against_reference(self, traced_fn, fn, (self_variable,) + args_variable, kwargs_variable, check_types=check_types) - # Fuser not supported on windows - if IS_SANDCASTLE or IS_WINDOWS: + if IS_SANDCASTLE: autodiff_nodes = autodiff_nodes + fusible_nodes fusible_nodes = [] self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes) @@ -16603,8 +17039,7 @@ def fn(*inputs, **kwargs): fn, (self_variable,) + args_variable, kwargs_variable, check_types=check_types) - # Fuser not supported on windows - if IS_SANDCASTLE or IS_WINDOWS: + if IS_SANDCASTLE: autodiff_nodes = autodiff_nodes + fusible_nodes fusible_nodes = [] self.assertAutodiffNode(script_fn.last_graph, @@ -17321,7 +17756,6 @@ def check_replicas(self, module, replicas, input_shape=(2, 2)): self.assertEqual(replica(replica_input).data, expected_output) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported") - @skipIfRocm def test_python_submodule_exception(self): module = self.Msm1(self.Mpy()).cuda() msg = "Cannot replicate.*" @@ -17329,14 +17763,12 @@ def test_python_submodule_exception(self): dp.replicate(module, {0, 1}) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported") - @skipIfRocm def test_python_submodule_script(self): module = self.Mpy1(self.Msm()).cuda() replicas = dp.replicate(module, {0, 1}) self.check_replicas(module, replicas) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported") - @skipIfRocm def test_shared_module(self): s = self.Msm() p1 = self.Mpy1(s) @@ -17345,14 +17777,12 @@ def test_shared_module(self): self.check_replicas(module, replicas) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported") - @skipIfRocm def test_traced_module(self): module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda() replicas = dp.replicate(module, {0, 1}) self.check_replicas(module, replicas) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported") - @skipIfRocm def test_tensor_sharing(self): module = self.Msm1(self.Msm()).cuda() replica = dp.replicate(module, {0, 1}) @@ -17376,6 +17806,29 @@ def test_tensor_sharing(self): class TestList(JitTestCase): + def test_in_check(self): + def int_in(x): + # type: (List[int]) -> bool + return 2 in x + + self.checkScript(int_in, ([1, 2, 3],)) + self.checkScript(int_in, ([1, 3, 3],)) + + def float_in(x): + # type: (List[float]) -> bool + return 2. in x + + self.checkScript(float_in, ([1., 2., 3.],)) + self.checkScript(float_in, ([1., 3., 3.],)) + + def str_in(x): + # type: (List[str]) -> bool + return 'hi' in x + + self.checkScript(str_in, (['not', 'here'],)) + self.checkScript(str_in, (['hi', 'bye'],)) + self.checkScript(str_in, ([],)) + def test_list_literal(self): def reassign(): x = [1] @@ -17430,6 +17883,66 @@ def reassign_nested(): with self.assertRaisesRegex(RuntimeError, "previously has type"): self.checkScript(reassign_nested, (), optimize=False) + def test_min_bool_list(self): + def jit_min_list(a, b): + # type: (List[bool], List[bool]) -> List[bool] + return min(a, b) + + self.checkScript(jit_min_list, ([True, False], [False, True])) + + def test_min_max_list(self): + def jit_min_list(a, b): + # type: (List[int], List[int]) -> List[int] + return min(a, b) + + def jit_min_list_float(a, b): + # type: (List[float], List[float]) -> List[float] + return min(a, b) + + def jit_min_list_bool(a, b): + # type: (List[bool], List[bool]) -> List[bool] + return min(a, b) + + def run_tests(func, a, b): + for t in zip(a, b): + self.checkScript(func, t) + + args_left_int = [[1, 8, 8], [2, 1, 1], [], [2], [1], [1, 2, 3]] + args_right_int = [[2, 1, 1], [1, 8, 8], [], [1], [], [1, 2]] + run_tests(jit_min_list, args_left_int, args_right_int) + + args_left_float = [[1., 8., 8.], [2., 1., 1.], [], [2.], [1.], [1., 2., 3.]] + args_right_float = [[2., 1., 1.], [1., 8., 8.], [], [1.], [], [1., 2.]] + run_tests(jit_min_list_float, args_left_float, args_right_float) + + args_left_bool = [[], [], [], [False], [True], [False, True], [True, True], + [False, False, False], [False, False, True]] + args_right_bool = [[], [False], [True], [True], [False], [True, True], + [False, True], [False, False, True], [False, False, False]] + run_tests(jit_min_list_bool, args_left_bool, args_right_bool) + + def jit_max_list(a, b): + # type: (List[int], List[int]) -> List[int] + return max(a, b) + + def jit_max_list_float(a, b): + # type: (List[float], List[float]) -> List[float] + return max(a, b) + + def jit_max_list_bool(a, b): + # type: (List[bool], List[bool]) -> List[bool] + return max(a, b) + + args_left_int = [[1, 8, 8], [8, 1, 1], [], [1], [], [1, 2]] + args_right_int = [[8, 1, 1], [1, 8, 8], [], [2], [1], [1, 2, 3]] + run_tests(jit_max_list, args_left_int, args_right_int) + + args_left_float = [[1., 8., 8.], [8., 1., 1.], [], [1.], [], [1., 2.]] + args_right_float = [[8., 1., 1.], [1., 8., 8.], [], [2.], [1.], [1., 2., 3.]] + run_tests(jit_max_list_float, args_left_float, args_right_float) + + run_tests(jit_max_list_bool, args_left_bool, args_right_bool) + def test_list_gather(self): def index(): a = [1, 2, 3] @@ -18040,6 +18553,52 @@ def copy_list(a): for l in [[], [1], [1, 2, 3]]: self.assertEqual(copy_list(l), l) + def test_min_max_single_list(self): + def min_intlist(li): + # type: (List[int]) -> int + return min(li) + + def max_intlist(li): + # type: (List[int]) -> int + return max(li) + + def min_boollist(li): + # type: (List[bool]) -> bool + return min(li) + + def max_boollist(li): + # type: (List[bool]) -> bool + return max(li) + + def min_floatlist(li): + # type: (List[float]) -> float + return min(li) + + def max_floatlist(li): + # type: (List[float]) -> float + return max(li) + + + int_lists = [1], [2, 1, 2], [-3, 4, 2], [-2, -7, 1, 4], [2, 1, 0, 4], [] + + def check_list(fn, li): + if len(li) == 0: + self.checkScriptRaisesRegex(fn, (li,), Exception, "arg is an empty sequence") + else: + self.checkScript(fn, (li,)) + + for int_list in int_lists: + check_list(min_intlist, int_list) + check_list(max_intlist, int_list) + + bool_li = list(map(lambda x: bool(x), int_list)) + check_list(min_boollist, bool_li) + check_list(max_boollist, bool_li) + + float_li = list(map(lambda x: float(x), int_list)) + check_list(min_floatlist, float_li) + check_list(max_floatlist, float_li) + class TestDict(JitTestCase): def dict(self): @@ -18335,6 +18894,36 @@ def fn(my_dict, keys): a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2} self.checkScript(fn, (a_dict, ('a', 'c'))) + def test_ordered_dict(self): + def test_func(fn, inputs): + self.assertEqual(fn(*inputs), torch.jit.script(fn)(*inputs)) + + def repeated_key(): + return OrderedDict([(1, 2), (2, 3), (1, 4)]) + + test_func(repeated_key, ()) + + def no_args(): + a = OrderedDict() + a["one"] = torch.tensor(1) + a["two"] = torch.tensor(2) + + test_func(no_args, ()) + + def test_dict_constructor(): + a = dict() + a["one"] = torch.tensor(1) + return a, dict([(1, 2), (2, 3), (1, 4)]) # noqa: C406 + + test_func(test_dict_constructor, ()) + + def test_dict_error(): + a = dict() + a[1] = 2 + return a + + with self.assertRaisesRegex(Exception, "Arguments for call are not"): + torch.jit.script(test_dict_error) class TestClassType(JitTestCase): def test_get_with_method(self): @@ -18975,6 +19564,14 @@ def __xor__(self, other): # type: (int) -> int return self.x ^ other + def __getitem__(self, other): + # type: (int) -> int + return other + 1 + + def __setitem__(self, idx, val): + # type: (int, int) -> None + self.x = val * idx + def add(): return BinOps(4) + 3 def sub(): # noqa: E306 @@ -19003,8 +19600,15 @@ def _or(): # noqa: E306 return BinOps(4) | 3 def _xor(): # noqa: E306 return BinOps(4) ^ 3 + def getitem(): # noqa: E306 + return BinOps(4)[1] + def setitem(): # noqa: E306 + a = BinOps(4) + a[1] = 5 + return a.x + + ops = [add, sub, mul, pow, ne, eq, lt, gt, le, ge, _and, _or, _xor, getitem, setitem] - ops = [add, sub, mul, pow, ne, eq, lt, gt, le, ge, _and, _or, _xor] if not PY2: ops.append(truediv) for func in ops: diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index d4da6d13d80e8..7d34d1c92d109 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch.testing import FileCheck -from common_utils import run_tests, IS_WINDOWS, skipIfRocm, IS_SANDCASTLE +from common_utils import run_tests, IS_SANDCASTLE from textwrap import dedent from itertools import product, permutations @@ -37,18 +37,16 @@ def func(x): self.assertEqual(func(a), a.abs() * 2) self.assertAllFused(func.graph_for(a)) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @enable_cpu_fuser def test_abs_cpu(self): self._test_fused_abs() @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @skipIfRocm def test_abs_cuda(self): self._test_fused_abs(device="cuda") @unittest.skipIf(not RUN_CUDA, "requires CUDA") - @skipIfRocm def test_zero_element_tensors(self): def decode(sin_t, cos_t): theta = torch.atan2(sin_t.float(), cos_t.float()) @@ -75,7 +73,6 @@ def f(x, y): self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_broadcast_cuda(self): def scaleshift(x, scale, shift): return x * scale + shift @@ -124,7 +121,6 @@ def test_cuda_half(self): self.assertEqual(grads_half, fusion_grads) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_checks_cat_inputs(self): # We shouldn't treat cat nodes as broadcasting. All their inputs # need to be checked for having the same map size, before we can @@ -142,7 +138,6 @@ def f(x, y): self.assertAllFused(f.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "No CUDA") - @skipIfRocm def test_chunk_cuda(self): def fn(x): a, b, c = x.chunk(3, 1) @@ -185,7 +180,7 @@ def chunk_4_last(x): for fn in fns: self.checkScript(fn, [tensor]) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @enable_cpu_fuser def test_chunk_correctness(self): return self._test_chunk_correctness(self, 'cpu') @@ -195,7 +190,6 @@ def test_chunk_correctness_cuda(self): return self._test_chunk_correctness(self, 'cuda') @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_chunk_distributes_cuda(self): def f(x, y): z1, z2 = (x + y).chunk(2, dim=1) @@ -210,7 +204,6 @@ def f(x, y): .check_count('ConstantChunk', 2, exactly=True).run(str(graph)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_chunk_motion_deduplicates_inputs(self): def func1(x): z = x * x @@ -233,7 +226,6 @@ def func2(x): self.assertEqual(len(list(fusion_group.inputs())), 1) @unittest.skipIf(not RUN_CUDA, "No CUDA") - @skipIfRocm def test_chunk_multiple_cuda(self): # The arguments are intentionally used out of order as a test to see # if the fusion compiler adds extra args in the correct order @@ -254,7 +246,6 @@ def fn(s, x, y, z): self.assertAllFused(ge.graph_for(*inputs)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_clamp(self): def func2(a, b): return torch.clamp(a + b, min=0, max=2) @@ -284,7 +275,6 @@ def funcOptMax(a, b): self.assertAllFused(graph, except_for={'aten::Float'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_dropout(self): def func(x): x = torch.nn.functional.dropout(x) @@ -298,7 +288,6 @@ def func(x): self.assertAllFused(graph, except_for={'aten::div', 'prim::Constant'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_comparison_eq_ne(self): def f(x, y): mask = (x == 0).type_as(x) @@ -322,7 +311,6 @@ def fn_test_comparison_gt_lt(x, y): return z @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_comparison_gt_lt_cuda(self): x = torch.randn(4, 4, dtype=torch.float, device='cuda') y = torch.randn(4, 4, dtype=torch.float, device='cuda') @@ -331,7 +319,6 @@ def test_comparison_gt_lt_cuda(self): self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_comparison_ge_le_cuda(self): def f(x, y): mask = (x >= 0).type_as(x) @@ -351,7 +338,6 @@ def f(x, y): "aten::_size_if_not_equal")) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_addcmul_cuda(self): t = torch.randn(1, 4, dtype=torch.float, device='cuda') t1 = torch.randn(4, 1, dtype=torch.float, device='cuda') @@ -371,7 +357,6 @@ def foo(t, t1, t2): # If this is a real problem, we'll need to revisit Torchscript Function # lifetimes in Python. @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_lerp(self): start = torch.randn(4, 1, dtype=torch.float, device='cuda') end = torch.randn(1, 4, dtype=torch.float, device='cuda') @@ -394,7 +379,6 @@ def foo_weight_tensor(start, end): self.assertAllFused(graph) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_concat_cuda(self): hx = torch.randn(3, 20, dtype=torch.float, device='cuda') cx = torch.randn(3, 20, dtype=torch.float, device='cuda') @@ -408,7 +392,6 @@ def foo(hx, cx): FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_concat_invariant_cuda(self): # Invariant: the output of prim::FusedConcat may # not be an input to any node inside the FusionGroup. @@ -431,7 +414,6 @@ def fn_test_exp(x, y): return (x + .5 * y).exp() @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_exp_cuda(self): x = torch.randn(4, 4, dtype=torch.float, device='cuda') y = torch.randn(4, 4, dtype=torch.float, device='cuda') @@ -440,7 +422,6 @@ def test_exp_cuda(self): self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm @_inline_everything def test_fuse_decompose_normalization(self): class ResLike(torch.jit.ScriptModule): @@ -495,7 +476,6 @@ def test_norm_decompose(nm, in_opt_graph, not_in_opt_graph, in_fusegraph): ['aten::layer_norm('], ['aten::sub', 'aten::mul', 'aten::add']) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_threshold(self): def f(x): return torch.threshold(x, 0, -10) + x + x + x @@ -507,7 +487,6 @@ def f(x): self.assertAllFused(scripted.graph_for(x)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_scalar_arg_cuda(self): def fn_test_scalar_arg(x, p): # type: (Tensor, float) -> Tensor @@ -523,7 +502,7 @@ def fn_test_scalar_arg(x, p): self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes", "aten::_size_if_not_equal")) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @enable_cpu_fuser def test_fuser_deduplication(self): # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation @@ -543,7 +522,7 @@ def f(x, y): # check that a, b share storage, i.e. were generated as a single output in the fuser self.assertEqual(ga.data_ptr(), gb.data_ptr()) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @enable_cpu_fuser @unittest.skip("temporarily disabled because fusion was restricted in fixing #22833") def test_fuser_iou(self): @@ -587,7 +566,6 @@ def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") - @skipIfRocm @enable_cpu_fuser def test_fusion_reuse_multi_gpu(self): def fn(x, y): @@ -608,7 +586,6 @@ def fn(x, y): @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") - @skipIfRocm @enable_cpu_fuser def test_kernel_cache_multi_gpu(self): def not_fusible(x): @@ -638,7 +615,6 @@ def fn(x, y, z): self.assertEqual(new_cache_size - prev_cache_size, 1) @unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device") - @skipIfRocm def test_nonzero_device_cuda(self): device = 'cuda:' + str(1) x = torch.tensor([0.4], dtype=torch.float, device=device) @@ -651,7 +627,6 @@ def doit(x, y): self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_lstm_cuda(self): inputs = get_lstm_inputs('cuda', training=True) module = self.checkScript(LSTMCellS, inputs) @@ -670,7 +645,6 @@ def test_lstm_cuda(self): "aten::_grad_sum_to_size")) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_lstm_concat_cuda(self): inputs = get_lstm_inputs('cuda') ge = self.checkTrace(LSTMCellC, inputs) @@ -678,7 +652,6 @@ def test_lstm_concat_cuda(self): FileCheck().check("FusedConcat").check_next("return").run(str(graph)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_lstm_gates_permutations_cuda(self): # lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh. # Test that any permutation of this will still result in one FusionGroup. @@ -702,7 +675,6 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh): # TODO: Fuser doesn't work at all when inputs require grad. Fix that @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_lstm_traced_cuda(self): inputs = get_lstm_inputs('cuda') ge = self.checkTrace(LSTMCellF, inputs) @@ -711,7 +683,7 @@ def test_lstm_traced_cuda(self): .check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \ .check_next("return").check_not("FusionGroup_1").run(str(graph)) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746") @enable_cpu_fuser def test_lstm_traced_cpu(self): @@ -730,7 +702,6 @@ def test_lstm_traced_cpu(self): raise @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_milstm_cuda(self): inputs = get_milstm_inputs('cuda', training=True) module = self.checkScript(MiLSTMCell, inputs) @@ -743,7 +714,6 @@ def test_milstm_cuda(self): (hy + cy).sum().backward() @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_rand_cuda(self): class M(torch.jit.ScriptModule): __constants__ = ['d'] @@ -772,7 +742,6 @@ def fn_test_relu(x, y): return F.relu(x + .5 * y) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_relu_cuda(self): x = torch.randn(4, 4, dtype=torch.float, device='cuda') y = torch.randn(4, 4, dtype=torch.float, device='cuda') @@ -781,7 +750,6 @@ def test_relu_cuda(self): self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_erf_cuda(self): def fn_test_erf(x): return F.relu(torch.erf(x) - torch.erfc(x)) @@ -794,7 +762,6 @@ def fn_test_erf(x): "aten::_size_if_not_equal")) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_rand_broadcast_cuda(self): def fn_test_rand(x, y): r = torch.rand_like(y) @@ -815,7 +782,7 @@ def fn_test_rand(x, y): out = script_f(x, y) self.assertEqual(out[0], out[1]) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @enable_cpu_fuser def test_scalar(self): def fn(x, y): @@ -827,7 +794,6 @@ def fn(x, y): self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_small_constant_cuda(self): def fn_test_small_constant(x, y): return (1e-8 * x + 5e-9 * y) * 1e8 @@ -838,7 +804,6 @@ def fn_test_small_constant(x, y): self.assertAllFused(ge.graph_for(x, y)) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_tensor_scalar_ops_cuda(self): def should_fuse(x): z = 3. @@ -863,7 +828,7 @@ def should_not_fuse(x, z): self.assertGraphContainsExactly( ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True) - @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser CPU support for Windows or Sandcastle") + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") @enable_cpu_fuser def test_where_and_typing(self): def f(x, y): @@ -883,7 +848,6 @@ def f(x, y): self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'}) @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA") - @skipIfRocm def test_grad_sum_to_size_elimination(self): def my_broadcasted_cell(a, b, c): diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index a1c74d6930ce1..b42d68d8fe55f 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -3,9 +3,9 @@ from torch.testing import FileCheck from typing import NamedTuple, List, Optional import unittest +import sys import torch - class TestScriptPy3(JitTestCase): def test_joined_str(self): def func(x): @@ -48,6 +48,21 @@ def foo(x) -> float: self.assertEqual(foo(torch.rand(3, 4)), 18.0) + @unittest.skipIf(sys.version_info[0] < 3 and sys.version_info[1] < 6, "dict not ordered") + def test_dict_preserves_order(self): + def dict_ordering(): + a : Dict[int, int] = {} + for i in range(1000): + a[i] = i + 1 + return a + + self.checkScript(dict_ordering, ()) + di = torch.jit.script(dict_ordering)() + res = list(di.items()) + for i in range(1000): + key, value = res[i] + self.assertTrue(key == i and value == i + 1) + def test_return_named_tuple(self): class FeatureVector(NamedTuple): float_features: float diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index 85c66832cd958..8a61b9240fe49 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1,23 +1,34 @@ import unittest -from common_utils import TestCase, run_tests +from common_utils import TestCase, run_tests, TEST_NUMPY from common_cuda import TEST_CUDA -from collections import namedtuple +from collections import namedtuple, OrderedDict import itertools import functools import torch from torch import Tensor +from torch._six import PY2 import torch.nn.functional as F from multiprocessing.reduction import ForkingPickler import pickle import io +import os import sys import warnings +def check_env_flag(name, default=''): + return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] + +TEST_NAMEDTENSOR = check_env_flag('TEST_NAMEDTENSOR') + skipIfNamedTensorDisabled = \ unittest.skipIf(not torch._C._BUILD_NAMEDTENSOR, 'PyTorch not compiled with namedtensor support') +skipIfNotTestingNamedTensor = \ + unittest.skipIf(not TEST_NAMEDTENSOR, + 'TEST_NAMEDTENSOR=0; set it to 1 to enable named tensor tests') + def pass_name_to_python_arg_parser(name): x = torch.empty(2, names=(name,)) @@ -69,6 +80,17 @@ def fn(*inputs): class TestNamedTensor(TestCase): + def test_aaa_must_run_first_check_experimental_warning(self): + # TODO(rzou): It would be nice for this to be a "real" python warning. + # Right now this error message only prints once and doesn't respect + # warnings.simplefilter behavior (where python users can control whether + # or not to display warnings once, all the time, or never). + with warnings.catch_warnings(record=True) as warns: + x = torch.randn(3, 3, names=('N', 'C')) + self.assertEqual(len(warns), 1) + self.assertTrue(str(warns[0].message).startswith( + 'Named tensors and all their associated APIs are an experimental feature')) + def test_trivial(self): pass @@ -89,8 +111,8 @@ def _test_name_inference(self, op, args=(), expected_names=(), device='cpu', # Right now I don't know what it should look like. def assertTensorDataAndNamesEqual(self, x, y): self.assertEqual(x.names, y.names) - unnamed_x = x.view_names(None) - unnamed_y = y.view_names(None) + unnamed_x = x.renamed(None) + unnamed_y = y.renamed(None) self.assertEqual(unnamed_x, unnamed_y) def _test_factory(self, factory, device): @@ -132,18 +154,6 @@ def _test_factory(self, factory, device): names65 = ['A' * i for i in range(1, 66)] x = factory([1] * 65, names=names64, device=device) - # Tests for tagged names - x = factory(2, 3, 1, names=('C.in', 'H', 'C.out'), device=device) - self.assertEqual(x.names, ('C.in', 'H', 'C.out')) - - with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'): - x = factory(2, 1, 1, names=('C.in', 'H', 'C.in'), device=device) - - with self.assertRaisesRegex( - RuntimeError, - 'with duplicate names unless they are tagged and have different tags'): - x = factory(2, 1, 1, names=('C.in', 'H', 'C'), device=device) - def test_has_names(self): unnamed = torch.empty(2, 3) none_named = torch.empty(2, 3, names=(None, None)) @@ -155,6 +165,51 @@ def test_has_names(self): self.assertTrue(partially_named.has_names()) self.assertTrue(fully_named.has_names()) + @unittest.skipIf(PY2, "Ellipsis object not supported in python 2") + def test_py3_ellipsis(self): + # Need to exec or else flake8 will complain about invalid python 2. + tensor = torch.randn(2, 3, 5, 7) + scope = {'tensor': tensor} + code_str = "output = tensor.refine_names('N', ..., 'C')" + exec(code_str, globals(), scope) + self.assertEqual(scope['output'].names, ['N', None, None, 'C']) + + def test_refine_names(self): + # Unnamed tensor -> Unnamed tensor + self._test_name_inference(Tensor.refine_names, + [create('None:1,None:2,None:3'), 'N', 'C', 'H'], + ['N', 'C', 'H']) + + # Named tensor -> Named tensor + self._test_name_inference(Tensor.refine_names, + [create('N:1,C:2,H:3'), 'N', 'C', 'H'], + ['N', 'C', 'H']) + + # Partially named tensor -> named tensor + self._test_name_inference(Tensor.refine_names, + [create('None:1,C:2,None:3'), None, 'C', 'H'], + [None, 'C', 'H']) + + # Too few names + self._test_name_inference(Tensor.refine_names, + [create('None:2,None:3'), 'N', 'C', 'H'], + maybe_raises_regex="different number of dims") + + # Cannot change Tensor[D] to Tensor[N] + self._test_name_inference(Tensor.refine_names, + [create('D:3'), 'N'], + maybe_raises_regex="is different from") + + # Cannot change Tensor[D] to Tensor[None] + self._test_name_inference(Tensor.refine_names, + [create('D:3'), None], + maybe_raises_regex="'D' is more specific than None") + + # globbing behavior exists + self._test_name_inference(Tensor.refine_names, + [create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'], + [None, None, 'C', 'H']) + def test_repr(self): named_tensor = torch.zeros(2, 3).names_('N', 'C') expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]], names=('N', 'C'))" @@ -186,7 +241,7 @@ def test_no_multiprocessing_support(self): def test_big_tensor_repr(self): def check_repr(named_tensor): - unnamed_tensor = named_tensor.view_names(None) + unnamed_tensor = named_tensor.renamed(None) expected = "{}, names={})".format(repr(unnamed_tensor)[:-1], named_tensor.names) self.assertEqual(repr(named_tensor), expected) @@ -218,82 +273,82 @@ def test_names_(self): with self.assertRaisesRegex(RuntimeError, 'duplicate names'): tensor.names_('N', 'N') - def test_view_names(self): + def test_renamed(self): tensor = torch.empty(1, 1, names=('N', 'C')) - self.assertEqual(tensor.view_names(None).names, (None, None)) - self.assertEqual(tensor.view_names('H', 'W').names, ('H', 'W')) + self.assertEqual(tensor.renamed(None).names, (None, None)) + self.assertEqual(tensor.renamed('H', 'W').names, ('H', 'W')) # Check that we didn't modify tensor.names self.assertEqual(tensor.names, ('N', 'C')) with self.assertRaisesRegex(RuntimeError, 'Number of names'): - tensor.view_names('N', 'C', 'W') + tensor.renamed('N', 'C', 'W') with self.assertRaisesRegex(RuntimeError, 'duplicate names'): - tensor.view_names('N', 'N') + tensor.renamed('N', 'N') with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'): - tensor.view_names(None, N='batch') + tensor.renamed(None, N='batch') - # view_names returns a view on the tensor - self.assertEqual(tensor.view_names('H', 'W').data_ptr(), tensor.data_ptr()) - self.assertEqual(tensor.view_names(None).data_ptr(), tensor.data_ptr()) + # renamed returns a view on the tensor + self.assertEqual(tensor.renamed('H', 'W').data_ptr(), tensor.data_ptr()) + self.assertEqual(tensor.renamed(None).data_ptr(), tensor.data_ptr()) - def test_view_names_globber(self): + def test_renamed_globber(self): scalar = torch.randn([]) unnamed_tensor = torch.empty(1, 1, 1, 1) named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) - self.assertEqual(scalar.view_names(None).names, []) - self.assertEqual(scalar.view_names('*').names, []) + self.assertEqual(scalar.renamed(None).names, []) + self.assertEqual(scalar.renamed('...').names, []) # Check that it works with unnamed tensors - self.assertEqual(unnamed_tensor.view_names('*').names, unnamed_tensor.names) - self.assertEqual(unnamed_tensor.view_names('*', 'H', 'W').names, + self.assertEqual(unnamed_tensor.renamed('...').names, unnamed_tensor.names) + self.assertEqual(unnamed_tensor.renamed('...', 'H', 'W').names, [None, None, 'H', 'W']) - self.assertEqual(unnamed_tensor.view_names('N', '*', 'W').names, + self.assertEqual(unnamed_tensor.renamed('N', '...', 'W').names, ['N', None, None, 'W']) - self.assertEqual(unnamed_tensor.view_names('N', 'C', '*').names, + self.assertEqual(unnamed_tensor.renamed('N', 'C', '...').names, ['N', 'C', None, None]) # Check that it works with named tensors - self.assertEqual(named_tensor.view_names('*').names, named_tensor.names) - self.assertEqual(named_tensor.view_names('*', 'width').names, + self.assertEqual(named_tensor.renamed('...').names, named_tensor.names) + self.assertEqual(named_tensor.renamed('...', 'width').names, ['N', 'C', 'H', 'width']) - self.assertEqual(named_tensor.view_names('batch', 'channels', '*', 'width').names, + self.assertEqual(named_tensor.renamed('batch', 'channels', '...', 'width').names, ['batch', 'channels', 'H', 'width']) - self.assertEqual(named_tensor.view_names('batch', '*').names, + self.assertEqual(named_tensor.renamed('batch', '...').names, ['batch', 'C', 'H', 'W']) # Test empty glob - self.assertEqual(unnamed_tensor.view_names('*', None, None, None, None).names, + self.assertEqual(unnamed_tensor.renamed('...', None, None, None, None).names, [None, None, None, None]) - self.assertEqual(named_tensor.view_names('N', 'C', 'H', '*', 'W').names, + self.assertEqual(named_tensor.renamed('N', 'C', 'H', '...', 'W').names, ['N', 'C', 'H', 'W']) # Multiple globs throw with self.assertRaisesRegex(RuntimeError, 'More than one '): - named_tensor.view_names('*', 'channels', '*') + named_tensor.renamed('...', 'channels', '...') - def test_view_names_rename_map(self): + def test_renamed_rename_map(self): scalar = torch.randn([]) unnamed_tensor = torch.empty(1, 1, 1, 1) named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): - scalar.view_names(N='batch') + scalar.renamed(N='batch') with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): - unnamed_tensor.view_names(N='batch') + unnamed_tensor.renamed(N='batch') with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): - named_tensor.view_names(B='batch') + named_tensor.renamed(B='batch') with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): - named_tensor.view_names(H='height', B='batch') + named_tensor.renamed(H='height', B='batch') - self.assertEqual(named_tensor.view_names(N='batch').data_ptr(), + self.assertEqual(named_tensor.renamed(N='batch').data_ptr(), named_tensor.data_ptr()) - self.assertEqual(named_tensor.view_names(N='batch').names, + self.assertEqual(named_tensor.renamed(N='batch').names, ['batch', 'C', 'H', 'W']) - self.assertEqual(named_tensor.view_names(N='batch', H='height').names, + self.assertEqual(named_tensor.renamed(N='batch', H='height').names, ['batch', 'C', 'height', 'W']) def test_set_names_property(self): @@ -343,6 +398,48 @@ def _test(factory, device): expected = torch.full([1, 2, 3], 2, device=device).names_(*names) self.assertTensorDataAndNamesEqual(result, expected) + def test_tensor_from_lists(self): + names = ('N', 'C') + tensor = torch.tensor([[1]], names=names) + self.assertEqual(tensor.names, names) + + names = ('N',) + tensor = torch.tensor([1], names=names) + self.assertEqual(tensor.names, names) + + with self.assertRaisesRegex(RuntimeError, 'Number of names'): + names = ('N', 'C') + tensor = torch.tensor([1], names=names) + + @unittest.skipIf(not TEST_NUMPY, "no numpy") + def test_tensor_from_numpy(self): + import numpy as np + arr = np.array([[1]]) + names = ('N', 'C') + tensor = torch.tensor([[1]], names=names) + self.assertEqual(tensor.names, names) + + def test_tensor_from_tensor(self): + x = torch.randn(1, 1) + names = ('N', 'C') + tensor = torch.tensor(x, names=names) + self.assertEqual(tensor.names, names) + + def test_tensor_from_named_tensor(self): + x = torch.randn(1, 1, names=('N', 'D')) + tensor = torch.tensor(x) + self.assertEqual(tensor.names, ('N', 'D')) + + # there's no way to distinguish between names=None and not passing in names. + # If the user passes in names=None they are asking for trouble. + x = torch.randn(1, 1, names=('N', 'D')) + tensor = torch.tensor(x, names=None) + self.assertEqual(tensor.names, ('N', 'D')) + + x = torch.randn(1, 1, names=('N', 'D')) + with self.assertRaisesRegex(RuntimeError, "Name mismatch"): + tensor = torch.tensor(x, names=('N', 'C')) + def test_size(self): t = torch.empty(2, 3, 5, names=('N', None, 'C')) self.assertEqual(t.size('N'), 2) @@ -487,7 +584,7 @@ def out_function(name, *args, **kwargs): out_fn = getattr(torch, name) def fn(a, b): - result = a.new_empty([0]) + result = torch.empty([0], dtype=a.dtype, device=a.device) out_fn(a, b, *args, out=result, **kwargs) return result @@ -564,7 +661,7 @@ def out_function(name, *args, **kwargs): out_fn = getattr(torch, name) def fn(tensor): - result = tensor.new_empty([0]) + result = torch.empty([0], dtype=tensor.dtype, device=tensor.device) out_fn(tensor, *args, out=result, **kwargs) return result @@ -686,6 +783,82 @@ def test_bernoulli(self): torch.bernoulli(tensor, out=result) self.assertEqual(result.names, names) + def test_flatten(self): + tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W')) + + # basic + out = tensor.flatten('D', 'W', 'features') + self.assertEqual(out.names, ['N', 'C', 'features']) + self.assertEqual(out.renamed(None), tensor.renamed(None).view(2, 3, -1)) + + # int overload + out = tensor.flatten(2, 4, 'features') + self.assertEqual(out.names, ['N', 'C', 'features']) + self.assertEqual(out.renamed(None), tensor.renamed(None).view(2, 3, -1)) + + # list overload + out = tensor.flatten(['D', 'H', 'W'], 'features') + self.assertEqual(out.names, ['N', 'C', 'features']) + self.assertEqual(out.renamed(None), tensor.renamed(None).view(2, 3, -1)) + + # Non-contiguous flatten: N and H are not "adjacent" in memory. + sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D')) + sentences = sentences.transpose('T', 'H') + out = sentences.flatten('N', 'H', 'N_H') + self.assertEqual(out.names, ['N_H', 'T', 'D']) + + with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"): + tensor.flatten(['D', 'L'], 'features') + + with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): + tensor.flatten(['D', 'W'], 'features') + + with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): + tensor.flatten(['H', 'D', 'W'], 'features') + + def test_unflatten(self): + tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K')) + + # accepts iterable of tuples + out = tensor.unflatten('D', (('C', 2), ('H', 3), ('W', 5))) + self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K')) + self.assertEqual(out.shape, (7, 2, 3, 5, 11)) + + # accepts OrderedDict + out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5)))) + self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K')) + self.assertEqual(out.shape, (7, 2, 3, 5, 11)) + + # Unflatten left-most + out = tensor.unflatten('N', (('N', 7), ('H', 1))) + self.assertEqual(out.names, ('N', 'H', 'D', 'K')) + self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11)) + + # Unflatten right-most + out = tensor.unflatten('K', (('K', 11), ('H', 1))) + self.assertEqual(out.names, ('N', 'D', 'K', 'H')) + self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1)) + + # takes positional dim + out = tensor.unflatten(1, (('C', 2), ('H', 3), ('W', 5))) + self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K')) + self.assertEqual(out.shape, (7, 2, 3, 5, 11)) + + with self.assertRaisesRegex(RuntimeError, "don't multiply up to"): + tensor.unflatten('D', (('H', 3), ('W', 5))) + + with self.assertRaisesRegex(RuntimeError, 'OrderedDict or iterable of tuples'): + tensor.unflatten('D', None) + + with self.assertRaisesRegex(RuntimeError, 'non-empty'): + tensor.unflatten('D', OrderedDict()) + + def test_unsupported_op_error_msg(self): + named = torch.randn(3, 3, names=('N', 'C')) + with self.assertRaisesRegex( + RuntimeError, "pdist is not yet supported with named tensors"): + torch.pdist(named) + def test_reduction_fns(self): def check_output(output, expected_names): if isinstance(output, torch.Tensor): @@ -721,7 +894,7 @@ def test_multidim_reduce(op_name, device): def test_out_variant(op_name, device): t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) - out = t.new_empty([0]) + out = torch.empty([0], device=device) getattr(torch, op_name)(t, 'C', out=out) check_output(out, ['N', 'L']) @@ -735,23 +908,27 @@ def test_keepdim(op_name, device): 'supports_complete_reduce', 'supports_multidim_reduce', 'supports_out_variant', + 'supports_keepdim', + 'output_lambda', ]) tests = [ - Case('sum', True, True, True), - Case('prod', True, False, True), - Case('mean', True, True, True), - Case('var', True, True, True), - Case('std', True, True, True), - Case('std_mean', True, True, False), - Case('var_mean', True, True, False), + Case('sum', True, True, True, True, None), + Case('prod', True, False, True, True, None), + Case('mean', True, True, True, True, None), + Case('var', True, True, True, True, None), + Case('std', True, True, True, True, None), + Case('std_mean', True, True, False, True, None), + Case('var_mean', True, True, False, True, None), + Case('unbind', False, False, False, False, None), ] for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()): op_name = testcase.op_name test_simple_reduce(op_name, device) - test_keepdim(op_name, device) + if testcase.supports_keepdim: + test_keepdim(op_name, device) if testcase.supports_out_variant: test_out_variant(op_name, device) if testcase.supports_complete_reduce: @@ -763,31 +940,31 @@ def test_masked_select(self): # simple self._test_name_inference( torch.masked_select, - (create('N:2,C:3'), (create('2,3') > 0).view_names('N', 'C')), + (create('N:2,C:3'), (create('2,3') > 0).renamed('N', 'C')), expected_names=[None]) # left broadcast self._test_name_inference( torch.masked_select, - (create('C:3'), (create('2,3') > 0).view_names('N', 'C')), + (create('C:3'), (create('2,3') > 0).renamed('N', 'C')), expected_names=[None]) # right broadcast self._test_name_inference( torch.masked_select, - (create('N:2,C:3'), (create('3') > 0).view_names('C')), + (create('N:2,C:3'), (create('3') > 0).renamed('C')), expected_names=[None]) # error self._test_name_inference( torch.masked_select, - (create('N:2,C:3'), (create('3') > 0).view_names('D')), + (create('N:2,C:3'), (create('3') > 0).renamed('D')), maybe_raises_regex='do not match') # out= self._test_name_inference( out_fn(torch.masked_select), - (create('0'), create('N:2,C:3'), (create('2,3') > 0).view_names('N', 'C')), + (create('0'), create('N:2,C:3'), (create('2,3') > 0).renamed('N', 'C')), expected_names=[None]) def test_cat(self): @@ -825,37 +1002,37 @@ def test_masked_fill(self): # simple self._test_name_inference( Tensor.masked_fill, - (create('N:2,C:3'), (create('2,3') > 0).view_names('N', 'C'), 3.14), + (create('N:2,C:3'), (create('2,3') > 0).renamed('N', 'C'), 3.14), expected_names=['N', 'C']) # left broadcast self._test_name_inference( Tensor.masked_fill, - (create('C:3'), (create('2,3') > 0).view_names('N', 'C'), 3.14), + (create('C:3'), (create('2,3') > 0).renamed('N', 'C'), 3.14), maybe_raises_regex="must be less than or equal to") # right broadcast self._test_name_inference( Tensor.masked_fill, - (create('N:2,C:3'), (create('3') > 0).view_names('C'), 3.14), + (create('N:2,C:3'), (create('3') > 0).renamed('C'), 3.14), expected_names=['N', 'C']) # error self._test_name_inference( Tensor.masked_fill, - (create('N:2,C:3'), (create('3') > 0).view_names('D'), 3.14), + (create('N:2,C:3'), (create('3') > 0).renamed('D'), 3.14), maybe_raises_regex='do not match') # inplace self._test_name_inference( Tensor.masked_fill_, - (create('N:2,C:3'), (create('2,3') > 0).view_names('N', 'C'), 3.14), + (create('N:2,C:3'), (create('2,3') > 0).renamed('N', 'C'), 3.14), expected_names=['N', 'C']) # inplace, computed names don't match output tensor names self._test_name_inference( Tensor.masked_fill_, - (create('N:2,None:3'), (create('2,3') > 0).view_names('N', 'C'), 3.14), + (create('N:2,None:3'), (create('2,3') > 0).renamed('N', 'C'), 3.14), maybe_raises_regex="not the same as the computed output names") @@ -916,23 +1093,6 @@ def _test_select(self, device): RuntimeError, 'Please look up dimensions by name'): y = x.select(None, 1) - with self.assertRaisesRegex( - RuntimeError, 'Name \'C.in\' not found in'): - y = x.select('C.in', 1) - - x = torch.empty(2, 3, 4, 5, names=('N', 'C.in', 'H', 'W'), device=device) - y = x.select('C', 1) - self.assertEqual(y.names, ('N', 'H', 'W')) - - x = torch.empty(2, 3, 4, 5, names=('C.out', 'C.in', 'H', 'W'), device=device) - y = x.select('C.in', 1) - self.assertEqual(y.names, ('C.out', 'H', 'W')) - - with self.assertRaisesRegex( - RuntimeError, 'Name \'C\' could refer to multiple dimensions'): - y = x.select('C', 1) - - def test_select(self): self._test_select('cpu') @@ -952,7 +1112,22 @@ def test_as_strided(self): def test_as_strided_cuda(self): self._test_as_strided('cuda') - def test_no_jit_support(self): + def test_no_jit_tracer_support(self): + def foo(x): + return torch.full(x.shape, 2, names=('N',)) + + with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): + x = torch.randn(3) + torch.jit.trace(foo, example_inputs=x) + + def bar(x): + return x.select('N', 1) + + with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): + x = torch.randn(3) + torch.jit.trace(bar, example_inputs=x) + + def test_no_jit_script_support(self): @torch.jit.script def foo(x): return x + 1 @@ -973,15 +1148,73 @@ def return_named_tensor(input): return_named_tensor(torch.randn(1, 1)) def test_align_to(self): + # trivial + tensor = create('N:3') + output = tensor.align_to('N') + self.assertEqual(output.names, ['N']) + self.assertEqual(output.shape, [3]) + + # unsqueeze behavior + tensor = create('N:3') + output = tensor.align_to('N', 'D') + self.assertEqual(output.names, ['N', 'D']) + self.assertEqual(output.shape, [3, 1]) + + # transpose behavior + tensor = create('N:3,C:2') + output = tensor.align_to('C', 'N') + self.assertEqual(output.names, ['C', 'N']) + self.assertEqual(output.shape, [2, 3]) + + # unsqueeze / transpose + tensor = create('C:2,N:3,H:5') + output = tensor.align_to('N', 'H', 'W', 'C') + self.assertEqual(output.names, ['N', 'H', 'W', 'C']) + self.assertEqual(output.shape, [3, 5, 1, 2]) + + # globbing + tensor = create('N:7,H:3,W:5,C:2') + output = tensor.align_to('...', 'C', 'H', 'W') + self.assertEqual(output.names, ['N', 'C', 'H', 'W']) + self.assertEqual(output.shape, [7, 2, 3, 5]) + + tensor = create('N:7,C:2,H:3,W:5') + output = tensor.align_to('...', 'W', 'H') + self.assertEqual(output.names, ['N', 'C', 'W', 'H']) + self.assertEqual(output.shape, [7, 2, 5, 3]) + + # All input dimensions must be named + with self.assertRaisesRegex(RuntimeError, "All input dims must be named"): + create('None:2,C:3').align_to('N', 'C') + + # not enough names + with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"): + create('N:2,C:3').align_to('C') + + # names not found + with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"): + create('N:2,C:3').align_to('D', 'N') + + def test_align_as(self): + # align_as calls align_to internally. align_to has pretty substantial tests, + # so just test some basic things here. + tensor = create('C:2,N:3,H:5') + other = create('N:1,H:1,W:1,C:1') + output = tensor.align_as(other) + self.assertEqual(output.names, ['N', 'H', 'W', 'C']) + self.assertEqual(output.shape, [3, 5, 1, 2]) + + def test_align_tensors_two_inputs(self): def _test(tensor_namedshape, align_names, expected_sizes, expected_error): tensor_names, tensor_sizes = tensor_namedshape tensor = torch.empty(*tensor_sizes, names=tensor_names) + other = torch.empty([1] * len(align_names), names=align_names) if expected_error is not None: with self.assertRaisesRegex(RuntimeError, expected_error): - tensor.align_to(align_names) + torch.align_tensors(tensor, other) return - output = tensor.align_to(align_names) + output, _ = torch.align_tensors(tensor, other) self.assertEqual(output.shape, expected_sizes) self.assertEqual(output.names, align_names) @@ -1002,10 +1235,6 @@ def _test(tensor_namedshape, align_names, expected_sizes, expected_error): align_names=['D'], expected_sizes=None, expected_error='not a subsequence'), - Case(tensor_namedshape=(['N', 'C'], [2, 3]), - align_names=['C'], - expected_sizes=None, - expected_error='shorter list of dims'), # single-dim alignment test Case(tensor_namedshape=(['C'], [2]), @@ -1038,10 +1267,6 @@ def _test(tensor_namedshape, align_names, expected_sizes, expected_error): expected_error=None), # unnamed tensor tests - Case(tensor_namedshape=[None, [2, 3]], - align_names=[None], - expected_sizes=None, - expected_error='shorter list'), Case(tensor_namedshape=[None, [2, 3]], align_names=[None, None], expected_sizes=[2, 3], @@ -1090,14 +1315,12 @@ def _test(tensor_namedshape, align_names, expected_sizes, expected_error): _test(*test) def test_align_tensors(self): - # align_tensors shares code with align_to. test_align_to already tests - # the alignment rules, so we don't do that again here. def reference_fn(*tensors): longest_names = tensors[0].names for tensor in tensors: if len(tensor.names) > len(longest_names): longest_names = tensor.names - return [tensor.align_to(longest_names) for tensor in tensors] + return [tensor.align_to(*longest_names) for tensor in tensors] x = torch.empty(1, 1, names=('N', 'H')) y = torch.empty(2, 3, 5, names=('N', 'C', 'H')) @@ -1406,7 +1629,7 @@ def test_dot(self): # Disable all tests if named tensor is not available. for attr in dir(TestNamedTensor): if attr.startswith('test_'): - new_test = skipIfNamedTensorDisabled(getattr(TestNamedTensor, attr)) + new_test = skipIfNamedTensorDisabled(skipIfNotTestingNamedTensor(getattr(TestNamedTensor, attr))) setattr(TestNamedTensor, attr, new_test) if __name__ == '__main__': diff --git a/test/test_nn.py b/test/test_nn.py index 3d56ff044a862..585e5aea32e6e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -24,7 +24,7 @@ import torch.nn.utils.rnn as rnn_utils from torch.nn.utils import clip_grad_norm_, clip_grad_value_ from torch.nn.utils import parameters_to_vector, vector_to_parameters -from torch.autograd import Variable, gradcheck +from torch.autograd import gradcheck from torch.autograd.gradcheck import gradgradcheck from torch.nn import Parameter from torch.nn.parallel._functions import Broadcast @@ -35,6 +35,7 @@ from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \ module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \ ctcloss_reference, new_module_tests +from common_device_type import instantiate_device_type_tests from torch.nn import MultiheadAttention @@ -905,90 +906,66 @@ def test_zero_grad(self): self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_()) def test_no_grad(self): - module = nn.Conv2d(2, 5, kernel_size=3, padding=1) - input = torch.randn(1, 2, 10, 10) - x = input - y = input.clone() + for dtype in [torch.bfloat16, torch.float, torch.double]: + module = nn.Conv2d(2, 5, kernel_size=3, padding=1).to(dtype) + input = torch.randn(1, 2, 10, 10).to(dtype) + x = input + y = input.clone() - output = module(x) - self.assertTrue(output.requires_grad) - output.backward(torch.ones(1, 5, 10, 10)) + output = module(x) + self.assertTrue(output.requires_grad) + output.backward(torch.ones(1, 5, 10, 10)) - with torch.no_grad(): - output2 = module(y) - self.assertFalse(output2.requires_grad) - self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10))) + with torch.no_grad(): + output2 = module(y) + self.assertFalse(output2.requires_grad) + self.assertRaises(RuntimeError, lambda: output2.backward(torch.ones(1, 5, 10, 10))) def test_invalid_conv1d(self): - module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) - input = torch.randn(1, 3, 4) - with self.assertRaisesRegex(RuntimeError, - r'Calculated padded input size per channel: \(4\). ' + - r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'): - module(input) + for dtype in [torch.bfloat16, torch.float, torch.double]: + module = nn.Conv1d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True).to(dtype) + input = torch.randn(1, 3, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, + r'Calculated padded input size per channel: \(4\). ' + + r'Kernel size: \(10\). Kernel size can\'t be greater than actual input size'): + module(input) - # Negative stride check - module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True) - input = torch.randn(1, 3, 4) - with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'): - module(input) + # Negative stride check + module = nn.Conv1d(in_channels=3, out_channels=6, kernel_size=3, stride=-1, bias=True).to(dtype) + input = torch.randn(1, 3, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'): + module(input) def test_invalid_conv2d(self): - module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2) - input = torch.empty(1, 1, 4, 4) - self.assertRaises(RuntimeError, lambda: module(input)) - - module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) - input = torch.randn(1, 3, 1, 1) - with self.assertRaisesRegex(RuntimeError, - r'Calculated padded input size per channel: \(1 x 1\). ' + - r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'): - module(input) - - # Negative stride check - module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True) - input = torch.randn(1, 3, 4, 4) - with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'): - module(input) - - def test_invalid_conv3d(self): - module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2) - input = torch.empty(1, 1, 4, 4, 4) - self.assertRaises(RuntimeError, lambda: module(input)) - - # Negative stride check - module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2) - input = torch.empty(1, 1, 4, 4, 4) - with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'): - module(input) - - def _test_dropout(self, cls, cuda, input): - p = 0.2 - device = torch.device("cuda") if cuda else torch.device("cpu") - input = input.to(device).fill_(1 - p) - - module = cls(p) - input_var = input.clone().requires_grad_() - output = module(input_var) - self.assertLess(abs(output.data.mean() - (1 - p)), 0.05) - output.backward(input) - self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05) + for dtype in [torch.bfloat16, torch.float, torch.double]: + module = torch.nn.Conv2d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) + input = torch.empty(1, 1, 4, 4).to(dtype) + self.assertRaises(RuntimeError, lambda: module(input)) - module = cls(p, True) - input_var = input.clone().requires_grad_() - output = module(input_var + 0) - self.assertLess(abs(output.data.mean() - (1 - p)), 0.05) - output.backward(input) - self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05) + module = nn.Conv2d(in_channels=3, out_channels=33, kernel_size=10, stride=1, bias=True) + input = torch.randn(1, 3, 1, 1) + with self.assertRaisesRegex(RuntimeError, + r'Calculated padded input size per channel: \(1 x 1\). ' + + r'Kernel size: \(10 x 10\). Kernel size can\'t be greater than actual input size'): + module(input) - # check eval mode doesn't change anything - for inplace in [True, False]: - module = cls(p, inplace).eval() - self.assertEqual(input, module(input)) + # Negative stride check + module = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=4, stride=-1, bias=True).to(dtype) + input = torch.randn(1, 3, 4, 4).to(dtype) + with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'): + module(input) - # Check that these don't raise errors - module.__repr__() - str(module) + def test_invalid_conv3d(self): + for dtype in [torch.bfloat16, torch.float, torch.double]: + module = torch.nn.Conv3d(1, 1, kernel_size=3, dilation=2, stride=2).to(dtype) + input = torch.empty(1, 1, 4, 4, 4).to(dtype) + self.assertRaises(RuntimeError, lambda: module(input)) + + # Negative stride check + module = torch.nn.Conv3d(1, 1, kernel_size=3, stride=-2) + input = torch.empty(1, 1, 4, 4, 4) + with self.assertRaisesRegex(RuntimeError, 'negative stride is not supported'): + module(input) def _test_alpha_dropout(self, cls, input): mean = input.mean() @@ -1804,7 +1781,7 @@ def test_overwrite_module_params_on_conversion_cpu_cuda(self): # Without using `torch.no_grad()`, this will leak CUDA memory. # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875) mw[0][0] = 5 - with self.assertRaisesRegex(RuntimeError, "Expected object of backend CPU but got backend CUDA"): + with self.assertRaisesRegex(RuntimeError, "Expected object of backend CUDA but got backend CPU"): mw[0][0] == mw._base[0][0] try: @@ -1917,7 +1894,7 @@ def compare_scaling(grads): grads = torch.arange(1., 101).view(10, 10), torch.ones(10).div(1000) for norm_type in [0.5, 1.5, 2, 4, 'inf']: for p, g in zip(l.parameters(), grads): - p._grad = Variable(g.clone().view_as(p.data)) + p._grad = g.clone().view_as(p.data) norm_before = compute_norm(norm_type) norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type) norm_after = compute_norm(norm_type) @@ -1986,7 +1963,7 @@ def test_vector_to_parameters(self): fc1 = nn.Linear(10, 20) model = nn.Sequential(conv1, fc1) - vec = Variable(torch.arange(0., 980)) + vec = torch.arange(0., 980) vector_to_parameters(vec, model.parameters()) sample = next(model.parameters())[0, 0, 0] @@ -2353,7 +2330,7 @@ def test_threshold_int(self): def test_embedding_sparse_basic(self): embedding = nn.Embedding(10, 20, sparse=True) - input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]])) + input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long) embedding(input).sum().backward() self.assertTrue(embedding.weight.grad.is_sparse) self.assertEqual(embedding.weight.grad.shape, embedding.weight.shape) @@ -2449,13 +2426,13 @@ def _test_embedding_backward(self, device='cpu', dtype=torch.float64): def test_embedding_padding_idx(self): embedding = nn.Embedding(10, 20, padding_idx=0) - input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]])) + input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long) output = embedding(input) self.assertEqual(output[0][0].sum(), 0) self.assertEqual(output[1][2].sum(), 0) embedding = nn.Embedding(10, 20, padding_idx=0, sparse=True) - input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]])) + input = torch.tensor([[0, 2, 4, 5], [4, 3, 0, 9]], dtype=torch.long) output = embedding(input) self.assertEqual(output[0][0].sum(), 0) self.assertEqual(output[1][2].sum(), 0) @@ -2463,13 +2440,13 @@ def test_embedding_padding_idx(self): # negative indexing check for padding_idx # padding_idx=-2, num_embeddings=10 ==> index 8 padded embedding = nn.Embedding(10, 20, padding_idx=-2) - input = Variable(torch.LongTensor([[0, 2, 8, 5], [4, 8, 0, 9]])) + input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long) output = embedding(input) self.assertEqual(output[0][2].sum(), 0) self.assertEqual(output[1][1].sum(), 0) embedding = nn.Embedding(10, 20, padding_idx=-2, sparse=True) - input = Variable(torch.LongTensor([[0, 2, 8, 5], [4, 8, 0, 9]])) + input = torch.tensor([[0, 2, 8, 5], [4, 8, 0, 9]], dtype=torch.long) output = embedding(input) self.assertEqual(output[0][2].sum(), 0) self.assertEqual(output[1][1].sum(), 0) @@ -2483,7 +2460,7 @@ def test_embedding_padding_idx(self): embedding = nn.Embedding(5, 2, padding_idx=padding_idx) for n in (1, 2): for other_indices in ([], [1, 3], [2]): - indices = torch.LongTensor(other_indices + [padding_idx] * n) + indices = torch.tensor(other_indices + [padding_idx] * n, dtype=torch.long) pre = embedding.weight[padding_idx].clone() embedding(indices).sum().backward() after = (embedding.weight + embedding.weight.grad)[padding_idx] @@ -2501,7 +2478,7 @@ def test_embedding_padding_idx(self): def test_embedding_max_norm(self): embedding = nn.Embedding(22, 5, max_norm=1.0) - input = Variable(torch.LongTensor([2, 8, 8, 6])) + input = torch.tensor([2, 8, 8, 6], dtype=torch.long) output = embedding(input) self.assertEqual(output[1], output[2]) self.assertTrue(output.data.norm(p=2, dim=1).le(1).all()) @@ -2555,6 +2532,7 @@ def test_embedding_functional(self): res_F = F.embedding(a, embeddings) self.assertEqual(res_old, res_F) + # test is flaky on ROCm CI @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") @repeat_test_for_types([torch.float, torch.half]) @skipIfRocm @@ -2896,7 +2874,7 @@ def test_embeddingbag_from_pretrained_options(self): @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_pool3d_size_one_feature_dim(self): # Tests crazy strides for feature dim of size 1 - x = Variable(torch.randn(7, 1, 5, 3, 2, device="cuda")) + x = torch.randn(7, 1, 5, 3, 2, device="cuda") strange_strides = [30, 1234, 6, 2, 1] y = x.as_strided(x.size(), strange_strides) x = x.cpu().as_strided(x.size(), strange_strides) @@ -2957,6 +2935,7 @@ def _test_embedding_bag_empty_input(self, device): x = torch.tensor([], device=device, dtype=torch.long) for sparse in [True, False]: Embed = torch.nn.EmbeddingBag(m, n, sparse=sparse) + Embed.to(device) output = Embed(input=x, offsets=torch.tensor([0], device=device, dtype=torch.long)) self.assertEqual(output, torch.zeros_like(output)) @@ -3155,51 +3134,6 @@ def func(x): gradcheck(func, [x]) gradgradcheck(func, [x]) - def test_Dropout(self): - input = torch.Tensor(1000) - self._test_dropout(nn.Dropout, False, input) - - def test_Dropout2d(self): - b = random.randint(1, 5) - w = random.randint(1, 5) - h = random.randint(1, 5) - num_features = 1000 - input = torch.Tensor(num_features, b, w, h) - self._test_dropout(nn.Dropout2d, False, input) - - def test_Dropout3d(self): - b = random.randint(1, 5) - w = random.randint(1, 5) - h = random.randint(1, 5) - d = random.randint(1, 2) - num_features = 1000 - input = torch.Tensor(num_features, b, d, w, h) - self._test_dropout(nn.Dropout3d, False, input) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_Dropout_cuda(self): - input = torch.Tensor(1000) - self._test_dropout(nn.Dropout, True, input) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_Dropout2d_cuda(self): - b = random.randint(1, 5) - w = random.randint(1, 5) - h = random.randint(1, 5) - num_features = 1000 - input = torch.Tensor(num_features, b, w, h) - self._test_dropout(nn.Dropout2d, True, input) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_Dropout3d_cuda(self): - b = random.randint(1, 5) - w = random.randint(1, 5) - h = random.randint(1, 5) - d = random.randint(1, 2) - num_features = 1000 - input = torch.Tensor(num_features, b, d, w, h) - self._test_dropout(nn.Dropout3d, True, input) - def test_AlphaDropout(self): # generate random tensor with zero mean and unit std input = torch.randn(5000) @@ -3214,498 +3148,137 @@ def test_FeatureAlphaDropout(self): input = torch.randn(num_features, b, d, w, h) self._test_alpha_dropout(nn.FeatureAlphaDropout, input) - def _test_InstanceNorm_general(self, cls, input, device="cpu", dtype=torch.float): - # default case track_running_stats=False - b, c = input.size(0), input.size(1) - input_var = input.to(device=device, dtype=dtype).requires_grad_() - - IN = cls(c, eps=0).to(device, dtype) + def test_pad(self): + inputs = torch.randn(1, 3, 4, 4, requires_grad=True) + _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,)) + _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,)) + _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,)) + self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,))) + self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,))) - output = IN(input_var) - out_reshaped = output.view(b * c, -1) + inputs = torch.randn(1, 2, 3, 4, 4, requires_grad=True) + self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,))) - mean = out_reshaped.mean(1) - var = out_reshaped.var(1, unbiased=False) + # assert that relfection padding errors when pad >= input size + expected_err_msg = r"Padding size should be less than the corresponding input dimension" + self.assertRaisesRegex(RuntimeError, expected_err_msg, + lambda: F.pad(torch.randn(1, 1, 2, 3), (1, 1, 3, 0), mode='reflect')) + self.assertRaisesRegex(RuntimeError, expected_err_msg, + lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect')) - self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5) - self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5) + def test_pad_scalar_error(self): + inputs = torch.tensor(0., requires_grad=True) + self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1))) + self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,))) - # check that eval mode doesn't change behavior - grad_out = torch.randn_like(output) - res1 = output.data.clone() - output.backward(grad_out) - grad1 = input_var.grad.data.clone() + @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), + "Scipy v1.0 and/or numpy not found") + def test_multihead_attention(self): + def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None): + """ Numpy-based reference implementation of scaled dot attention + for testing""" - IN.eval() - output = IN(input_var) - input_var.grad = None - output.backward(grad_out) - res2 = output.data - grad2 = input_var.grad.data - self.assertEqual(res1, res2) - self.assertEqual(grad1, grad2) + QKT = _batchmatmul( + Q, + np.transpose(K, axes=[0, 1, 3, 2]) + / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head) + ) + b1, b2, s1, s2 = QKT.shape + if unseen_mask is not None or src_lengths is not None: + # assert s1 == s2 + for i in range(b1): + for j in range(b2): + for m in range(s1): + for n in range(s2): + if unseen_mask is not None and unseen_mask[m][n] == 0: + QKT[i, j, m, n] = -np.inf + if key_padding_mask is not None and key_padding_mask[i][n]: + QKT[i, j, m, n] = -np.inf - # If track_running_stats=True and momentum=1, running_mean/var should be - # equal to mean/var of the input (with unbias correction) - IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype) + reference = _softmax(QKT) + ref_attn_weight = reference + ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2 + reference = _batchmatmul(reference, V) + return reference, ref_attn_weight - output = IN(input_var) + def _batchmatmul(a, b): # batchmatmul over 4 dim matrix + """ Numpy-based batch matrix multiply over 4 dim matrix""" + assert a.shape[0] == b.shape[0] + assert a.shape[1] == b.shape[1] + retval = np.zeros( + (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32 + ) + for i in range(a.shape[0]): + for j in range(a.shape[1]): + retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :]) + return retval - input_reshaped = input_var.transpose(1, 0).reshape(c, -1) - mean = input_reshaped.mean(1) + def _softmax(x): # softmax over 4 dim matrix + """ Numpy-based reference softmax over 4 dim matrix""" + np.seterr(invalid='ignore') + output = np.zeros(x.shape, dtype=np.float64) + for i in range(x.shape[0]): + for j in range(x.shape[1]): + for k in range(x.shape[2]): + x_curr = x[i, j, k, :] + e_x = np.exp(x_curr - np.amax(x_curr)) + output[i, j, k, :] = e_x / np.sum(e_x) + return output - input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1) - var = input_reshaped.var(2, unbiased=True)[:, :] + def _split_heads_ref(X, dims, nheads, d_head): + X_split = np.reshape(X, dims[:2] + [nheads, d_head]) + X_split_transposed = np.transpose(X_split, [0, 2, 1, 3]) + reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head]) + return reference - self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5) - self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5) + def _combine_heads_ref(X, dims, nheads, d_head): + X_transposed = np.transpose(X, [0, 2, 1, 3]) + reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head]) + return reference - # in eval mode, adding X * std to a channel in input should make the - # corresponding channel in output have mean X - IN.eval() - delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype) - delta = delta.view(-1, *[1 for _ in range(2, input.dim())]) - output = IN(input_var + delta) - self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c)) + def _fc(X, X_weight, X_bias): + X_fc_b = X_bias.detach().numpy() + X_fc_w = X_weight.detach().numpy() + return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b - def _test_InstanceNorm_cuda_half(self, cls, input): - # THNN - input = Variable(input.cuda().half().random_(1, 10), requires_grad=True) - m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half) - thnn_output = m(input) - thnn_output.sum().backward() - thnn_input_grad = input.grad.data.clone() - self.assertEqual(thnn_output.type(), input.type()) - # cuDNN - if TEST_CUDNN: - input.grad = None - m = m.float() - cudnn_output = m(input) - cudnn_output.sum().backward() - cudnn_input_grad = input.grad.data.clone() - self.assertEqual(cudnn_output.type(), input.type()) - self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4) - self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3) + def _create_src_lengths_mask(batch_size, src_lengths): + """ + Generate boolean mask to prevent attention beyond the end of source + Inputs: + batch_size : int + src_lengths : [batch_size] of sentence lengths + Outputs: + [batch_size, max_src_len] + """ + max_srclen = src_lengths.max() + src_indices = torch.arange(0, max_srclen).unsqueeze(0).type_as(src_lengths) + src_indices = src_indices.expand(batch_size, max_srclen) + src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) + # returns [batch_size, max_seq_len] + return (src_indices < src_lengths).int().detach() - def test_InstanceNorm1d_general(self): - b = random.randint(3, 5) - c = random.randint(3, 5) - d = random.randint(8, 10) + def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False, + saved_kv=False, same_embed_dim=False): + for _ in range(100): + batch_sz, seq_len = [random.randint(2, 10) for r in range(2)] + d_head = random.randint(3, 10) + nheads = random.randint(3, 10) + d_model = d_head * nheads + if same_embed_dim: + kv_dim = d_model + else: + kv_dim = random.randint(5, 20) + dims = [batch_sz, seq_len, kv_dim] - input = torch.rand(b, c, d) - self._test_InstanceNorm_general(nn.InstanceNorm1d, input, dtype=torch.float) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_InstanceNorm1d_general_cuda(self): - b = random.randint(3, 5) - c = random.randint(3, 5) - d = random.randint(8, 10) - - input = torch.rand(b, c, d) - self._test_InstanceNorm_general(nn.InstanceNorm1d, input, "cuda", torch.float) - self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input) - - def test_InstanceNorm2d_general(self): - b = random.randint(3, 5) - c = random.randint(3, 5) - w = random.randint(3, 6) - h = random.randint(6, 8) - - input = torch.rand(b, c, h, w) - self._test_InstanceNorm_general(nn.InstanceNorm2d, input, dtype=torch.float) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_InstanceNorm2d_general_cuda(self): - b = random.randint(3, 5) - c = random.randint(3, 5) - w = random.randint(3, 6) - h = random.randint(6, 8) - - input = torch.rand(b, c, h, w) - self._test_InstanceNorm_general(nn.InstanceNorm2d, input, "cuda", torch.float) - self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input) - - def test_InstanceNorm3d_general(self): - b = random.randint(3, 5) - c = random.randint(3, 5) - w = random.randint(2, 5) - h = random.randint(2, 5) - d = random.randint(2, 5) - - input = torch.rand(b, c, h, w, d) - self._test_InstanceNorm_general(nn.InstanceNorm3d, input, dtype=torch.float) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - @skipIfRocm - def test_InstanceNorm3d_general_cuda(self): - b = random.randint(3, 5) - c = random.randint(2, 5) - w = random.randint(2, 5) - h = random.randint(2, 5) - d = random.randint(2, 5) - - input = torch.rand(b, c, h, w, d) - self._test_InstanceNorm_general(nn.InstanceNorm3d, input, "cuda", torch.float) - self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input) - - def _test_LayerNorm_general(self, device="cpu", dtype=torch.float): - for i in range(2, 6): - shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist() - x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) - normalized_ndim = random.randint(1, i - 1) # inclusive - normalized_shape = shape[-normalized_ndim:] - unnormalized_shape = shape[:-normalized_ndim] - - # test that LN normalizes to mean 0 and stddev 1 - ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype) - ln.weight.data.fill_(1) - ln.bias.data.fill_(0) - output = ln(x) - out_reshaped = output.view(*(unnormalized_shape + [-1])) - mean = out_reshaped.mean(-1) - var = out_reshaped.var(-1, unbiased=False) - self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5) - self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5) - - # test that LN applies weight and bias correctly - scale, bias = torch.empty(2).uniform_(0.2, 2).tolist() - ln.weight.data.fill_(scale) - ln.bias.data.fill_(bias) - output = ln(x) - out_reshaped = output.view(*(unnormalized_shape + [-1])) - mean = out_reshaped.mean(-1) - var = out_reshaped.var(-1, unbiased=False) - self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5) - self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5) - - bad_norm_shape_input_shape = { - (): (), - (2, 3): (3,), - (2,): (1, 2, 3), - (10,): (2, 3), - 10: (2, 3), - } - for norm_shape, input_shape in bad_norm_shape_input_shape.items(): - ln = nn.LayerNorm(norm_shape) - input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10) - self.assertRaises(RuntimeError, lambda: ln(input)) - - def _test_LayerNorm_cuda_half(self): - input = Variable(torch.empty(2, 3, 3, 2).to("cuda", torch.half).random_(1, 10), requires_grad=True) - m = nn.LayerNorm([3, 2]).to("cuda", torch.half) - output = m(input) - output.sum().backward() - self.assertEqual(output.type(), input.type()) - - def test_LayerNorm_general(self): - self._test_LayerNorm_general() - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_LayerNorm_general_cuda(self): - self._test_LayerNorm_general("cuda") - self._test_LayerNorm_cuda_half() - - def _test_GroupNorm_general(self, device="cpu", dtype=torch.float): - good_shape_g = { - (1, 2, 3, 4): 2, - (2, 3, 10): 3, - (3, 1, 1, 1, 2): 1, - (2, 6, 4, 2, 2): 3, - } - for shape, g in good_shape_g.items(): - x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) - b = shape[0] - c = shape[1] - - # test that GN normalizes to mean 0 and stddev 1 - gn = nn.GroupNorm(g, c, eps=0).to(device, dtype) - gn.weight.data.fill_(1) - gn.bias.data.fill_(0) - output = gn(x) - out_reshaped = output.view(b, g, -1) - mean = out_reshaped.mean(-1) - var = out_reshaped.var(-1, unbiased=False) - self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5) - self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5) - - # test that GN applies weight and bias correctly - scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2) - bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2) - gn.weight.data.copy_(scale) - gn.bias.data.copy_(bias) - output = gn(x) - out_reshaped = output.view(b, c, -1) - out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1) - out_normed_reshaped = out_normed.view(b, g, -1) - mean = out_normed_reshaped.mean(-1) - var = out_normed_reshaped.var(-1, unbiased=False) - self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5) - self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5) - - bad_shape_g = { - (1, 2, 3, 4): 3, - (2, 3, 10): 2, - (3, 1, 1, 1, 2): 10, - (2, 6, 4, 2, 2): 4, - } - for shape, g in bad_shape_g.items(): - gn = nn.GroupNorm(g, shape[1]) - input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) - self.assertRaises(RuntimeError, lambda: gn(input)) - - def _test_GroupNorm_cuda_half(self): - input = Variable(torch.empty(2, 3, 3, 2).to("cuda", torch.half).random_(1, 10), requires_grad=True) - input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10) - m = nn.GroupNorm(2, 4).to("cuda", torch.half) - output = m(input) - output.sum().backward() - self.assertEqual(output.type(), input.type()) - - def test_GroupNorm_general(self): - self._test_GroupNorm_general(dtype=torch.float) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_GroupNorm_general_cuda(self): - self._test_GroupNorm_general("cuda", torch.float) - self._test_GroupNorm_cuda_half() - - def test_pad(self): - inputs = torch.randn(1, 3, 4, 4, requires_grad=True) - _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (1, 1, 1, 1)), (inputs,)) - _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1)), (inputs,)) - _assertGradAndGradgradChecks(self, lambda x: F.pad(x, (-1, 1, -2, 1), value=2), (inputs,)) - self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='replicate'), (inputs,))) - self.assertTrue(gradcheck(lambda x: F.pad(x, (-1, 1, -2, 1), mode='reflect'), (inputs,))) - - inputs = torch.randn(1, 2, 3, 4, 4, requires_grad=True) - self.assertTrue(gradcheck(lambda x: F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate'), (inputs,))) - - # assert that relfection padding errors when pad >= input size - expected_err_msg = r"Padding size should be less than the corresponding input dimension" - self.assertRaisesRegex(RuntimeError, expected_err_msg, - lambda: F.pad(torch.randn(1, 1, 2, 3), (1, 1, 3, 0), mode='reflect')) - self.assertRaisesRegex(RuntimeError, expected_err_msg, - lambda: F.pad(torch.randn(1, 1, 2), (2, 1), mode='reflect')) - - @staticmethod - def _test_one_hot(self, use_cuda=False): - device = torch.device('cuda' if use_cuda else 'cpu') - with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) - - with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3) - - t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device)) - expected = torch.tensor([[0, 0, 0, 1, 0], - [0, 0, 0, 0, 1], - [0, 1, 0, 0, 0], - [1, 0, 0, 0, 0]], device=device) - self.assertEqual(t, expected) - - t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1) - expected = torch.tensor([[0, 0, 0, 1, 0], - [0, 0, 0, 0, 1], - [0, 1, 0, 0, 0], - [1, 0, 0, 0, 0]], device=device) - self.assertEqual(t, expected) - - t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6) - expected = torch.tensor([[0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 1, 0], - [0, 1, 0, 0, 0, 0], - [1, 0, 0, 0, 0, 0]], device=device) - self.assertEqual(t, expected) - - t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device)) - expected = torch.tensor([[[0, 0, 0, 1, 0], - [0, 0, 0, 0, 1]], - [[0, 1, 0, 0, 0], - [1, 0, 0, 0, 0]]], device=device) - self.assertEqual(t, expected) - - t = torch.nn.functional.one_hot(torch.tensor(4, device=device)) - expected = torch.tensor([0, 0, 0, 0, 1], device=device) - self.assertEqual(t, expected) - - t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100) - expected = torch.empty([4, 0, 100]) - self.assertEqual(t, expected) - - with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device)) - - with self.assertRaises(RuntimeError): - torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2) - - def test_one_hot(self): - self._test_one_hot(self) - - @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - def test_one_hot_cuda(self): - self._test_one_hot(self, use_cuda=True) - - def test_pad_scalar_error(self): - inputs = torch.tensor(0., requires_grad=True) - self.assertRaises(AssertionError, lambda: F.pad(inputs, (1, 1))) - self.assertRaises(AssertionError, lambda: F.pad(inputs, (1,))) - - def test_nn_scalars(self): - # One off tests to ensure scalars from nn.yaml are properly applied - def verify_scalars(input, output): - if input.dim() == 0: - self.assertEqual((), output.shape) - else: - self.assertNotEqual((), output.shape) - output.sum().backward() - self.assertEqual(input.shape, input.grad.shape) - - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: - for input_shape in [(5, 6), ()]: - for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid, - torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid, - torch.nn.Tanh]: - input = torch.randn(input_shape, device=device, requires_grad=True) - m = module() - output = m(input) - verify_scalars(input, output) - - def test_nn_scalars_reductions(self): - # One off tests to ensure scalars from nn.yaml are properly applied - def verify_reduction_scalars(input, reduction, output): - if reduction != 'none' or input.dim() == 0: - self.assertEqual((), output.shape) - else: - self.assertNotEqual((), output.shape) - output.sum().backward() - self.assertEqual(input.shape, input.grad.shape) - - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: - for input_shape in [(5, 6), ()]: - for reduction in ['none', 'mean', 'sum']: - for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss, - torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]: - input = torch.randn(input_shape, device=device, requires_grad=True) - target = torch.empty(input_shape, device=device).random_(2) - sigmoid = nn.Sigmoid() - - input = torch.randn(input_shape, device=device, requires_grad=True) - m = module(reduction=reduction) - output = m(sigmoid(input), target) - verify_reduction_scalars(input, reduction, output) - - @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), - "Scipy v1.0 and/or numpy not found") - def test_multihead_attention(self): - def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, key_padding_mask=None): - """ Numpy-based reference implementation of scaled dot attention - for testing""" - - QKT = _batchmatmul( - Q, - np.transpose(K, axes=[0, 1, 3, 2]) - / np.sqrt(dims[3], dtype=np.float32), # divide by sqrt(d_head) - ) - b1, b2, s1, s2 = QKT.shape - if unseen_mask is not None or src_lengths is not None: - # assert s1 == s2 - for i in range(b1): - for j in range(b2): - for m in range(s1): - for n in range(s2): - if unseen_mask is not None and unseen_mask[m][n] == 0: - QKT[i, j, m, n] = -np.inf - if key_padding_mask is not None and key_padding_mask[i][n]: - QKT[i, j, m, n] = -np.inf - - reference = _softmax(QKT) - ref_attn_weight = reference - ref_attn_weight = np.sum(ref_attn_weight, axis=1) / b2 - reference = _batchmatmul(reference, V) - return reference, ref_attn_weight - - def _batchmatmul(a, b): # batchmatmul over 4 dim matrix - """ Numpy-based batch matrix multiply over 4 dim matrix""" - assert a.shape[0] == b.shape[0] - assert a.shape[1] == b.shape[1] - retval = np.zeros( - (a.shape[0], a.shape[1], a.shape[2], b.shape[3]), dtype=np.float32 - ) - for i in range(a.shape[0]): - for j in range(a.shape[1]): - retval[i, j, :, :] = np.matmul(a[i, j, :, :], b[i, j, :, :]) - return retval - - def _softmax(x): # softmax over 4 dim matrix - """ Numpy-based reference softmax over 4 dim matrix""" - np.seterr(invalid='ignore') - output = np.zeros(x.shape, dtype=np.float64) - for i in range(x.shape[0]): - for j in range(x.shape[1]): - for k in range(x.shape[2]): - x_curr = x[i, j, k, :] - e_x = np.exp(x_curr - np.amax(x_curr)) - output[i, j, k, :] = e_x / np.sum(e_x) - return output - - def _split_heads_ref(X, dims, nheads, d_head): - X_split = np.reshape(X, dims[:2] + [nheads, d_head]) - X_split_transposed = np.transpose(X_split, [0, 2, 1, 3]) - reference = np.reshape(X_split_transposed, [dims[0], nheads, dims[1], d_head]) - return reference - - def _combine_heads_ref(X, dims, nheads, d_head): - X_transposed = np.transpose(X, [0, 2, 1, 3]) - reference = np.reshape(X_transposed, dims[:2] + [nheads * d_head]) - return reference - - def _fc(X, X_weight, X_bias): - X_fc_b = X_bias.detach().numpy() - X_fc_w = X_weight.detach().numpy() - return np.matmul(X, np.transpose(X_fc_w)) + X_fc_b - - def _create_src_lengths_mask(batch_size, src_lengths): - """ - Generate boolean mask to prevent attention beyond the end of source - - Inputs: - batch_size : int - src_lengths : [batch_size] of sentence lengths - - Outputs: - [batch_size, max_src_len] - """ - max_srclen = src_lengths.max() - src_indices = torch.arange(0, max_srclen).unsqueeze(0).type_as(src_lengths) - src_indices = src_indices.expand(batch_size, max_srclen) - src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_srclen) - # returns [batch_size, max_seq_len] - return (src_indices < src_lengths).int().detach() - - def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, add_zero_attn=False, - saved_kv=False, same_embed_dim=False): - for _ in range(100): - batch_sz, seq_len = [random.randint(2, 10) for r in range(2)] - d_head = random.randint(3, 10) - nheads = random.randint(3, 10) - d_model = d_head * nheads - if same_embed_dim: - kv_dim = d_model - else: - kv_dim = random.randint(5, 20) - dims = [batch_sz, seq_len, kv_dim] - - saved_k = None - saved_k_tensor = None - saved_v = None - saved_v_tensor = None - if saved_kv: - saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head) - saved_k_tensor = torch.from_numpy(saved_k) - saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head) - saved_v_tensor = torch.from_numpy(saved_v) + saved_k = None + saved_k_tensor = None + saved_v = None + saved_v_tensor = None + if saved_kv: + saved_k = np.random.rand(batch_sz * nheads, seq_len, d_head) + saved_k_tensor = torch.from_numpy(saved_k) + saved_v = np.random.rand(batch_sz * nheads, seq_len, d_head) + saved_v_tensor = torch.from_numpy(saved_v) key_padding_mask = None key_padding_mask_tensor = None @@ -4007,7 +3580,6 @@ def test_batchnorm_grad(self): self._test_batchnorm_grad() @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") - @skipIfRocm def test_batchnorm_grad_cuda(self): self._test_batchnorm_grad("cuda") if TEST_CUDNN: @@ -4211,7 +3783,6 @@ def _test_gather(self, output_device): _assertGradAndGradgradChecks(self, lambda x, y: dp.gather((x, y), output_device), inputs) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - @skipIfRocm def test_gather_cpu(self): self._test_gather(-1) @@ -4222,10 +3793,10 @@ def test_gather_gpu(self): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_gather_different_len_dicts(self): inputs = ( - {'a': Variable(torch.randn(1, 2).cuda(0), requires_grad=True)}, + {'a': torch.randn(1, 2, requires_grad=True, device="cuda:0")}, { - 'b': Variable(torch.randn(1, 2).cuda(1), requires_grad=True), - 'a': Variable(torch.randn(1, 2).cuda(1), requires_grad=True) + 'b': torch.randn(1, 2, requires_grad=True, device="cuda:1"), + 'a': torch.randn(1, 2, requires_grad=True, device="cuda:1"), } ) with self.assertRaises(ValueError): @@ -4264,7 +3835,7 @@ def test_broadcast_no_grad(self): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_replicate(self): module = nn.Linear(10, 5).float().cuda() - input = Variable(torch.randn(2, 10).float().cuda()) + input = torch.randn(2, 10, dtype=torch.float, device="cuda") expected_output = module(input).data for devices in [(0, 1), [0, 1]]: replicas = dp.replicate(module, devices) @@ -4287,7 +3858,6 @@ def test_replicate_buffers(self): self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device') @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - @skipIfRocm def test_data_parallel_buffers_requiring_grad(self): class TestModule(nn.Module): def __init__(self, t): @@ -4310,7 +3880,6 @@ def fn(t): torch.autograd.gradcheck(fn, (m.t_rg,)) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - @skipIfRocm def test_data_parallel_rnn(self): class TestModule(torch.nn.Module): @@ -4449,7 +4018,7 @@ def local_test(out): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_data_parallel_small_back(self): l = nn.Linear(10, 5).float().cuda() - i = Variable(torch.randn(20, 10).float().cuda()) + i = torch.randn(20, 10, dtype=torch.float, device="cuda") out = dp.data_parallel(l, i, (0, 1)) self.assertEqual(out, l(i)) @@ -4544,7 +4113,7 @@ def forward(self, x): gc.collect() model = nn.DataParallel(Model().cuda()) - data = Variable(torch.randn(1).cuda()) + data = torch.randn(1, device="cuda") model(data) refcycles = gc.collect() @@ -4560,7 +4129,7 @@ def forward(self, x): return x l = Layer() - i = Variable(torch.randn(20, 10).float().cuda()) + i = torch.randn(20, 10, dtype=torch.float, device="cuda") with torch.no_grad(): dp.data_parallel(l, i, (0, 1)) self.assertRaises(AssertionError, lambda: dp.data_parallel(l, i, (0, 1))) @@ -4568,7 +4137,7 @@ def forward(self, x): @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") def test_data_parallel(self): l = nn.Linear(10, 5).float().cuda() - i = Variable(torch.randn(20, 10).float().cuda(1)) + i = torch.randn(20, 10, dtype=torch.float, device="cuda:1") l.cuda(1) expected_out = l(i) loss = expected_out.sum() @@ -4660,7 +4229,7 @@ class Net(nn.Module): def forward(self, *input): return fn(input) - i = Variable(torch.randn(20, 3).float().cuda(1)) + i = torch.randn(20, 3, dtype=torch.float, device="cuda:1") input = (i.cos(), (i.sin(), i), i.sin()) gpus = range(torch.cuda.device_count()) output = dp.data_parallel(Net(), input, gpus) @@ -4771,7 +4340,6 @@ def test_data_parallel_device_args(self): self.assertEqual(out, l(i)) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - @skipIfRocm def test_data_parallel_function_deletion(self): # this test case is originated from #16532 def gradient_penalty(net, x): @@ -5091,8 +4659,8 @@ def test_assignments(get_list, a, b, c): self.assertEqual(l.state_dict()['buf'], buf) def test_Conv2d_inconsistent_types(self): - inputs = Variable(torch.randn(4, 1, 7, 7).float()) - weights = Variable(torch.randn(1, 1, 3, 3).double()) + inputs = torch.randn(4, 1, 7, 7, dtype=torch.float) + weights = torch.randn(1, 1, 3, 3, dtype=torch.double) # inconsistent types should raise an exception self.assertRaises(RuntimeError, lambda: nn.functional.conv2d(inputs, weights)) # but it should work with the same type @@ -5100,9 +4668,9 @@ def test_Conv2d_inconsistent_types(self): @unittest.skipIf(not TEST_CUDA, 'CUDA not available') def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self): - inputs = Variable(torch.randn(4, 1, 7, 7).float().cuda()) - weights = Variable(torch.randn(1, 1, 3, 3).double().cuda()) - bias = Variable(torch.randn(1).double().cuda()) + inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") + weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") + bias = torch.randn(1, dtype=torch.double, device="cuda") with torch.backends.cudnn.flags(enabled=False): # inconsistent types should raise an exception @@ -5115,9 +4683,9 @@ def test_Conv2d_inconsistent_types_on_GPU_without_cudnn(self): @unittest.skipIf(not TEST_CUDA, 'CUDA not available') @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available') def test_Conv2d_inconsistent_types_on_GPU_with_cudnn(self): - inputs = Variable(torch.randn(4, 1, 7, 7).float().cuda()) - weights = Variable(torch.randn(1, 1, 3, 3).double().cuda()) - bias = Variable(torch.randn(1).double().cuda()) + inputs = torch.randn(4, 1, 7, 7, dtype=torch.float, device="cuda") + weights = torch.randn(1, 1, 3, 3, dtype=torch.double, device="cuda") + bias = torch.randn(1, dtype=torch.double, device="cuda") with torch.backends.cudnn.flags(enabled=True): # inconsistent types should raise an exception @@ -5230,44 +4798,46 @@ def run_test(benchmark): run_test(benchmark=True) def test_conv_modules_raise_error_on_incorrect_input_size(self): - modules = [nn.Conv1d(3, 8, 3), nn.ConvTranspose1d(3, 8, 3), - nn.Conv2d(3, 8, 3), nn.ConvTranspose2d(3, 8, 3), - nn.Conv3d(3, 8, 3), nn.ConvTranspose3d(3, 8, 3)] + for dtype in [torch.bfloat16, torch.double, torch.float]: + modules = [nn.Conv1d(3, 8, 3).to(dtype), nn.ConvTranspose1d(3, 8, 3).to(dtype), + nn.Conv2d(3, 8, 3).to(dtype), nn.ConvTranspose2d(3, 8, 3).to(dtype), + nn.Conv3d(3, 8, 3).to(dtype), nn.ConvTranspose3d(3, 8, 3).to(dtype)] - invalid_input_dims = [(2, 4), (2, 4), - (3, 5), (3, 5), - (4, 6), (4, 6)] + invalid_input_dims = [(2, 4), (2, 4), + (3, 5), (3, 5), + (4, 6), (4, 6)] - for invalid_dims, module in zip(invalid_input_dims, modules): - for dims in invalid_dims: - input = torch.empty(torch.Size((3, ) * dims)) - self.assertRaises(RuntimeError, lambda: module(input)) + for invalid_dims, module in zip(invalid_input_dims, modules): + for dims in invalid_dims: + input = torch.empty(torch.Size((3, ) * dims)) + self.assertRaises(RuntimeError, lambda: module(input)) def test_conv_shapecheck(self): - def test(should_raise, module, input_size): - input = torch.empty(3, *input_size) + def test(should_raise, module, input_size, dtype): + input = torch.empty(3, *input_size).to(dtype) if should_raise: self.assertRaises(RuntimeError, lambda: module(input)) else: # just run it to ensure no exception raised. module(input) - # Conv1d - test(True, nn.Conv1d(1, 1, 3), (1, 2)) - test(True, nn.Conv1d(1, 1, 3, stride=2), (1, 2)) - test(False, nn.Conv1d(1, 1, 2), (1, 2)) - test(False, nn.Conv1d(1, 1, 2, stride=2), (1, 2)) - test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1), (1, 2)) + for dtype in [torch.bfloat16, torch.float, torch.double]: + # Conv1d + test(True, nn.Conv1d(1, 1, 3).to(dtype), (1, 2), dtype) + test(True, nn.Conv1d(1, 1, 3, stride=2).to(dtype), (1, 2), dtype) + test(False, nn.Conv1d(1, 1, 2).to(dtype), (1, 2), dtype) + test(False, nn.Conv1d(1, 1, 2, stride=2).to(dtype), (1, 2), dtype) + test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1).to(dtype), (1, 2), dtype) - # Conv2d - test(True, nn.Conv2d(1, 1, (3, 3)), (1, 2, 2)) - test(False, nn.Conv2d(1, 1, (3, 3)), (1, 3, 3)) - test(False, nn.Conv2d(1, 1, (3, 3), padding=1), (1, 2, 2)) + # Conv2d + test(True, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 2, 2), dtype) + test(False, nn.Conv2d(1, 1, (3, 3)).to(dtype), (1, 3, 3), dtype) + test(False, nn.Conv2d(1, 1, (3, 3), padding=1).to(dtype), (1, 2, 2), dtype) - # Conv3D - test(True, nn.Conv3d(1, 1, (3, 3, 3)), (1, 2, 2, 2)) - test(False, nn.Conv3d(1, 1, (3, 3, 3)), (1, 3, 3, 3)) - test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1), (1, 2, 2, 2)) + # Conv3D + test(True, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 2, 2, 2), dtype) + test(False, nn.Conv3d(1, 1, (3, 3, 3)).to(dtype), (1, 3, 3, 3), dtype) + test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1).to(dtype), (1, 2, 2, 2), dtype) def test_ConvTranspose2d_output_size(self): m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) @@ -5306,14 +4876,14 @@ def _test_Conv2d_naive_groups(self, device="cpu", dtype=torch.float): m1 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) m1.weight.data.copy_(m.weight.data[:2]) m1.bias.data.copy_(m.bias.data[:2]) - i1 = Variable(i.data[:, :2].contiguous(), requires_grad=True) + i1 = i.data[:, :2].contiguous().requires_grad_(True) output1 = m1(i1) output1.backward(grad_output[:, :2].contiguous()) m2 = nn.Conv2d(2, 2, kernel_size=3).to(device, dtype) m2.weight.data.copy_(m.weight.data[2:]) m2.bias.data.copy_(m.bias.data[2:]) - i2 = Variable(i.data[:, 2:].contiguous(), requires_grad=True) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) output2 = m2(i2) output2.backward(grad_output[:, 2:].contiguous()) @@ -5343,13 +4913,13 @@ def test_Conv2d_groups_nobias(self): m1 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) m1.weight.data.copy_(m.weight.data[:2]) - i1 = Variable(i.data[:, :2].contiguous(), requires_grad=True) + i1 = i.data[:, :2].contiguous().requires_grad_(True) output1 = m1(i1) output1.backward(grad_output[:, :2].contiguous()) m2 = nn.Conv2d(2, 2, kernel_size=3, bias=False).to(device, dtype) m2.weight.data.copy_(m.weight.data[2:]) - i2 = Variable(i.data[:, 2:].contiguous(), requires_grad=True) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) output2 = m2(i2) output2.backward(grad_output[:, 2:].contiguous()) @@ -5379,13 +4949,13 @@ def test_Conv2d_groups_nobias_v2(self): m1 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) m1.weight.data.copy_(m.weight.data[:8]) - i1 = Variable(i.data[:, :2].contiguous(), requires_grad=True) + i1 = i.data[:, :2].contiguous().requires_grad_(True) output1 = m1(i1) output1.backward(grad_output[:, :8].contiguous()) m2 = nn.Conv2d(2, 8, kernel_size=3, bias=False).to(device, dtype) m2.weight.data.copy_(m.weight.data[8:]) - i2 = Variable(i.data[:, 2:].contiguous(), requires_grad=True) + i2 = i.data[:, 2:].contiguous().requires_grad_(True) output2 = m2(i2) output2.backward(grad_output[:, 8:].contiguous()) @@ -5451,7 +5021,7 @@ def test_MaxUnpool2d_output_size(self): for i in range(0, 4, 2): for j in range(0, 4, 2): small_t[:, :, i, j] = 100 - output_small, indices_small = m(Variable(small_t)) + output_small, indices_small = m(small_t) for h in range(3, 10): for w in range(3, 10): if 4 <= h <= 6 and 4 <= w <= 6: @@ -5504,8 +5074,8 @@ def _test_loss_equal_input_target_shape(self, cast): 'poisson_nll_loss': lambda x, y: F.poisson_nll_loss(x, y), } - input = Variable(cast(torch.randn(3, 5))) - target = Variable(cast(torch.randn(5, 3))) + input = cast(torch.randn(3, 5)) + target = cast(torch.randn(5, 3)) for _name, fn in losses.items(): self.assertRaises(Exception, lambda: fn(input, target)) @@ -5993,6 +5563,14 @@ def test_LSTM_cell(self): (hx + cx).sum().backward() + @unittest.skipIf(not TEST_CUDA, 'CUDA not available') + def test_pack_sequence_batch_sizes_throw(self): + with self.assertRaisesRegex(ValueError, r"batch_sizes should always be on CPU"): + m = nn.LSTM(3, 4, bidirectional=True, num_layers=2).to('cuda') + a = torch.rand(5, 3, device='cuda') + b = torch.tensor([1, 1, 1, 1, 1], device='cuda') + input = nn.utils.rnn.PackedSequence(a, b) + def test_Transformer_cell(self): # this is just a smoke test; these modules are implemented through # autograd so no Jacobian test is needed @@ -6442,11 +6020,11 @@ def test_cudnn_weight_format(self): first_warn = True for rnn in rnns: rnn.cuda() - input = Variable(torch.randn(5, 4, 10).cuda(), requires_grad=True) - hx = Variable(torch.randn(1, 5, 20).cuda(), requires_grad=True) + input = torch.randn(5, 4, 10, requires_grad=True, device="cuda") + hx = torch.randn(1, 5, 20, requires_grad=True, device="cuda") all_vars = [input, hx] + list(rnn.parameters()) if isinstance(rnn, nn.LSTM): - cx = Variable(torch.randn(1, 5, 20).cuda(), requires_grad=True) + cx = torch.randn(1, 5, 20, requires_grad=True, device="cuda") all_vars[2:2] = [cx] hx = (hx, cx) @@ -6491,13 +6069,13 @@ def test_cudnn_weight_tying(self): for rnn in rnns: rnn.bias_ih_l0_reverse = rnn.bias_ih_l0 rnn.cuda() - input = Variable(torch.randn(5, 4, 10).cuda(), requires_grad=True) - hx = Variable(torch.randn(2, 5, 20).cuda(), requires_grad=True) + input = torch.randn(5, 4, 10, requires_grad=True, device="cuda") + hx = torch.randn(2, 5, 20, requires_grad=True, device="cuda") all_vars = [input, hx] + list(rnn.parameters()) opt = torch.optim.SGD(rnn.parameters(), lr=0.1) opt.zero_grad() if isinstance(rnn, nn.LSTM): - cx = Variable(torch.randn(2, 5, 20).cuda(), requires_grad=True) + cx = torch.randn(2, 5, 20, requires_grad=True, device="cuda") all_vars[2:2] = [cx] hx = (hx, cx) @@ -6545,16 +6123,16 @@ def check_rnn_grads(rnn1, rnn2): is_lstm = isinstance(rnn, nn.LSTM) if is_lstm: - hx = (Variable(hx_val.clone(), requires_grad=True), - Variable(hx_val.clone().add(1), requires_grad=True)) - hx_cuda = (Variable(hx_val.clone().cuda(), requires_grad=True), - Variable(hx_val.clone().cuda().add(1), requires_grad=True)) + hx = (hx_val.clone().requires_grad_(True), + hx_val.clone().add(1).requires_grad_(True)) + hx_cuda = (hx_val.clone().cuda().requires_grad_(True), + hx_val.clone().cuda().add(1).requires_grad_(True)) else: - hx = Variable(hx_val.clone(), requires_grad=True) - hx_cuda = Variable(hx_val.clone().cuda(), requires_grad=True) + hx = hx_val.clone().requires_grad_(True) + hx_cuda = hx_val.clone().cuda().requires_grad_(True) - inp = Variable(input_val.clone(), requires_grad=True) - inp_cu = Variable(input_val.clone().cuda(), requires_grad=True) + inp = input_val.clone().requires_grad_(True) + inp_cu = input_val.clone().cuda().requires_grad_(True) output1, hy1 = rnn(inp, hx) output2, hy2 = rnn_cuda(inp_cu, hx_cuda) if is_lstm: @@ -6912,16 +6490,16 @@ def forward_backward(cuda, rnn, input_val, hx_val, grad_output, grad_hy, weights if isinstance(input_val, rnn_utils.PackedSequence): input = rnn_utils.PackedSequence( - Variable(input_val.data.data, requires_grad=True), input_val.batch_sizes) + input_val.data.data.requires_grad_(True), input_val.batch_sizes) input_var = input.data else: - input = Variable(input_val.clone(), requires_grad=True) + input = input_val.clone().requires_grad_(True) input_var = input if is_lstm: - hx = (Variable(hx_val.clone(), requires_grad=True), - Variable(hx_val.add(1), requires_grad=True)) + hx = (hx_val.clone().requires_grad_(True), + hx_val.add(1).requires_grad_(True)) else: - hx = Variable(hx_val.clone(), requires_grad=True) + hx = hx_val.clone().requires_grad_(True) if cuda: rnn.cuda() @@ -8645,7 +8223,7 @@ def test_upsamplingNearest2d(self): m = nn.Upsample(size=4, mode='nearest') in_t = torch.ones(1, 1, 2, 2) with warnings.catch_warnings(record=True) as w: - out_t = m(Variable(in_t)) + out_t = m(in_t) self.assertEqual(torch.ones(1, 1, 4, 4), out_t.data) input = torch.randn(1, 1, 2, 2, requires_grad=True) @@ -8718,7 +8296,7 @@ def test_upsamplingNearest3d(self): m = nn.Upsample(size=4, mode='nearest') in_t = torch.ones(1, 1, 2, 2, 2) with warnings.catch_warnings(record=True) as w: - out_t = m(Variable(in_t)) + out_t = m(in_t) self.assertEqual(torch.ones(1, 1, 4, 4, 4), out_t.data) input = torch.randn(1, 1, 2, 2, 2, requires_grad=True) @@ -9057,9 +8635,9 @@ def test_conv_double_backward_stride(self): def test_cudnn_noncontiguous_weight(self): # Noncontiguous weights must be contiguous() before being # passed to cuDNN - input = Variable(torch.cuda.DoubleTensor([1, 1, 1]).view(1, 1, 3)) - weights1 = Variable(torch.cuda.DoubleTensor([1]).expand(1, 1, 2)) - weights2 = Variable(torch.cuda.DoubleTensor([1]).expand(1, 1, 2)).contiguous() + input = torch.tensor([1, 1, 1], dtype=torch.double, device="cuda").view(1, 1, 3) + weights1 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2) + weights2 = torch.tensor([1], dtype=torch.double, device="cuda").expand(1, 1, 2).contiguous() self.assertEqual(F.conv1d(input, weights1, bias=None, stride=2, dilation=2), F.conv1d(input, weights2, bias=None, stride=2, dilation=2)) @@ -9218,11 +8796,25 @@ def test_softmin(self): self.assertEqual(F.softmin(x, 1), F.softmax(-x, 1)) self.assertEqual(F.softmin(x, 0), F.softmax(-x, 0)) - def test_log_softmax(self): - x_small = torch.ones(1, 2, dtype=torch.float32) + @repeat_test_for_types([torch.float, torch.bfloat16]) + def test_log_softmax(self, dtype=torch.float): + x_small = torch.ones(1, 2, dtype=dtype) x_big = x_small + 1e16 self.assertEqual(F.log_softmax(x_small, -1), F.log_softmax(x_big, -1)) + def test_log_softmax_cpu(self, dtype=torch.bfloat16): + inputf = torch.rand(32, 100, device="cpu", dtype=torch.float, requires_grad=True) + input = inputf.to(dtype).detach().requires_grad_(True) + outf = F.log_softmax(inputf, dim=-1) + out = F.log_softmax(input, dim=-1) + self.assertEqual(out.dtype, dtype) + self.assertEqual(out, outf, prec=0.1) + + out.sum().backward() + outf.sum().backward() + self.assertEqual(input.grad.dtype, dtype) + self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0.1) + def test_adaptive_log_softmax(self): # args validation with self.assertRaises(ValueError): @@ -9316,6 +8908,21 @@ def test_adaptive_log_softmax(self): out = asfm.predict(x) self.assertEqual(out, asfm.log_prob(x).argmax(dim=1)) + def test_cross_entropy_loss(self, dtype=torch.bfloat16): + loss_cpu = nn.CrossEntropyLoss().cpu() + inputf = torch.randn(15, 10, device="cpu", dtype=torch.float, requires_grad=True) + input = inputf.to(dtype).detach().requires_grad_(True) + target = torch.empty(15, dtype=torch.long).random_(10) + + outf = loss_cpu(inputf, target) + out = loss_cpu(input, target) + self.assertEqual(out.dtype, dtype) + self.assertEqual(out, outf, prec=1e-1) + + outf.backward() + out.backward() + self.assertEqual(input.grad.dtype, dtype) + self.assertEqual(input.grad, inputf.grad, prec=1e-1) class TestNNInit(TestCase): def setUp(self): @@ -9768,293 +9375,661 @@ def eval_constructor(*args, **kwargs): desc = test_params.get('desc', None) test_params['desc'] = 'with_long_tensor' if desc is None else desc + '_with_long_tensor' - def double_equivalent_of_long_tensor(size): - return torch.randint(-1000, 1000, size=size).double() + def double_equivalent_of_long_tensor(size): + return torch.randint(-1000, 1000, size=size).double() + + def apply_to_cons(t): + if t.is_floating_point(): + if isinstance(t, Parameter): + return Parameter(double_equivalent_of_long_tensor(t.size())) + elif isinstance(t, torch.Tensor): + return double_equivalent_of_long_tensor(t.size()) + else: + return t + + def gen_long_tensor_constructor(constructor): + def long_tensor_constructor(*args, **kwargs): + cons = constructor(*args, **kwargs) + cons._apply(apply_to_cons) + return cons + long_tensor_constructor.__name__ = constructor.__name__ + return long_tensor_constructor + + def gen_long_tensor_input(input_size): + def input_func(): + return double_equivalent_of_long_tensor(input_size) + return input_func + + def reference_fn(i, p, m): + m._apply(lambda t: t.long()) + input = i.long() + out = m.forward(input) + return out + + test_params['constructor'] = gen_long_tensor_constructor(test_params['constructor']) + test_params['input_fn'] = gen_long_tensor_input(test_params['input_size']) + test_params['reference_fn'] = reference_fn + test_params['check_forward_only'] = True + # Currently we don't support conv2d/conv3d for LongTensor in CUDA + test_params['test_cuda'] = False + test = NewModuleTest(**test_params) + + add_test(test, decorator) + +for test_params in criterion_tests + new_criterion_tests: + name = test_params.pop('module_name') + test_params['constructor'] = getattr(nn, name) + test = NewCriterionTest(**test_params) + decorator = test_params.pop('decorator', None) + add_test(test, decorator) + if 'check_sum_reduction' in test_params: + desc = test_params.get('desc', None) + test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction' + + def gen_sum_reduction_constructor(constructor): + def sum_reduction_constructor(*args, **kwargs): + cons = constructor(*args, reduction='sum', **kwargs) + return cons + sum_reduction_constructor.__name__ = constructor.__name__ + return sum_reduction_constructor + + test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor']) + test = NewCriterionTest(**test_params) + add_test(test, decorator) + + +class UnpoolingNet(nn.Module): + def __init__(self, pool, unpool): + super(UnpoolingNet, self).__init__() + self.pool = pool + self.unpool = unpool + + def forward(self, input): + return self.unpool(*self.pool(input)) + + +add_test(NewModuleTest( + constructor=lambda: UnpoolingNet( + nn.MaxPool1d(2, return_indices=True), + nn.MaxUnpool1d(2)), + input_size=(1, 1, 4), + fullname='MaxUnpool1d_net',)) +add_test(NewModuleTest( + constructor=lambda: UnpoolingNet( + nn.MaxPool2d(2, return_indices=True), + nn.MaxUnpool2d(2)), + input_size=(1, 1, 2, 4), + fullname='MaxUnpool2d_net',)) +add_test(NewModuleTest( + constructor=lambda: UnpoolingNet( + nn.MaxPool3d(2, return_indices=True), + nn.MaxUnpool3d(2)), + input_size=(1, 1, 2, 4, 6), + fullname='MaxUnpool3d_net', + check_gradgrad=False,)) + + +class _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss): + def __call__(self, input): + t = torch.tensor([0, 1, 4, 8]).to(input.device) + return nn.AdaptiveLogSoftmaxWithLoss.__call__(self, input, t).output + +add_test(NewModuleTest( + constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]), + input_size=(4, 16), + fullname='AdaptiveLogSoftmax')) + + +# The following are helpers for TestNN.test_affine_* +if torch.cuda.is_available(): + def device_(): + return ['cpu', 'cuda'] +else: + def device_(): + return ['cpu'] + + +def angle_rad_(): + return [r * math.pi * 2 for r in [0.0, 0.5, 0.25, 0.125, random.random()]] + + +def axis_vector_(): + t = (random.random(), random.random(), random.random()) + l = sum(x ** 2 for x in t) ** 0.5 + + return [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), tuple(x / l for x in t)] + + +def input_size2d_(): + return [[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]] + + +def output_size2d_(): + return [[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]] + + +def input_size2dsq_(): + return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6]] + + +def output_size2dsq_(): + return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6]] + + +def input_size3d_(): + return [[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]] + + +def input_size3dsq_(): + return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]] + + +def output_size3dsq_(): + return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]] + + +def output_size3d_(): + return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]] + + +def _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad): + input_center = [(x - 1) / 2.0 for x in input_size] + output_center = [(x - 1) / 2.0 for x in output_size] + + s = math.sin(angle_rad) + c = math.cos(angle_rad) + + intrans_ary = np.array([ + [1, 0, input_center[2]], + [0, 1, input_center[3]], + [0, 0, 1], + ], dtype=np.float64) + + inscale_ary = np.array([ + [input_center[2], 0, 0], + [0, input_center[3], 0], + [0, 0, 1], + ], dtype=np.float64) + + rotation_ary = np.array([ + [c, -s, 0], + [s, c, 0], + [0, 0, 1], + ], dtype=np.float64) + + outscale_ary = np.array([ + [1.0 / output_center[2], 0, 0], + [0, 1.0 / output_center[3], 0], + [0, 0, 1], + ], dtype=np.float64) + + outtrans_ary = np.array([ + [1, 0, -output_center[2]], + [0, 1, -output_center[3]], + [0, 0, 1], + ], dtype=np.float64) + + reorder_ary = np.array([ + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + ], dtype=np.float64) + + transform_ary = np.dot(np.dot(np.dot(np.dot( + intrans_ary, + inscale_ary), + rotation_ary.T), + outscale_ary), + outtrans_ary) + grid_ary = np.dot(np.dot(np.dot(reorder_ary, rotation_ary.T), outscale_ary), outtrans_ary) + + transform_tensor = torch.from_numpy((rotation_ary)).to(device, torch.float32) + transform_tensor = transform_tensor[:2].unsqueeze(0) + + return transform_tensor, transform_ary, grid_ary + + +def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector): + input_center = [(x - 1) / 2.0 for x in input_size] + output_center = [(x - 1) / 2.0 for x in output_size] + + s = math.sin(angle_rad) + c = math.cos(angle_rad) + c1 = 1 - c + + intrans_ary = np.array([ + [1, 0, 0, input_center[2]], + [0, 1, 0, input_center[3]], + [0, 0, 1, input_center[4]], + [0, 0, 0, 1], + ], dtype=np.float64) + + inscale_ary = np.array([ + [input_center[2], 0, 0, 0], + [0, input_center[3], 0, 0], + [0, 0, input_center[4], 0], + [0, 0, 0, 1], + ], dtype=np.float64) + + l, m, n = axis_vector + scipyRotation_ary = np.array([ + [l * l * c1 + c, m * l * c1 - n * s, n * l * c1 + m * s, 0], + [l * m * c1 + n * s, m * m * c1 + c, n * m * c1 - l * s, 0], + [l * n * c1 - m * s, m * n * c1 + l * s, n * n * c1 + c, 0], + [0, 0, 0, 1], + ], dtype=np.float64) + + z, y, x = axis_vector + torchRotation_ary = np.array([ + [x * x * c1 + c, y * x * c1 - z * s, z * x * c1 + y * s, 0], + [x * y * c1 + z * s, y * y * c1 + c, z * y * c1 - x * s, 0], + [x * z * c1 - y * s, y * z * c1 + x * s, z * z * c1 + c, 0], + [0, 0, 0, 1], + ], dtype=np.float64) + + outscale_ary = np.array([ + [1.0 / output_center[2], 0, 0, 0], + [0, 1.0 / output_center[3], 0, 0], + [0, 0, 1.0 / output_center[4], 0], + [0, 0, 0, 1], + ], dtype=np.float64) - def apply_to_cons(t): - if t.is_floating_point(): - if isinstance(t, Parameter): - return Parameter(double_equivalent_of_long_tensor(t.size())) - elif isinstance(t, torch.Tensor): - return double_equivalent_of_long_tensor(t.size()) - else: - return t + outtrans_ary = np.array([ + [1, 0, 0, -output_center[2]], + [0, 1, 0, -output_center[3]], + [0, 0, 1, -output_center[4]], + [0, 0, 0, 1], + ], dtype=np.float64) - def gen_long_tensor_constructor(constructor): - def long_tensor_constructor(*args, **kwargs): - cons = constructor(*args, **kwargs) - cons._apply(apply_to_cons) - return cons - long_tensor_constructor.__name__ = constructor.__name__ - return long_tensor_constructor + reorder_ary = np.array([ + [0, 0, 1, 0], + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 1], + ], dtype=np.float64) - def gen_long_tensor_input(input_size): - def input_func(): - return double_equivalent_of_long_tensor(input_size) - return input_func + transform_ary = np.dot(np.dot(np.dot(np.dot( + intrans_ary, + inscale_ary), + np.linalg.inv(scipyRotation_ary)), + outscale_ary), + outtrans_ary) + grid_ary = np.dot(np.dot(np.dot(reorder_ary, np.linalg.inv(scipyRotation_ary)), outscale_ary), outtrans_ary) - def reference_fn(i, p, m): - m._apply(lambda t: t.long()) - input = i.long() - out = m.forward(input) - return out + transform_tensor = torch.from_numpy((torchRotation_ary)).to(device, torch.float32) + transform_tensor = transform_tensor[:3].unsqueeze(0) - test_params['constructor'] = gen_long_tensor_constructor(test_params['constructor']) - test_params['input_fn'] = gen_long_tensor_input(test_params['input_size']) - test_params['reference_fn'] = reference_fn - test_params['check_forward_only'] = True - # Currently we don't support conv2d/conv3d for LongTensor in CUDA - test_params['test_cuda'] = False - test = NewModuleTest(**test_params) + return transform_tensor, transform_ary, grid_ary +# end TestNN.test_affine_* helpers - add_test(test, decorator) +class GenericDeviceTypeHelpers(object): + def _test_dropout(self, cls, device, input): + p = 0.2 + input = input.to(device).fill_(1 - p) -for test_params in criterion_tests + new_criterion_tests: - name = test_params.pop('module_name') - test_params['constructor'] = getattr(nn, name) - test = NewCriterionTest(**test_params) - decorator = test_params.pop('decorator', None) - add_test(test, decorator) - if 'check_sum_reduction' in test_params: - desc = test_params.get('desc', None) - test_params['desc'] = 'sum_reduction' if desc is None else desc + '_sum_reduction' + module = cls(p) + input_var = input.clone().requires_grad_() + output = module(input_var) + self.assertLess(abs(output.data.mean() - (1 - p)), 0.05) + output.backward(input) + self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05) - def gen_sum_reduction_constructor(constructor): - def sum_reduction_constructor(*args, **kwargs): - cons = constructor(*args, reduction='sum', **kwargs) - return cons - sum_reduction_constructor.__name__ = constructor.__name__ - return sum_reduction_constructor + module = cls(p, True) + input_var = input.clone().requires_grad_() + output = module(input_var + 0) + self.assertLess(abs(output.data.mean() - (1 - p)), 0.05) + output.backward(input) + self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05) - test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor']) - test = NewCriterionTest(**test_params) - add_test(test, decorator) + # check eval mode doesn't change anything + for inplace in [True, False]: + module = cls(p, inplace).eval() + self.assertEqual(input, module(input)) + # Check that these don't raise errors + module.__repr__() + str(module) -class UnpoolingNet(nn.Module): - def __init__(self, pool, unpool): - super(UnpoolingNet, self).__init__() - self.pool = pool - self.unpool = unpool + def _test_InstanceNorm_general(self, cls, input, device, dtype=torch.float): + # default case track_running_stats=False + b, c = input.size(0), input.size(1) + input_var = input.to(device=device, dtype=dtype).requires_grad_() - def forward(self, input): - return self.unpool(*self.pool(input)) + IN = cls(c, eps=0).to(device, dtype) + output = IN(input_var) + out_reshaped = output.view(b * c, -1) -add_test(NewModuleTest( - constructor=lambda: UnpoolingNet( - nn.MaxPool1d(2, return_indices=True), - nn.MaxUnpool1d(2)), - input_size=(1, 1, 4), - fullname='MaxUnpool1d_net',)) -add_test(NewModuleTest( - constructor=lambda: UnpoolingNet( - nn.MaxPool2d(2, return_indices=True), - nn.MaxUnpool2d(2)), - input_size=(1, 1, 2, 4), - fullname='MaxUnpool2d_net',)) -add_test(NewModuleTest( - constructor=lambda: UnpoolingNet( - nn.MaxPool3d(2, return_indices=True), - nn.MaxUnpool3d(2)), - input_size=(1, 1, 2, 4, 6), - fullname='MaxUnpool3d_net', - check_gradgrad=False,)) + mean = out_reshaped.mean(1) + var = out_reshaped.var(1, unbiased=False) + self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5) + self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5) -class _AdaptiveLogSoftmaxWithLoss(nn.AdaptiveLogSoftmaxWithLoss): - def __call__(self, input): - t = torch.tensor([0, 1, 4, 8]).to(input.device) - return nn.AdaptiveLogSoftmaxWithLoss.__call__(self, input, t).output + # check that eval mode doesn't change behavior + grad_out = torch.randn_like(output) + res1 = output.data.clone() + output.backward(grad_out) + grad1 = input_var.grad.data.clone() -add_test(NewModuleTest( - constructor=lambda: _AdaptiveLogSoftmaxWithLoss(16, 10, [2, 6]), - input_size=(4, 16), - fullname='AdaptiveLogSoftmax')) + IN.eval() + output = IN(input_var) + input_var.grad = None + output.backward(grad_out) + res2 = output.data + grad2 = input_var.grad.data + self.assertEqual(res1, res2) + self.assertEqual(grad1, grad2) + # If track_running_stats=True and momentum=1, running_mean/var should be + # equal to mean/var of the input (with unbias correction) + IN = cls(c, momentum=1, eps=0, track_running_stats=True).to(device, dtype) -# The following are helpers for TestNN.test_affine_* -if torch.cuda.is_available(): - def device_(): - return ['cpu', 'cuda'] -else: - def device_(): - return ['cpu'] + output = IN(input_var) + input_reshaped = input_var.transpose(1, 0).reshape(c, -1) + mean = input_reshaped.mean(1) -def angle_rad_(): - return [r * math.pi * 2 for r in [0.0, 0.5, 0.25, 0.125, random.random()]] + input_reshaped = input_var.transpose(1, 0).reshape(c, b, -1) + var = input_reshaped.var(2, unbiased=True)[:, :] + self.assertAlmostEqual(torch.abs(mean.data - IN.running_mean).mean(), 0, delta=1e-5) + self.assertAlmostEqual(torch.abs(var.data.mean(1) - IN.running_var).mean(), 0, delta=1e-5) -def axis_vector_(): - t = (random.random(), random.random(), random.random()) - l = sum(x ** 2 for x in t) ** 0.5 + # in eval mode, adding X * std to a channel in input should make the + # corresponding channel in output have mean X + IN.eval() + delta = IN.running_var.sqrt() * torch.arange(c, device=device, dtype=dtype) + delta = delta.view(-1, *[1 for _ in range(2, input.dim())]) + output = IN(input_var + delta) + self.assertEqual(output.transpose(0, 1).reshape(c, -1).mean(1), torch.arange(c)) - return [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), tuple(x / l for x in t)] + def _test_InstanceNorm_cuda_half(self, cls, input): + # THNN + input = input.to(device='cuda', dtype=torch.half).random_(1, 10).requires_grad_(True) + m = cls(input.size(1), affine=True, track_running_stats=True).to("cuda", torch.half) + thnn_output = m(input) + thnn_output.sum().backward() + thnn_input_grad = input.grad.data.clone() + self.assertEqual(thnn_output.type(), input.type()) + # cuDNN + if TEST_CUDNN: + input.grad = None + m = m.float() + cudnn_output = m(input) + cudnn_output.sum().backward() + cudnn_input_grad = input.grad.data.clone() + self.assertEqual(cudnn_output.type(), input.type()) + self.assertAlmostEqual(cudnn_output, thnn_output, delta=1e-4) + self.assertAlmostEqual(cudnn_input_grad, thnn_input_grad, delta=1e-3) + def _test_LayerNorm_general(self, device, dtype=torch.float): + for i in range(2, 6): + shape = torch.randint(3, 6, (i,), dtype=torch.long).tolist() + x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) + normalized_ndim = random.randint(1, i - 1) # inclusive + normalized_shape = shape[-normalized_ndim:] + unnormalized_shape = shape[:-normalized_ndim] -def input_size2d_(): - return [[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]] + # test that LN normalizes to mean 0 and stddev 1 + ln = nn.LayerNorm(normalized_shape, eps=0).to(device, dtype) + ln.weight.data.fill_(1) + ln.bias.data.fill_(0) + output = ln(x) + out_reshaped = output.view(*(unnormalized_shape + [-1])) + mean = out_reshaped.mean(-1) + var = out_reshaped.var(-1, unbiased=False) + self.assertAlmostEqual(torch.abs(mean.data).mean(), 0, delta=1e-5) + self.assertAlmostEqual(torch.abs(var.data).mean(), 1, delta=1e-5) + # test that LN applies weight and bias correctly + scale, bias = torch.empty(2).uniform_(0.2, 2).tolist() + ln.weight.data.fill_(scale) + ln.bias.data.fill_(bias) + output = ln(x) + out_reshaped = output.view(*(unnormalized_shape + [-1])) + mean = out_reshaped.mean(-1) + var = out_reshaped.var(-1, unbiased=False) + self.assertAlmostEqual(torch.abs(mean.data).mean(), bias, delta=1e-5) + self.assertAlmostEqual(torch.abs(var.data).mean(), scale ** 2, delta=1e-5) -def output_size2d_(): - return [[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]] + bad_norm_shape_input_shape = { + (): (), + (2, 3): (3,), + (2,): (1, 2, 3), + (10,): (2, 3), + 10: (2, 3), + } + for norm_shape, input_shape in bad_norm_shape_input_shape.items(): + ln = nn.LayerNorm(norm_shape) + input = torch.empty(input_shape, device=device, dtype=dtype).uniform_(0, 10) + self.assertRaises(RuntimeError, lambda: ln(input)) + + def _test_LayerNorm_cuda_half(self): + input = torch.empty(2, 3, 3, 2, device="cuda", dtype=torch.half).random_(1, 10).requires_grad_(True) + m = nn.LayerNorm([3, 2]).to("cuda", torch.half) + output = m(input) + output.sum().backward() + self.assertEqual(output.type(), input.type()) + def _test_GroupNorm_general(self, device, dtype=torch.float): + good_shape_g = { + (1, 2, 3, 4): 2, + (2, 3, 10): 3, + (3, 1, 1, 1, 2): 1, + (2, 6, 4, 2, 2): 3, + } + for shape, g in good_shape_g.items(): + x = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) + b = shape[0] + c = shape[1] -def input_size2dsq_(): - return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6]] + # test that GN normalizes to mean 0 and stddev 1 + gn = nn.GroupNorm(g, c, eps=0).to(device, dtype) + gn.weight.data.fill_(1) + gn.bias.data.fill_(0) + output = gn(x) + out_reshaped = output.view(b, g, -1) + mean = out_reshaped.mean(-1) + var = out_reshaped.var(-1, unbiased=False) + self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5) + self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5) + # test that GN applies weight and bias correctly + scale = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2) + bias = torch.empty(c, device=device, dtype=dtype).uniform_(0.2, 2) + gn.weight.data.copy_(scale) + gn.bias.data.copy_(bias) + output = gn(x) + out_reshaped = output.view(b, c, -1) + out_normed = (out_reshaped - bias.view(c, 1)) / scale.view(c, 1) + out_normed_reshaped = out_normed.view(b, g, -1) + mean = out_normed_reshaped.mean(-1) + var = out_normed_reshaped.var(-1, unbiased=False) + self.assertAlmostEqual(torch.abs(mean).mean(), 0, delta=1e-5) + self.assertAlmostEqual(torch.abs(var).mean(), 1, delta=1e-5) -def output_size2dsq_(): - return [[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6]] + bad_shape_g = { + (1, 2, 3, 4): 3, + (2, 3, 10): 2, + (3, 1, 1, 1, 2): 10, + (2, 6, 4, 2, 2): 4, + } + for shape, g in bad_shape_g.items(): + gn = nn.GroupNorm(g, shape[1]) + input = torch.empty(*shape, device=device, dtype=dtype).uniform_(0, 10) + self.assertRaises(RuntimeError, lambda: gn(input)) + def _test_GroupNorm_cuda_half(self): + input = torch.zeros(2, 4, 3, 2, requires_grad=True).cuda().half().random_(1, 10) + m = nn.GroupNorm(2, 4).to("cuda", torch.half) + output = m(input) + output.sum().backward() + self.assertEqual(output.type(), input.type()) -def input_size3d_(): - return [[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]] +class TestNNDeviceType(NNTestCase, GenericDeviceTypeHelpers): + def test_Dropout(self, device): + input = torch.Tensor(1000) + self._test_dropout(nn.Dropout, device, input) + def test_Dropout2d(self, device): + b = random.randint(1, 5) + w = random.randint(1, 5) + h = random.randint(1, 5) + num_features = 1000 + input = torch.Tensor(num_features, b, w, h) + self._test_dropout(nn.Dropout2d, device, input) -def input_size3dsq_(): - return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]] + def test_Dropout3d(self, device): + b = random.randint(1, 5) + w = random.randint(1, 5) + h = random.randint(1, 5) + d = random.randint(1, 2) + num_features = 1000 + input = torch.Tensor(num_features, b, d, w, h) + self._test_dropout(nn.Dropout3d, device, input) + def test_InstanceNorm1d_general(self, device): + b = random.randint(3, 5) + c = random.randint(3, 5) + d = random.randint(8, 10) -def output_size3dsq_(): - return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]] + input = torch.rand(b, c, d) + self._test_InstanceNorm_general(nn.InstanceNorm1d, input, device) + if device == 'cuda': + self._test_InstanceNorm_cuda_half(nn.InstanceNorm1d, input) -def output_size3d_(): - return [[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]] + def test_InstanceNorm2d_general(self, device): + b = random.randint(3, 5) + c = random.randint(3, 5) + w = random.randint(3, 6) + h = random.randint(6, 8) + input = torch.rand(b, c, h, w) + self._test_InstanceNorm_general(nn.InstanceNorm2d, input, device) -def _buildEquivalentAffineTransforms2d(device, input_size, output_size, angle_rad): - input_center = [(x - 1) / 2.0 for x in input_size] - output_center = [(x - 1) / 2.0 for x in output_size] + if device == 'cuda': + self._test_InstanceNorm_cuda_half(nn.InstanceNorm2d, input) - s = math.sin(angle_rad) - c = math.cos(angle_rad) + def test_InstanceNorm3d_general(self, device): + b = random.randint(3, 5) + c = random.randint(3, 5) + w = random.randint(2, 5) + h = random.randint(2, 5) + d = random.randint(2, 5) - intrans_ary = np.array([ - [1, 0, input_center[2]], - [0, 1, input_center[3]], - [0, 0, 1], - ], dtype=np.float64) + input = torch.rand(b, c, h, w, d) + self._test_InstanceNorm_general(nn.InstanceNorm3d, input, device) - inscale_ary = np.array([ - [input_center[2], 0, 0], - [0, input_center[3], 0], - [0, 0, 1], - ], dtype=np.float64) + if device == 'cuda': + self._test_InstanceNorm_cuda_half(nn.InstanceNorm3d, input) - rotation_ary = np.array([ - [c, -s, 0], - [s, c, 0], - [0, 0, 1], - ], dtype=np.float64) + def test_LayerNorm_general(self, device): + self._test_LayerNorm_general(device) - outscale_ary = np.array([ - [1.0 / output_center[2], 0, 0], - [0, 1.0 / output_center[3], 0], - [0, 0, 1], - ], dtype=np.float64) + if device == 'cuda': + self._test_LayerNorm_cuda_half() - outtrans_ary = np.array([ - [1, 0, -output_center[2]], - [0, 1, -output_center[3]], - [0, 0, 1], - ], dtype=np.float64) + def test_GroupNorm_general(self, device): + self._test_GroupNorm_general(device) - reorder_ary = np.array([ - [0, 1, 0], - [1, 0, 0], - [0, 0, 1], - ], dtype=np.float64) + if device == 'cuda': + self._test_GroupNorm_cuda_half() - transform_ary = np.dot(np.dot(np.dot(np.dot( - intrans_ary, - inscale_ary), - rotation_ary.T), - outscale_ary), - outtrans_ary) - grid_ary = np.dot(np.dot(np.dot(reorder_ary, rotation_ary.T), outscale_ary), outtrans_ary) + def test_one_hot(self, device): + with self.assertRaises(RuntimeError): + torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1) - transform_tensor = torch.from_numpy((rotation_ary)).to(device, torch.float32) - transform_tensor = transform_tensor[:2].unsqueeze(0) + with self.assertRaises(RuntimeError): + torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 3) - return transform_tensor, transform_ary, grid_ary + t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device)) + expected = torch.tensor([[0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0]], device=device) + self.assertEqual(t, expected) + t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -1) + expected = torch.tensor([[0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0]], device=device) + self.assertEqual(t, expected) -def _buildEquivalentAffineTransforms3d(device, input_size, output_size, angle_rad, axis_vector): - input_center = [(x - 1) / 2.0 for x in input_size] - output_center = [(x - 1) / 2.0 for x in output_size] + t = torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), 6) + expected = torch.tensor([[0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0]], device=device) + self.assertEqual(t, expected) - s = math.sin(angle_rad) - c = math.cos(angle_rad) - c1 = 1 - c + t = torch.nn.functional.one_hot(torch.tensor([[3, 4], [1, 0]], device=device)) + expected = torch.tensor([[[0, 0, 0, 1, 0], + [0, 0, 0, 0, 1]], + [[0, 1, 0, 0, 0], + [1, 0, 0, 0, 0]]], device=device) + self.assertEqual(t, expected) - intrans_ary = np.array([ - [1, 0, 0, input_center[2]], - [0, 1, 0, input_center[3]], - [0, 0, 1, input_center[4]], - [0, 0, 0, 1], - ], dtype=np.float64) + t = torch.nn.functional.one_hot(torch.tensor(4, device=device)) + expected = torch.tensor([0, 0, 0, 0, 1], device=device) + self.assertEqual(t, expected) - inscale_ary = np.array([ - [input_center[2], 0, 0, 0], - [0, input_center[3], 0, 0], - [0, 0, input_center[4], 0], - [0, 0, 0, 1], - ], dtype=np.float64) + t = torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device), 100) + expected = torch.empty([4, 0, 100]) + self.assertEqual(t, expected) - l, m, n = axis_vector - scipyRotation_ary = np.array([ - [l * l * c1 + c, m * l * c1 - n * s, n * l * c1 + m * s, 0], - [l * m * c1 + n * s, m * m * c1 + c, n * m * c1 - l * s, 0], - [l * n * c1 - m * s, m * n * c1 + l * s, n * n * c1 + c, 0], - [0, 0, 0, 1], - ], dtype=np.float64) + with self.assertRaises(RuntimeError): + torch.nn.functional.one_hot(torch.empty([4, 0], dtype=torch.long, device=device)) - z, y, x = axis_vector - torchRotation_ary = np.array([ - [x * x * c1 + c, y * x * c1 - z * s, z * x * c1 + y * s, 0], - [x * y * c1 + z * s, y * y * c1 + c, z * y * c1 - x * s, 0], - [x * z * c1 - y * s, y * z * c1 + x * s, z * z * c1 + c, 0], - [0, 0, 0, 1], - ], dtype=np.float64) + with self.assertRaises(RuntimeError): + torch.nn.functional.one_hot(torch.tensor([3, 4, 1, 0], device=device), -2) - outscale_ary = np.array([ - [1.0 / output_center[2], 0, 0, 0], - [0, 1.0 / output_center[3], 0, 0], - [0, 0, 1.0 / output_center[4], 0], - [0, 0, 0, 1], - ], dtype=np.float64) + def test_nn_scalars(self, device): + # One off tests to ensure scalars from nn.yaml are properly applied + def verify_scalars(input, output): + if input.dim() == 0: + self.assertEqual((), output.shape) + else: + self.assertNotEqual((), output.shape) + output.sum().backward() + self.assertEqual(input.shape, input.grad.shape) - outtrans_ary = np.array([ - [1, 0, 0, -output_center[2]], - [0, 1, 0, -output_center[3]], - [0, 0, 1, -output_center[4]], - [0, 0, 0, 1], - ], dtype=np.float64) + for input_shape in [(5, 6), ()]: + for module in [torch.nn.ELU, torch.nn.Hardtanh, torch.nn.LeakyReLU, torch.nn.LogSigmoid, + torch.nn.RReLU, torch.nn.Softshrink, torch.nn.Softplus, torch.nn.Sigmoid, + torch.nn.Tanh]: + input = torch.randn(input_shape, device=device, requires_grad=True) + m = module() + output = m(input) + verify_scalars(input, output) - reorder_ary = np.array([ - [0, 0, 1, 0], - [0, 1, 0, 0], - [1, 0, 0, 0], - [0, 0, 0, 1], - ], dtype=np.float64) + def test_nn_scalars_reductions(self, device): + # One off tests to ensure scalars from nn.yaml are properly applied + def verify_reduction_scalars(input, reduction, output): + if reduction != 'none' or input.dim() == 0: + self.assertEqual((), output.shape) + else: + self.assertNotEqual((), output.shape) + output.sum().backward() + self.assertEqual(input.shape, input.grad.shape) - transform_ary = np.dot(np.dot(np.dot(np.dot( - intrans_ary, - inscale_ary), - np.linalg.inv(scipyRotation_ary)), - outscale_ary), - outtrans_ary) - grid_ary = np.dot(np.dot(np.dot(reorder_ary, np.linalg.inv(scipyRotation_ary)), outscale_ary), outtrans_ary) + for input_shape in [(5, 6), ()]: + for reduction in ['none', 'mean', 'sum']: + for module in [torch.nn.BCELoss, torch.nn.L1Loss, torch.nn.MSELoss, + torch.nn.SmoothL1Loss, torch.nn.SoftMarginLoss]: + input = torch.randn(input_shape, device=device, requires_grad=True) + target = torch.empty(input_shape, device=device).random_(2) + sigmoid = nn.Sigmoid() - transform_tensor = torch.from_numpy((torchRotation_ary)).to(device, torch.float32) - transform_tensor = transform_tensor[:3].unsqueeze(0) + input = torch.randn(input_shape, device=device, requires_grad=True) + m = module(reduction=reduction) + output = m(sigmoid(input), target) + verify_reduction_scalars(input, reduction, output) - return transform_tensor, transform_ary, grid_ary -# end TestNN.test_affine_* helpers +instantiate_device_type_tests(TestNNDeviceType, globals()) if __name__ == '__main__': run_tests() diff --git a/test/test_optim.py b/test/test_optim.py index f66e53d343cc7..1bbd88c152cf8 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -14,8 +14,7 @@ from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \ ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \ CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR -from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ - skipIfRocm +from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -285,7 +284,6 @@ def test_sgd_sparse(self): [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)] ) - @skipIfRocm def test_adam(self): self._test_basic_cases( lambda weight, bias: optim.Adam([weight, bias], lr=1e-3) @@ -401,7 +399,6 @@ def test_adagrad_sparse(self): lambda opt: ReduceLROnPlateau(opt, threshold=1e-4)] ) - @skipIfRocm def test_adamax(self): self._test_basic_cases( lambda weight, bias: optim.Adamax([weight, bias], lr=1e-1) @@ -426,7 +423,6 @@ def test_rmsprop(self): with self.assertRaisesRegex(ValueError, "Invalid momentum value: -1.0"): optim.RMSprop(None, lr=1e-2, momentum=-1.0) - @skipIfRocm def test_asgd(self): self._test_basic_cases( lambda weight, bias: optim.ASGD([weight, bias], lr=1e-3, t0=100) @@ -451,7 +447,6 @@ def test_rprop(self): with self.assertRaisesRegex(ValueError, "Invalid eta values: 1.0, 0.5"): optim.Rprop(None, lr=1e-2, etas=(1.0, 0.5)) - @skipIfRocm def test_lbfgs(self): self._test_basic_cases( lambda weight, bias: optim.LBFGS([weight, bias]), @@ -542,6 +537,32 @@ def setUp(self): [{'params': self.net.conv1.parameters()}, {'params': self.net.conv2.parameters(), 'lr': 0.5}], lr=0.05) + def test_no_cyclic_references(self): + import gc + param = Variable(torch.Tensor(10), requires_grad=True) + optim = SGD([param], lr=0.5) + scheduler = LambdaLR(optim, lambda epoch: 1.0) + del scheduler + + # Prior to Python 3.7, local variables in a function will be referred by the current frame. + import sys + if sys.version_info < (3, 7): + import inspect + referrers = gc.get_referrers(optim) + self.assertTrue( + len(referrers) == 1 and referrers[0] is inspect.currentframe(), + "Optimizer should contain no cyclic references (except current frame)") + del referrers + else: + self.assertTrue( + len(gc.get_referrers(optim)) == 0, + "Optimizer should contain no cyclic references") + + gc.collect() + del optim + self.assertEqual( + gc.collect(), 0, "Optimizer should be garbage-collected on __del__") + def test_old_pattern_warning(self): epochs = 35 with warnings.catch_warnings(record=True) as ws: @@ -1288,5 +1309,28 @@ def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iteratio msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format( batch_num, momentum_target[batch_num], param_group['momentum']), delta=1e-5) + def test_cosine_then_cyclic(self): + # https://github.com/pytorch/pytorch/issues/21965 + + max_lr = 0.3 + base_lr = 0.1 + optim_lr = 0.5 + + model = torch.nn.Linear(2, 1) + optimizer = torch.optim.SGD(model.parameters(), lr=optim_lr) + lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.1) + lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR( + optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3 + ) + + for i in range(40): + if i <= lr_scheduler_1.T_max: + lr_scheduler_1.step() + else: + lr_scheduler_2.step() + last_lr = optimizer.param_groups[0]["lr"] + + self.assertLessEqual(last_lr, max_lr) + if __name__ == '__main__': run_tests() diff --git a/test/test_qat.py b/test/test_qat.py index 1a0547971cb34..7acaf27dc86c5 100644 --- a/test/test_qat.py +++ b/test/test_qat.py @@ -7,7 +7,7 @@ from torch.nn import Conv2d, BatchNorm2d, ReLU from torch.nn._intrinsic.qat import ConvBn2d, ConvBnReLU2d from torch.quantization.QConfig import default_qat_qconfig -from torch.utils.mkldnn import disable_mkldnn_conv +import torch.backends.mkldnn from common_utils import TestCase, run_tests from hypothesis import given from hypothesis import strategies as st @@ -59,7 +59,11 @@ def test_conv_bn_relu( momentum, freeze_bn ): - with disable_mkldnn_conv(): + # **** WARNING: This is used to temporarily disable MKL-DNN convolution due + # to a bug: https://github.com/pytorch/pytorch/issues/23825 + # Once this bug is fixed, this context manager as well as its callsites + # should be removed! + with torch.backends.mkldnn.flags(enabled=False): input_channels = input_channels_per_group * groups output_channels = output_channels_per_group * groups dilation_h = dilation_w = dilation diff --git a/test/test_quantization.py b/test/test_quantization.py index aea17d85faefc..bad08b00eb848 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -1,4 +1,5 @@ import unittest +import math import torch import torch.nn as nn import torch.nn.quantized as nnq @@ -6,12 +7,12 @@ import torch.nn._intrinsic.quantized as nniq import torch.nn._intrinsic.qat as nniqat from torch.quantization import \ - QConfig_dynamic, default_weight_observer, \ + QConfig_dynamic, default_weight_observer, dump_tensor,\ quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \ - quantize_dynamic, default_qconfig, default_qat_qconfig, \ - default_dynamic_qconfig, MinMaxObserver, QuantWrapper + quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \ + default_dynamic_qconfig, HistogramObserver, MinMaxObserver, PerChannelMinMaxObserver, TensorObserver, QuantWrapper -from common_utils import run_tests, tempfile +from common_utils import run_tests from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \ SkipQuantModel, QuantStubModel, \ ModelForFusion, ManualLinearQATModel, ManualConvLinearQATModel, \ @@ -25,6 +26,7 @@ from hypothesis import given from hypothesis import strategies as st +from hypothesis_utils import no_deadline import io import copy @@ -302,6 +304,7 @@ def test_single_layer(self): def checkQuantized(model): self.checkDynamicQuantizedLinear(model.fc1) + self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) @@ -325,6 +328,7 @@ def test_two_layers(self): def checkQuantized(model): self.assertEqual(type(model.fc1), torch.nn.Linear) self.checkDynamicQuantizedLinear(model.fc2) + self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) @@ -350,6 +354,7 @@ def checkQuantized(model): self.checkDynamicQuantizedLinear(model.fc3) self.checkDynamicQuantizedLinear(model.sub2.fc1) self.checkLinear(model.sub2.fc2) + self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) @@ -376,6 +381,7 @@ def checkQuantized(model): self.checkDynamicQuantizedLinear(model.sub2.fc1) self.checkDynamicQuantizedLinear(model.sub2.fc2) self.checkDynamicQuantizedLinear(model.fc3) + self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) @@ -406,6 +412,7 @@ def checkQuantized(model): self.checkDynamicQuantizedLinear(model.sub2.fc1) self.checkDynamicQuantizedLinear(model.sub2.fc2) self.checkDynamicQuantizedLinear(model.fc3) + self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) @@ -434,6 +441,7 @@ def checkQuantized(model): self.checkLinear(model.sub2.fc1) self.checkDynamicQuantizedLinear(model.sub2.fc2) test_only_eval_fn(model, self.calib_data) + self.checkScriptable(model, self.calib_data, check_save_load=True) checkQuantized(model) @@ -484,12 +492,20 @@ def test_quantized_rnn(self): torch.nn.LSTM: torch.nn.quantized.dynamic.LSTM, } model_int8 = quantize_dynamic( - model, qconfig_dynamic_dict, default_dynamic_module_mapping + model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping, + dtype=torch.qint8 + ) + model_fp16 = quantize_dynamic( + model=model, qconfig_dict=qconfig_dynamic_dict, mapping=default_dynamic_module_mapping, + dtype=torch.float16 ) cell_int8 = model_int8.lstm + cell_fp16 = model_fp16.lstm assert type(cell_int8) == torch.nn.quantized.dynamic.LSTM, \ 'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic' + assert type(cell_fp16) == torch.nn.quantized.dynamic.LSTM, \ + 'torch.nn.LSTM should be converted to torch.nn.quantized.dynamic.LSTM after quantize_dynamic' niter = 10 x = torch.tensor([[100, -155], @@ -513,7 +529,40 @@ def test_quantized_rnn(self): torch.testing.assert_allclose(output_int8, ref_out) self.assertEqual(output_int8, ref_out) - for out, ref in zip(final_hiddens_int8, ref_hid): + for out_val, ref_val in zip(final_hiddens_int8, ref_hid): + torch.testing.assert_allclose(out_val, ref_val) + + class ScriptWrapper(torch.nn.Module): + def __init__(self, cell): + super(ScriptWrapper, self).__init__() + self.cell = cell + + def forward(self, x, hiddens): + # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + return self.cell(x, hiddens) + + # TODO: TorchScript overloads don't work without this wrapper + cell_script = torch.jit.script(ScriptWrapper(cell_int8)) + out_script, hid_script = cell_script(x, hiddens) + self.assertEqual(len(out_script), len(ref_out)) + for out_val, ref_val in zip(out_script, ref_out): + torch.testing.assert_allclose(out_val, ref_val) + + # Test save/load + b = io.BytesIO() + torch.jit.save(cell_script, b) + b.seek(0) + loaded = torch.jit.load(b) + out_loaded, hid_loaded = loaded(x, hiddens) + for loaded_val, ref_val in zip(out_loaded, ref_out): + torch.testing.assert_allclose(loaded_val, ref_val) + + # Compare fp16 quantized to unquantized + output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens) + + torch.testing.assert_allclose(output_fp16, ref_out) + self.assertEqual(output_fp16, ref_out) + for out, ref in zip(final_hiddens_fp16, ref_hid): torch.testing.assert_allclose(out, ref) @unittest.skipIf( @@ -589,10 +638,10 @@ def setUp(self): def test_scriptability_serialization(self): # test serialization of quantized functional modules - with tempfile.TemporaryFile() as f: - torch.save(self.qmodel_under_test, f) - f.seek(0) - loaded = torch.load(f) + b = io.BytesIO() + torch.save(self.qmodel_under_test, b) + b.seek(0) + loaded = torch.load(b) self.assertEqual(self.qmodel_under_test.myadd.zero_point, loaded.myadd.zero_point) state_dict = self.qmodel_under_test.state_dict() self.assertTrue('myadd.zero_point' in state_dict.keys(), @@ -711,9 +760,13 @@ def checkQuantized(model): class ObserverTest(QuantizationTestCase): @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), - qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric))) - def test_minmax_observer(self, qdtype, qscheme): - myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme) + qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), + reduce_range=st.booleans()) + def test_minmax_observer(self, qdtype, qscheme, reduce_range): + # reduce_range cannot be true for symmetric quantization with uint8 + if qdtype == torch.quint8 and qscheme == torch.per_tensor_symmetric: + reduce_range = False + myobs = MinMaxObserver(dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0]) y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0]) result = myobs(x) @@ -722,15 +775,82 @@ def test_minmax_observer(self, qdtype, qscheme): self.assertEqual(myobs.min_val, 1.0) self.assertEqual(myobs.max_val, 8.0) qparams = myobs.calculate_qparams() - if qscheme == torch.per_tensor_symmetric: - ref_scale = 0.062745 - ref_zero_point = 0 if qdtype is torch.qint8 else 128 + if reduce_range: + if qscheme == torch.per_tensor_symmetric: + ref_scale = 0.062745 * 255 / 127 + ref_zero_point = 0 if qdtype is torch.qint8 else 128 + else: + ref_scale = 0.0313725 * 255 / 127 + ref_zero_point = -64 if qdtype is torch.qint8 else 0 else: - ref_scale = 0.0313725 - ref_zero_point = -128 if qdtype is torch.qint8 else 0 + if qscheme == torch.per_tensor_symmetric: + ref_scale = 0.062745 + ref_zero_point = 0 if qdtype is torch.qint8 else 128 + else: + ref_scale = 0.0313725 + ref_zero_point = -128 if qdtype is torch.qint8 else 0 self.assertEqual(qparams[1].item(), ref_zero_point) self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) + @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), + qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)), + ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans()) + def test_per_channel_minmax_observer(self, qdtype, qscheme, ch_axis, reduce_range): + # reduce_range cannot be true for symmetric quantization with uint8 + if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric: + reduce_range = False + myobs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme) + x = torch.tensor( + [ + [[[1.0, 2.0], [2.0, 2.5]], [[3.0, 4.0], [4.5, 6.0]]], + [[[-4.0, -3.0], [5.0, 5.0]], [[6.0, 3.0], [7.0, 8.0]]], + ] + ) + result = myobs(x) + self.assertEqual(result, x) + qparams = myobs.calculate_qparams() + ref_min_vals = [[1.0, -4.0], [-4.0, 3.0], [-4.0, 2.0], [-4.0, -3.0]] + ref_max_vals = [[6.0, 8.0], [5.0, 8.0], [6.0, 8.0], [7.0, 8.0]] + per_channel_symmetric_ref_scales = [ + [0.04705882, 0.06274509], + [0.03921569, 0.0627451], + [0.04705882, 0.0627451], + [0.05490196, 0.0627451], + ] + per_channel_affine_ref_scales = [ + [0.02352941, 0.04705882], + [0.03529412, 0.03137255], + [0.03921569, 0.03137255], + [0.04313726, 0.04313726], + ] + per_channel_affine_qint8_zp = [ + [-128, -43], + [-15, -128], + [-26, -128], + [-35, -58], + ] + per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]] + + self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis]) + self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis]) + if qscheme == torch.per_channel_symmetric: + ref_scales = per_channel_symmetric_ref_scales[ch_axis] + ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128] + else: + ref_scales = per_channel_affine_ref_scales[ch_axis] + ref_zero_points = ( + per_channel_affine_qint8_zp[ch_axis] + if qdtype is torch.qint8 + else per_channel_affine_quint8_zp[ch_axis] + ) + + if reduce_range: + ref_scales = [s * 255 / 127 for s in ref_scales] + ref_zero_points = [math.floor(z / 2) for z in ref_zero_points] + + self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype))) + self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype))) + def test_observer_scriptable(self): obs = torch.quantization.default_observer()() scripted = torch.jit.script(obs) @@ -747,5 +867,75 @@ def test_observer_scriptable(self): loaded = torch.jit.load(buf) self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams()) +@unittest.skipIf(not torch.fbgemm_is_cpu_supported(), + 'Quantization requires FBGEMM. FBGEMM does not play' + ' well with UBSAN at the moment, so we skip the test if' + ' we are in a UBSAN environment.') +class QuantizationDebugTest(QuantizationTestCase): + def test_tensor_observer(self): + model = SingleLayerLinearModel() + model.qconfig = default_debug_qconfig + prepare(model) + # run the evaluation and dump all tensors + test_only_eval_fn(model, self.calib_data) + test_only_eval_fn(model, self.calib_data) + tensor_dict = {} + dump_tensor(model, tensor_dict) + + # we can torch,save() and torch_load() in bento for further analysis + self.assertTrue('fc1.module.activation' in tensor_dict.keys(), + 'activation is not recorded in the dict') + self.assertEqual(len(tensor_dict['fc1.module.activation']), 2 * len(self.calib_data)) + + @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), + qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric))) + def test_tensor_observer_scriptable(self, qdtype, qscheme): + obs = TensorObserver(dtype=qdtype, qscheme=qscheme) + scripted = torch.jit.script(obs) + + x = torch.rand(3, 4) + obs(x) + scripted(x) + self.assertTrue(torch.equal(obs.get_tensor_value()[0], scripted.get_tensor_value()[0])) + buf = io.BytesIO() + torch.jit.save(scripted, buf) + buf.seek(0) + loaded = torch.jit.load(buf) + self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0])) + + @no_deadline + @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)), + qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)), + reduce_range=st.booleans()) + def test_histogram_observer(self, qdtype, qscheme, reduce_range): + myobs = HistogramObserver(bins=3, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range) + x = torch.tensor([2.0, 3.0, 4.0, 5.0]) + y = torch.tensor([5.0, 6.0, 7.0, 8.0]) + myobs(x) + myobs(y) + self.assertEqual(myobs.min_val, 2.0) + self.assertEqual(myobs.max_val, 8.0) + self.assertEqual(myobs.histogram, [2., 3., 3.]) + + qparams = myobs.calculate_qparams() + + if reduce_range: + if qscheme == torch.per_tensor_symmetric: + ref_scale = 0.0470588 * 255 / 127 + ref_zero_point = 0 if qdtype is torch.qint8 else 128 + else: + ref_scale = 0.0235294 * 255 / 127 + ref_zero_point = -64 if qdtype is torch.qint8 else 0 + else: + if qscheme == torch.per_tensor_symmetric: + ref_scale = 0.0470588 + ref_zero_point = 0 if qdtype is torch.qint8 else 128 + else: + ref_scale = 0.0235294 + ref_zero_point = -128 if qdtype is torch.qint8 else 0 + + self.assertEqual(qparams[1].item(), ref_zero_point) + self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5) + if __name__ == '__main__': run_tests() diff --git a/test/test_quantized.py b/test/test_quantized.py index f080fd5a4e0c5..fc34f352e906c 100644 --- a/test/test_quantized.py +++ b/test/test_quantized.py @@ -11,8 +11,10 @@ import hypothesis_utils as hu from hypothesis_utils import no_deadline -from common_utils import TEST_WITH_UBSAN, TestCase, run_tests, IS_WINDOWS, IS_PPC -from common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams +from common_utils import TEST_WITH_UBSAN, TestCase, run_tests, IS_WINDOWS, IS_PPC, \ + TEST_WITH_QNNPACK +from common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ + enable_mobile_quantized_engine # Make sure we won't have overflows from vpmaddubsw instruction used in FBGEMM. # On the current Intel x86 architecture, we need to utilize vpmaddubsw instruction @@ -137,6 +139,9 @@ def test_qrelu6(self, X): self.assertEqual(qY, qY_hat, message="{} relu failed".format(name)) """Tests the correctness of the scalar addition.""" + @unittest.skip("temporarily disable until failures are fixed. " + + "See https://github.com/pytorch/pytorch/issues/26279") + @no_deadline @given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5), elements=st.floats(-1e6, 1e6, allow_nan=False), qparams=hu.qparams()), @@ -195,7 +200,6 @@ def test_qadd_relu_same_qparams(self): torch.qint32 : np.int32 } qC = _quantize(C, scale, zero_point, dtype=np_dtype[dtype]) - # print('C', qC) qC_hat = add(qA, qB, scale=scale, zero_point=zero_point) np.testing.assert_equal(qC, qC_hat.int_repr(), "Quantized addition failed.") @@ -432,6 +436,63 @@ def test_max_pool2d(self, X, kernel, stride, dilation, padding): self.assertEqual(a_ref, a_hat.dequantize(), message="ops.quantized.max_pool2d results are off") + """Tests max pool operation on NHWC quantized tensors.""" + @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4, + min_side=1, max_side=10), + qparams=hu.qparams()), + kernel=st.sampled_from((3, 5, 7)), + stride=st.sampled_from((None, 1, 2)), + dilation=st.integers(1, 2), + padding=st.integers(0, 2)) + def test_max_pool2d_nhwc(self, X, kernel, stride, dilation, padding): + X, (scale, zero_point, torch_type) = X + # Ensure we hit the vectorized paths + # 176 = 128 + 32 + 16 + # 128 hits the interleaved path + # 32 hits the non-interleaved path + # 16 hits the scalar path + if X.shape[1] < 176: + X = np.repeat(X, 176 / X.shape[1], 1) + # Check constraints + assume(kernel // 2 >= padding) # Kernel cannot be overhanging! + iH, iW = X.shape[-2:] + oH = pool_output_shape(iH, kernel, padding, stride, dilation) + assume(oH > 0) + oW = pool_output_shape(iW, kernel, padding, stride, dilation) + assume(oW > 0) + + X_nchw = np.ascontiguousarray(X.transpose([0, 2, 3, 1])) + a = torch.from_numpy(X_nchw).permute([0, 3, 1, 2]) + a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel, + stride=stride, + padding=padding, dilation=dilation) + a_ref = torch.quantize_linear(a_pool, scale=scale, + zero_point=zero_point, dtype=torch_type) + a_ref = a_ref.dequantize() + qa = torch.quantize_linear(torch.from_numpy(X_nchw), scale=scale, zero_point=zero_point, + dtype=torch_type).permute([0, 3, 1, 2]) + self.assertTrue(qa.stride() != sorted(qa.stride())) + + ops_under_test = { + "torch": torch.max_pool2d, + "nn.functional": torch.nn.functional.max_pool2d, + "nn.quantized.functional": torch.nn.quantized.functional.max_pool2d + } + + for name, op in ops_under_test.items(): + a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding, + dilation=dilation) + self.assertTrue(a_hat.stride() != sorted(a_hat.stride())) + self.assertEqual(a_ref, a_hat.dequantize(), + message="{} results are off".format(name)) + # Test the ops.quantized separately, because None is not treated. + a_hat = torch.ops.quantized.max_pool2d( + qa, kernel_size=_pair(kernel), + stride=_pair(kernel if stride is None else stride), + padding=_pair(padding), dilation=_pair(dilation)) + self.assertEqual(a_ref, a_hat.dequantize(), + message="ops.quantized.max_pool2d results are off") + @no_deadline @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4, min_side=1, max_side=10), @@ -619,11 +680,11 @@ class TestDynamicQuantizedLinear(TestCase): use_channelwise=st.booleans()) def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, use_relu, use_multi_dim_input, use_channelwise): - qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack + qlinear_prepack = torch.ops.quantized.linear_prepack if use_relu: - qlinear_dynamic = torch.ops.quantized.fbgemm_linear_relu_dynamic + qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic else: - qlinear_dynamic = torch.ops.quantized.fbgemm_linear_dynamic + qlinear_dynamic = torch.ops.quantized.linear_dynamic if use_multi_dim_input: batch_size *= 3 # Test the multi-dim input tensor @@ -704,9 +765,9 @@ def test_qlinear(self, batch_size, input_channels, output_channels, X_q = torch.quantize_linear(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) # Weight prepacking operator for dynamic quantized Linear - W_prepack = qlinear_prepack(W_q) + W_prepack = qlinear_prepack(W_q, b_fp32) # Dynamic quantized Linear operator with prepacked weight - Y_fp32 = qlinear_dynamic(X_q.dequantize(), W_prepack, b_fp32) + Y_fp32 = qlinear_dynamic(X_q.dequantize(), W_prepack) # Y_fp32 = qlinear_dynamic(X_fp32, W_prepack, b_fp32) Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32) @@ -718,7 +779,7 @@ def test_qlinear(self, batch_size, input_channels, output_channels, Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 self.assertEqual(Y_fp32, Y_fp32_ref, - message="torch.ops.quantized.fbgemm_linear_dynamic results are off") + message="torch.ops.quantized.linear_dynamic (fbgemm) results are off") @unittest.skipIf( not torch.fbgemm_is_cpu_supported(), @@ -736,11 +797,11 @@ class TestQuantizedLinear(unittest.TestCase): use_channelwise=st.booleans()) def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, use_relu, use_multi_dim_input, use_channelwise): - qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack + qlinear_prepack = torch.ops.quantized.linear_prepack if use_relu: - qlinear = torch.ops.quantized.fbgemm_linear_relu + qlinear = torch.ops.quantized.linear_relu else: - qlinear = torch.ops.quantized.fbgemm_linear + qlinear = torch.ops.quantized.linear if use_multi_dim_input: batch_size *= 3 # Test the multi-dim input tensor @@ -815,13 +876,14 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, Y_zp = 5 # Weight prepacking operator for quantized Linear - W_prepack = qlinear_prepack(W_q) + float_bias = b if use_bias else None + W_prepack = qlinear_prepack(W_q, float_bias) if use_multi_dim_input: X_q = X_q.view(3, int(batch_size / 3), input_channels) # Quantized Linear operator with prepacked weight - Y_q = qlinear(X_q, W_prepack, b_q, Y_scale, Y_zp) + Y_q = qlinear(X_q, W_prepack, Y_scale, Y_zp) if not use_channelwise: # Test the per-tensor quantization only @@ -851,21 +913,20 @@ def test_qlinear(self, batch_size, input_channels, output_channels, use_bias, np.testing.assert_equal( Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy()) - """Tests the correctness of the quantized::fbgemm_linear_unpack op.""" + """Tests the correctness of the quantized::linear_unpack (fbgemm) op.""" @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,), qparams=hu.qparams(dtypes=torch.qint8)), use_channelwise=st.booleans()) def test_qlinear_unpack(self, W, use_channelwise): W, (W_scale, W_zp, torch_type) = W - if use_channelwise: output_channels = W.shape[0] W_scales = torch.rand(output_channels).to(torch.double) W_zps = torch.round(torch.rand(output_channels) * 100 - 50).to(torch.int64) - qlinear_prepack = torch.ops.quantized.fbgemm_linear_prepack - qlinear_unpack = torch.ops.quantized.fbgemm_linear_unpack + qlinear_prepack = torch.ops.quantized.linear_prepack + qlinear_unpack = torch.ops.quantized.linear_unpack W = torch.from_numpy(W) @@ -879,7 +940,7 @@ def test_qlinear_unpack(self, W, use_channelwise): # Weight prepacking operator for quantized Linear W_prepack = qlinear_prepack(W_q) # Weight unpack operator for quantized Linear (Used for serialization) - W_q_origin = qlinear_unpack(W_prepack) + W_q_origin = qlinear_unpack(W_prepack)[0] # Assert equal np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy()) @@ -916,7 +977,7 @@ class TestQuantizedConv(unittest.TestCase): stride_w=st.integers(1, 2), pad_h=st.integers(0, 2), pad_w=st.integers(0, 2), - dilation=st.integers(1, 1), + dilation=st.integers(1, 2), X_scale=st.floats(0.2, 1.6), X_zero_point=st.integers(0, 4), W_scale=st.lists(st.floats(0.2, 1.6), min_size=1, max_size=2), @@ -951,11 +1012,10 @@ def test_qconv( use_relu, use_channelwise ): - - qconv = torch.ops.quantized.fbgemm_conv2d + qconv = torch.ops.quantized.conv2d if use_relu: - qconv = torch.ops.quantized.fbgemm_conv2d_relu - qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack + qconv = torch.ops.quantized.conv2d_relu + qconv_prepack = torch.ops.quantized.conv_prepack # C input_channels = input_channels_per_group * groups @@ -964,6 +1024,10 @@ def test_qconv( dilation_h = dilation_w = dilation + # Padded input size should be at least as big as dilated kernel + assume(height + 2 * pad_h >= dilation_h * (kernel_h - 1) + 1) + assume(width + 2 * pad_w >= dilation_w * (kernel_w - 1) + 1) + W_scale = W_scale * output_channels W_zero_point = W_zero_point * output_channels # Resize W_scale and W_zero_points arrays equal to output_channels @@ -1032,34 +1096,22 @@ def test_qconv( # quantize reference results for comparision result_ref_q = torch.quantize_linear(result_ref, scale=Y_scale, zero_point=Y_zero_point, dtype=torch.quint8) - # reformat X_init and W_init in the required format by qconv operator - # NCHW -> NHWC - X_NHWC = X.permute([0, 2, 3, 1]).contiguous() - # K(C/G)RS -> KRS(C/G) - W_KRSC = W.permute([0, 2, 3, 1]).contiguous() - - X_q = torch.quantize_linear(X_NHWC, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8) + X_q = torch.quantize_linear(X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8) if use_channelwise: - W_q = torch.quantize_linear_per_channel(W_KRSC, + W_q = torch.quantize_linear_per_channel(W, W_scales_tensor.to(dtype=torch.double), W_zero_points_tensor.to(dtype=torch.long), [0], dtype=torch.qint8) - b_q = torch.quantize_linear_per_channel(b, - X_scale * W_scales_tensor.to(dtype=torch.double), - torch.zeros(output_channels, dtype=torch.long), - [0], - dtype=torch.qint32) if use_bias else None else: - W_q = torch.quantize_linear(W_KRSC, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8) - b_q = torch.quantize_linear(b, scale=X_scale * W_scale[0], zero_point=0, dtype=torch.qint32) if use_bias else None + W_q = torch.quantize_linear(W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8) - W_prepack = qconv_prepack(W_q, stride, pad, dilation, groups) + bias_float = b if use_bias else None + W_prepack = qconv_prepack(W_q, bias_float, stride, pad, dilation, groups) Y_q = qconv( X_q, W_prepack, - b_q, stride, pad, dilation, @@ -1068,9 +1120,6 @@ def test_qconv( Y_zero_point, ) - # Back to NCHW format - Y_q = Y_q.permute([0, 3, 1, 2]).contiguous() - # Make sure the results match # assert_array_almost_equal compares using the following formula: # abs(desired-actual) < 1.5 * 10**(-decimal) @@ -1086,7 +1135,7 @@ def test_qconv( # assuming the rounding mode is round-to-nearest, ties-to-even. np.testing.assert_array_almost_equal(result_ref_q.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=0) - """Tests the correctness of the quantized::fbgemm_qconv_unpack op.""" + """Tests the correctness of the quantized::qconv_unpack (fbgemm) op.""" @given(X=hu.tensor_conv2d(min_batch=1, max_batch=3, min_in_channels=1, max_in_channels=7, min_out_channels=1, max_out_channels=7, @@ -1116,29 +1165,28 @@ def test_qconv_unpack(self, X, strideH, strideW, padH, padW, channelwise): filters_scale = torch.tensor([filters_scale] * output_channels).to(torch.double) filters_zero_point = torch.tensor([filters_zero_point] * output_channels).to(torch.long) - qconv_prepack = torch.ops.quantized.fbgemm_conv_prepack - qconv_unpack = torch.ops.quantized.fbgemm_conv_unpack + qconv_prepack = torch.ops.quantized.conv_prepack + qconv_unpack = torch.ops.quantized.conv_unpack - # Orig tensor is assumed to be in K(C/G)RS format W = torch.from_numpy(filters).to(torch.float) - # K(C/G)RS -> KRS(C/G) - W_KRSC = W.permute([0, 2, 3, 1]).contiguous() if channelwise: - W_q = torch.quantize_linear_per_channel(W_KRSC, + W_q = torch.quantize_linear_per_channel(W, scales=filters_scale, zero_points=filters_zero_point, axis=[0], dtype=filters_qtype) else: - W_q = torch.quantize_linear(W_KRSC, scale=filters_scale, zero_point=filters_zero_point, dtype=filters_qtype) + W_q = torch.quantize_linear(W, scale=filters_scale, zero_point=filters_zero_point, dtype=filters_qtype) # Pack weights using weight packing operator strides = [strideH, strideW] paddings = [padH, padW] dilations = [1, 1] - W_packed = qconv_prepack(W_q, strides, paddings, dilations, groups) + bias = torch.from_numpy(bias).to(torch.float) + W_packed = qconv_prepack(W_q, bias, strides, paddings, dilations, groups) # Unpack weights weight unpacking operator (Used for serialization) - W_unpacked = qconv_unpack(W_packed) + W_unpacked = qconv_unpack(W_packed)[0] + bias = qconv_unpack(W_packed)[1] # Assert equal np.testing.assert_equal(W_q.int_repr().numpy(), W_unpacked.int_repr().numpy()) @@ -1151,6 +1199,7 @@ def test_qconv_unpack(self, X, strideH, strideW, padH, padW, channelwise): np.testing.assert_equal(np.float32(W_q.q_scale()), np.float32(W_unpacked.q_scale())) np.testing.assert_equal(W_q.q_zero_point(), W_unpacked.q_zero_point()) +@unittest.skipIf(not TEST_WITH_QNNPACK, "This Pytorch Build has not been built with QNNPACK") @unittest.skipIf(IS_WINDOWS, "QNNPACK has not been built for Windows") @unittest.skipIf(IS_PPC, "QNNPACK is not currently supported on ppc64le") @unittest.skipIf(TEST_WITH_UBSAN, @@ -1176,76 +1225,288 @@ def test_qnnpack_relu(self, X): qY = torch.quantize_linear(Y, scale=scale, zero_point=zero_point, dtype=torch_type) self.assertEqual(qY, qY_hat) - """Tests the correctness of the quantized::qnnpack_linear op.""" - @given(output_channels=st.sampled_from([2, 4, 5, 8, 16, 32]), - X=hu.tensor(shapes=hu.array_shapes(2, 3, 8, 15), - qparams=hu.qparams(dtypes=torch.quint8))) - def test_qnnpack_linear(self, output_channels, X): - X, (X_scale, X_zp, torch_type) = X - qmin = torch.iinfo(torch_type).min - qmax = torch.iinfo(torch_type).max + @given(batch_size=st.integers(1, 4), + input_channels=st.integers(16, 32), + output_channels=st.integers(4, 8), + use_relu=st.booleans()) + def test_qlinear_qnnpack(self, batch_size, input_channels, output_channels, use_relu): - input_channels = X.shape[X.ndim - 1] + with enable_mobile_quantized_engine(): + qlinear_prepack = torch.ops.quantized.linear_prepack + if use_relu: + qlinear = torch.ops.quantized.linear_relu + else: + qlinear = torch.ops.quantized.linear + + X_scale = 1.5 + X_zp = 5 + X_value_min = 0 + X_value_max = 225 + X_q0 = np.round( + np.random.rand(batch_size, input_channels) * + (X_value_max - X_value_min) + + X_value_min + ).astype(np.uint8) + + W_scales = np.random.rand(output_channels) + W_zp = 2 + W_value_min = -128 + W_value_max = 127 + W_q0 = np.round( + np.random.rand(output_channels, input_channels) + * (W_value_max - W_value_min) + + W_value_min + ).astype(np.uint8) + + b_value_min = -10 + b_value_max = 10 + b_q0 = np.round( + np.random.rand(output_channels) * + (b_value_max - b_value_min) + b_value_min + ).astype(np.int32) + + X = torch.from_numpy(_dequantize( + X_q0, X_scale, X_zp)).to(dtype=torch.float) + X_q = torch.quantize_linear( + X, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) - input_rows = 1 + W = torch.from_numpy(_dequantize( + W_q0, W_scales[0], W_zp)).to(dtype=torch.float) + W_q = torch.quantize_linear(W, scale=W_scales[0], zero_point=( + W_zp), dtype=torch.quint8) + b = torch.from_numpy(_dequantize( + b_q0, X_scale * (W_scales[0].item()), 0)).to(dtype=torch.float) + b_q = torch.quantize_linear( + b, scale=X_scale * (W_scales[0].item()), zero_point=0, dtype=torch.qint32) - for x in range(X.ndim - 1): - input_rows *= X.shape[x] + # Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with + # Y_scale * 255 (max for uint8). + Y_scale = 125.1234 + Y_zp = 5 - qnnpack_linear = torch.ops.quantized.qnnpack_linear + # Weight prepacking operator for quantized Linear + W_prepack = qlinear_prepack(W_q, b_q) - X_q0 = np.round(X * (qmin - qmax) + qmin).astype(np.uint8) + # Quantized Linear operator with prepacked weight + Y_q = qlinear(X_q, W_prepack, Y_scale, Y_zp) - W_scale = 0.4 - W_zp = 0 - W_value_min = 0 - W_value_max = 255 - W_q0 = np.round( - np.random.rand(output_channels, input_channels) - * (W_value_max - W_value_min) - + W_value_min - ).astype(np.uint8) + # Reference quantized Linear operator + Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0, + W_scales[0], W_zp, b_q0, Y_scale, Y_zp) + if use_relu: + Y_q_ref[Y_q_ref < Y_zp] = Y_zp + # Assert equal + np.testing.assert_array_almost_equal(Y_q_ref, Y_q.int_repr().numpy(), decimal=4) + + # Test both per-tensor and per-channel quantization + # Reference quantized result from PyTorch Linear operator + W_fp32 = W_q.dequantize().to(dtype=torch.float) + X_fp32 = X_q.dequantize().to(dtype=torch.float) + b_fp32 = b_q.dequantize().to(dtype=torch.float) + Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) + if use_relu: + Y_fp32_ref[Y_fp32_ref < 0.0] = 0.0 + Y_q_ref2 = torch.quantize_linear( + Y_fp32_ref, Y_scale, Y_zp, torch.quint8) + # Assert equal + np.testing.assert_equal( + Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy()) - b_value_min = -10 - b_value_max = 10 - b_q0 = np.round( - np.random.rand(output_channels) * (b_value_max - b_value_min) + b_value_min - ).astype(np.int32) + """Tests the correctness of the quantized::linear_unpack (qnnpack) op.""" + @given(W=hu.tensor(shapes=hu.array_shapes(2, 2,), + qparams=hu.qparams(dtypes=torch.quint8))) + def test_qlinear_unpack(self, W): + W, (W_scale, W_zp, torch_type) = W - X_scale = 10 - X_zp = 0 - X = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float) - W = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float) - b = torch.from_numpy(_dequantize(b_q0, X_scale * W_scale, 0)).to(dtype=torch.float) + with enable_mobile_quantized_engine(): + qlinear_prepack = torch.ops.quantized.linear_prepack + qlinear_unpack = torch.ops.quantized.linear_unpack - X_q = torch.quantize_linear(X, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) - W_q = torch.quantize_linear(W, scale=W_scale, zero_point=W_zp, dtype=torch.quint8) - b_q = torch.quantize_linear(b, scale=X_scale * W_scale, zero_point=0, dtype=torch.qint32) + W = torch.from_numpy(W) + W_q = torch.quantize_linear(W, scale=W_scale, zero_point=W_zp, + dtype=torch_type) - Y_scale = 5.4 # This makes sure that the max output value does not exceed 255. - Y_zp = 0 + # Weight prepacking operator for quantized Linear + W_prepack = qlinear_prepack(W_q) + # Weight unpack operator for quantized Linear (Used for serialization) + W_q_origin = qlinear_unpack(W_prepack)[0] - # Reference quantized Linear operator - Y_q_ref = qlinear_ref(X_q0, X_scale, X_zp, W_q0, W_scale, W_zp, b_q0, Y_scale, Y_zp) - Y_q_ref_float = _dequantize(Y_q_ref, Y_scale, Y_zp) + # Assert equal + np.testing.assert_equal(W_q.int_repr(), W_q_origin.int_repr().numpy()) - # Quantized linear operator - Y_q = qnnpack_linear(X_q, W_q, b_q, Y_scale, Y_zp) + np.testing.assert_equal(np.float32( + W_q.q_scale()), np.float32(W_q_origin.q_scale())) + np.testing.assert_equal( + W_q.q_zero_point(), W_q_origin.q_zero_point()) - # Assert equal - np.testing.assert_array_almost_equal(Y_q_ref_float, Y_q.dequantize().numpy(), decimal=4) + @given(batch_size=st.integers(1, 3), + input_channels_per_group=st.sampled_from([8, 16, 32]), + height=st.integers(10, 16), + width=st.integers(7, 14), + output_channels_per_group=st.sampled_from([8, 16, 32]), + groups=st.integers(1, 3), + kernel_h=st.integers(1, 7), + kernel_w=st.integers(1, 7), + stride_h=st.integers(1, 2), + stride_w=st.integers(1, 2), + pad_h=st.integers(0, 2), + pad_w=st.integers(0, 2), + dilation_h=st.integers(1, 1), + X_scale=st.floats(1.2, 1.6), + X_zp=st.integers(0, 4), + W_scale=st.floats(0.2, 1.6), + W_zp=st.integers(2, 5), + Y_scale=st.floats(4.2, 5.6), + Y_zp=st.integers(0, 4), + use_relu=st.booleans()) + def test_qconv_qnnpack( + self, + batch_size, + input_channels_per_group, + height, + width, + output_channels_per_group, + groups, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + X_scale, + X_zp, + W_scale, + W_zp, + Y_scale, + Y_zp, + use_relu): + with enable_mobile_quantized_engine(): + # C + input_channels = input_channels_per_group * groups + # K + output_channels = output_channels_per_group * groups + + stride = [stride_h, stride_w] + padding = [pad_h, pad_w] + kernel = [kernel_h, kernel_w] + dilation = [dilation_h, dilation_h] + + W_value_min = 0 + W_value_max = 10 + W_init = torch.from_numpy( + np.random.randint( + W_value_min, + W_value_max, + (output_channels, int(input_channels / groups), kernel_h, kernel_w)), + ) + b_init = torch.from_numpy(np.random.randint(0, 10, (output_channels,))) + + X_value_min = 0 + X_value_max = 10 + X_init = torch.from_numpy(np.random.randint( + X_value_min, X_value_max, (batch_size, input_channels, height, width))) + + # Existing floating point conv operator + conv_op = torch.nn.Conv2d( + input_channels, + output_channels, + (kernel_h, kernel_w), + (stride_h, stride_w), + (pad_h, pad_w), + (dilation_h, dilation_h), + groups, + ) + + X = X_scale * (X_init - X_zp).to(dtype=torch.float) + + W = W_scale * (W_init - W_zp).to(dtype=torch.float) + + b = X_scale * W_scale * (b_init - 0).to(dtype=torch.float) + + # assign weights + conv_op.weight = torch.nn.Parameter(W, requires_grad=False) + conv_op.bias = torch.nn.Parameter(b, requires_grad=False) + + result_ref = conv_op(X) + + X_q = torch.quantize_linear(X, scale=X_scale, zero_point=X_zp, dtype=torch.quint8) + W_q = torch.quantize_linear(W, scale=W_scale, zero_point=W_zp, dtype=torch.quint8) + b_q = torch.quantize_linear(b, scale=X_scale * W_scale, zero_point=0, dtype=torch.qint32) + + W_pack = torch.ops.quantized.conv_prepack(W_q, b_q, stride, padding, dilation, groups) + qconv = torch.ops.quantized.conv2d + if use_relu: + qconv = torch.ops.quantized.conv2d_relu + + Y_q = qconv( + X_q, + W_pack, + stride, + padding, + dilation, + groups, + Y_scale, + Y_zp + ) - # Reference quantized result from PyTorch Linear operator + if use_relu: + relu = torch.nn.ReLU() + result_ref = relu(result_ref) - W_fp32 = W_q.dequantize().to(dtype=torch.float) - X_fp32 = X_q.dequantize().to(dtype=torch.float) - b_fp32 = b_q.dequantize().to(dtype=torch.float) - Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32) - Y_fp32_ref = Y_fp32_ref.view(-1, output_channels) - Y_q_ref2 = torch.quantize_linear(Y_fp32_ref, Y_scale, Y_zp, torch.quint8) + result_ref_q = torch.quantize_linear(result_ref, scale=Y_scale, zero_point=Y_zp, dtype=torch.quint8) - # Assert equal - np.testing.assert_array_almost_equal(Y_q_ref2.dequantize().numpy(), Y_q.dequantize().numpy(), decimal=4) + np.testing.assert_array_almost_equal(result_ref_q.int_repr().numpy(), Y_q.int_repr().numpy(), decimal=0) + + """Tests the correctness of the quantized::qconv_unpack (qnnpack) op.""" + @given(X=hu.tensor_conv2d(min_batch=1, max_batch=3, + min_in_channels=1, max_in_channels=7, + min_out_channels=1, max_out_channels=7, + H_range=(6, 12), W_range=(6, 12), + kH_range=(3, 5), kW_range=(3, 5), + max_groups=4, + qparams=[hu.qparams(dtypes=torch.quint8, + zero_point_min=0, + zero_point_max=0), + hu.qparams(dtypes=torch.quint8, + zero_point_min=0, + zero_point_max=0), + hu.qparams(dtypes=torch.qint32, + zero_point_min=0, + zero_point_max=0)]), + strideH=st.integers(1, 3), strideW=st.integers(1, 3), + padH=st.integers(1, 2), padW=st.integers(1, 2)) + def test_qconv_unpack(self, X, strideH, strideW, padH, padW): + with enable_mobile_quantized_engine(): + (inputs, filters, bias, groups) = X + inputs, (inputs_scale, inputs_zero_point, inputs_qtype) = inputs + filters, (filters_scale, filters_zero_point, filters_qtype) = filters + bias, (bias_scale, bias_zero_point, bias_qtype) = bias + + qconv_prepack = torch.ops.quantized.conv_prepack + qconv_unpack = torch.ops.quantized.conv_unpack + + # Orig tensor is assumed to be in K(C/G)RS format + W = torch.from_numpy(filters).to(torch.float) + + W_q = torch.quantize_linear(W, scale=filters_scale, zero_point=filters_zero_point, dtype=filters_qtype) + + # Pack weights using weight packing operator + strides = [strideH, strideW] + paddings = [padH, padW] + dilations = [1, 1] + bias = torch.from_numpy(bias).to(torch.float) + b_q = torch.quantize_linear(bias, scale=bias_scale, zero_point=bias_zero_point, dtype=bias_qtype) + W_packed = qconv_prepack(W_q, b_q, strides, paddings, dilations, groups) + # Unpack weights weight unpacking operator (Used for serialization) + W_unpacked = qconv_unpack(W_packed)[0] + b_q = qconv_unpack(W_packed)[1] + + # Assert equal + np.testing.assert_equal(W_q.int_repr().numpy(), W_unpacked.int_repr().numpy()) + + np.testing.assert_equal(np.float32(W_q.q_scale()), np.float32(W_unpacked.q_scale())) + np.testing.assert_equal(W_q.q_zero_point(), W_unpacked.q_zero_point()) """Tests the correctness of the quantized::qnnpack_add op.""" @given(A=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5), @@ -1322,6 +1583,7 @@ def test_qnnpack_maxpool2d(self, A, kernel, stride, padding): d = (dilation, dilation) p = (padding, padding) + # TODO(supriyar): unify qnnpack op apis with generic ones and follow logical NCHW q_max_pool = torch.ops.quantized.qnnpack_maxpool2d a = scale * (X - zero_point).to(dtype=torch.float) diff --git a/test/test_quantized_models.py b/test/test_quantized_models.py new file mode 100644 index 0000000000000..08d08c4c9a50a --- /dev/null +++ b/test/test_quantized_models.py @@ -0,0 +1,29 @@ +import torch +import torch.jit +from common_utils import run_tests +from common_quantization import QuantizationTestCase, ModelMultipleOps + +class ModelNumerics(QuantizationTestCase): + def test_float_quant_compare(self): + torch.manual_seed(42) + myModel = ModelMultipleOps().to(torch.float32) + myModel.eval() + calib_data = torch.rand(1024, 3, 15, 15, dtype=torch.float32) + eval_data = torch.rand(1, 3, 15, 15, dtype=torch.float32) + out_ref = myModel(eval_data) + qModel = torch.quantization.QuantWrapper(myModel) + qModel.eval() + qModel.qconfig = torch.quantization.default_qconfig + torch.quantization.fuse_modules(qModel.module, [['conv1', 'bn1', 'relu1']]) + torch.quantization.prepare(qModel) + qModel(calib_data) + torch.quantization.convert(qModel) + out_q = qModel(eval_data) + SQNRdB = 20 * torch.log10(torch.norm(out_ref) / torch.norm(out_ref - out_q)) + # Quantized model output should be close to floating point model output numerically + # Setting target SQNR to be 30 dB so that relative error is 1e-3 below the desired + # output + self.assertGreater(SQNRdB, 30, msg='Quantized model numerics diverge from float, expect SQNR > 30 dB') + +if __name__ == "__main__": + run_tests() diff --git a/test/test_quantized_nn_mods.py b/test/test_quantized_nn_mods.py index 76d6f65e4abb1..b7e1a51f1fc72 100644 --- a/test/test_quantized_nn_mods.py +++ b/test/test_quantized_nn_mods.py @@ -6,13 +6,14 @@ from torch.nn.quantized.modules import Conv2d from torch.nn._intrinsic.quantized import ConvReLU2d import torch.quantization -from common_utils import run_tests, tempfile +from common_utils import run_tests from common_quantization import QuantizationTestCase, prepare_dynamic from common_quantized import _calculate_dynamic_qparams from hypothesis import given from hypothesis import strategies as st from hypothesis_utils import no_deadline import unittest +import io ''' Note that tests in this file are just API test, to make sure we wrapped the @@ -55,7 +56,6 @@ def test_conv_api(self, use_bias): dilation = (1, 1) X = torch.randn(N, iC, H, W, dtype=torch.float32) - X = X.permute([0, 2, 3, 1]).contiguous() qX = torch.quantize_linear(X, scale=scale, zero_point=128, dtype=torch.quint8) w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32) @@ -63,23 +63,22 @@ def test_conv_api(self, use_bias): qw = torch.quantize_linear(w, scale=scale, zero_point=0, dtype=torch.qint8) b = torch.randn(oC, dtype=torch.float32) if use_bias else None - q_bias = torch.quantize_linear(b, scale=1.0 / 1024, zero_point=0, dtype=torch.qint32) if use_bias else None - q_filters_ref = torch.ops.quantized.fbgemm_conv_prepack(qw.permute([0, 2, 3, 1]), - stride, - i_padding, - dilation, - g) + q_filters_ref = torch.ops.quantized.conv_prepack(qw, + b, + stride, + i_padding, + dilation, + g) - requantized_bias = torch.quantize_linear(q_bias.dequantize(), scale * scale, 0 , torch.qint32) if use_bias else None - ref_result = torch.ops.quantized.fbgemm_conv2d(qX.permute([0, 2, 3, 1]), q_filters_ref, - requantized_bias, stride, - i_padding, dilation, - g, scale, zero_point).permute([0, 3, 1, 2]) + ref_result = torch.ops.quantized.conv2d(qX, q_filters_ref, + stride, + i_padding, dilation, + g, scale, zero_point) q_result = torch.nn.quantized.functional.conv2d(qX, qw, - bias=q_bias, scale=scale, + bias=b, scale=scale, zero_point=zero_point, stride=stride, padding=i_padding, dilation=dilation, groups=g, @@ -112,18 +111,17 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_d qlinear = nnqd.Linear(in_features, out_features) # Run module with default-initialized parameters. # This tests that the constructor is correct. + qlinear.set_weight_bias(W_q, B) qlinear(X) - qlinear.set_weight(W_q) # Simple round-trip test to ensure weight()/set_weight() API self.assertEqual(qlinear.weight(), W_q) - W_pack = qlinear._packed_weight - qlinear.bias = B if use_bias else None + W_pack = qlinear._packed_params Z_dq = qlinear(X) # Check if the module implementation matches calling the # ops directly - Z_ref = torch.ops.quantized.fbgemm_linear_dynamic(X, W_pack, B) + Z_ref = torch.ops.quantized.linear_dynamic(X, W_pack) self.assertEqual(Z_ref, Z_dq) # Test serialization of dynamic quantized Linear Module using state_dict @@ -131,36 +129,36 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_d self.assertEqual(model_dict['weight'], W_q) if use_bias: self.assertEqual(model_dict['bias'], B) - with tempfile.TemporaryFile() as f: - torch.save(model_dict, f) - f.seek(0) - loaded_dict = torch.load(f) + b = io.BytesIO() + torch.save(model_dict, b) + b.seek(0) + loaded_dict = torch.load(b) for key in model_dict: self.assertEqual(model_dict[key], loaded_dict[key]) loaded_qlinear = nnqd.Linear(in_features, out_features) loaded_qlinear.load_state_dict(loaded_dict) - linear_unpack = torch.ops.quantized.fbgemm_linear_unpack - self.assertEqual(linear_unpack(qlinear._packed_weight), - linear_unpack(loaded_qlinear._packed_weight)) + linear_unpack = torch.ops.quantized.linear_unpack + self.assertEqual(linear_unpack(qlinear._packed_params), + linear_unpack(loaded_qlinear._packed_params)) if use_bias: - self.assertEqual(qlinear.bias, loaded_qlinear.bias) + self.assertEqual(qlinear.bias(), loaded_qlinear.bias()) self.assertTrue(dir(qlinear) == dir(loaded_qlinear)) - self.assertTrue(hasattr(qlinear, '_packed_weight')) - self.assertTrue(hasattr(loaded_qlinear, '_packed_weight')) - self.assertTrue(hasattr(qlinear, 'weight')) - self.assertTrue(hasattr(loaded_qlinear, 'weight')) + self.assertTrue(hasattr(qlinear, '_packed_params')) + self.assertTrue(hasattr(loaded_qlinear, '_packed_params')) + self.assertTrue(hasattr(qlinear, '_weight_bias')) + self.assertTrue(hasattr(loaded_qlinear, '_weight_bias')) - self.assertEqual(qlinear.weight(), loaded_qlinear.weight()) - self.assertEqual(qlinear.weight(), torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight)) + self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias()) + self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params)) Z_dq2 = qlinear(X) self.assertEqual(Z_dq, Z_dq2) # test serialization of module directly - with tempfile.TemporaryFile() as f: - torch.save(qlinear, f) - f.seek(0) - loaded = torch.load(f) + b = io.BytesIO() + torch.save(qlinear, b) + b.seek(0) + loaded = torch.load(b) # This check is disabled pending an issue in PyTorch serialization: # https://github.com/pytorch/pytorch/issues/24045 # self.assertEqual(qlinear.weight(), loaded.weight()) @@ -223,7 +221,6 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f X = torch.rand(batch_size, in_features).float() X_q = torch.quantize_linear(X, 0.2, 10, torch.quint8) B = torch.rand(out_features).float() if use_bias else None - B_q = torch.quantize_linear(B, W_q.q_scale() * X_q.q_scale(), 0, torch.qint32) if use_bias else None scale = 0.5 zero_point = 3 if use_fused: @@ -235,11 +232,10 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f # This tests that the constructor is correct. qlinear(X_q) - qlinear.set_weight(W_q) + qlinear.set_weight_bias(W_q, B) # Simple round-trip test to ensure weight()/set_weight() API self.assertEqual(qlinear.weight(), W_q) - W_pack = qlinear._packed_weight - qlinear.bias = B_q if use_bias else None + W_pack = qlinear._packed_params qlinear.scale = float(scale) qlinear.zero_point = int(zero_point) @@ -247,9 +243,9 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f # Check if the module implementation matches calling the # ops directly if use_fused: - Z_ref = torch.ops.quantized.fbgemm_linear_relu(X_q, W_pack, B_q, scale, zero_point) + Z_ref = torch.ops.quantized.linear_relu(X_q, W_pack, scale, zero_point) else: - Z_ref = torch.ops.quantized.fbgemm_linear(X_q, W_pack, B_q, scale, zero_point) + Z_ref = torch.ops.quantized.linear(X_q, W_pack, scale, zero_point) self.assertEqual(Z_ref, Z_q) # Test serialization of quantized Linear Module using state_dict @@ -257,11 +253,11 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f model_dict = qlinear.state_dict() self.assertEqual(model_dict['weight'], W_q) if use_bias: - self.assertEqual(model_dict['bias'], B_q) - with tempfile.TemporaryFile() as f: - torch.save(model_dict, f) - f.seek(0) - loaded_dict = torch.load(f) + self.assertEqual(model_dict['bias'], B) + b = io.BytesIO() + torch.save(model_dict, b) + b.seek(0) + loaded_dict = torch.load(b) for key in model_dict: self.assertEqual(model_dict[key], loaded_dict[key]) if use_fused: @@ -270,32 +266,31 @@ def test_linear_api(self, batch_size, in_features, out_features, use_bias, use_f loaded_qlinear = nnq.Linear(in_features, out_features) loaded_qlinear.load_state_dict(loaded_dict) - linear_unpack = torch.ops.quantized.fbgemm_linear_unpack - self.assertEqual(linear_unpack(qlinear._packed_weight), - linear_unpack(loaded_qlinear._packed_weight)) + linear_unpack = torch.ops.quantized.linear_unpack + self.assertEqual(linear_unpack(qlinear._packed_params), + linear_unpack(loaded_qlinear._packed_params)) if use_bias: - self.assertEqual(qlinear.bias, loaded_qlinear.bias) + self.assertEqual(qlinear.bias(), loaded_qlinear.bias()) self.assertEqual(qlinear.scale, loaded_qlinear.scale) self.assertEqual(qlinear.zero_point, loaded_qlinear.zero_point) self.assertTrue(dir(qlinear) == dir(loaded_qlinear)) - self.assertTrue(hasattr(qlinear, '_packed_weight')) - self.assertTrue(hasattr(loaded_qlinear, '_packed_weight')) - self.assertTrue(hasattr(qlinear, 'weight')) - self.assertTrue(hasattr(loaded_qlinear, 'weight')) - self.assertEqual(qlinear.weight(), loaded_qlinear.weight()) - self.assertEqual(qlinear.weight(), torch.ops.quantized.fbgemm_linear_unpack(qlinear._packed_weight)) + self.assertTrue(hasattr(qlinear, '_packed_params')) + self.assertTrue(hasattr(loaded_qlinear, '_packed_params')) + self.assertTrue(hasattr(qlinear, '_weight_bias')) + self.assertTrue(hasattr(loaded_qlinear, '_weight_bias')) + self.assertEqual(qlinear._weight_bias(), loaded_qlinear._weight_bias()) + self.assertEqual(qlinear._weight_bias(), torch.ops.quantized.linear_unpack(qlinear._packed_params)) Z_q2 = loaded_qlinear(X_q) self.assertEqual(Z_q, Z_q2) # test serialization of module directly - with tempfile.TemporaryFile() as f: - torch.save(qlinear, f) - f.seek(0) - loaded = torch.load(f) + b = io.BytesIO() + torch.save(qlinear, b) + b.seek(0) + loaded = torch.load(b) # This check is disabled pending an issue in PyTorch serialization: # https://github.com/pytorch/pytorch/issues/24045 # self.assertEqual(qlinear.weight(), loaded.weight()) - self.assertEqual(qlinear.bias, loaded.bias) self.assertEqual(qlinear.scale, loaded.scale) self.assertEqual(qlinear.zero_point, loaded.zero_point) @@ -352,7 +347,6 @@ def test_conv_api(self, use_bias, use_fused): scale, zero_point = 1.0 / 255, 128 X = torch.randn(N, iC, H, W, dtype=torch.float32) - X = X.permute([0, 2, 3, 1]).contiguous() qX = torch.quantize_linear(X, scale=scale, zero_point=128, dtype=torch.quint8) w = torch.randn(oC, iC // g, kH, kW, dtype=torch.float32) @@ -360,7 +354,6 @@ def test_conv_api(self, use_bias, use_fused): qw = torch.quantize_linear(w, scale=scale, zero_point=0, dtype=torch.qint8) b = torch.randn(oC, dtype=torch.float32) if use_bias else None - qb = torch.quantize_linear(b, scale=1.0 / 1024, zero_point=0, dtype=torch.qint32) if use_bias else None if use_fused: conv_under_test = ConvReLU2d(in_channels=iC, @@ -384,27 +377,26 @@ def test_conv_api(self, use_bias, use_fused): padding_mode='zeros') # Run module with default-initialized parameters. # This tests that the constructor is correct. + conv_under_test.set_weight_bias(qw, b) conv_under_test(qX) - conv_under_test.set_weight(qw) - conv_under_test.bias = qb conv_under_test.scale = scale conv_under_test.zero_point = zero_point # Test members - self.assertTrue(hasattr(conv_under_test, '_packed_weight')) + self.assertTrue(hasattr(conv_under_test, '_packed_params')) self.assertTrue(hasattr(conv_under_test, 'scale')) self.assertTrue(hasattr(conv_under_test, 'zero_point')) # Test properties self.assertEqual(qw, conv_under_test.weight()) - self.assertEqual(qb, conv_under_test.bias) + self.assertEqual(b, conv_under_test.bias()) self.assertEqual(scale, conv_under_test.scale) self.assertEqual(zero_point, conv_under_test.zero_point) # Test forward result_under_test = conv_under_test(qX) - result_reference = qF.conv2d(qX, qw, bias=qb, + result_reference = qF.conv2d(qX, qw, bias=b, scale=scale, zero_point=zero_point, stride=1, padding=0, dilation=1, groups=g, dtype=torch.quint8 @@ -428,11 +420,11 @@ def test_conv_api(self, use_bias, use_fused): model_dict = conv_under_test.state_dict() self.assertEqual(model_dict['weight'], qw) if use_bias: - self.assertEqual(model_dict['bias'], qb) - with tempfile.NamedTemporaryFile() as f: - torch.save(model_dict, f) - f.seek(0) - loaded_dict = torch.load(f) + self.assertEqual(model_dict['bias'], b) + b = io.BytesIO() + torch.save(model_dict, b) + b.seek(0) + loaded_dict = torch.load(b) for key in model_dict: self.assertEqual(loaded_dict[key], model_dict[key]) if use_fused: @@ -456,27 +448,27 @@ def test_conv_api(self, use_bias, use_fused): bias=use_bias, padding_mode='zeros') loaded_conv_under_test.load_state_dict(loaded_dict) - self.assertEqual(loaded_conv_under_test.weight(), conv_under_test.weight()) + self.assertEqual(loaded_conv_under_test._weight_bias(), conv_under_test._weight_bias()) if use_bias: - self.assertEqual(loaded_conv_under_test.bias, conv_under_test.bias) + self.assertEqual(loaded_conv_under_test.bias(), conv_under_test.bias()) self.assertEqual(loaded_conv_under_test.scale, conv_under_test.scale) self.assertEqual(loaded_conv_under_test.zero_point, conv_under_test.zero_point) self.assertTrue(dir(loaded_conv_under_test) == dir(conv_under_test)) - self.assertTrue(hasattr(conv_under_test, '_packed_weight')) - self.assertTrue(hasattr(loaded_conv_under_test, '_packed_weight')) - self.assertTrue(hasattr(conv_under_test, 'weight')) - self.assertTrue(hasattr(loaded_conv_under_test, 'weight')) - self.assertEqual(loaded_conv_under_test.weight(), conv_under_test.weight()) + self.assertTrue(hasattr(conv_under_test, '_packed_params')) + self.assertTrue(hasattr(loaded_conv_under_test, '_packed_params')) + self.assertTrue(hasattr(conv_under_test, '_weight_bias')) + self.assertTrue(hasattr(loaded_conv_under_test, '_weight_bias')) + self.assertEqual(loaded_conv_under_test._weight_bias(), conv_under_test._weight_bias()) self.assertEqual(loaded_conv_under_test.weight(), qw) loaded_result = loaded_conv_under_test(qX) self.assertEqual(loaded_result, result_reference) - with tempfile.NamedTemporaryFile() as f: - torch.save(conv_under_test, f) - f.seek(0) - loaded_conv = torch.load(f) + b = io.BytesIO() + torch.save(conv_under_test, b) + b.seek(0) + loaded_conv = torch.load(b) - self.assertEqual(conv_under_test.bias, loaded_conv.bias) + self.assertEqual(conv_under_test.bias(), loaded_conv.bias()) self.assertEqual(conv_under_test.scale, loaded_conv.scale) self.assertEqual(conv_under_test.zero_point, loaded_conv.zero_point) @@ -501,10 +493,8 @@ def test_conv_api(self, use_bias, use_fused): # Smoke test to make sure the module actually runs quantized_float_conv(qX) - # Check that bias is quantized based on output scale if use_bias: - qbias = torch.quantize_linear(float_conv.bias, quantized_float_conv[0].scale / 2**16, 0, torch.qint32) - self.assertEqual(quantized_float_conv[0].bias.dequantize(), qbias.dequantize()) + self.assertEqual(quantized_float_conv[0].bias(), float_conv.bias) # Smoke test extra_repr str(quantized_float_conv) diff --git a/test/test_quantized_tensor.py b/test/test_quantized_tensor.py index 055c122f9d0ea..81c1e86224097 100644 --- a/test/test_quantized_tensor.py +++ b/test/test_quantized_tensor.py @@ -62,8 +62,8 @@ def test_qtensor(self): "scale=1.0, zero_point=2)") def test_qtensor_quant_dequant(self): - r = torch.rand(3, 2, dtype=torch.float) * 2 - 4 - scale = 2 + r = torch.rand(3, 2, dtype=torch.float) * 4 - 2 + scale = 0.02 zero_point = 2 qr = torch.quantize_linear(r, scale, zero_point, torch.quint8) rqr = qr.dequantize() @@ -77,6 +77,7 @@ def test_per_channel_qtensor_creation(self): q = torch._empty_per_channel_affine_quantized_like(scales, zero_points, [numel], [ch_axis], dtype=torch.quint8) self.assertEqual(scales, q.q_per_channel_scales()) self.assertEqual(zero_points, q.q_per_channel_zero_points()) + self.assertEqual([ch_axis], q.q_per_channel_axis()) # create Tensor from uint8_t Tensor, scales and zero_points int_tensor = torch.randint(0, 100, size=(numel,), dtype=torch.uint8) @@ -84,6 +85,7 @@ def test_per_channel_qtensor_creation(self): self.assertEqual(int_tensor, q.int_repr()) self.assertEqual(scales, q.q_per_channel_scales()) self.assertEqual(zero_points, q.q_per_channel_zero_points()) + self.assertEqual([ch_axis], q.q_per_channel_axis()) def test_qtensor_creation(self): scale = 0.5 @@ -113,8 +115,8 @@ def test_qtensor_creation(self): torch.empty_like(q, dtype=torch.qint8) def test_qtensor_dtypes(self): - r = torch.rand(3, 2, dtype=torch.float) * 2 - 4 - scale = 2 + r = torch.rand(3, 2, dtype=torch.float) * 4 - 2 + scale = 0.2 zero_point = 2 qr = torch.quantize_linear(r, scale, zero_point, torch.qint8) rqr = qr.dequantize() @@ -135,8 +137,8 @@ def test_qtensor_dequantize_linear(self): self.assertEqual(qt, qt2.dequantize()) def test_qtensor_per_channel_affine(self): - r = torch.rand(3, 2, dtype=torch.float) * 2 - 4 - scales = torch.tensor([2.0, 3.0], dtype=torch.double) + r = torch.rand(3, 2, dtype=torch.float) * 4 - 2 + scales = torch.tensor([0.2, 0.03], dtype=torch.double) zero_points = torch.tensor([5, 10], dtype=torch.long) axis = [1] @@ -153,28 +155,58 @@ def quantize_c(data, scales, zero_points): self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy()))) def test_qtensor_permute(self): - r = torch.rand(100, 30, dtype=torch.float) * 2 - 4 - scale = 2 - zero_point = 2 + r = torch.rand(10, 30, 2, 2, dtype=torch.float) * 4 - 2 + scale = 0.02 + zero_point = 1 qr = torch.quantize_linear(r, scale, zero_point, torch.qint8) qr = qr.transpose(0, 1) rqr = qr.dequantize() # compare transpose + dequantized result with orignal transposed result - self.assertTrue(np.allclose(r.numpy().T, rqr.numpy(), atol=2 / scale)) + self.assertTrue(np.allclose(r.numpy().transpose([1, 0, 2, 3]), rqr.numpy(), atol=2 / scale)) qr = torch.quantize_linear(r, scale, zero_point, torch.qint8) - qr1 = qr.permute([1, 0]) + qr1 = qr.permute([1, 0, 2, 3]) qr2 = qr.transpose(0, 1) # compare int representation after transformations - self.assertTrue(torch.equal(qr1.int_repr(), qr2.int_repr())) - self.assertTrue(qr1.q_scale() == qr2.q_scale()) - self.assertTrue(qr1.q_zero_point() == qr2.q_zero_point()) + self.assertEqual(qr1.int_repr(), qr2.int_repr()) + self.assertEqual(qr1.q_scale(), qr2.q_scale()) + self.assertEqual(qr1.q_zero_point(), qr2.q_zero_point()) # compare dequantized result - self.assertTrue(np.array_equal(qr1.dequantize().numpy(), qr2.dequantize().numpy())) + self.assertEqual(qr1.dequantize(), qr2.dequantize()) # compare permuted + dequantized result with original transposed result - self.assertTrue(np.allclose(qr2.dequantize().numpy(), r.numpy().T, atol=2 / scale)) + self.assertTrue(np.allclose(qr2.dequantize().numpy(), r.numpy().transpose([1, 0, 2, 3]), atol=2 / scale)) # make permuted result contiguous - self.assertTrue(torch.equal(qr2.contiguous().int_repr(), qr2.int_repr())) + self.assertEqual(qr2.contiguous().int_repr(), qr2.int_repr()) + + # change memory format + qlast = qr.contiguous(memory_format=torch.channels_last) + self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride())))) + self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride())))) + self.assertEqual(qr.int_repr(), qlast.int_repr()) + self.assertEqual(qr.q_scale(), qlast.q_scale()) + self.assertEqual(qr.q_zero_point(), qlast.q_zero_point()) + self.assertEqual(qlast.dequantize(), qr.dequantize()) + + def test_qtensor_per_channel_permute(self): + r = torch.rand(20, 10, 2, 2, dtype=torch.float) * 4 - 2 + scales = torch.rand(10) * 0.02 + 0.01 + zero_points = torch.round(torch.rand(10) * 2 - 1).to(torch.long) + qr = torch.quantize_linear_per_channel(r, scales, zero_points, [1], torch.qint8) + + # we can't reorder the axis + with self.assertRaises(RuntimeError): + qr.transpose(0, 1) + + # but we can change memory format + qlast = qr.contiguous(memory_format=torch.channels_last) + self.assertEqual(qr.stride(), list(reversed(sorted(qr.stride())))) + self.assertNotEqual(qlast.stride(), list(reversed(sorted(qlast.stride())))) + self.assertEqual(qr.int_repr(), qlast.int_repr()) + self.assertEqual(scales, qlast.q_per_channel_scales()) + self.assertEqual(zero_points, qlast.q_per_channel_zero_points()) + self.assertEqual((1,), qlast.q_per_channel_axis()) + self.assertEqual(qlast.dequantize(), qr.dequantize()) + def test_qtensor_load_save(self): scale = 2.0 @@ -282,7 +314,7 @@ def test_qscheme_pickle(self): buf.seek(0) f2 = torch.load(buf) - self.assertTrue(f2.qscheme == torch.per_tensor_symmetric) + self.assertEqual(f2.qscheme, torch.per_tensor_symmetric) if __name__ == "__main__": run_tests() diff --git a/test/test_quantizer.py b/test/test_quantizer.py index 2fcbf4a319ac0..bca792ee51e25 100644 --- a/test/test_quantizer.py +++ b/test/test_quantizer.py @@ -41,6 +41,7 @@ def __init__(self): " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs" " with instruction set support avx2 or newer.", ) +@unittest.skip("temoprarily disable the test") class QuantizerTestCase(TestCase): @_tmp_donotuse_dont_inline_everything def test_default(self): @@ -82,8 +83,6 @@ def forward(self, x): def get_forward(m): return m._c._get_method('forward') # TODO: test jit.script as well - torch._C._jit_pass_constant_propagation(get_forward(script_module).graph) - ScriptedObserver = torch.jit.script(Observer()) ScriptedWeightObserver = torch.jit.script(WeightObserver()) qconfig_dict = { diff --git a/test/test_rpc.py b/test/test_rpc.py index d7c87ac486d4f..4bbaca83d2712 100644 --- a/test/test_rpc.py +++ b/test/test_rpc.py @@ -7,9 +7,17 @@ import torch import torch.distributed as dist +if not dist.is_available(): + print("c10d not available, skipping tests") + sys.exit(0) + +from torch.distributed.rpc import RpcBackend from common_distributed import MultiProcessTestCase from common_utils import load_tests, run_tests +from os import getenv +BACKEND = getenv('RPC_BACKEND', RpcBackend.PROCESS_GROUP) +RPC_INIT_URL = getenv('RPC_INIT_URL', '') # it is used to test python user defined function over rpc def my_function(a, b, c): @@ -63,11 +71,6 @@ def my_static_method(f): load_tests = load_tests -if not dist.is_available(): - print("c10d not available, skipping tests") - sys.exit(0) - - def _wrap_with_rpc(func): ''' We use this decorator for setting up and tearing down state since @@ -79,7 +82,10 @@ def wrapper(self): store = dist.FileStore(self.file.name, self.world_size) dist.init_process_group(backend='gloo', rank=self.rank, world_size=self.world_size, store=store) - dist.init_model_parallel('worker%d' % self.rank) + dist.init_model_parallel(self_name='worker%d' % self.rank, + backend=BACKEND, + self_rank=self.rank, + init_method=RPC_INIT_URL) func(self) dist.join_rpc() @@ -123,33 +129,39 @@ def test_self_add(self): ): dist.rpc(self_worker_name, torch.add, args=(torch.ones(2, 2), 1)) - def test_duplicated_names(self): + def test_reinit(self): store = dist.FileStore(self.file.name, self.world_size) dist.init_process_group(backend="gloo", rank=self.rank, world_size=self.world_size, store=store) with self.assertRaisesRegex(RuntimeError, "is not unique"): - dist.init_model_parallel("duplicated_name") + dist.init_model_parallel(self_name="duplicate_name", + backend=BACKEND, + self_rank=self.rank, + init_method=RPC_INIT_URL) dist.join_rpc() + @unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/25912") def test_invalid_names(self): store = dist.FileStore(self.file.name, self.world_size) dist.init_process_group(backend="gloo", rank=self.rank, world_size=self.world_size, store=store) with self.assertRaisesRegex(RuntimeError, "Worker name must match"): - dist.init_model_parallel("abc*") + dist.init_model_parallel(self_name="abc*") with self.assertRaisesRegex(RuntimeError, "Worker name must match"): - dist.init_model_parallel(" ") + dist.init_model_parallel(self_name=" ") with self.assertRaisesRegex(RuntimeError, "must be non-empty"): - dist.init_model_parallel("") + dist.init_model_parallel(self_name="") # If the number in the message does not match, it is likely that the # value of MAX_NAME_LEN in RPC WorkerId has changed. with self.assertRaisesRegex(RuntimeError, "shorter than 128"): - dist.init_model_parallel("".join(["a" for _ in range(500)])) - + dist.init_model_parallel(self_name="".join(["a" for _ in range(500)]), + backend=BACKEND, + self_rank=self.rank, + init_method=RPC_INIT_URL) dist.join_rpc() @_wrap_with_rpc @@ -385,5 +397,33 @@ def test_stress_light_rpc(self): def test_stress_heavy_rpc(self): self._stress_test_rpc(heavy_rpc, repeat=20, args=(torch.ones(100, 100),)) -if __name__ == "__main__": + @_wrap_with_rpc + def test_builtin_remote_ret(self): + n = self.rank + 1 + dst_rank = n % self.world_size + rref = dist.remote('worker{}'.format(dst_rank), torch.add, + args=(torch.ones(n, n), torch.ones(n, n))) + self.assertEqual(rref.to_here(), torch.ones(n, n) * 2) + + @_wrap_with_rpc + def test_multi_builtin_remote_ret(self): + m = 10 + n = self.rank + 1 + dst_rank = n % self.world_size + rrefs = [] + expected = [] + for i in range(m): + n = n + i + rrefs.append(dist.remote( + 'worker{}'.format(dst_rank), + torch.add, + args=(torch.ones(n, n), torch.ones(n, n)) + )) + expected.append(torch.ones(n, n) * 2) + + for i in range(m): + self.assertEqual(rrefs[i].to_here(), expected[i]) + + +if __name__ == '__main__': run_tests() diff --git a/test/test_sparse.py b/test/test_sparse.py index 1243103e6e498..f7795c68046a9 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -2107,21 +2107,21 @@ def test_cuda_sparse_cpu_dense_add(self): sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), torch.randn(4, 4, 4).cuda(), [3, 4, 4]) - with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): + with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): x + sparse_y x = torch.zeros(3, 4, 4, 0) sparse_y = torch.cuda.sparse.FloatTensor(torch.zeros(1, 4).long().cuda(), torch.randn(4, 4, 4, 0).cuda(), [3, 4, 4, 0]) - with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): + with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): x + sparse_y x = torch.zeros(0, 4, 4, 0) sparse_y = torch.cuda.sparse.FloatTensor(torch.LongTensor(1, 0).cuda(), torch.randn(0, 4, 4, 0).cuda(), [0, 4, 4, 0]) - with self.assertRaisesRegex(RuntimeError, "add: expected 'other' to be a CPU tensor\\, but got a CUDA tensor"): + with self.assertRaisesRegex(RuntimeError, "add: expected 'self' to be a CUDA tensor, but got a CPU tensor"): x + sparse_y diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index 7c9f1830196be..f34589de6a217 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -9,6 +9,7 @@ import shutil import sys import unittest +import uuid TEST_TENSORBOARD = True try: @@ -44,8 +45,6 @@ import torch from common_utils import TestCase, run_tests, TEST_WITH_ASAN -from google.protobuf import text_format -from PIL import Image def tensor_N(shape, dtype=float): numel = np.prod(shape) @@ -54,11 +53,22 @@ def tensor_N(shape, dtype=float): class BaseTestCase(TestCase): """ Base class used for all TensorBoard tests """ + def setUp(self): + if not TEST_TENSORBOARD: + return self.skipTest("Skip the test since TensorBoard is not installed") + self.temp_dirs = [] + + def createSummaryWriter(self): + temp_dir = str(uuid.uuid4()) + self.temp_dirs.append(temp_dir) + return SummaryWriter(temp_dir) + def tearDown(self): super(BaseTestCase, self).tearDown() - if os.path.exists('runs'): - # Remove directory created by SummaryWriter - shutil.rmtree('runs') + # Remove directories created by SummaryWriter + for temp_dir in self.temp_dirs: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) if TEST_TENSORBOARD: @@ -66,603 +76,615 @@ def tearDown(self): from torch.utils.tensorboard._utils import _prepare_video, convert_to_HWC from torch.utils.tensorboard._convert_np import make_np from torch.utils.tensorboard import _caffe2_graph as c2_graph - - class TestTensorBoardPyTorchNumpy(BaseTestCase): - def test_pytorch_np(self): - tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)] - for tensor in tensors: - # regular tensor - self.assertIsInstance(make_np(tensor), np.ndarray) - - # CUDA tensor - if torch.cuda.device_count() > 0: - self.assertIsInstance(make_np(tensor.cuda()), np.ndarray) - - # regular variable - self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray) - - # CUDA variable - if torch.cuda.device_count() > 0: - self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray) - - # python primitive type - self.assertIsInstance(make_np(0), np.ndarray) - self.assertIsInstance(make_np(0.1), np.ndarray) - - def test_pytorch_autograd_np(self): - x = torch.autograd.Variable(torch.Tensor(1)) - self.assertIsInstance(make_np(x), np.ndarray) - - def test_pytorch_write(self): - with SummaryWriter() as w: - w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0) - - def test_pytorch_histogram(self): - with SummaryWriter() as w: - w.add_histogram('float histogram', torch.rand((50,))) - w.add_histogram('int histogram', torch.randint(0, 100, (50,))) - - def test_pytorch_histogram_raw(self): - with SummaryWriter() as w: - num = 50 - floats = make_np(torch.rand((num,))) - bins = [0.0, 0.25, 0.5, 0.75, 1.0] - counts, limits = np.histogram(floats, bins) - sum_sq = floats.dot(floats).item() - w.add_histogram_raw('float histogram raw', - min=floats.min().item(), - max=floats.max().item(), - num=num, - sum=floats.sum().item(), - sum_squares=sum_sq, - bucket_limits=limits[1:].tolist(), - bucket_counts=counts.tolist()) - - ints = make_np(torch.randint(0, 100, (num,))) - bins = [0, 25, 50, 75, 100] - counts, limits = np.histogram(ints, bins) - sum_sq = ints.dot(ints).item() - w.add_histogram_raw('int histogram raw', - min=ints.min().item(), - max=ints.max().item(), - num=num, - sum=ints.sum().item(), - sum_squares=sum_sq, - bucket_limits=limits[1:].tolist(), - bucket_counts=counts.tolist()) - - ints = torch.tensor(range(0, 100)).float() - nbins = 100 - counts = torch.histc(ints, bins=nbins, min=0, max=99) - limits = torch.tensor(range(nbins)) - sum_sq = ints.dot(ints).item() - w.add_histogram_raw('int histogram raw', - min=ints.min().item(), - max=ints.max().item(), - num=num, - sum=ints.sum().item(), - sum_squares=sum_sq, - bucket_limits=limits.tolist(), - bucket_counts=counts.tolist()) - - class TestTensorBoardUtils(BaseTestCase): - def test_to_HWC(self): - test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) - converted = convert_to_HWC(test_image, 'chw') - self.assertEqual(converted.shape, (32, 32, 3)) - test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8) - converted = convert_to_HWC(test_image, 'nchw') - self.assertEqual(converted.shape, (64, 256, 3)) - test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8) - converted = convert_to_HWC(test_image, 'hw') - self.assertEqual(converted.shape, (32, 32, 3)) - - def test_prepare_video(self): - # At each timeframe, the sum over all other - # dimensions of the video should be the same. - shapes = [(16, 30, 3, 28, 28), - (36, 30, 3, 28, 28), - (19, 29, 3, 23, 19), - (3, 3, 3, 3, 3)] - for s in shapes: - V_input = np.random.random(s) - V_after = _prepare_video(np.copy(V_input)) - total_frame = s[1] - V_input = np.swapaxes(V_input, 0, 1) - for f in range(total_frame): - x = np.reshape(V_input[f], newshape=(-1)) - y = np.reshape(V_after[f], newshape=(-1)) - np.testing.assert_array_almost_equal(np.sum(x), np.sum(y)) - - def test_numpy_vid_uint8(self): - V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8) - V_after = _prepare_video(np.copy(V_input)) * 255 - total_frame = V_input.shape[1] + from google.protobuf import text_format + from PIL import Image + +class TestTensorBoardPyTorchNumpy(BaseTestCase): + def test_pytorch_np(self): + tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)] + for tensor in tensors: + # regular tensor + self.assertIsInstance(make_np(tensor), np.ndarray) + + # CUDA tensor + if torch.cuda.device_count() > 0: + self.assertIsInstance(make_np(tensor.cuda()), np.ndarray) + + # regular variable + self.assertIsInstance(make_np(torch.autograd.Variable(tensor)), np.ndarray) + + # CUDA variable + if torch.cuda.device_count() > 0: + self.assertIsInstance(make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray) + + # python primitive type + self.assertIsInstance(make_np(0), np.ndarray) + self.assertIsInstance(make_np(0.1), np.ndarray) + + def test_pytorch_autograd_np(self): + x = torch.autograd.Variable(torch.Tensor(1)) + self.assertIsInstance(make_np(x), np.ndarray) + + def test_pytorch_write(self): + with self.createSummaryWriter() as w: + w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0) + + def test_pytorch_histogram(self): + with self.createSummaryWriter() as w: + w.add_histogram('float histogram', torch.rand((50,))) + w.add_histogram('int histogram', torch.randint(0, 100, (50,))) + + def test_pytorch_histogram_raw(self): + with self.createSummaryWriter() as w: + num = 50 + floats = make_np(torch.rand((num,))) + bins = [0.0, 0.25, 0.5, 0.75, 1.0] + counts, limits = np.histogram(floats, bins) + sum_sq = floats.dot(floats).item() + w.add_histogram_raw('float histogram raw', + min=floats.min().item(), + max=floats.max().item(), + num=num, + sum=floats.sum().item(), + sum_squares=sum_sq, + bucket_limits=limits[1:].tolist(), + bucket_counts=counts.tolist()) + + ints = make_np(torch.randint(0, 100, (num,))) + bins = [0, 25, 50, 75, 100] + counts, limits = np.histogram(ints, bins) + sum_sq = ints.dot(ints).item() + w.add_histogram_raw('int histogram raw', + min=ints.min().item(), + max=ints.max().item(), + num=num, + sum=ints.sum().item(), + sum_squares=sum_sq, + bucket_limits=limits[1:].tolist(), + bucket_counts=counts.tolist()) + + ints = torch.tensor(range(0, 100)).float() + nbins = 100 + counts = torch.histc(ints, bins=nbins, min=0, max=99) + limits = torch.tensor(range(nbins)) + sum_sq = ints.dot(ints).item() + w.add_histogram_raw('int histogram raw', + min=ints.min().item(), + max=ints.max().item(), + num=num, + sum=ints.sum().item(), + sum_squares=sum_sq, + bucket_limits=limits.tolist(), + bucket_counts=counts.tolist()) + +class TestTensorBoardUtils(BaseTestCase): + def test_to_HWC(self): + test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) + converted = convert_to_HWC(test_image, 'chw') + self.assertEqual(converted.shape, (32, 32, 3)) + test_image = np.random.randint(0, 256, size=(16, 3, 32, 32), dtype=np.uint8) + converted = convert_to_HWC(test_image, 'nchw') + self.assertEqual(converted.shape, (64, 256, 3)) + test_image = np.random.randint(0, 256, size=(32, 32), dtype=np.uint8) + converted = convert_to_HWC(test_image, 'hw') + self.assertEqual(converted.shape, (32, 32, 3)) + + def test_prepare_video(self): + # At each timeframe, the sum over all other + # dimensions of the video should be the same. + shapes = [ + (16, 30, 3, 28, 28), + (36, 30, 3, 28, 28), + (19, 29, 3, 23, 19), + (3, 3, 3, 3, 3) + ] + for s in shapes: + V_input = np.random.random(s) + V_after = _prepare_video(np.copy(V_input)) + total_frame = s[1] V_input = np.swapaxes(V_input, 0, 1) for f in range(total_frame): x = np.reshape(V_input[f], newshape=(-1)) y = np.reshape(V_after[f], newshape=(-1)) np.testing.assert_array_almost_equal(np.sum(x), np.sum(y)) - freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440] - - true_positive_counts = [75, 64, 21, 5, 0] - false_positive_counts = [150, 105, 18, 0, 0] - true_negative_counts = [0, 45, 132, 150, 150] - false_negative_counts = [0, 11, 54, 70, 75] - precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0] - recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0] - - class TestTensorBoardWriter(BaseTestCase): - def test_writer(self): - with SummaryWriter() as writer: - sample_rate = 44100 - - n_iter = 0 - writer.add_scalar('data/scalar_systemtime', 0.1, n_iter) - writer.add_scalar('data/scalar_customtime', 0.2, n_iter, walltime=n_iter) - writer.add_scalars('data/scalar_group', {"xsinx": n_iter * np.sin(n_iter), - "xcosx": n_iter * np.cos(n_iter), - "arctanx": np.arctan(n_iter)}, n_iter) - x = np.zeros((32, 3, 64, 64)) # output from network - writer.add_images('Image', x, n_iter) # Tensor - writer.add_image_with_boxes('imagebox', - np.zeros((3, 64, 64)), - np.array([[10, 10, 40, 40], [40, 40, 60, 60]]), - n_iter) - x = np.zeros(sample_rate * 2) - - writer.add_audio('myAudio', x, n_iter) - writer.add_video('myVideo', np.random.rand(16, 48, 1, 28, 28).astype(np.float32), n_iter) - writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter) - writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', n_iter) - writer.add_histogram('hist', np.random.rand(100, 100), n_iter) - writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand( - 100), n_iter) # needs tensorboard 0.4RC or later - writer.add_pr_curve_raw('prcurve with raw data', true_positive_counts, - false_positive_counts, - true_negative_counts, - false_negative_counts, - precision, - recall, n_iter) - - v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float) - c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int) - f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int) - writer.add_mesh('my_mesh', vertices=v, colors=c, faces=f) - - class TestTensorBoardSummaryWriter(BaseTestCase): - def test_summary_writer_ctx(self): - # after using a SummaryWriter as a ctx it should be closed - with SummaryWriter(filename_suffix='.test') as writer: - writer.add_scalar('test', 1) - self.assertIs(writer.file_writer, None) - - def test_summary_writer_close(self): - # Opening and closing SummaryWriter a lot should not run into - # OSError: [Errno 24] Too many open files - passed = True - try: - writer = SummaryWriter() - writer.close() - except OSError: - passed = False - - self.assertTrue(passed) - - def test_pathlib(self): - import sys - if sys.version_info.major == 2: - import pathlib2 as pathlib - else: - import pathlib - p = pathlib.Path('./pathlibtest') - with SummaryWriter(p) as writer: - writer.add_scalar('test', 1) - import shutil - shutil.rmtree(str(p)) - - class TestTensorBoardEmbedding(BaseTestCase): - def test_embedding(self): - w = SummaryWriter() - all_features = torch.Tensor([[1, 2, 3], [5, 4, 1], [3, 7, 7]]) - all_labels = torch.Tensor([33, 44, 55]) - all_images = torch.zeros(3, 3, 5, 5) - - w.add_embedding(all_features, - metadata=all_labels, - label_img=all_images, - global_step=2) - - dataset_label = ['test'] * 2 + ['train'] * 2 - all_labels = list(zip(all_labels, dataset_label)) - w.add_embedding(all_features, - metadata=all_labels, - label_img=all_images, - metadata_header=['digit', 'dataset'], - global_step=2) - # assert... - - def test_embedding_64(self): - w = SummaryWriter() - all_features = torch.Tensor([[1, 2, 3], [5, 4, 1], [3, 7, 7]]) - all_labels = torch.Tensor([33, 44, 55]) - all_images = torch.zeros((3, 3, 5, 5), dtype=torch.float64) - - w.add_embedding(all_features, - metadata=all_labels, - label_img=all_images, - global_step=2) - - dataset_label = ['test'] * 2 + ['train'] * 2 - all_labels = list(zip(all_labels, dataset_label)) - w.add_embedding(all_features, - metadata=all_labels, - label_img=all_images, - metadata_header=['digit', 'dataset'], - global_step=2) - - class TestTensorBoardSummary(BaseTestCase): - def test_uint8_image(self): - ''' - Tests that uint8 image (pixel values in [0, 255]) is not changed - ''' - test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) - scale_factor = summary._calc_scale_factor(test_image) - self.assertEqual(scale_factor, 1, 'Values are already in [0, 255], scale factor should be 1') - - def test_float32_image(self): - ''' - Tests that float32 image (pixel values in [0, 1]) are scaled correctly - to [0, 255] - ''' - test_image = np.random.rand(3, 32, 32).astype(np.float32) - scale_factor = summary._calc_scale_factor(test_image) - self.assertEqual(scale_factor, 255, 'Values are in [0, 1], scale factor should be 255') - - def test_list_input(self): - with self.assertRaises(Exception) as e_info: - summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow') - - def test_empty_input(self): - with self.assertRaises(Exception) as e_info: - summary.histogram('dummy', np.ndarray(0), 'tensorflow') - - def test_image_with_boxes(self): - self.assertTrue(compare_image_proto(summary.image_boxes('dummy', - tensor_N(shape=(3, 32, 32)), - np.array([[10, 10, 40, 40]])), - self)) - - def test_image_with_one_channel(self): - self.assertTrue(compare_image_proto(summary.image('dummy', - tensor_N(shape=(1, 8, 8)), - dataformats='CHW'), - self)) # noqa E127 - - def test_image_with_one_channel_batched(self): - self.assertTrue(compare_image_proto(summary.image('dummy', - tensor_N(shape=(2, 1, 8, 8)), - dataformats='NCHW'), - self)) # noqa E127 - - def test_image_with_3_channel_batched(self): - self.assertTrue(compare_image_proto(summary.image('dummy', - tensor_N(shape=(2, 3, 8, 8)), - dataformats='NCHW'), - self)) # noqa E127 - - def test_image_without_channel(self): - self.assertTrue(compare_image_proto(summary.image('dummy', - tensor_N(shape=(8, 8)), - dataformats='HW'), - self)) # noqa E127 - - def test_video(self): - try: - import moviepy # noqa F401 - except ImportError: - return - self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self)) - summary.video('dummy', np.random.rand(16, 48, 1, 28, 28)) - summary.video('dummy', np.random.rand(20, 7, 1, 8, 8)) - - def test_audio(self): - self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self)) - - def test_text(self): - self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self)) - - def test_histogram_auto(self): - self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self)) - - def test_histogram_fd(self): - self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self)) - - def test_histogram_doane(self): - self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self)) - - def test_custom_scalars(self): - layout = {'Taiwan': {'twse': ['Multiline', ['twse/0050', 'twse/2330']]}, - 'USA': {'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']], - 'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']]}} - summary.custom_scalars(layout) # only smoke test. Because protobuf in python2/3 serialize dictionary differently. - - def test_hparams_smoke(self): - hp = {'lr': 0.1, 'bsize': 4} - mt = {'accuracy': 0.1, 'loss': 10} - summary.hparams(hp, mt) # only smoke test. Because protobuf in python2/3 serialize dictionary differently. - - hp = {'use_magic': True, 'init_string': "42"} - mt = {'accuracy': 0.1, 'loss': 10} - summary.hparams(hp, mt) - - mt = {'accuracy': torch.zeros(1), 'loss': torch.zeros(1)} - summary.hparams(hp, mt) - - def test_hparams_wrong_parameter(self): - with self.assertRaises(TypeError): - summary.hparams([], {}) - with self.assertRaises(TypeError): - summary.hparams({}, []) - with self.assertRaises(ValueError): - res = summary.hparams({'pytorch': [1, 2]}, {'accuracy': 2.0}) - # metric data is used in writer.py so the code path is different, which leads to different exception type. - with self.assertRaises(NotImplementedError): - with SummaryWriter() as writer: - writer.add_hparams({'pytorch': 1.0}, {'accuracy': [1, 2]}) - - def test_mesh(self): + def test_numpy_vid_uint8(self): + V_input = np.random.randint(0, 256, (16, 30, 3, 28, 28)).astype(np.uint8) + V_after = _prepare_video(np.copy(V_input)) * 255 + total_frame = V_input.shape[1] + V_input = np.swapaxes(V_input, 0, 1) + for f in range(total_frame): + x = np.reshape(V_input[f], newshape=(-1)) + y = np.reshape(V_after[f], newshape=(-1)) + np.testing.assert_array_almost_equal(np.sum(x), np.sum(y)) + +freqs = [262, 294, 330, 349, 392, 440, 440, 440, 440, 440, 440] + +true_positive_counts = [75, 64, 21, 5, 0] +false_positive_counts = [150, 105, 18, 0, 0] +true_negative_counts = [0, 45, 132, 150, 150] +false_negative_counts = [0, 11, 54, 70, 75] +precision = [0.3333333, 0.3786982, 0.5384616, 1.0, 0.0] +recall = [1.0, 0.8533334, 0.28, 0.0666667, 0.0] + +class TestTensorBoardWriter(BaseTestCase): + def test_writer(self): + with self.createSummaryWriter() as writer: + sample_rate = 44100 + + n_iter = 0 + writer.add_scalar('data/scalar_systemtime', 0.1, n_iter) + writer.add_scalar('data/scalar_customtime', 0.2, n_iter, walltime=n_iter) + writer.add_scalars('data/scalar_group', { + "xsinx": n_iter * np.sin(n_iter), + "xcosx": n_iter * np.cos(n_iter), + "arctanx": np.arctan(n_iter) + }, n_iter) + x = np.zeros((32, 3, 64, 64)) # output from network + writer.add_images('Image', x, n_iter) # Tensor + writer.add_image_with_boxes('imagebox', + np.zeros((3, 64, 64)), + np.array([[10, 10, 40, 40], [40, 40, 60, 60]]), + n_iter) + x = np.zeros(sample_rate * 2) + + writer.add_audio('myAudio', x, n_iter) + writer.add_video('myVideo', np.random.rand(16, 48, 1, 28, 28).astype(np.float32), n_iter) + writer.add_text('Text', 'text logged at step:' + str(n_iter), n_iter) + writer.add_text('markdown Text', '''a|b\n-|-\nc|d''', n_iter) + writer.add_histogram('hist', np.random.rand(100, 100), n_iter) + writer.add_pr_curve('xoxo', np.random.randint(2, size=100), np.random.rand( + 100), n_iter) # needs tensorboard 0.4RC or later + writer.add_pr_curve_raw('prcurve with raw data', true_positive_counts, + false_positive_counts, + true_negative_counts, + false_negative_counts, + precision, + recall, n_iter) + v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float) c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int) f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int) - mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None) - self.assertTrue(compare_proto(mesh, self)) - - def remove_whitespace(string): - return string.replace(' ', '').replace('\t', '').replace('\n', '') - - def read_expected_content(function_ptr): - module_id = function_ptr.__class__.__module__ - test_dir = os.path.dirname(sys.modules[module_id].__file__) - functionName = function_ptr.id().split('.')[-1] - expected_file = os.path.join(test_dir, - "expect", - 'TestTensorBoard.' + functionName + ".expect") - assert os.path.exists(expected_file) - with open(expected_file, "r") as f: - return f.read() - - def compare_image_proto(actual_proto, function_ptr): - expected_str = read_expected_content(function_ptr) - expected_proto = Summary() - text_format.Parse(expected_str, expected_proto) - - [actual, expected] = [actual_proto.value[0], expected_proto.value[0]] - actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string)) - expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string)) - - return ( - actual.tag == expected.tag and - actual.image.height == expected.image.height and - actual.image.width == expected.image.width and - actual.image.colorspace == expected.image.colorspace and - actual_img == expected_img - ) - - def compare_proto(str_to_compare, function_ptr): - expected = read_expected_content(function_ptr) - str_to_compare = str(str_to_compare) - return remove_whitespace(str_to_compare) == remove_whitespace(expected) - - def write_proto(str_to_compare, function_ptr): - module_id = function_ptr.__class__.__module__ - test_dir = os.path.dirname(sys.modules[module_id].__file__) - functionName = function_ptr.id().split('.')[-1] - expected_file = os.path.join(test_dir, - "expect", - 'TestTensorBoard.' + functionName + ".expect") - with open(expected_file, 'w') as f: - f.write(str(str_to_compare)) - - class TestTensorBoardPytorchGraph(BaseTestCase): - def test_pytorch_graph(self): - dummy_input = (torch.zeros(1, 3),) - - class myLinear(torch.nn.Module): - def __init__(self): - super(myLinear, self).__init__() - self.l = torch.nn.Linear(3, 5) - - def forward(self, x): - return self.l(x) - - with SummaryWriter(comment='LinearModel') as w: - w.add_graph(myLinear(), dummy_input) - - def test_mlp_graph(self): - dummy_input = (torch.zeros(2, 1, 28, 28),) - - # This MLP class with the above input is expected - # to fail JIT optimizations as seen at - # https://github.com/pytorch/pytorch/issues/18903 - # - # However, it should not raise an error during - # the add_graph call and still continue. - class myMLP(torch.nn.Module): - def __init__(self): - super(myMLP, self).__init__() - self.input_len = 1 * 28 * 28 - self.fc1 = torch.nn.Linear(self.input_len, 1200) - self.fc2 = torch.nn.Linear(1200, 1200) - self.fc3 = torch.nn.Linear(1200, 10) - - def forward(self, x, update_batch_stats=True): - h = torch.nn.functional.relu( - self.fc1(x.view(-1, self.input_len))) - h = self.fc2(h) - h = torch.nn.functional.relu(h) - h = self.fc3(h) - return h - - with SummaryWriter(comment='MLPModel') as w: - w.add_graph(myMLP(), dummy_input) - - def test_wrong_input_size(self): - with self.assertRaises(RuntimeError) as e_info: - dummy_input = torch.rand(1, 9) - model = torch.nn.Linear(3, 5) - with SummaryWriter(comment='expect_error') as w: - w.add_graph(model, dummy_input) # error - - @skipIfNoTorchVision - def test_torchvision_smoke(self): - model_input_shapes = { - 'alexnet': (2, 3, 224, 224), - 'resnet34': (2, 3, 224, 224), - 'resnet152': (2, 3, 224, 224), - 'densenet121': (2, 3, 224, 224), - 'vgg16': (2, 3, 224, 224), - 'vgg19': (2, 3, 224, 224), - 'vgg16_bn': (2, 3, 224, 224), - 'vgg19_bn': (2, 3, 224, 224), - 'mobilenet_v2': (2, 3, 224, 224), + writer.add_mesh('my_mesh', vertices=v, colors=c, faces=f) + +class TestTensorBoardSummaryWriter(BaseTestCase): + def test_summary_writer_ctx(self): + # after using a SummaryWriter as a ctx it should be closed + with self.createSummaryWriter() as writer: + writer.add_scalar('test', 1) + self.assertIs(writer.file_writer, None) + + def test_summary_writer_close(self): + # Opening and closing SummaryWriter a lot should not run into + # OSError: [Errno 24] Too many open files + passed = True + try: + writer = self.createSummaryWriter() + writer.close() + except OSError: + passed = False + + self.assertTrue(passed) + + def test_pathlib(self): + import sys + if sys.version_info.major == 2: + import pathlib2 as pathlib + else: + import pathlib + p = pathlib.Path('./pathlibtest' + str(uuid.uuid4())) + with SummaryWriter(p) as writer: + writer.add_scalar('test', 1) + import shutil + shutil.rmtree(str(p)) + +class TestTensorBoardEmbedding(BaseTestCase): + def test_embedding(self): + w = self.createSummaryWriter() + all_features = torch.Tensor([[1, 2, 3], [5, 4, 1], [3, 7, 7]]) + all_labels = torch.Tensor([33, 44, 55]) + all_images = torch.zeros(3, 3, 5, 5) + + w.add_embedding(all_features, + metadata=all_labels, + label_img=all_images, + global_step=2) + + dataset_label = ['test'] * 2 + ['train'] * 2 + all_labels = list(zip(all_labels, dataset_label)) + w.add_embedding(all_features, + metadata=all_labels, + label_img=all_images, + metadata_header=['digit', 'dataset'], + global_step=2) + # assert... + + def test_embedding_64(self): + w = self.createSummaryWriter() + all_features = torch.Tensor([[1, 2, 3], [5, 4, 1], [3, 7, 7]]) + all_labels = torch.Tensor([33, 44, 55]) + all_images = torch.zeros((3, 3, 5, 5), dtype=torch.float64) + + w.add_embedding(all_features, + metadata=all_labels, + label_img=all_images, + global_step=2) + + dataset_label = ['test'] * 2 + ['train'] * 2 + all_labels = list(zip(all_labels, dataset_label)) + w.add_embedding(all_features, + metadata=all_labels, + label_img=all_images, + metadata_header=['digit', 'dataset'], + global_step=2) + +class TestTensorBoardSummary(BaseTestCase): + def test_uint8_image(self): + ''' + Tests that uint8 image (pixel values in [0, 255]) is not changed + ''' + test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) + scale_factor = summary._calc_scale_factor(test_image) + self.assertEqual(scale_factor, 1, 'Values are already in [0, 255], scale factor should be 1') + + def test_float32_image(self): + ''' + Tests that float32 image (pixel values in [0, 1]) are scaled correctly + to [0, 255] + ''' + test_image = np.random.rand(3, 32, 32).astype(np.float32) + scale_factor = summary._calc_scale_factor(test_image) + self.assertEqual(scale_factor, 255, 'Values are in [0, 1], scale factor should be 255') + + def test_list_input(self): + with self.assertRaises(Exception) as e_info: + summary.histogram('dummy', [1, 3, 4, 5, 6], 'tensorflow') + + def test_empty_input(self): + with self.assertRaises(Exception) as e_info: + summary.histogram('dummy', np.ndarray(0), 'tensorflow') + + def test_image_with_boxes(self): + self.assertTrue(compare_image_proto(summary.image_boxes('dummy', + tensor_N(shape=(3, 32, 32)), + np.array([[10, 10, 40, 40]])), + self)) + + def test_image_with_one_channel(self): + self.assertTrue(compare_image_proto(summary.image('dummy', + tensor_N(shape=(1, 8, 8)), + dataformats='CHW'), + self)) # noqa E127 + + def test_image_with_one_channel_batched(self): + self.assertTrue(compare_image_proto(summary.image('dummy', + tensor_N(shape=(2, 1, 8, 8)), + dataformats='NCHW'), + self)) # noqa E127 + + def test_image_with_3_channel_batched(self): + self.assertTrue(compare_image_proto(summary.image('dummy', + tensor_N(shape=(2, 3, 8, 8)), + dataformats='NCHW'), + self)) # noqa E127 + + def test_image_without_channel(self): + self.assertTrue(compare_image_proto(summary.image('dummy', + tensor_N(shape=(8, 8)), + dataformats='HW'), + self)) # noqa E127 + + def test_video(self): + try: + import moviepy # noqa F401 + except ImportError: + return + self.assertTrue(compare_proto(summary.video('dummy', tensor_N(shape=(4, 3, 1, 8, 8))), self)) + summary.video('dummy', np.random.rand(16, 48, 1, 28, 28)) + summary.video('dummy', np.random.rand(20, 7, 1, 8, 8)) + + def test_audio(self): + self.assertTrue(compare_proto(summary.audio('dummy', tensor_N(shape=(42,))), self)) + + def test_text(self): + self.assertTrue(compare_proto(summary.text('dummy', 'text 123'), self)) + + def test_histogram_auto(self): + self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='auto', max_bins=5), self)) + + def test_histogram_fd(self): + self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='fd', max_bins=5), self)) + + def test_histogram_doane(self): + self.assertTrue(compare_proto(summary.histogram('dummy', tensor_N(shape=(1024,)), bins='doane', max_bins=5), self)) + + def test_custom_scalars(self): + layout = { + 'Taiwan': { + 'twse': ['Multiline', ['twse/0050', 'twse/2330']] + }, + 'USA': { + 'dow': ['Margin', ['dow/aaa', 'dow/bbb', 'dow/ccc']], + 'nasdaq': ['Margin', ['nasdaq/aaa', 'nasdaq/bbb', 'nasdaq/ccc']] } - for model_name, input_shape in model_input_shapes.items(): - with SummaryWriter(comment=model_name) as w: - model = getattr(torchvision.models, model_name)() - w.add_graph(model, torch.zeros(input_shape)) - - class TestTensorBoardFigure(BaseTestCase): - @skipIfNoMatplotlib - def test_figure(self): - writer = SummaryWriter() - - figure, axes = plt.figure(), plt.gca() - circle1 = plt.Circle((0.2, 0.5), 0.2, color='r') - circle2 = plt.Circle((0.8, 0.5), 0.2, color='g') - axes.add_patch(circle1) - axes.add_patch(circle2) - plt.axis('scaled') + } + summary.custom_scalars(layout) # only smoke test. Because protobuf in python2/3 serialize dictionary differently. + + def test_hparams_smoke(self): + hp = {'lr': 0.1, 'bsize': 4} + mt = {'accuracy': 0.1, 'loss': 10} + summary.hparams(hp, mt) # only smoke test. Because protobuf in python2/3 serialize dictionary differently. + + hp = {'use_magic': True, 'init_string': "42"} + mt = {'accuracy': 0.1, 'loss': 10} + summary.hparams(hp, mt) + + mt = {'accuracy': torch.zeros(1), 'loss': torch.zeros(1)} + summary.hparams(hp, mt) + + def test_hparams_wrong_parameter(self): + with self.assertRaises(TypeError): + summary.hparams([], {}) + with self.assertRaises(TypeError): + summary.hparams({}, []) + with self.assertRaises(ValueError): + res = summary.hparams({'pytorch': [1, 2]}, {'accuracy': 2.0}) + # metric data is used in writer.py so the code path is different, which leads to different exception type. + with self.assertRaises(NotImplementedError): + with self.createSummaryWriter() as writer: + writer.add_hparams({'pytorch': 1.0}, {'accuracy': [1, 2]}) + + def test_mesh(self): + v = np.array([[[1, 1, 1], [-1, -1, 1], [1, -1, -1], [-1, 1, -1]]], dtype=float) + c = np.array([[[255, 0, 0], [0, 255, 0], [0, 0, 255], [255, 0, 255]]], dtype=int) + f = np.array([[[0, 2, 3], [0, 3, 1], [0, 1, 2], [1, 3, 2]]], dtype=int) + mesh = summary.mesh('my_mesh', vertices=v, colors=c, faces=f, config_dict=None) + self.assertTrue(compare_proto(mesh, self)) + +def remove_whitespace(string): + return string.replace(' ', '').replace('\t', '').replace('\n', '') + +def read_expected_content(function_ptr): + module_id = function_ptr.__class__.__module__ + test_dir = os.path.dirname(sys.modules[module_id].__file__) + functionName = function_ptr.id().split('.')[-1] + expected_file = os.path.join(test_dir, + "expect", + 'TestTensorBoard.' + functionName + ".expect") + assert os.path.exists(expected_file) + with open(expected_file, "r") as f: + return f.read() + +def compare_image_proto(actual_proto, function_ptr): + expected_str = read_expected_content(function_ptr) + expected_proto = Summary() + text_format.Parse(expected_str, expected_proto) + + [actual, expected] = [actual_proto.value[0], expected_proto.value[0]] + actual_img = Image.open(io.BytesIO(actual.image.encoded_image_string)) + expected_img = Image.open(io.BytesIO(expected.image.encoded_image_string)) + + return ( + actual.tag == expected.tag and + actual.image.height == expected.image.height and + actual.image.width == expected.image.width and + actual.image.colorspace == expected.image.colorspace and + actual_img == expected_img + ) + +def compare_proto(str_to_compare, function_ptr): + expected = read_expected_content(function_ptr) + str_to_compare = str(str_to_compare) + return remove_whitespace(str_to_compare) == remove_whitespace(expected) + +def write_proto(str_to_compare, function_ptr): + module_id = function_ptr.__class__.__module__ + test_dir = os.path.dirname(sys.modules[module_id].__file__) + functionName = function_ptr.id().split('.')[-1] + expected_file = os.path.join(test_dir, + "expect", + 'TestTensorBoard.' + functionName + ".expect") + with open(expected_file, 'w') as f: + f.write(str(str_to_compare)) + +class TestTensorBoardPytorchGraph(BaseTestCase): + def test_pytorch_graph(self): + dummy_input = (torch.zeros(1, 3),) + + class myLinear(torch.nn.Module): + def __init__(self): + super(myLinear, self).__init__() + self.l = torch.nn.Linear(3, 5) + + def forward(self, x): + return self.l(x) + + with self.createSummaryWriter() as w: + w.add_graph(myLinear(), dummy_input) + + def test_mlp_graph(self): + dummy_input = (torch.zeros(2, 1, 28, 28),) + + # This MLP class with the above input is expected + # to fail JIT optimizations as seen at + # https://github.com/pytorch/pytorch/issues/18903 + # + # However, it should not raise an error during + # the add_graph call and still continue. + class myMLP(torch.nn.Module): + def __init__(self): + super(myMLP, self).__init__() + self.input_len = 1 * 28 * 28 + self.fc1 = torch.nn.Linear(self.input_len, 1200) + self.fc2 = torch.nn.Linear(1200, 1200) + self.fc3 = torch.nn.Linear(1200, 10) + + def forward(self, x, update_batch_stats=True): + h = torch.nn.functional.relu( + self.fc1(x.view(-1, self.input_len))) + h = self.fc2(h) + h = torch.nn.functional.relu(h) + h = self.fc3(h) + return h + + with self.createSummaryWriter() as w: + w.add_graph(myMLP(), dummy_input) + + def test_wrong_input_size(self): + with self.assertRaises(RuntimeError) as e_info: + dummy_input = torch.rand(1, 9) + model = torch.nn.Linear(3, 5) + with self.createSummaryWriter() as w: + w.add_graph(model, dummy_input) # error + + @skipIfNoTorchVision + def test_torchvision_smoke(self): + model_input_shapes = { + 'alexnet': (2, 3, 224, 224), + 'resnet34': (2, 3, 224, 224), + 'resnet152': (2, 3, 224, 224), + 'densenet121': (2, 3, 224, 224), + 'vgg16': (2, 3, 224, 224), + 'vgg19': (2, 3, 224, 224), + 'vgg16_bn': (2, 3, 224, 224), + 'vgg19_bn': (2, 3, 224, 224), + 'mobilenet_v2': (2, 3, 224, 224), + } + for model_name, input_shape in model_input_shapes.items(): + with self.createSummaryWriter() as w: + model = getattr(torchvision.models, model_name)() + w.add_graph(model, torch.zeros(input_shape)) + +class TestTensorBoardFigure(BaseTestCase): + @skipIfNoMatplotlib + def test_figure(self): + writer = self.createSummaryWriter() + + figure, axes = plt.figure(), plt.gca() + circle1 = plt.Circle((0.2, 0.5), 0.2, color='r') + circle2 = plt.Circle((0.8, 0.5), 0.2, color='g') + axes.add_patch(circle1) + axes.add_patch(circle2) + plt.axis('scaled') + plt.tight_layout() + + writer.add_figure("add_figure/figure", figure, 0, close=False) + self.assertTrue(plt.fignum_exists(figure.number)) + + writer.add_figure("add_figure/figure", figure, 1) + self.assertFalse(plt.fignum_exists(figure.number)) + + writer.close() + + @skipIfNoMatplotlib + def test_figure_list(self): + writer = self.createSummaryWriter() + + figures = [] + for i in range(5): + figure = plt.figure() + plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i)) + plt.xlabel("X") + plt.xlabel("Y") + plt.legend() plt.tight_layout() - - writer.add_figure("add_figure/figure", figure, 0, close=False) - self.assertTrue(plt.fignum_exists(figure.number)) - - writer.add_figure("add_figure/figure", figure, 1) - self.assertFalse(plt.fignum_exists(figure.number)) - - writer.close() - - @skipIfNoMatplotlib - def test_figure_list(self): - writer = SummaryWriter() - - figures = [] - for i in range(5): - figure = plt.figure() - plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i)) - plt.xlabel("X") - plt.xlabel("Y") - plt.legend() - plt.tight_layout() - figures.append(figure) - - writer.add_figure("add_figure/figure_list", figures, 0, close=False) - self.assertTrue(all([plt.fignum_exists(figure.number) is True for figure in figures])) # noqa F812 - - writer.add_figure("add_figure/figure_list", figures, 1) - self.assertTrue(all([plt.fignum_exists(figure.number) is False for figure in figures])) # noqa F812 - - writer.close() - - class TestTensorBoardNumpy(BaseTestCase): - def test_scalar(self): - res = make_np(1.1) - self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - res = make_np(1 << 64 - 1) # uint64_max - self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - res = make_np(np.float16(1.00000087)) - self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - res = make_np(np.float128(1.00008 + 9)) - self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - res = make_np(np.int64(100000000000)) - self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) - - @skipIfNoCaffe2 - def test_caffe2_np(self): - workspace.FeedBlob("testBlob", tensor_N(shape=(1, 3, 64, 64))) - self.assertIsInstance(make_np('testBlob'), np.ndarray) - - @skipIfNoCaffe2 - def test_caffe2_np_expect_fail(self): - with self.assertRaises(RuntimeError): - res = make_np('This_blob_does_not_exist') - - def test_pytorch_np_expect_fail(self): - with self.assertRaises(NotImplementedError): - res = make_np({'pytorch': 1.0}) - - @skipIfNoCaffe2 - @unittest.skipIf(TEST_WITH_ASAN, "Caffe2 failure with ASAN") - def test_caffe2_simple_model(self): - model = ModelHelper(name="mnist") - # how come those inputs don't break the forward pass =.=a - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(np.int)) - - with core.NameScope("conv1"): - conv1 = brew.conv(model, "data", 'conv1', dim_in=1, dim_out=20, kernel=5) - # Image size: 24 x 24 -> 12 x 12 - pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2) - # Image size: 12 x 12 -> 8 x 8 - conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5) - # Image size: 8 x 8 -> 4 x 4 - pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2) - with core.NameScope("classifier"): - # 50 * 4 * 4 stands for dim_out from previous layer multiplied by the image size - fc3 = brew.fc(model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500) - relu = brew.relu(model, fc3, fc3) - pred = brew.fc(model, relu, 'pred', 500, 10) - softmax = brew.softmax(model, pred, 'softmax') - xent = model.LabelCrossEntropy([softmax, "label"], 'xent') - # compute the expected loss - loss = model.AveragedLoss(xent, "loss") - model.net.RunAllOnMKL() - model.param_init_net.RunAllOnMKL() - model.AddGradientOperators([loss], skip=1) - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) - - @skipIfNoCaffe2 - def test_caffe2_simple_cnnmodel(self): - model = cnn.CNNModelHelper("NCHW", name="overfeat") - workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) - workspace.FeedBlob("label", np.random.randn(1, 1000).astype(np.int)) - with core.NameScope("conv1"): - conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4) - relu1 = model.Relu(conv1, conv1) - pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2) - with core.NameScope("classifier"): - fc = model.FC(pool1, "fc", 4096, 1000) - pred = model.Softmax(fc, "pred") - xent = model.LabelCrossEntropy([pred, "label"], "xent") - loss = model.AveragedLoss(xent, "loss") - - blob_name_tracker = {} - graph = c2_graph.model_to_graph_def( - model, - blob_name_tracker=blob_name_tracker, - shapes={}, - show_simplified=False, - ) - compare_proto(graph, self) + figures.append(figure) + + writer.add_figure("add_figure/figure_list", figures, 0, close=False) + self.assertTrue(all([plt.fignum_exists(figure.number) is True for figure in figures])) # noqa F812 + + writer.add_figure("add_figure/figure_list", figures, 1) + self.assertTrue(all([plt.fignum_exists(figure.number) is False for figure in figures])) # noqa F812 + + writer.close() + +class TestTensorBoardNumpy(BaseTestCase): + def test_scalar(self): + res = make_np(1.1) + self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) + res = make_np(1 << 64 - 1) # uint64_max + self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) + res = make_np(np.float16(1.00000087)) + self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) + res = make_np(np.float128(1.00008 + 9)) + self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) + res = make_np(np.int64(100000000000)) + self.assertIsInstance(res, np.ndarray) and self.assertEqual(res.shape, (1,)) + + @skipIfNoCaffe2 + def test_caffe2_np(self): + workspace.FeedBlob("testBlob", tensor_N(shape=(1, 3, 64, 64))) + self.assertIsInstance(make_np('testBlob'), np.ndarray) + + @skipIfNoCaffe2 + def test_caffe2_np_expect_fail(self): + with self.assertRaises(RuntimeError): + res = make_np('This_blob_does_not_exist') + + def test_pytorch_np_expect_fail(self): + with self.assertRaises(NotImplementedError): + res = make_np({'pytorch': 1.0}) + + @skipIfNoCaffe2 + @unittest.skipIf(TEST_WITH_ASAN, "Caffe2 failure with ASAN") + def test_caffe2_simple_model(self): + model = ModelHelper(name="mnist") + # how come those inputs don't break the forward pass =.=a + workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) + workspace.FeedBlob("label", np.random.randn(1, 1000).astype(np.int)) + + with core.NameScope("conv1"): + conv1 = brew.conv(model, "data", 'conv1', dim_in=1, dim_out=20, kernel=5) + # Image size: 24 x 24 -> 12 x 12 + pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2) + # Image size: 12 x 12 -> 8 x 8 + conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5) + # Image size: 8 x 8 -> 4 x 4 + pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2) + with core.NameScope("classifier"): + # 50 * 4 * 4 stands for dim_out from previous layer multiplied by the image size + fc3 = brew.fc(model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500) + relu = brew.relu(model, fc3, fc3) + pred = brew.fc(model, relu, 'pred', 500, 10) + softmax = brew.softmax(model, pred, 'softmax') + xent = model.LabelCrossEntropy([softmax, "label"], 'xent') + # compute the expected loss + loss = model.AveragedLoss(xent, "loss") + model.net.RunAllOnMKL() + model.param_init_net.RunAllOnMKL() + model.AddGradientOperators([loss], skip=1) + blob_name_tracker = {} + graph = c2_graph.model_to_graph_def( + model, + blob_name_tracker=blob_name_tracker, + shapes={}, + show_simplified=False, + ) + compare_proto(graph, self) + + @skipIfNoCaffe2 + def test_caffe2_simple_cnnmodel(self): + model = cnn.CNNModelHelper("NCHW", name="overfeat") + workspace.FeedBlob("data", np.random.randn(1, 3, 64, 64).astype(np.float32)) + workspace.FeedBlob("label", np.random.randn(1, 1000).astype(np.int)) + with core.NameScope("conv1"): + conv1 = model.Conv("data", "conv1", 3, 96, 11, stride=4) + relu1 = model.Relu(conv1, conv1) + pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2) + with core.NameScope("classifier"): + fc = model.FC(pool1, "fc", 4096, 1000) + pred = model.Softmax(fc, "pred") + xent = model.LabelCrossEntropy([pred, "label"], "xent") + loss = model.AveragedLoss(xent, "loss") + + blob_name_tracker = {} + graph = c2_graph.model_to_graph_def( + model, + blob_name_tracker=blob_name_tracker, + shapes={}, + show_simplified=False, + ) + compare_proto(graph, self) if __name__ == '__main__': run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index 08ec2fe88f26e..dbe3d8898c319 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -29,9 +29,13 @@ _compare_trilu_indices from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \ TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \ - IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \ - IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, torchtest, TEST_WITH_ROCM + IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \ + IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \ + skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf from multiprocessing.reduction import ForkingPickler +from common_device_type import instantiate_device_type_tests, \ + skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, onlyCUDA, onlyCPU, \ + dtypes, dtypesIfCUDA # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -101,7 +105,7 @@ def __exit__(self, *args): # This is intentionally prefixed by an underscore. Otherwise pytest will try to # run its methods as test cases. -class _TestTorchMixin(torchtest): +class _TestTorchMixin(object): def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True): float_types = [torch.double, torch.float] @@ -239,10 +243,13 @@ def test_namespace(ns, *skips): 'sparse_resize_', 'sparse_resize_and_clear_', 'align_to', # BUILD_NAMEDTENSOR only - 'view_names', # BUILD_NAMEDTENSOR only + 'align_as', # BUILD_NAMEDTENSOR only + 'renamed', # BUILD_NAMEDTENSOR only 'names_', # BUILD_NAMEDTENSOR only 'has_names', # BUILD_NAMEDTENSOR only 'rename', # BUILD_NAMEDTENSOR only + 'refine_names', # BUILD_NAMEDTENSOR only + 'unflatten', # BUILD_NAMEDTENSOR only ) test_namespace(torch.nn) test_namespace(torch.nn.functional, 'assert_int_or_pair', 'feature_alpha_dropout') @@ -401,84 +408,6 @@ def test_addmm(self): res2[i, j] += m1[i, k] * m2[k, j] self.assertEqual(res1, res2, prec) - def test_logical_any(self): - for device in torch.testing.get_all_device_types(): - x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.any()) - - self.assertEqual( - torch.zeros([1, 3, 400], dtype=torch.uint8, device=device), - x.any(0, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 1, 400], dtype=torch.uint8, device=device), - x.any(1, keepdim=True)) - - self.assertEqual( - torch.zeros([2, 3, 1], dtype=torch.uint8, device=device), - x.any(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 1 - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.any()) - - y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(0, keepdim=True)) - - y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(1, keepdim=True)) - - y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 1 - self.assertEqual(y, x.any(2, keepdim=True)) - - def test_logical_all(self): - for device in torch.testing.get_all_device_types(): - x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device) - - self.assertEqual( - torch.tensor(1, dtype=torch.uint8, device=device), - x.all()) - - self.assertEqual( - torch.ones([1, 3, 400], dtype=torch.uint8, device=device), - x.all(0, keepdim=True)) - - self.assertEqual( - torch.ones([2, 1, 400], dtype=torch.uint8, device=device), - x.all(1, keepdim=True)) - - self.assertEqual( - torch.ones([2, 3, 1], dtype=torch.uint8, device=device), - x.all(2, keepdim=True)) - - # set the last element to 0 - x[-1][-1][-1] = 0 - - self.assertEqual( - torch.tensor(0, dtype=torch.uint8, device=device), - x.all()) - - y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(0, keepdim=True)) - - y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(1, keepdim=True)) - - y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device) - y[-1][-1][-1] = 0 - self.assertEqual(y, x.all(2, keepdim=True)) - def test_allclose(self): x = torch.tensor([1.0, 2.0, 3.0]) y = torch.tensor([1.01, 2.01, 3.01]) @@ -700,6 +629,9 @@ def test_polygamma(self): lambda x: polygamma(n, x).item(), self._digamma_input(test_poles=False)) + with self.assertRaisesRegex(RuntimeError, r'polygamma\(n, x\) does not support negative n\.'): + torch.polygamma(-1, torch.tensor([1.0, 2.0])) + def test_asin(self): self._test_math(torch.asin, lambda x: math.asin(x) if abs(x) <= 1 else nan) @@ -776,18 +708,6 @@ def test_erf(self): def test_erfc(self): self._test_math_by_name('erfc') - def test_erfinv(self): - def checkType(tensor): - inputValues = torch.randn(4, 4, out=tensor()).clamp(-2., 2.) - self.assertEqual(tensor(inputValues).erf().erfinv(), tensor(inputValues)) - # test inf - self.assertTrue(torch.equal(tensor([-1, 1]).erfinv(), tensor([-inf, inf]))) - # test nan - self.assertEqual(tensor([-2, 2]).erfinv(), tensor([nan, nan])) - - checkType(torch.FloatTensor) - checkType(torch.DoubleTensor) - def test_exp(self): def exp(x): try: @@ -819,12 +739,6 @@ def test_floor(self): def test_ceil(self): self._test_math_by_name('ceil') - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_ceil_out_cpu_cuda(self): - a = torch.randn(1) - b = torch.randn(1, device="cuda") - self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b)) - def test_rsqrt(self): def rsqrt(x): if x == 0: @@ -864,23 +778,6 @@ def test_has_storage(self): self.assertIsNotNone(torch.Tensor([0, 0, 0]).nonzero().storage()) self.assertIsNotNone(torch.Tensor().new().storage()) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_has_storage_numpy(self): - for dtype in [np.float32, np.float64, np.int64, - np.int32, np.int16, np.uint8]: - arr = np.array([1], dtype=dtype) - self.assertIsNotNone(torch.FloatTensor(arr).storage()) - self.assertIsNotNone(torch.DoubleTensor(arr).storage()) - self.assertIsNotNone(torch.IntTensor(arr).storage()) - self.assertIsNotNone(torch.LongTensor(arr).storage()) - self.assertIsNotNone(torch.ByteTensor(arr).storage()) - if torch.cuda.is_available(): - self.assertIsNotNone(torch.cuda.FloatTensor(arr).storage()) - self.assertIsNotNone(torch.cuda.DoubleTensor(arr).storage()) - self.assertIsNotNone(torch.cuda.IntTensor(arr).storage()) - self.assertIsNotNone(torch.cuda.LongTensor(arr).storage()) - self.assertIsNotNone(torch.cuda.ByteTensor(arr).storage()) - def _testSelection(self, torchfn, mathfn): # contiguous m1 = torch.randn(100, 100) @@ -936,18 +833,6 @@ def _testSelection(self, torchfn, mathfn): def test_max(self): self._testSelection(torch.max, max) - def test_log_normal(self): - for device in torch.testing.get_all_device_types(): - a = torch.tensor([10], dtype=torch.float, device=device).log_normal_() - self.assertEqual(a.dtype, torch.float) - self.assertEqual(a.size(), torch.Size([1])) - - def test_geometric(self): - for device in torch.testing.get_all_device_types(): - a = torch.tensor([10], dtype=torch.float, device=device).geometric_(0.5) - self.assertEqual(a.dtype, torch.float) - self.assertEqual(a.size(), torch.Size([1])) - @staticmethod def _test_max_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'): for dtype in dtypes: @@ -971,11841 +856,11493 @@ def _test_min_with_inf(self, dtypes=(torch.float, torch.double), device='cpu'): def test_min_with_inf(self): self._test_min_with_inf(self) - @staticmethod - def _test_norm(self, device): - # full reduction - x = torch.randn(25, device=device) - xn = x.cpu().numpy() - for p in [0, 1, 2, 3, 4, inf, -inf]: - res = x.norm(p).item() - expected = np.linalg.norm(xn, p) - self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p)) - - # one dimension - x = torch.randn(25, 25, device=device) - xn = x.cpu().numpy() - for p in [0, 1, 2, 3, 4, inf, -inf]: - res = x.norm(p, 1).cpu().numpy() - expected = np.linalg.norm(xn, p, 1) - self.assertEqual(res.shape, expected.shape) - self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p)) - - # matrix norm - for p in ['fro', 'nuc']: - res = x.norm(p).cpu().numpy() - expected = np.linalg.norm(xn, p) - self.assertEqual(res.shape, expected.shape) - self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p)) + def test_dim_reduction_uint8_overflow(self): + example = [[-1, 2, 1], [5, 3, 6]] + x = torch.tensor(example, dtype=torch.uint8) + self.assertEqual(x.sum(dtype=torch.uint8).item(), 16) + self.assertEqual(x.sum(0, dtype=torch.uint8), torch.FloatTensor([4, 5, 7])) + self.assertEqual(x.sum(1, dtype=torch.uint8), torch.FloatTensor([2, 14])) + y = torch.tensor(example, dtype=torch.uint8) + torch.sum(x, 0, out=y) + self.assertEqual(x.sum(0, dtype=torch.uint8), y) - # larger tensor sanity check - self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000))) + @unittest.skipIf(not TEST_SCIPY, "Scipy not found") + def test_logsumexp(self): + from scipy.special import logsumexp + a = torch.randn(5, 4) + a[0, 0] = inf + a[1, :] = -inf + actual = a.logsumexp(1) + expected = logsumexp(a.numpy(), 1) + self.assertEqual(expected.shape, actual.shape) + self.assertTrue(np.allclose(expected, actual.numpy())) + # check that out is actually inplace + b = torch.zeros(5, 2) + c = b[:, 0] + torch.logsumexp(a, 1, out=c) + self.assertTrue(np.allclose(expected, b[:, 0].numpy())) @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - @skipIfNoLapack - def test_norm(self): - self._test_norm(self, device='cpu') + def test_cpu_parallel(self): + # To use parallel branches we'll need to compare on tensors + # that are relatively large. Even if this is run on a single + # core machine these tests will still give you signal on + # the correctness - @staticmethod - def _test_nuclear_norm_axes(self, device='cpu'): - def check_single_nuclear_norm(x, axes): - if x.is_cuda and randrange(100) < 95: - return # too many cpu <==> gpu copies + def _run_test(size): + for dim in range(len(size) + 1): + nv = np.round(np.random.rand(*size)) # 0s and 1s + tv = torch.from_numpy(nv) + # Parallelisim is only used if numel is + # larger than grainsize defined in Parallel.h + self.assertTrue(tv.numel() > 32768) + if dim == len(size): + nvs = nv.sum() + tvs = tv.sum() + else: + nvs = nv.sum(dim) + tvs = tv.sum(dim) + diff = np.abs(nvs - tvs.numpy()).sum() + self.assertEqual(diff, 0) - a = np.array(x.cpu(), copy=False) - expected = np.linalg.norm(a, "nuc", axis=axes) + _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3]) + _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) + _run_test([1, 32 * 8 * 32 * 8]) + _run_test([1, 32770]) - ans = torch.norm(x, "nuc", dim=axes) - self.assertTrue(ans.is_contiguous()) - self.assertEqual(ans.shape, expected.shape) - self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)) + def _testCSelection(self, torchfn, mathfn): + # Two tensors + size = (100, 100) + a = torch.rand(*size) + b = torch.rand(*size) + c = torchfn(a, b) + expected_c = torch.zeros(*size) + expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b)) + self.assertEqual(expected_c, c, 0) - out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device) - ans = torch.norm(x, "nuc", dim=axes, out=out) - self.assertIs(ans, out) - self.assertTrue(ans.is_contiguous()) - self.assertEqual(ans.shape, expected.shape) - self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)) + def test_max_elementwise(self): + self._testCSelection(torch.max, max) - for n in range(1, 3): - for m in range(1, 3): - for axes in permutations([0, 1], 2): - # 2d, inner dimensions C - x = torch.randn(n, m, device=device) - check_single_nuclear_norm(x, axes) + def test_min_elementwise(self): + self._testCSelection(torch.min, min) - # 2d, inner dimensions Fortran - x = torch.randn(m, n, device=device).transpose(-1, -2) - check_single_nuclear_norm(x, axes) + def test_all_any(self): + def test(size): + x = torch.ones(*size).byte() + self.assertTrue(x.all()) + self.assertTrue(x.any()) - # 2d, inner dimensions non-contiguous - x = torch.randn(n, 2 * m, device=device)[:, ::2] - check_single_nuclear_norm(x, axes) + x[3] = 0 + self.assertFalse(x.all()) + self.assertTrue(x.any()) - # 2d, all dimensions non-contiguous - x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2] - check_single_nuclear_norm(x, axes) + x.zero_() + self.assertFalse(x.all()) + self.assertFalse(x.any()) - for o in range(1, 3): - for axes in permutations([0, 1, 2], 2): - # 3d, inner dimensions C - x = torch.randn(o, n, m, device=device) - check_single_nuclear_norm(x, axes) + x.fill_(2) + self.assertTrue(x.all()) + self.assertTrue(x.any()) - # 3d, inner dimensions Fortran - x = torch.randn(o, m, n, device=device).transpose(-1, -2) - check_single_nuclear_norm(x, axes) + x = torch.ones(*size).bool() + self.assertTrue(x.all()) + self.assertTrue(x.any()) - # 3d, inner dimensions non-contiguous - x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2] - check_single_nuclear_norm(x, axes) + x[3] = False + self.assertFalse(x.all()) + self.assertTrue(x.any()) - # 3d, all dimensions non-contiguous - x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2] - check_single_nuclear_norm(x, axes) + test((10,)) + test((5, 5)) - for r in range(1, 3): - for axes in permutations([0, 1, 2, 3], 2): - # 4d, inner dimensions C - x = torch.randn(r, o, n, m, device=device) - check_single_nuclear_norm(x, axes) + def test_where_bool_tensor(self): + for d in torch.testing.get_all_device_types(): + a = torch.tensor([True, False], device=d) + res = torch.where(a > 0) + self.assertEqual(1, len(res)) - # 4d, inner dimensions Fortran - x = torch.randn(r, o, n, m, device=device).transpose(-1, -2) - check_single_nuclear_norm(x, axes) + def test_all_any_with_dim(self): + def test(x): + r1 = x.prod(dim=0, keepdim=False).byte() + r2 = x.all(dim=0, keepdim=False) + self.assertEqual(r1.shape, r2.shape) + self.assertTrue((r1 == r2).all()) - # 4d, inner dimensions non-contiguous - x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2] - check_single_nuclear_norm(x, axes) + r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte() + r4 = x.any(dim=1, keepdim=True) + self.assertEqual(r3.shape, r4.shape) + self.assertTrue((r3 == r4).all()) - # 4d, all dimensions non-contiguous - x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2] - check_single_nuclear_norm(x, axes) + test(torch.ByteTensor([[0, 0, 0], + [0, 0, 1], + [0, 1, 1], + [1, 1, 1]])) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_nuclear_norm_axes_small_brute_force(self): - self._test_nuclear_norm_axes(self) + def test_mv(self): + def _test_mv(m1, v1): + res1 = torch.mv(m1, v1) + res2 = res1.clone().zero_() + for i, j in iter_indices(m1): + res2[i] += m1[i][j] * v1[j] - @staticmethod - def _test_nuclear_norm_exceptions(self, device='cpu'): - for lst in [], [1], [1, 2]: - for axes in (), (0,), (0, 1): - x = torch.tensor(lst, dtype=torch.double, device=device) - self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) + self.assertEqual(res1, res2) - x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) - self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) - self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 2)) + _test_mv(torch.randn(100, 100, dtype=torch.float32), torch.randn(100, dtype=torch.float32)) + _test_mv(torch.randn(100, 100, dtype=torch.float64), torch.randn(100, dtype=torch.float64)) + _test_mv(torch.randint(0, 100, (100, 100), dtype=torch.int32), torch.randint(0, 100, (100, ), dtype=torch.int32)) + _test_mv(torch.randint(0, 100, (100, 100), dtype=torch.int64), torch.randint(0, 100, (100, ), dtype=torch.int64)) + _test_mv(torch.randn(100, 100, dtype=torch.float32).bfloat16(), torch.randn(100, dtype=torch.float32).bfloat16()) - def test_nuclear_norm_exceptions(self): - self._test_nuclear_norm_exceptions(self) + def test_numpy_args(self): + x1 = torch.randn(10) + x2 = torch.randn(10) + res1 = torch.add(input=x1, other=x2) + res2 = torch.add(x1=x1, x2=x2) + self.assertEqual(res1, res2) - @staticmethod - def _test_dist(self, device): - def run_test(x, y): - for p in [0, 1, 2, 3, 4, inf, -inf]: - dist_xy = torch.dist(x, y, p) - dist_xy_norm = torch.norm(x - y, p) - self.assertEqual(dist_xy, dist_xy_norm) - - run_test(torch.randn(5, device=device), torch.randn(5, device=device)) + x1 = torch.randn(10, 10, 10) + res1 = x1.sum(dim=(0, 2), keepdim=True) + res2 = x1.sum(axis=(0, 2), keepdims=True) + self.assertEqual(res1, res2) - x = torch.zeros(3, device=device) - y = torch.zeros(3, device=device) - y[1] = 1. - run_test(x, y) + def test_sub(self): + for dtype in torch.testing.get_all_dtypes(): + m1 = torch.tensor([2.34, 4.44], dtype=dtype) + m2 = torch.tensor([1.23, 2.33], dtype=dtype) - def test_dist(self): - self._test_dist(self, device='cpu') + if (dtype == torch.half or dtype == torch.bool): + self.assertRaises(RuntimeError, lambda: m1 - m2) + elif (dtype == torch.bfloat16): + # bfloat16 has a lower precision so we have to have a separate check for it + self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype), 0.01) + else: + self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype)) - def test_dim_reduction_uint8_overflow(self): - example = [[-1, 2, 1], [5, 3, 6]] - x = torch.tensor(example, dtype=torch.uint8) - self.assertEqual(x.sum(dtype=torch.uint8).item(), 16) - self.assertEqual(x.sum(0, dtype=torch.uint8), torch.FloatTensor([4, 5, 7])) - self.assertEqual(x.sum(1, dtype=torch.uint8), torch.FloatTensor([2, 14])) - y = torch.tensor(example, dtype=torch.uint8) - torch.sum(x, 0, out=y) - self.assertEqual(x.sum(0, dtype=torch.uint8), y) + def test_csub(self): + # with a tensor + a = torch.randn(100, 90) + b = a.clone().normal_() - @staticmethod - def _test_dim_reduction(self, cast): - example = [[-1, 2, 1], [5, 3, 6]] + res_add = torch.add(a, -1, b) + res_csub = a.clone() + res_csub.sub_(b) + self.assertEqual(res_add, res_csub) - types = [torch.double, - torch.float, - torch.int64, - torch.int32, - torch.int16] + # with a scalar + a = torch.randn(100, 100) - # This won't test for 256bit instructions, since we usually - # only work on 1 cacheline (1024bit) at a time and these - # examples aren't big enough to trigger that. - for dtype in types: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.sum().item(), 16) - self.assertEqual(x.sum(0), torch.FloatTensor([4, 5, 7])) - self.assertEqual(x.sum(1), torch.FloatTensor([2, 14])) - y = cast(torch.tensor(example, dtype=dtype)) - torch.sum(x, 0, out=y) - self.assertEqual(x.sum(0), y) + scalar = 123.5 + res_add = torch.add(a, -scalar) + res_csub = a.clone() + res_csub.sub_(scalar) + self.assertEqual(res_add, res_csub) - # Mean not supported for Int types - for dtype in types[:2]: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.mean().item(), 16.0 / 6) - self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2])) - self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3])) - self.assertEqual(x.mean(), x.mean((0, 1))) + def test_threshold(self): + for dtype in torch.testing.get_all_math_dtypes('cpu'): + if dtype != torch.uint8 and dtype != torch.float16: + # 100 is wide enough to use AVX2 instructions for all types + x = torch.randn(100).sign().to(dtype=dtype) + y = torch.threshold(x, 0, 0) + self.assertTrue(y.le(0).any()) - for dtype in types: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.prod().item(), -180) - self.assertEqual(x.prod(0), torch.FloatTensor([-5, 6, 6])) - self.assertEqual(x.prod(1), torch.FloatTensor([-2, 90])) + def test_reciprocal(self): + for dtype in [torch.float, torch.double]: + a = torch.randn(100, 89, dtype=dtype) + res_div = 1 / a + res_reciprocal = a.clone() + res_reciprocal.reciprocal_() + self.assertEqual(res_reciprocal, res_div) - for dtype in types: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.max().item(), 6) - self.assertEqual(x.max(0), (torch.FloatTensor([5, 3, 6]), torch.FloatTensor([1, 1, 1]))) - self.assertEqual(x.max(1), (torch.FloatTensor([2, 6]), torch.FloatTensor([1, 2]))) + def test_div(self): + m1 = torch.randn(10, 10) + res1 = m1.clone() + res1[:, 3].div_(2) + res2 = m1.clone() + for i in range(m1.size(0)): + res2[i, 3] = res2[i, 3] / 2 + self.assertEqual(res1, res2) - for dtype in types: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.min().item(), -1) - self.assertEqual(x.min(0), (torch.FloatTensor([-1, 2, 1]), torch.FloatTensor([0, 0, 0]))) - self.assertEqual(x.min(1), (torch.FloatTensor([-1, 3]), torch.FloatTensor([0, 1]))) + a1 = torch.tensor([4.2, 6.2], dtype=torch.bfloat16) + a2 = torch.tensor([2., 2.], dtype=torch.bfloat16) + self.assertEqual(a1 / a2, torch.tensor([2.1, 3.1], dtype=torch.bfloat16), 0.01) + self.assertEqual(a1.div(a2), a1 / a2) - for dtype in types: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.argmax().item(), 5) - self.assertEqual(x.argmax(dim=None).item(), 5) - self.assertEqual(x.argmax(dim=0), torch.FloatTensor([1, 1, 1])) - self.assertEqual(x.argmax(dim=1), torch.FloatTensor([1, 2])) - self.assertEqual(x.argmax(dim=0, keepdim=True), torch.FloatTensor([[1, 1, 1]])) - # test that non-contiguous tensors work - self.assertEqual(x[:, :2].argmax().item(), 2) + def test_floordiv(self): + for dtype in torch.testing.get_all_math_dtypes('cpu'): + if dtype is torch.float16: + continue + x = torch.randn(100).mul(10).to(dtype) + y = x // 3 + self.assertEqual(y.dtype, x.dtype) + z = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=y.dtype) + self.assertEqual(y, z) - for dtype in types: - x = cast(torch.tensor(example, dtype=dtype)) - self.assertEqual(x.argmin().item(), 0) - self.assertEqual(x.argmin(dim=None).item(), 0) - self.assertEqual(x.argmin(dim=0), torch.FloatTensor([0, 0, 0])) - self.assertEqual(x.argmin(dim=1), torch.FloatTensor([0, 1])) - self.assertEqual(x.argmin(dim=1, keepdim=True), torch.FloatTensor([[0], [1]])) - # test that non-contiguous tensors work - self.assertEqual(x[:, :2].argmin().item(), 0) + def test_rdiv(self): + for dtype in torch.testing.get_all_math_dtypes('cpu'): + if dtype is torch.float16: + continue + x = torch.rand(100).add(1).mul(4).to(dtype) + y = 30 / x + if dtype.is_floating_point: + z = torch.tensor([30 / v.item() for v in x], dtype=dtype) + else: + z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype) + self.assertEqual(y, z) - dim_red_fns = [ - "mean", "median", "mode", "norm", "prod", - "std", "sum", "var", "max", "min"] + def test_fmod(self): + m1 = torch.Tensor(10, 10).uniform_(-10., 10.) + res1 = m1.clone() + q = 2.1 + res1[:, 3].fmod_(q) + res2 = m1.clone() + for i in range(m1.size(1)): + res2[i, 3] = math.fmod(res2[i, 3], q) + self.assertEqual(res1, res2) - def normfn_attr(t, dim, keepdim=False, out=None): - attr = torch.norm - return attr(t, 2, dim, keepdim, out=out) + def test_remainder(self): + # Check the Floating point case, both tensor and scalar overloads + for use_item in [True, False]: + m1 = torch.Tensor(10, 10).uniform_(-10., 10.) + res1 = m1.clone() + res2 = m1.clone() + qs = torch.arange(-5.1, 4.1) + # Check the case where the divisor is a simple float + for col_idx, q in enumerate(qs): + # Reference + for i in range(m1.size(0)): + res2[i, col_idx] = res2[i, col_idx] % q + # To test + res1[:, col_idx].remainder_(q if not use_item else q.item()) + self.assertEqual(res1, res2) + # Check the case where the divisor is a tensor + res1 = m1.clone() + res1.remainder_(qs.unsqueeze(0).expand_as(res1)) + self.assertEqual(res1, res2) - for fn_name in dim_red_fns: - fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr + # Check the LongTensor case, both tensor and scalar overloads + for use_item in [True, False]: + long_m1 = torch.LongTensor(10, 10).random_(-10, 10) + long_res1 = long_m1.clone() + long_res2 = long_m1.clone() + long_qs = torch.arange(-5, 5) + long_qs[5] = 5 # Can't handle the divisor=0 case + for col_idx, long_q in enumerate(long_qs): + # Reference + for i in range(long_m1.size(0)): + long_res2[i, col_idx] = long_res2[i, col_idx] % long_q + # To test + long_res1[:, col_idx].remainder_(long_q if not use_item else long_q.item()) + self.assertEqual(long_res1, long_res2) + # Divisor is a tensor case + long_res1 = long_m1.clone() + long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1)) - def fn(x, dim, keepdim=False, out=None): - ans = fn_attr(x, dim, keepdim=keepdim, out=out) - return ans if not istuple(ans) else ans[0] + def test_mm(self): + def _test_mm(n, m, p, dtype, genf): + # helper function + def matrixmultiply(mat1, mat2): + n = mat1.size(0) + m = mat1.size(1) + p = mat2.size(1) + res = torch.zeros(n, p, dtype=dtype) + for i, j in iter_indices(res): + res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) + return res - def fn_tuple(x, dim, keepdim=False, out=None): - return fn_attr(x, dim, keepdim=keepdim, out=out) + # contiguous case + mat1 = genf(n, m) + mat2 = genf(m, p) + res = torch.mm(mat1, mat2) - def test_multidim(x, dim): - self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True)) - self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension()) - self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension()) + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - # general case - x = cast(torch.randn(3, 4, 5)) - dim = random.randint(0, 2) - test_multidim(x, dim) + # non contiguous case 1 + mat1 = genf(n, m) + mat2 = genf(p, m).t() + res = torch.mm(mat1, mat2) - # check 1-d behavior - x = cast(torch.randn(1)) - dim = 0 - self.assertEqual(fn(x, dim).shape, ()) - self.assertEqual(fn(x, dim, keepdim=True).shape, (1,)) + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - # check reducing of a singleton dimension - dims = [3, 4, 5] - singleton_dim = random.randint(0, 2) - dims[singleton_dim] = 1 - x = cast(torch.randn(dims)) - test_multidim(x, singleton_dim) + # non contiguous case 2 + mat1 = genf(m, n).t() + mat2 = genf(m, p) + res = torch.mm(mat1, mat2) - # check reducing with output kwargs - if fn_name in ['median', 'mode', 'max', 'min']: - y = cast(torch.randn(5, 3)) - values = cast(torch.randn(5, 3)) - indices = cast(torch.zeros(5, 3).long() - 1) - fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1])) - values_expected, indices_expected = fn_tuple(y, 1, keepdim=False) - self.assertEqual(values[:, 1], values_expected, - '{} values with out= kwarg'.format(fn_name)) - self.assertEqual(indices[:, 1], indices_expected, - '{} indices with out= kwarg'.format(fn_name)) - continue + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - x = cast(torch.randn(5, 3)) - y = cast(torch.randn(5, 3)) - fn(y, 1, keepdim=False, out=x[:, 1]) - expected = fn(y, 1, keepdim=False) - self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name)) + # non contiguous case 3 + mat1 = genf(m, n).t() + mat2 = genf(p, m).t() + res = torch.mm(mat1, mat2) - def test_dim_reduction(self): - self._test_dim_reduction(self, lambda t: t) + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - def test_reduction_empty(self): - fns_to_test = [ - # name, function, identity - ('max', torch.max, None), - ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None), - ('argmax', torch.argmax, None), - ('min', torch.min, None), - ('argmin', torch.argmin, None), - ('mode', torch.mode, None), - ('median', torch.median, None), + # test with zero stride + mat1 = genf(n, m) + mat2 = genf(m, 1).expand(m, p) + res = torch.mm(mat1, mat2) - ('prod', torch.prod, 1), - ('sum', torch.sum, 0), - ('norm', torch.norm, 0), - ('mean', torch.mean, nan), - ('var', torch.var, nan), - ('std', torch.std, nan), - ('logsumexp', torch.logsumexp, -inf), - ] + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - shape = (2, 0, 4) - for device in torch.testing.get_all_device_types(): - x = torch.randn(shape, device=device) - - for fn in [torch.max, torch.min]: - ident_err = 'operation does not have an identity' - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x)) - - for item in fns_to_test: - name, fn, identity = item - if identity is None: - ident_err = 'does not have an identity' - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2)) - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True)) - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) - self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) - else: - self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2)) - self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=True)) - # assertEqual doesn't work with inf, -inf, nan and two tensors. - check = (torch.testing.assert_allclose if math.isnan(identity) or math.isinf(identity) else - self.assertEqual) - check(torch.full((2, 4), identity, device=device), fn(x, dim=1)) - check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=True)) - try: - check(torch.full((), identity, device=device), fn(x)) - except TypeError as err: - # ignore if there is no allreduce. - self.assertTrue('dim' in str(err)) - - # any - xb = x.to(torch.uint8) - yb = x.to(torch.uint8) - self.assertEqual((2, 0), xb.any(2).shape) - self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) - self.assertEqual(torch.zeros((2, 4), device=device), xb.any(1)) - self.assertEqual(torch.zeros((2, 1, 4), device=device), xb.any(1, keepdim=True)) - self.assertEqual(torch.zeros((), device=device), xb.any()) - - # all - self.assertEqual((2, 0), xb.all(2).shape) - self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) - self.assertEqual(torch.ones((2, 4), device=device), xb.all(1)) - self.assertEqual(torch.ones((2, 1, 4), device=device), xb.all(1, keepdim=True)) - self.assertEqual(torch.ones((), device=device), xb.all()) - - def test_pairwise_distance_empty(self): - for device in torch.testing.get_all_device_types(): - shape = (2, 0) - x = torch.randn(shape, device=device) - y = torch.randn(shape, device=device) + # explicitly exercise the _out variant in torch.mm(). + # contiguous case + mat1 = genf(n, m) + mat2 = genf(m, p) + res = genf(n, p) + torch.mm(mat1, mat2, out=res) - self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y)) - self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True)) + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - shape = (0, 2) - x = torch.randn(shape, device=device) - y = torch.randn(shape, device=device) - self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y)) - self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True)) + # explicitly exercise the _out variant in torch.mm(). + # non contiguous case 3 + mat1 = genf(m, n).t() + mat2 = genf(p, m).t() + res = genf(n, p) + torch.mm(mat1, mat2, out=res) - def test_pdist_empty(self): - for device in torch.testing.get_all_device_types(): - shape = (0, 2) - x = torch.randn(shape, device=device) - self.assertEqual(torch.empty(0, device=device), torch.pdist(x)) + res2 = matrixmultiply(mat1, mat2) + self.assertEqual(res, res2) - shape = (1, 2) - x = torch.randn(shape, device=device) - self.assertEqual(torch.empty(0, device=device), torch.pdist(x)) + for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]: + _test_mm(n, m, p, torch.float32, lambda x, y: torch.randn(x, y, dtype=torch.float32)) + _test_mm(n, m, p, torch.float64, lambda x, y: torch.randn(x, y, dtype=torch.float64)) + _test_mm(n, m, p, torch.int32, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32)) + _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64)) + _test_mm(n, m, p, torch.bfloat16, lambda x, y: torch.randn(x, y, dtype=torch.float32).bfloat16()) - shape = (3, 0) - x = torch.randn(shape, device=device) - self.assertEqual(torch.zeros(3, device=device), torch.pdist(x)) + @staticmethod + def _test_lu_solve(self, cast, pivot=True): + from common_utils import lu_solve_test_helper + for k, n in zip([2, 3, 5], [3, 5, 7]): + b, A, LU_data, LU_pivots = lu_solve_test_helper(self, (n,), (n, k), cast, pivot) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - def test_pdist_norm(self): - def test_pdist_single(shape, device, p, dtype, trans): - x = torch.randn(shape, dtype=dtype, device=device) - if trans: - x.transpose_(-2, -1) - actual = torch.pdist(x, p=p) - expected = brute_pdist(x, p=p) - self.assertEqual(expected.shape, actual.shape) - self.assertTrue(torch.allclose(expected, actual)) + @skipIfNoLapack + def test_lu_solve(self): + self._test_lu_solve(self, lambda t: t) - for device in torch.testing.get_all_device_types(): - for shape in [(4, 5), (3, 2), (2, 1)]: - for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: - for trans in [False, True]: - for dtype in [torch.float32, torch.float64]: - test_pdist_single(shape, device, p, dtype, trans) - - # do a simplified comparison with big inputs, see: - # https://github.com/pytorch/pytorch/issues/15511 - for dtype in [torch.float32, torch.float64]: - test_pdist_single((1000, 2), device, 2, dtype, False) - - def test_cdist_empty(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn((0, 5), device=device) - y = torch.randn((4, 5), device=device) - self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y)) + @staticmethod + def _test_lu_solve_batched(self, cast, pivot=True): + from common_utils import lu_solve_test_helper - x = torch.randn((2, 5), device=device) - y = torch.randn((0, 5), device=device) - self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y)) + def lu_solve_batch_test_helper(A_dims, b_dims, cast, pivot): + b, A, LU_data, LU_pivots = lu_solve_test_helper(self, A_dims, b_dims, cast, pivot) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output + self.assertEqual(x_exp, x_act) # Equality check + self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check - x = torch.randn((2, 0), device=device) - y = torch.randn((3, 0), device=device) - self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y)) + for batchsize in [1, 3, 4]: + lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), cast, pivot) - x = torch.randn((2, 0), device=device) - y = torch.randn((0, 0), device=device) - self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y)) + # tensors with 0 elements + b = cast(torch.randn(3, 0, 3)) + A = cast(torch.randn(3, 0, 0)) + LU_data, LU_pivots = torch.lu(A) + self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) - def test_cdist_norm(self): - for device in torch.testing.get_all_device_types(): - for r1 in [3, 4, 5, 6]: - for m in [2, 3, 4, 10]: - for r2 in [4, 6, 7, 8]: - for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: - x = torch.randn(r1, m, device=device) - y = torch.randn(r2, m, device=device) - actual = torch.cdist(x, y, p=p) - expected = brute_cdist(x, y, p=p) - self.assertTrue(torch.allclose(expected, actual)) - - def test_cdist_norm_batch(self): - for device in torch.testing.get_all_device_types(): - for r1 in [3, 4, 5, 6]: - for m in [2, 3, 4, 10]: - for r2 in [4, 6, 7, 8]: - for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: - x = torch.randn(2, 3, 6, r1, m, device=device) - y = torch.randn(2, 3, 6, r2, m, device=device) - actual = torch.cdist(x, y, p=p) - expected = brute_cdist(x, y, p=p) - self.assertTrue(torch.allclose(expected, actual)) - - def test_cdist_large(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn(1000, 10, device=device) - y = torch.randn(1000, 10, device=device) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertTrue(torch.allclose(expected, actual)) + @skipIfNoLapack + def test_lu_solve_batched(self): + self._test_lu_solve_batched(self, lambda t: t) - def test_cdist_large_batch(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn(4, 3, 1000, 10, device=device) - y = torch.randn(4, 3, 1000, 10, device=device) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertTrue(torch.allclose(expected, actual)) + @staticmethod + def _test_lu_unpack(self, cast, pivot=True): + def run_test(shape, cast): + a = cast(torch.randn(*shape)) + a_lu, p = torch.lu(a, pivot=pivot) + p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p) + self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a) - def test_cdist_non_contiguous(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn(5, 7, device=device).transpose(-1, -2) - y = torch.randn(5, 3, device=device).transpose(-1, -2) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertFalse(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) - self.assertTrue(torch.allclose(expected, actual)) + run_test((3, 3), cast) + run_test((5, 3, 3), cast) + run_test((7, 3, 5, 5), cast) + run_test((7, 5, 3, 3, 3), cast) - x = torch.randn(7, 5, device=device) - y = torch.randn(5, 3, device=device).t() - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertTrue(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) - self.assertTrue(torch.allclose(expected, actual)) + @skipIfNoLapack + def test_lu_unpack(self): + self._test_lu_unpack(self, lambda t: t) - x = torch.randn(5, 7, device=device).t() - y = torch.randn(3, 5, device=device) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertFalse(x.is_contiguous()) - self.assertTrue(y.is_contiguous()) - self.assertTrue(torch.allclose(expected, actual)) + def test_bmm(self): + num_batches = 10 + M, N, O = 23, 8, 12 + b1 = torch.randn(num_batches, M, N) + b2 = torch.randn(num_batches, N, O) + res = torch.bmm(b1, b2) + for i in range(num_batches): + r = torch.mm(b1[i], b2[i]) + self.assertEqual(r, res[i]) + if torch.cuda.is_available(): + # check that mixed arguments are rejected + self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda())) + self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2)) - def test_cdist_non_contiguous_batch(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn(4, 3, 2, 5, 7, device=device).transpose(-1, -2) - y = torch.randn(4, 3, 2, 5, 3, device=device).transpose(-1, -2) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertFalse(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) - self.assertTrue(torch.allclose(expected, actual)) + def test_addbmm(self): + # num_batches = 10 + # M, N, O = 12, 8, 5 + num_batches = 2 + M, N, O = 2, 3, 4 + b1 = torch.randn(num_batches, M, N) + b2 = torch.randn(num_batches, N, O) + res = torch.bmm(b1, b2) + res2 = torch.Tensor().resize_as_(res[0]).zero_() - x = torch.randn(7, 2, 7, 5, device=device) - y = torch.randn(7, 2, 5, 3, device=device).transpose(-1, -2) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertTrue(x.is_contiguous()) - self.assertFalse(y.is_contiguous()) - self.assertTrue(torch.allclose(expected, actual)) + res2.addbmm_(b1, b2) + self.assertEqual(res2, res.sum(0, False)) - x = torch.randn(4, 5, 7, device=device).transpose(-1, -2) - y = torch.randn(4, 3, 5, device=device) - actual = torch.cdist(x, y, p=2) - expected = brute_cdist(x, y, p=2) - self.assertFalse(x.is_contiguous()) - self.assertTrue(y.is_contiguous()) - self.assertTrue(torch.allclose(expected, actual)) + res2.addbmm_(1, b1, b2) + self.assertEqual(res2, res.sum(0, False) * 2) - @unittest.skipIf(not TEST_SCIPY, "Scipy not found") - def test_logsumexp(self): - from scipy.special import logsumexp - a = torch.randn(5, 4) - a[0, 0] = inf - a[1, :] = -inf - actual = a.logsumexp(1) - expected = logsumexp(a.numpy(), 1) - self.assertEqual(expected.shape, actual.shape) - self.assertTrue(np.allclose(expected, actual.numpy())) - # check that out is actually inplace - b = torch.zeros(5, 2) - c = b[:, 0] - torch.logsumexp(a, 1, out=c) - self.assertTrue(np.allclose(expected, b[:, 0].numpy())) + res2.addbmm_(1., .5, b1, b2) + self.assertEqual(res2, res.sum(0, False) * 2.5) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_cpu_parallel(self): - # To use parallel branches we'll need to compare on tensors - # that are relatively large. Even if this is run on a single - # core machine these tests will still give you signal on - # the correctness + res3 = torch.addbmm(1, res2, 0, b1, b2) + self.assertEqual(res3, res2) - def _run_test(size): - for dim in range(len(size) + 1): - nv = np.round(np.random.rand(*size)) # 0s and 1s - tv = torch.from_numpy(nv) - # Parallelisim is only used if numel is - # larger than grainsize defined in Parallel.h - self.assertTrue(tv.numel() > 32768) - if dim == len(size): - nvs = nv.sum() - tvs = tv.sum() - else: - nvs = nv.sum(dim) - tvs = tv.sum(dim) - diff = np.abs(nvs - tvs.numpy()).sum() - self.assertEqual(diff, 0) + res4 = torch.addbmm(1, res2, .5, b1, b2) + self.assertEqual(res4, res.sum(0, False) * 3) - _run_test([2, 3, 3, 3, 3, 2, 2, 3, 2, 3, 2, 3, 3]) - _run_test([4, 4, 4, 4, 4, 4, 4, 4, 4, 4]) - _run_test([1, 32 * 8 * 32 * 8]) - _run_test([1, 32770]) + res5 = torch.addbmm(0, res2, 1, b1, b2) + self.assertEqual(res5, res.sum(0, False)) - def _testCSelection(self, torchfn, mathfn): - # Two tensors - size = (100, 100) - a = torch.rand(*size) - b = torch.rand(*size) - c = torchfn(a, b) - expected_c = torch.zeros(*size) - expected_c.map2_(a, b, lambda _, a, b: mathfn(a, b)) - self.assertEqual(expected_c, c, 0) + res6 = torch.addbmm(.1, res2, .5, b1, b2) + self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5)) - def test_max_elementwise(self): - self._testCSelection(torch.max, max) + def test_baddbmm(self): + num_batches = 10 + M, N, O = 12, 8, 5 + b1 = torch.randn(num_batches, M, N) + b2 = torch.randn(num_batches, N, O) + res = torch.bmm(b1, b2) + res2 = torch.Tensor().resize_as_(res).zero_() - def test_min_elementwise(self): - self._testCSelection(torch.min, min) + res2.baddbmm_(b1, b2) + self.assertEqual(res2, res) - @staticmethod - def _test_lerp(self, cast): - start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] - for shapes in product(start_end_shapes, start_end_shapes): - start = cast(torch.randn(shapes[0])) - end = cast(torch.randn(shapes[1])) + res2.baddbmm_(1, b1, b2) + self.assertEqual(res2, res * 2) - # Tensor weights - for weight in [cast(torch.randn(shapes[0])), random.random()]: - actual = torch.lerp(start, end, weight) - actual_method = start.lerp(end, weight) - self.assertEqual(actual, actual_method) - actual_out = cast(torch.Tensor()) - torch.lerp(start, end, weight, out=actual_out) - self.assertEqual(actual, actual_out) - expected = start + weight * (end - start) - self.assertEqual(expected, actual) + res2.baddbmm_(1, .5, b1, b2) + self.assertEqual(res2, res * 2.5) - def test_lerp(self): - self._test_lerp(self, lambda t: t) + res3 = torch.baddbmm(1, res2, 0, b1, b2) + self.assertEqual(res3, res2) - def test_all_any(self): - def test(size): - x = torch.ones(*size).byte() - self.assertTrue(x.all()) - self.assertTrue(x.any()) + res4 = torch.baddbmm(1, res2, .5, b1, b2) + self.assertEqual(res4, res * 3) - x[3] = 0 - self.assertFalse(x.all()) - self.assertTrue(x.any()) - - x.zero_() - self.assertFalse(x.all()) - self.assertFalse(x.any()) - - x.fill_(2) - self.assertTrue(x.all()) - self.assertTrue(x.any()) - - x = torch.ones(*size).bool() - self.assertTrue(x.all()) - self.assertTrue(x.any()) - - x[3] = False - self.assertFalse(x.all()) - self.assertTrue(x.any()) - - test((10,)) - test((5, 5)) - - def test_all_any_empty(self): - x = torch.ByteTensor() - self.assertTrue(x.all()) - self.assertFalse(x.any()) - - x = torch.BoolTensor() - self.assertTrue(x.all()) - self.assertFalse(x.any()) - - def test_all_any_with_dim(self): - def test(x): - r1 = x.prod(dim=0, keepdim=False).byte() - r2 = x.all(dim=0, keepdim=False) - self.assertEqual(r1.shape, r2.shape) - self.assertTrue((r1 == r2).all()) - - r3 = x.sum(dim=1, keepdim=True).clamp(0, 1).byte() - r4 = x.any(dim=1, keepdim=True) - self.assertEqual(r3.shape, r4.shape) - self.assertTrue((r3 == r4).all()) - - test(torch.ByteTensor([[0, 0, 0], - [0, 0, 1], - [0, 1, 1], - [1, 1, 1]])) + res5 = torch.baddbmm(0, res2, 1, b1, b2) + self.assertEqual(res5, res) - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_all_any_empty_cuda(self): - x = torch.cuda.ByteTensor() - self.assertTrue(x.all()) - self.assertFalse(x.any()) + res6 = torch.baddbmm(.1, res2, .5, b1, b2) + self.assertEqual(res6, res2 * .1 + res * .5) - x = torch.cuda.BoolTensor() - self.assertTrue(x.all()) - self.assertFalse(x.any()) + def test_pow(self): + # [res] torch.pow([res,] x) - def test_mv(self): - def _test_mv(m1, v1): - res1 = torch.mv(m1, v1) + # pow has dedicated implementation for different exponents + for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]: + # base - tensor, exponent - number + # contiguous + m1 = torch.rand(100, 100) + 0.5 + res1 = torch.pow(m1[4], exponent) res2 = res1.clone().zero_() - for i, j in iter_indices(m1): - res2[i] += m1[i][j] * v1[j] - + for i in range(res2.size(0)): + res2[i] = math.pow(m1[4][i], exponent) self.assertEqual(res1, res2) - _test_mv(torch.randn(100, 100, dtype=torch.float32), torch.randn(100, dtype=torch.float32)) - _test_mv(torch.randn(100, 100, dtype=torch.float64), torch.randn(100, dtype=torch.float64)) - _test_mv(torch.randint(0, 100, (100, 100), dtype=torch.int32), torch.randint(0, 100, (100, ), dtype=torch.int32)) - _test_mv(torch.randint(0, 100, (100, 100), dtype=torch.int64), torch.randint(0, 100, (100, ), dtype=torch.int64)) - _test_mv(torch.randn(100, 100, dtype=torch.float32).bfloat16(), torch.randn(100, dtype=torch.float32).bfloat16()) + # non-contiguous + m1 = torch.rand(100, 100) + 0.5 + res1 = torch.pow(m1[:, 4], exponent) + res2 = res1.clone().zero_() + for i in range(res2.size(0)): + res2[i] = math.pow(m1[i, 4], exponent) + self.assertEqual(res1, res2) - def test_numpy_args(self): - x1 = torch.randn(10) - x2 = torch.randn(10) - res1 = torch.add(input=x1, other=x2) - res2 = torch.add(x1=x1, x2=x2) + # base - number, exponent - tensor + # contiguous + m1 = torch.randn(100, 100) + res1 = torch.pow(3, m1[4]) + res2 = res1.clone().zero_() + for i in range(res2.size(0)): + res2[i] = math.pow(3, m1[4, i]) self.assertEqual(res1, res2) - x1 = torch.randn(10, 10, 10) - res1 = x1.sum(dim=(0, 2), keepdim=True) - res2 = x1.sum(axis=(0, 2), keepdims=True) + # non-contiguous + m1 = torch.randn(100, 100) + res1 = torch.pow(3, m1[:, 4]) + res2 = res1.clone().zero_() + for i in range(res2.size(0)): + res2[i] = math.pow(3, m1[i][4]) self.assertEqual(res1, res2) - def test_addcdiv(self): - def _test_addcdiv(a, alpha, b, c): - actual = torch.addcdiv(a, alpha, b, c) - expected = a + (alpha * b) / c - self.assertTrue(torch.allclose(expected, actual, equal_nan=True)) - - def non_zero_rand(size, dtype, device): - if dtype.is_floating_point: - a = torch.rand(size=size, dtype=dtype, device=device) - elif dtype == torch.uint8: - a = torch.randint(1, 5, size=size, dtype=dtype, device=device) - else: - a = torch.randint(-5, 5, size=size, dtype=dtype, device=device) - return a + (a == 0).type(dtype) + # resize behavior for exp == 1 + m1 = torch.randn(2, 2) + out = torch.randn([0]) + torch.pow(m1, 1, out=out) + self.assertEqual(out, m1) - for device in torch.testing.get_all_device_types(): - for dtype in torch.testing.get_all_math_dtypes(device): - _test_addcdiv( - non_zero_rand((2, 2), dtype=dtype, device=device), - 0.5, - non_zero_rand((2, 2), dtype=dtype, device=device), - non_zero_rand((2, 2), dtype=dtype, device=device)) - - def test_add(self): - for device in torch.testing.get_all_device_types(): - # [res] torch.add([res,] tensor1, tensor2) - m1 = torch.randn(100, 100, device=device) - v1 = torch.randn(100, device=device) + def _test_cop(self, torchfn, mathfn): + def reference_implementation(res2): + for i, j in iter_indices(sm1): + idx1d = i * sm1.size(0) + j + res2[i, j] = mathfn(sm1[i, j], sm2[idx1d]) + return res2 - # contiguous - res1 = torch.add(m1[4], v1) - res2 = res1.clone().zero_() - for i in range(m1.size(1)): - res2[i] = m1[4, i] + v1[i] - self.assertEqual(res1, res2) + # contiguous + m1 = torch.randn(10, 10, 10) + m2 = torch.randn(10, 10 * 10) + sm1 = m1[4] + sm2 = m2[4] - m1 = torch.randn(100, 100, device=device) - v1 = torch.randn(100, device=device) + res1 = torchfn(sm1, sm2.view(10, 10)) + res2 = reference_implementation(res1.clone()) + self.assertEqual(res1, res2) - # non-contiguous - res1 = torch.add(m1[:, 4], v1) - res2 = res1.clone().zero_() - for i in range(m1.size(0)): - res2[i] = m1[i, 4] + v1[i] - self.assertEqual(res1, res2) + # non-contiguous + m1 = torch.randn(10, 10, 10) + m2 = torch.randn(10 * 10, 10 * 10) + sm1 = m1[:, 4] + sm2 = m2[:, 4] + # view as sm1.size() + sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0])) + res1 = torchfn(sm1, sm2) + # reference_implementation assumes 1-d sm2 + sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()) + res2 = reference_implementation(res1.clone()) + self.assertEqual(res1, res2) - # [res] torch.add([res,] tensor, value) - m1 = torch.randn(10, 10, device=device) + def test_cdiv(self): + self._test_cop(torch.div, lambda x, y: x / y) - # contiguous - res1 = m1.clone() - res1[3].add_(2) - res2 = m1.clone() - for i in range(m1.size(1)): - res2[3, i] = res2[3, i] + 2 - self.assertEqual(res1, res2) + def test_cfmod(self): + self._test_cop(torch.fmod, math.fmod) - # non-contiguous - m1 = torch.randn(10, 10, device=device) - res1 = m1.clone() - res1[:, 3].add_(2) - res2 = m1.clone() - for i in range(m1.size(0)): - res2[i, 3] = res2[i, 3] + 2 - self.assertEqual(res1, res2) + def test_cremainder(self): + self._test_cop(torch.remainder, lambda x, y: x % y) - # inter-type - m1 = torch.randn(10, 10, device=device) - self.assertEqual(m1 + 3, m1 + torch.tensor(3)) - self.assertEqual(3 + m1, torch.tensor(3) + m1) - one = torch.tensor(1, dtype=torch.uint8, device=device) - self.assertEqual(torch.add(one, 1), 2) - self.assertEqual(torch.add(one, 1).dtype, torch.uint8) - - # contiguous + non-contiguous - m1 = torch.randn(10, 10, device=device) - m2 = torch.randn(10, 10, device=device).t() - res = m1 + m2 - self.assertTrue(res.is_contiguous()) - self.assertEqual(res, m1 + m2.contiguous()) - - # 1d + empty - m1 = torch.tensor([1.0], dtype=torch.float, device=device) - m2 = torch.tensor([], dtype=torch.float, device=device) - self.assertEqual(m1 + m2, []) - - # bool - m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) - m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) - expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device) - self.assertEqual(m1 + m2, expected) - - # fused multiply add - a = torch.zeros(2, 3, dtype=torch.bool, device=device) - res = torch.add(a, a, alpha=0) - expected = torch.zeros(2, 3, device=device).bool() - self.assertEqual(res, expected) - - # bfloat16 - m1 = torch.tensor([1., 2.], dtype=torch.bfloat16) - m2 = torch.tensor([3., 4.], dtype=torch.bfloat16) - self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16)) - - def test_bool_sub(self): - for device in torch.testing.get_all_device_types(): - m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) - m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with two bool tensors is not supported. " - r"Use the `\^` or `logical_xor\(\)` operator instead.", - lambda: m1 - m2) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with a bool tensor is not supported. " - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: 1 - m1) - self.assertRaisesRegex(RuntimeError, - r"Subtraction, the `\-` operator, with a bool tensor is not supported. " - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: m2 - 1) + def test_cmul(self): + self._test_cop(torch.mul, lambda x, y: x * y) - def test_sub(self): - for dtype in torch.testing.get_all_dtypes(): - m1 = torch.tensor([2.34, 4.44], dtype=dtype) - m2 = torch.tensor([1.23, 2.33], dtype=dtype) + def test_cpow(self): + self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y)) - if (dtype == torch.half or dtype == torch.bool): - self.assertRaises(RuntimeError, lambda: m1 - m2) - elif (dtype == torch.bfloat16): - # bfloat16 has a lower precision so we have to have a separate check for it - self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype), 0.01) - else: - self.assertEqual(m1 - m2, torch.tensor([1.11, 2.11], dtype=dtype)) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_einsum(self): + # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f + x = torch.randn(5) + y = torch.randn(7) + A = torch.randn(3, 5) + B = torch.randn(2, 5) + C = torch.randn(2, 3, 5) + D = torch.randn(2, 5, 7) + E = torch.randn(7, 9) + F = torch.randn(2, 3, 5, 7) + G = torch.randn(7, 11, 13) + H = torch.randn(4, 4) + I = torch.randn(3, 4, 4) + l = torch.randn(5, 10) + r = torch.randn(5, 20) + w = torch.randn(30, 10, 20) + test_list = [ + # -- Vector + ("i->", x), # sum + ("i,i->", x, x), # dot + ("i,i->i", x, x), # vector element-wise mul + ("i,j->ij", x, y), # outer + # -- Matrix + ("ij->ji", A), # transpose + ("ij->j", A), # row sum + ("ij->i", A), # col sum + ("ij,ij->ij", A, A), # matrix element-wise mul + ("ij,j->i", A, x), # matrix vector multiplication + ("ij,kj->ik", A, B), # matmul + ("ij,ab->ijab", A, E), # matrix outer product + # -- Tensor + ("aij,ajk->aik", C, D), # batch matmul + ("ijk,jk->i", C, A), # tensor matrix contraction + ("aij,jk->aik", D, E), # tensor matrix contraction + ("abcd,dfg->abcfg", F, G), # tensor tensor contraction + ("ijk,jk->ik", C, A), # tensor matrix contraction with double indices + ("ijk,jk->ij", C, A), # tensor matrix contraction with double indices + ("ijk,ik->j", C, B), # non contiguous + ("ijk,ik->jk", C, B), # non contiguous with double indices + # -- Diagonal + ("ii", H), # trace + ("ii->i", H), # diagonal + # -- Ellipsis + ("i...->...", H), + ("ki,...k->i...", A.t(), B), + ("k...,jk", A.t(), B), + ("...ii->...i", I), # batch diagonal + # -- Other + ("bn,anm,bm->ba", l, w, r), # as torch.bilinear + ("... ii->...i ", I), # batch diagonal with spaces + ] + for test in test_list: + actual = torch.einsum(test[0], test[1:]) + expected = np.einsum(test[0], *[t.numpy() for t in test[1:]]) + self.assertEqual(expected.shape, actual.shape, test[0]) + self.assertTrue(np.allclose(expected, actual.numpy()), test[0]) + # test vararg + actual2 = torch.einsum(test[0], *test[1:]) + self.assertEqual(expected.shape, actual2.shape, test[0]) + self.assertTrue(np.allclose(expected, actual2.numpy()), test[0]) - def test_csub(self): - # with a tensor - a = torch.randn(100, 90) - b = a.clone().normal_() + def do_einsum(*args): + return torch.einsum(test[0], args) + # FIXME: following test cases fail gradcheck + if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}: + gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:]) + self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) + self.assertTrue(A._version == 0) # check that we do not use inplace ops - res_add = torch.add(a, -1, b) - res_csub = a.clone() - res_csub.sub_(b) - self.assertEqual(res_add, res_csub) + def test_sum_all(self): + def check_sum_all(tensor): + pylist = tensor.reshape(-1).tolist() + self.assertEqual(tensor.sum(), sum(pylist)) - # with a scalar - a = torch.randn(100, 100) + check_sum_all(torch.tensor([1, 2, 3, 4, 5])) + check_sum_all(torch.randn(200000)) + check_sum_all(torch.randn(2000, 2)[:, 0]) + check_sum_all(torch.tensor([True, False, True], dtype=torch.bool)) - scalar = 123.5 - res_add = torch.add(a, -scalar) - res_csub = a.clone() - res_csub.sub_(scalar) - self.assertEqual(res_add, res_csub) + def _assert_matches_numpy(self, t, n): + self.assertEqual(n.shape, t.shape) + if t.dtype == torch.float: + self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05, + equal_nan=True)) + else: + self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True)) - @staticmethod - def _test_neg(self, cast): - float_types = [torch.DoubleTensor, torch.FloatTensor, torch.LongTensor] - int_types = [torch.IntTensor, torch.ShortTensor, torch.ByteTensor, - torch.CharTensor] - - for t in float_types + int_types: - if t in float_types: - a = cast(torch.randn(100, 90).type(t)) - else: - a = cast(torch.randint(-128, 128, (100, 90), dtype=t.dtype)) - zeros = cast(torch.Tensor().type(t)).resize_as_(a).zero_() + def _test_dim_ops(self, pytorch_op, numpy_op, + use_floating=True, use_integral=True): + def do_one(tensors_dict, dim): + for category, tensors in tensors_dict.items(): + if category == "slice": + dim = 0 + for tensor in tensors: + # we have no control over NumPy warnings... + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + expected = numpy_op(tensor.numpy(), dim) + actual = pytorch_op(tensor, dim) + self._assert_matches_numpy(actual, expected) + if torch.cuda.is_available(): + self._assert_matches_numpy(pytorch_op(tensor.cuda(), + dim).cpu(), + expected) + do_one(self._make_tensors((5, 400000), use_floating=use_floating, + use_integral=use_integral), 1) + do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, + use_integral=use_integral), 0) + do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, + use_integral=use_integral), 1) + do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, + use_integral=use_integral), 2) + do_one(self._make_tensors((100000, ), use_floating=use_floating, + use_integral=use_integral), -1) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), 0) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), 1) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), 2) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), (1, 2)) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), (1, -1)) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), (0, 2)) + do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, + use_integral=use_integral), (0, 2, 1)) - if t == torch.ByteTensor: - res_add = torch.add(zeros, a, alpha=255) - else: - res_add = torch.add(zeros, a, alpha=-1) - res_neg = a.clone() - res_neg.neg_() - self.assertEqual(res_neg, res_add) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_sum_dim(self): + self._test_dim_ops( + lambda t, d: t.sum(d), + lambda n, d: n.sum(d)) - # test out of place as well - res_neg_out_place = a.clone().neg() - self.assertEqual(res_neg_out_place, res_add) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_mean_dim(self): + self._test_dim_ops( + lambda t, d: t.mean(d), + lambda n, d: n.mean(d), + use_integral=False) - # test via __neg__ operator - res_neg_op = -a.clone() - self.assertEqual(res_neg_op, res_add) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_std_dim(self): + for unbiased in [False, True]: + self._test_dim_ops( + lambda t, d: t.std(d, unbiased=unbiased), + lambda n, d: n.std(d, ddof=1 if unbiased else 0), + use_integral=False) - # bool - self.assertRaisesRegex( - RuntimeError, - r"Negation, the `\-` operator, on a bool tensor is not supported. " - r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", - lambda: - cast(torch.tensor([False, True]))) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_var_dim(self): + for unbiased in [False, True]: + self._test_dim_ops( + lambda t, d: t.var(d, unbiased=unbiased), + lambda n, d: n.var(d, ddof=1 if unbiased else 0), + use_integral=False) - def test_neg(self): - self._test_neg(self, lambda t: t) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + @unittest.skipIf(not TEST_SCIPY, 'Scipy not found') + def test_logsumexp_dim(self): + from scipy.special import logsumexp + self._test_dim_ops( + lambda t, d: t.logsumexp(d), + lambda n, d: logsumexp(n, d), + use_integral=False) - @staticmethod - def _test_bitwise_not(self, device): - res = 0xffff - torch.arange(127, dtype=torch.int8, device=device) - for dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - if dtype == torch.bool: - a = torch.tensor([True, False], device=device) - expected_res = torch.tensor([False, True], device=device) - else: - a = torch.arange(127, dtype=dtype, device=device) - expected_res = res.type(dtype) - # new tensor - self.assertEqual(expected_res, a.bitwise_not()) - # out - b = torch.empty(0, dtype=dtype, device=device) - torch.bitwise_not(a, out=b) - self.assertEqual(expected_res, b) - # in-place - a.bitwise_not_() - self.assertEqual(expected_res, a) + def test_sum_out(self): + x = torch.rand(100, 100) + res1 = torch.sum(x, 1) + res2 = torch.Tensor() + torch.sum(x, 1, out=res2) + self.assertEqual(res1, res2) + x = torch.rand(100, 100, 100) + res1 = x.sum(2).sum(1) + res2 = torch.Tensor() + torch.sum(x, (2, 1), out=res2) + self.assertEqual(res1, res2) - # test exceptions - for dtype in(torch.half, torch.float, torch.double): - a = torch.zeros(10, dtype=dtype, device=device) - # new tensor - with self.assertRaises(RuntimeError): - a.bitwise_not() - # out - b = torch.empty(0, dtype=dtype, device=device) - with self.assertRaises(RuntimeError): - torch.bitwise_not(a, out=b) - # in-place - with self.assertRaises(RuntimeError): - a.bitwise_not_() + # TODO: these tests only check if it's possible to pass a return value + # it'd be good to expand them + def test_prod(self): + x = torch.rand(100, 100) + res1 = torch.prod(x, 1) + res2 = torch.Tensor() + torch.prod(x, 1, out=res2) + self.assertEqual(res1, res2) - def test_bitwise_not(self): - self._test_bitwise_not(self, 'cpu') + def _test_reduce_integer_upcast(self, fn, has_out=True): + shape = (3, 4, 5) + reduced_shape = fn(torch.ones(shape)).shape - @staticmethod - def _test_logical_not(self, device): - for dtype in torch.testing.get_all_dtypes(): - a = torch.tensor([10, 1, 0], dtype=dtype, device=device) - if dtype == torch.bfloat16: - self.assertRaises(RuntimeError, lambda: a.logical_not()) - continue - expected_res = torch.tensor([0, 0, 1], dtype=dtype, device=device) - # new tensor - self.assertEqual(expected_res.bool(), a.logical_not()) - # out - for out_dtype in torch.testing.get_all_dtypes(): - b = torch.empty(0, dtype=out_dtype, device=device) - if out_dtype == torch.bfloat16: - self.assertRaises(RuntimeError, lambda: torch.logical_not(a, out=b)) - continue - torch.logical_not(a, out=b) - self.assertEqual(expected_res.bool(), b.bool()) - # in-place - a.logical_not_() - self.assertEqual(expected_res, a) + def _test_out(dtype, other_dtype): + out = torch.ones(reduced_shape, dtype=dtype) + result = fn(x, out=out) + self.assertIs(out.dtype, result.dtype) + self.assertEqual(fn(x.type(dtype)), result) + result = fn(x, out=out, dtype=dtype) + self.assertIs(out.dtype, result.dtype) + self.assertEqual(fn(x.type(dtype)), result) + # 'out' is favored over dtype, check error + self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) - def test_logical_not(self): - self._test_logical_not(self, 'cpu') + for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]: + x = torch.ones(shape, dtype=dtype) + expected_dtype = dtype if dtype.is_floating_point else torch.int64 + self.assertIs(expected_dtype, fn(x).dtype) + self.assertEqual(fn(x.type(expected_dtype)), fn(x)) - @staticmethod - def _test_logical_xor(self, device): - for dtype in (torch.bool,): # Will add more dtypes in the future - expected_res = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - a = torch.tensor([10, 0, 1, 0], dtype=dtype, device=device) - b = torch.tensor([1, 0, 0, 10], dtype=dtype, device=device) - # new tensor - self.assertEqual(expected_res, a.logical_xor(b)) - # out - c = torch.empty(0, dtype=dtype, device=device) - torch.logical_xor(a, b, out=c) - self.assertEqual(expected_res, c) - # out is not bool - c = torch.empty(0, dtype=torch.uint8, device=device) - with self.assertRaisesRegex(RuntimeError, - r"The output tensor of logical_xor must be a bool tensor\."): - torch.logical_xor(a, b, out=c) - # in-place - a.logical_xor_(b) - self.assertEqual(expected_res, a) + if dtype.is_floating_point: + other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 + else: + other_dtype = torch.int32 if dtype != torch.int32 else torch.int16 + self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype) + self.assertEqual(fn(x.type(other_dtype)), fn(x, dtype=other_dtype)) - def test_logical_xor(self): - self._test_logical_xor(self, 'cpu') + # test mixed int/float + mixed_dtype = torch.int32 if dtype.is_floating_point else torch.float32 + self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype) + self.assertEqual(fn(x.type(mixed_dtype)), fn(x, dtype=mixed_dtype)) - def test_threshold(self): - for dtype in torch.testing.get_all_math_dtypes('cpu'): - if dtype != torch.uint8 and dtype != torch.float16: - # 100 is wide enough to use AVX2 instructions for all types - x = torch.randn(100).sign().to(dtype=dtype) - y = torch.threshold(x, 0, 0) - self.assertTrue(y.le(0).any()) + if has_out: + _test_out(dtype, other_dtype) + _test_out(dtype, mixed_dtype) - def test_reciprocal(self): - for dtype in [torch.float, torch.double]: - a = torch.randn(100, 89, dtype=dtype) - res_div = 1 / a - res_reciprocal = a.clone() - res_reciprocal.reciprocal_() - self.assertEqual(res_reciprocal, res_div) + def test_sum_integer_upcast(self): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False) + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs)) - def test_mul(self): - for device in torch.testing.get_all_device_types(): - m1 = torch.randn(10, 10, device=device) - res1 = m1.clone() - res1[:, 3].mul_(2) - res2 = m1.clone() - for i in range(res1.size(0)): - res2[i, 3] = res2[i, 3] * 2 - self.assertEqual(res1, res2) + def test_prod_integer_upcast(self): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False) + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs)) - a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) - a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) - self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device)) + def test_cumsum_integer_upcast(self): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs)) - if device == 'cpu': - a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) - a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) - self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), 0.01) - self.assertEqual(a1.mul(a2), a1 * a2) + def test_cumprod_integer_upcast(self): + self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) - def test_div(self): - m1 = torch.randn(10, 10) - res1 = m1.clone() - res1[:, 3].div_(2) - res2 = m1.clone() - for i in range(m1.size(0)): - res2[i, 3] = res2[i, 3] / 2 + def test_cross(self): + x = torch.rand(100, 3, 100) + y = torch.rand(100, 3, 100) + res1 = torch.cross(x, y) + res2 = torch.Tensor() + torch.cross(x, y, out=res2) self.assertEqual(res1, res2) - a1 = torch.tensor([4.2, 6.2], dtype=torch.bfloat16) - a2 = torch.tensor([2., 2.], dtype=torch.bfloat16) - self.assertEqual(a1 / a2, torch.tensor([2.1, 3.1], dtype=torch.bfloat16), 0.01) - self.assertEqual(a1.div(a2), a1 / a2) - - def test_floordiv(self): - for dtype in torch.testing.get_all_math_dtypes('cpu'): - if dtype is torch.float16: - continue - x = torch.randn(100).mul(10).to(dtype) - y = x // 3 - self.assertEqual(y.dtype, x.dtype) - z = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=y.dtype) - self.assertEqual(y, z) + def test_cross_with_and_without_dim(self): + x = torch.rand(100, 3) + y = torch.rand(100, 3) + res1 = torch.cross(x, y, dim=1) + res2 = torch.cross(x, y, dim=-1) + res3 = torch.cross(x, y) + self.assertEqual(res1, res2) + self.assertEqual(res1, res3) - def test_rdiv(self): - for dtype in torch.testing.get_all_math_dtypes('cpu'): - if dtype is torch.float16: - continue - x = torch.rand(100).add(1).mul(4).to(dtype) - y = 30 / x - if dtype.is_floating_point: - z = torch.tensor([30 / v.item() for v in x], dtype=dtype) - else: - z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype) - self.assertEqual(y, z) + def test_cross_validation(self): + self.assertRaisesRegex( + RuntimeError, "inconsistent tensors dimensions", + lambda: torch.cross(torch.rand(100, 3), torch.rand(100, 3, 10))) + self.assertRaisesRegex( + RuntimeError, "inconsistent tensors sizes", + lambda: torch.cross(torch.rand(5, 3), torch.rand(3, 5))) + self.assertRaisesRegex( + RuntimeError, "no dimension of size 3 in input", + lambda: torch.cross(torch.rand(5, 4), torch.rand(5, 4))) + self.assertRaisesRegex( + RuntimeError, "dimension 0 does not have size 3", + lambda: torch.cross(torch.rand(5, 4, 3), torch.rand(5, 4, 3), dim=0)) + self.assertRaisesRegex( + RuntimeError, "dimension -1 does not have size 3", + lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-1)) + self.assertRaisesRegex( + IndexError, "Dimension out of range", + lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-5)) - def test_fmod(self): - m1 = torch.Tensor(10, 10).uniform_(-10., 10.) - res1 = m1.clone() - q = 2.1 - res1[:, 3].fmod_(q) - res2 = m1.clone() - for i in range(m1.size(1)): - res2[i, 3] = math.fmod(res2[i, 3], q) + def test_zeros(self): + res1 = torch.zeros(100, 100) + res2 = torch.Tensor() + torch.zeros(100, 100, out=res2) self.assertEqual(res1, res2) - def test_remainder(self): - # Check the Floating point case, both tensor and scalar overloads - for use_item in [True, False]: - m1 = torch.Tensor(10, 10).uniform_(-10., 10.) - res1 = m1.clone() - res2 = m1.clone() - qs = torch.arange(-5.1, 4.1) - # Check the case where the divisor is a simple float - for col_idx, q in enumerate(qs): - # Reference - for i in range(m1.size(0)): - res2[i, col_idx] = res2[i, col_idx] % q - # To test - res1[:, col_idx].remainder_(q if not use_item else q.item()) - self.assertEqual(res1, res2) - # Check the case where the divisor is a tensor - res1 = m1.clone() - res1.remainder_(qs.unsqueeze(0).expand_as(res1)) - self.assertEqual(res1, res2) + boolTensor = torch.zeros(2, 2, dtype=torch.bool) + expected = torch.tensor([[False, False], [False, False]], dtype=torch.bool) + self.assertEqual(boolTensor, expected) - # Check the LongTensor case, both tensor and scalar overloads - for use_item in [True, False]: - long_m1 = torch.LongTensor(10, 10).random_(-10, 10) - long_res1 = long_m1.clone() - long_res2 = long_m1.clone() - long_qs = torch.arange(-5, 5) - long_qs[5] = 5 # Can't handle the divisor=0 case - for col_idx, long_q in enumerate(long_qs): - # Reference - for i in range(long_m1.size(0)): - long_res2[i, col_idx] = long_res2[i, col_idx] % long_q - # To test - long_res1[:, col_idx].remainder_(long_q if not use_item else long_q.item()) - self.assertEqual(long_res1, long_res2) - # Divisor is a tensor case - long_res1 = long_m1.clone() - long_res1.remainder_(long_qs.unsqueeze(0).expand_as(long_res1)) + halfTensor = torch.zeros(1, 1, dtype=torch.half) + expected = torch.tensor([[0.]], dtype=torch.float16) + self.assertEqual(halfTensor, expected) - @staticmethod - def _test_remainder_overflow(self, dtype, device): - # Check Integer Overflows - x = torch.tensor(23500, dtype=dtype, device=device) - q = 392486996410368 - self.assertEqual(x % q, x) - self.assertEqual(-x % q, q - x) - self.assertEqual(x % -q, x - q) - self.assertEqual(-x % -q, -x) + bfloat16Tensor = torch.zeros(1, 1, dtype=torch.bfloat16) + expected = torch.tensor([[0.]], dtype=torch.bfloat16) + self.assertEqual(bfloat16Tensor, expected) - def test_remainder_overflow(self): - self._test_remainder_overflow(self, dtype=torch.int64, device='cpu') + def test_zeros_out(self): + shape = (3, 4) + out = torch.zeros(shape) + torch.zeros(shape, out=out) - def test_mm(self): - def _test_mm(n, m, p, dtype, genf): - # helper function - def matrixmultiply(mat1, mat2): - n = mat1.size(0) - m = mat1.size(1) - p = mat2.size(1) - res = torch.zeros(n, p, dtype=dtype) - for i, j in iter_indices(res): - res[i, j] = sum(mat1[i, k] * mat2[k, j] for k in range(m)) - return res + # change the dtype, layout, device + self.assertRaises(RuntimeError, lambda: torch.zeros(shape, dtype=torch.int64, out=out)) + self.assertRaises(RuntimeError, lambda: torch.zeros(shape, layout=torch.sparse_coo, out=out)) + if torch.cuda.is_available(): + self.assertRaises(RuntimeError, lambda: torch.zeros(shape, device='cuda', out=out)) - # contiguous case - mat1 = genf(n, m) - mat2 = genf(m, p) - res = torch.mm(mat1, mat2) + # leave them the same + self.assertEqual(torch.zeros(shape), torch.zeros(shape, dtype=out.dtype, out=out)) + self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out)) + self.assertEqual(torch.zeros(shape), torch.zeros(shape, device='cpu', out=out)) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + def test_ones(self): + res1 = torch.ones(100, 100) + res2 = torch.Tensor() + torch.ones(100, 100, out=res2) + self.assertEqual(res1, res2) - # non contiguous case 1 - mat1 = genf(n, m) - mat2 = genf(p, m).t() - res = torch.mm(mat1, mat2) + # test boolean tensor + res1 = torch.ones(1, 2, dtype=torch.bool) + expected = torch.tensor([[True, True]], dtype=torch.bool) + self.assertEqual(res1, expected) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + def test_ones_like(self): + expected = torch.ones(100, 100) - # non contiguous case 2 - mat1 = genf(m, n).t() - mat2 = genf(m, p) - res = torch.mm(mat1, mat2) + res1 = torch.ones_like(expected) + self.assertEqual(res1, expected) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + # test boolean tensor + expected = torch.tensor([True, True], dtype=torch.bool) + res1 = torch.ones_like(expected) + self.assertEqual(res1, expected) - # non contiguous case 3 - mat1 = genf(m, n).t() - mat2 = genf(p, m).t() - res = torch.mm(mat1, mat2) + def test_dtypes(self): + all_dtypes = torch.testing.get_all_dtypes() + do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu')) + if torch.cuda.is_available(): + all_dtypes.remove(torch.bfloat16) # Remove once _th_zero_ is enabled on cuda for bfloat16 + do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0')) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + def test_copy_dtypes(self): + all_dtypes = torch.testing.get_all_dtypes() + for dtype in all_dtypes: + copied_dtype = copy.deepcopy(dtype) + self.assertIs(dtype, copied_dtype) - # test with zero stride - mat1 = genf(n, m) - mat2 = genf(m, 1).expand(m, p) - res = torch.mm(mat1, mat2) + def test_copy_transpose(self): + x = torch.arange(100 * 100, dtype=torch.float).reshape(100, 100).t() + y = torch.empty(100, 100, dtype=torch.float) + y.copy_(x) + self.assertEqual(y[:, 0], range(100)) + self.assertEqual(y[:, 40], range(4000, 4100)) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + y = torch.empty(100, 100, dtype=torch.double) + y.copy_(x) + self.assertEqual(y[:, 0], range(100)) + self.assertEqual(y[:, 40], range(4000, 4100)) - # explicitly exercise the _out variant in torch.mm(). - # contiguous case - mat1 = genf(n, m) - mat2 = genf(m, p) - res = genf(n, p) - torch.mm(mat1, mat2, out=res) + def test_device(self): + cpu = torch.device('cpu') + self.assertEqual('cpu', str(cpu)) + self.assertEqual('cpu', cpu.type) + self.assertEqual(None, cpu.index) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + cpu0 = torch.device('cpu:0') + self.assertEqual('cpu:0', str(cpu0)) + self.assertEqual('cpu', cpu0.type) + self.assertEqual(0, cpu0.index) - # explicitly exercise the _out variant in torch.mm(). - # non contiguous case 3 - mat1 = genf(m, n).t() - mat2 = genf(p, m).t() - res = genf(n, p) - torch.mm(mat1, mat2, out=res) + cpu0 = torch.device('cpu', 0) + self.assertEqual('cpu:0', str(cpu0)) + self.assertEqual('cpu', cpu0.type) + self.assertEqual(0, cpu0.index) - res2 = matrixmultiply(mat1, mat2) - self.assertEqual(res, res2) + cuda = torch.device('cuda') + self.assertEqual('cuda', str(cuda)) + self.assertEqual('cuda', cuda.type) + self.assertEqual(None, cuda.index) - for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]: - _test_mm(n, m, p, torch.float32, lambda x, y: torch.randn(x, y, dtype=torch.float32)) - _test_mm(n, m, p, torch.float64, lambda x, y: torch.randn(x, y, dtype=torch.float64)) - _test_mm(n, m, p, torch.int32, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32)) - _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64)) - _test_mm(n, m, p, torch.bfloat16, lambda x, y: torch.randn(x, y, dtype=torch.float32).bfloat16()) + cuda1 = torch.device('cuda:1') + self.assertEqual('cuda:1', str(cuda1)) + self.assertEqual('cuda', cuda1.type) + self.assertEqual(1, cuda1.index) - @staticmethod - def _test_lu(self, cast, pivot=True): - from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank + cuda1 = torch.device('cuda', 1) + self.assertEqual('cuda:1', str(cuda1)) + self.assertEqual('cuda', cuda1.type) + self.assertEqual(1, cuda1.index) - def run_test(matrix_size, batches, cast): - a = cast(fullrank(matrix_size, *batches)) - a_LU_info, pivots_info, info_ = a.lu(pivot=pivot, get_infos=True) - self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size))) - self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,))) - self.assertEqual(info_.size(), torch.Size(batches)) - self.assertEqual(info_.abs().sum(), 0) - a_LU, pivots = a.lu(pivot=pivot) - self.assertEqual(a_LU, a_LU_info) - self.assertEqual(pivots_info, pivots) - if a.is_cuda: - a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True) - self.assertEqual(nopiv, cast(torch.arange(1, 1 + a.size(-1), dtype=torch.int32).expand(a.shape[:-1]))) - self.assertEqual(info_, info_nopiv) - P, L, U = torch.lu_unpack(a_LU, pivots) - self.assertEqual(P.matmul(L.matmul(U)), a) - - for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]): - run_test(ms, batch, cast) - - # Info should be positive for rank deficient matrices - a = cast(torch.ones(5, 3, 3)) - self.assertGreater(a.lu(pivot=pivot, get_infos=True)[2][0], 0) - - # Error checking, no pivoting variant on CPU - with self.assertRaisesRegex(RuntimeError, - 'lu without pivoting is not implemented on the CPU'): - torch.lu(torch.empty(1, 2, 2), pivot=False) + self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1')) + self.assertRaises(RuntimeError, lambda: torch.device('cpu:1')) + self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1)) + self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1)) + self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1')) + self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1)) + self.assertRaises(RuntimeError, lambda: torch.device(-1)) - @skipIfNoLapack - def test_lu(self): - self._test_lu(self, lambda t: t, pivot=True) + self.assertRaises(RuntimeError, lambda: torch.device('other')) + self.assertRaises(RuntimeError, lambda: torch.device('other:0')) - @staticmethod - def _test_lu_solve(self, cast, pivot=True): - from common_utils import lu_solve_test_helper - for k, n in zip([2, 3, 5], [3, 5, 7]): - b, A, LU_data, LU_pivots = lu_solve_test_helper(self, (n,), (n, k), cast, pivot) - x = torch.lu_solve(b, LU_data, LU_pivots) - b_ = torch.matmul(A, x) - self.assertEqual(b_, b) + device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'} + device_hash_set = set() + for device in list(device_set): + device_hash_set.add(hash(torch.device(device))) + self.assertEqual(len(device_set), len(device_hash_set)) - @skipIfNoLapack - def test_lu_solve(self): - self._test_lu_solve(self, lambda t: t) + def test_tensor_device(self): + def assertEqual(device_str, fn): + self.assertEqual(torch.device(device_str), fn().device) + self.assertEqual(device_str, str(fn().device)) - @staticmethod - def _test_lu_solve_batched(self, cast, pivot=True): - from common_utils import lu_solve_test_helper + assertEqual('cpu', lambda: torch.tensor(5)) + assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu')) + # NOTE: 'cpu' is the canonical representation of 'cpu:0', but 'cuda:X' is the canonical + # representation of cuda devices. + assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu:0')) + assertEqual('cpu', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0')) + if TEST_NUMPY: + assertEqual('cpu', lambda: torch.tensor(np.random.randn(2, 3), device='cpu')) - def lu_solve_batch_test_helper(A_dims, b_dims, cast, pivot): - b, A, LU_data, LU_pivots = lu_solve_test_helper(self, A_dims, b_dims, cast, pivot) - x_exp_list = [] - for i in range(b_dims[0]): - x_exp_list.append(torch.lu_solve(b[i], LU_data[i], LU_pivots[i])) - x_exp = torch.stack(x_exp_list) # Stacked output - x_act = torch.lu_solve(b, LU_data, LU_pivots) # Actual output - self.assertEqual(x_exp, x_act) # Equality check - self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check + if torch.cuda.is_available(): + assertEqual('cuda:0', lambda: torch.tensor(5).cuda(0)) + assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0')) + self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu')) + self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu:0')) + assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device=0)) + assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0')) + assertEqual('cuda:' + str(torch.cuda.current_device()), + lambda: torch.tensor(5, dtype=torch.int64, device='cuda')) + assertEqual('cuda:0', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0')) + if TEST_NUMPY: + assertEqual('cuda:0', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:0')) - for batchsize in [1, 3, 4]: - lu_solve_batch_test_helper((5, batchsize), (batchsize, 5, 10), cast, pivot) + if torch.cuda.device_count() > 1: + assertEqual('cuda:1', lambda: torch.tensor(5).cuda(1)) + assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1')) + assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1)) + assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1')) + assertEqual('cuda:1', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:1')) + if TEST_NUMPY: + assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1')) - # tensors with 0 elements - b = cast(torch.randn(3, 0, 3)) - A = cast(torch.randn(3, 0, 0)) - LU_data, LU_pivots = torch.lu(A) - self.assertEqual(torch.empty_like(b), b.lu_solve(LU_data, LU_pivots)) + def test_to(self): + def test_copy_behavior(t, non_blocking=False): + self.assertIs(t, t.to(t, non_blocking=non_blocking)) + self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking)) + self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) + self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) + self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) + self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)) - @skipIfNoLapack - def test_lu_solve_batched(self): - self._test_lu_solve_batched(self, lambda t: t) + devices = [t.device] + if t.device.type == 'cuda': + if t.device.index == -1: + devices.append('cuda:{}'.format(torch.cuda.current_device())) + elif t.device.index == torch.cuda.current_device(): + devices.append('cuda') + for device in devices: + self.assertIs(t, t.to(device, non_blocking=non_blocking)) + self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) + self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) + self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) - @staticmethod - def _test_lu_solve_batched_non_contiguous(self, cast): - from numpy.linalg import solve - from common_utils import random_fullrank_matrix_distinct_singular_value + a = torch.tensor(5) + test_copy_behavior(a) + self.assertEqual(a.device, a.to('cpu').device) + self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device) + self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype) + self.assertEqual(a.device, a.to(torch.float32).device) + self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype) + self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr()) + self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr()) + self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr()) + self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr()) - A = random_fullrank_matrix_distinct_singular_value(2, 2) - b = torch.randn(2, 2, 2) - x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())) - A = cast(A).permute(0, 2, 1) - b = cast(b).permute(2, 1, 0) - assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" - LU_data, LU_pivots = torch.lu(A) - x = torch.lu_solve(b, LU_data, LU_pivots) - self.assertEqual(x, cast(x_exp)) + if torch.cuda.is_available(): + for non_blocking in [True, False]: + for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: + b = torch.tensor(5., device=cuda) + test_copy_behavior(b, non_blocking) + self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device) + self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device) + self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device) + self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype) + self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device) + self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype) + self.assertEqual(b.device, b.to(dtype=torch.int32).device) - @skipIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lu_solve_batched_non_contiguous(self): - self._test_lu_solve_batched_non_contiguous(self, lambda t: t) + def test_to_with_tensor(self): + a = torch.tensor(5) + self.assertEqual(a.device, a.to(a).device) - @staticmethod - def _test_lu_solve_batched_many_batches(self, cast): - from common_utils import lu_solve_test_helper + if torch.cuda.is_available(): + for non_blocking in [True, False]: + for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: + b = torch.tensor(5., device=cuda) + self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device) + self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device) + self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device) - def run_test(A_dims, b_dims, cast): - b, A, LU_data, LU_pivots = lu_solve_test_helper(self, A_dims, b_dims, cast, True) - x = torch.lu_solve(b, LU_data, LU_pivots) - b_ = torch.matmul(A, x) - self.assertEqual(b_, b.expand_as(b_)) + def test_empty_full(self): + do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cpu')) + if torch.cuda.device_count() > 0: + do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, None) + do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cuda:0')) - run_test((5, 65536), (65536, 5, 10), cast) - run_test((5, 262144), (262144, 5, 10), cast) + def test_dtype_out_match(self): + d = torch.autograd.Variable(torch.DoubleTensor(2, 3)) + self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), out=d, dtype=torch.float32)) - @skipIfNoLapack - @slowTest - def test_lu_solve_batched_many_batches(self): - self._test_lu_solve_batched_many_batches(self, lambda t: t) + def test_constructor_dtypes(self): + default_type = torch.Tensor().type() + self.assertIs(torch.Tensor().dtype, torch.get_default_dtype()) - @staticmethod - def _test_lu_solve_batched_broadcasting(self, cast, pivot=True): - from numpy.linalg import solve - from common_utils import random_fullrank_matrix_distinct_singular_value + self.assertIs(torch.uint8, torch.ByteTensor.dtype) + self.assertIs(torch.float32, torch.FloatTensor.dtype) + self.assertIs(torch.float64, torch.DoubleTensor.dtype) - def run_test(A_dims, b_dims, cast, pivot): - A_matrix_size = A_dims[-1] - A_batch_dims = A_dims[:-2] - A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims) - b = torch.randn(*b_dims) - x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())) - A, b = cast(A), cast(b) - LU_data, LU_pivots = torch.lu(A, pivot=pivot) - x = torch.lu_solve(b, LU_data, LU_pivots) - self.assertEqual(x, cast(x_exp)) + torch.set_default_tensor_type('torch.FloatTensor') + self.assertIs(torch.float32, torch.get_default_dtype()) + self.assertIs(torch.FloatStorage, torch.Storage) - # test against numpy.linalg.solve - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, pivot) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), cast, pivot) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), cast, pivot) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, pivot) # broadcasting A & b + torch.set_default_dtype(torch.float64) + self.assertIs(torch.float64, torch.get_default_dtype()) + self.assertIs(torch.DoubleStorage, torch.Storage) - @skipIfNoLapack - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - def test_lu_solve_batched_broadcasting(self): - self._test_lu_solve_batched_broadcasting(self, lambda t: t) + torch.set_default_tensor_type(torch.FloatTensor) + self.assertIs(torch.float32, torch.get_default_dtype()) + self.assertIs(torch.FloatStorage, torch.Storage) - @staticmethod - def _test_lu_unpack(self, cast, pivot=True): - def run_test(shape, cast): - a = cast(torch.randn(*shape)) - a_lu, p = torch.lu(a, pivot=pivot) - p_ref, l_ref, u_ref = torch.lu_unpack(a_lu, p) - self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a) + if torch.cuda.is_available(): + torch.set_default_tensor_type(torch.cuda.FloatTensor) + self.assertIs(torch.float32, torch.get_default_dtype()) + self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype) + self.assertIs(torch.cuda.FloatStorage, torch.Storage) - run_test((3, 3), cast) - run_test((5, 3, 3), cast) - run_test((7, 3, 5, 5), cast) - run_test((7, 5, 3, 3, 3), cast) + torch.set_default_dtype(torch.float64) + self.assertIs(torch.float64, torch.get_default_dtype()) + self.assertIs(torch.cuda.DoubleStorage, torch.Storage) - @skipIfNoLapack - def test_lu_unpack(self): - self._test_lu_unpack(self, lambda t: t) + # don't support integral or sparse default types. + self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor')) + self.assertRaises(TypeError, lambda: torch.set_default_dtype(torch.int64)) - def test_bmm(self): - num_batches = 10 - M, N, O = 23, 8, 12 - b1 = torch.randn(num_batches, M, N) - b2 = torch.randn(num_batches, N, O) - res = torch.bmm(b1, b2) - for i in range(num_batches): - r = torch.mm(b1[i], b2[i]) - self.assertEqual(r, res[i]) - if torch.cuda.is_available(): - # check that mixed arguments are rejected - self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda())) - self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2)) + # don't allow passing dtype to set_default_tensor_type + self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32)) - def test_addbmm(self): - # num_batches = 10 - # M, N, O = 12, 8, 5 - num_batches = 2 - M, N, O = 2, 3, 4 - b1 = torch.randn(num_batches, M, N) - b2 = torch.randn(num_batches, N, O) - res = torch.bmm(b1, b2) - res2 = torch.Tensor().resize_as_(res[0]).zero_() + torch.set_default_tensor_type(default_type) - res2.addbmm_(b1, b2) - self.assertEqual(res2, res.sum(0, False)) - - res2.addbmm_(1, b1, b2) - self.assertEqual(res2, res.sum(0, False) * 2) - - res2.addbmm_(1., .5, b1, b2) - self.assertEqual(res2, res.sum(0, False) * 2.5) + def test_constructor_device_legacy(self): + self.assertRaises(RuntimeError, lambda: torch.FloatTensor(device='cuda')) + self.assertRaises(RuntimeError, lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device='cuda')) + self.assertRaises(RuntimeError, lambda: torch.FloatTensor((2.0, 3.0), device='cuda')) - res3 = torch.addbmm(1, res2, 0, b1, b2) - self.assertEqual(res3, res2) + self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cuda')) + self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cuda')) + self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cuda')) - res4 = torch.addbmm(1, res2, .5, b1, b2) - self.assertEqual(res4, res.sum(0, False) * 3) + x = torch.randn((3,), device='cpu') + self.assertRaises(RuntimeError, lambda: x.new(device='cuda')) + self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda')) + self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cuda')) - res5 = torch.addbmm(0, res2, 1, b1, b2) - self.assertEqual(res5, res.sum(0, False)) + if torch.cuda.is_available(): + self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(device='cpu')) + self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device='cpu')) + self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor((2.0, 3.0), device='cpu')) - res6 = torch.addbmm(.1, res2, .5, b1, b2) - self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5)) + default_type = torch.Tensor().type() + torch.set_default_tensor_type(torch.cuda.FloatTensor) + self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu')) + self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu')) + self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu')) + torch.set_default_tensor_type(torch.cuda.FloatTensor) + torch.set_default_tensor_type(default_type) - def test_baddbmm(self): - num_batches = 10 - M, N, O = 12, 8, 5 - b1 = torch.randn(num_batches, M, N) - b2 = torch.randn(num_batches, N, O) - res = torch.bmm(b1, b2) - res2 = torch.Tensor().resize_as_(res).zero_() + x = torch.randn((3,), device='cuda') + self.assertRaises(RuntimeError, lambda: x.new(device='cpu')) + self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu')) + self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cpu')) - res2.baddbmm_(b1, b2) - self.assertEqual(res2, res) + def test_type(self): + x = torch.randn(3, 3).double() + self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32) + self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32) + self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype()) + self.assertEqual(x.type(torch.int32).dtype, torch.int32) - res2.baddbmm_(1, b1, b2) - self.assertEqual(res2, res * 2) + def test_tensor_factory(self): + expected = torch.Tensor([1, 1]) + # test data + res1 = torch.tensor([1, 1]) + self.assertEqual(res1, expected) - res2.baddbmm_(1, .5, b1, b2) - self.assertEqual(res2, res * 2.5) + res1 = torch.tensor([1, 1], dtype=torch.int) + self.assertEqual(res1, expected) + self.assertIs(torch.int, res1.dtype) - res3 = torch.baddbmm(1, res2, 0, b1, b2) - self.assertEqual(res3, res2) + # test copy + res2 = torch.tensor(expected) + self.assertEqual(res2, expected) + res2[1] = 2 + self.assertEqual(expected, torch.ones_like(expected)) - res4 = torch.baddbmm(1, res2, .5, b1, b2) - self.assertEqual(res4, res * 3) + res2 = torch.tensor(expected, dtype=torch.int) + self.assertEqual(res1, expected) + self.assertIs(torch.int, res1.dtype) - res5 = torch.baddbmm(0, res2, 1, b1, b2) - self.assertEqual(res5, res) + # test copy with numpy + if TEST_NUMPY: + for dtype in [np.float64, np.int64, np.int8, np.uint8]: + a = np.array([5.]).astype(dtype) + res1 = torch.tensor(a) + self.assertEqual(5., res1[0].item()) + a[0] = 7. + self.assertEqual(5., res1[0].item()) - res6 = torch.baddbmm(.1, res2, .5, b1, b2) - self.assertEqual(res6, res2 * .1 + res * .5) + # test boolean tensor + a = torch.tensor([True, True, False, True, True], dtype=torch.bool) + b = torch.tensor([-1, -1.1, 0, 1, 1.1], dtype=torch.bool) + self.assertEqual(a, b) - @staticmethod - def _test_clamp(self, device='cpu'): - m1 = torch.rand(100, device=device).mul(5).add(-2.5) # uniform in [-2.5, 2.5] - # just in case we're extremely lucky. - min_val = -1 - max_val = 1 - m1[1] = min_val - m1[2] = max_val + def test_tensor_factory_copy_var(self): - res1 = m1.clone() - res1.clamp_(min_val, max_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = max(min_val, min(max_val, res2[i])) - self.assertEqual(res1, res2) + def check_copy(copy, is_leaf, requires_grad, data_ptr=None): + if data_ptr is None: + data_ptr = copy.data_ptr + self.assertEqual(copy.data, source.data) + self.assertTrue(copy.is_leaf == is_leaf) + self.assertTrue(copy.requires_grad == requires_grad) + self.assertTrue(copy.data_ptr == data_ptr) - out = m1.clone() - torch.clamp(m1, min=min_val, max=max_val, out=out) - self.assertEqual(out, res1) + source = torch.randn(5, 5, dtype=torch.double, requires_grad=True) + # test torch.tensor() + check_copy(torch.tensor(source), True, False) + check_copy(torch.tensor(source, requires_grad=False), True, False) + check_copy(torch.tensor(source, requires_grad=True), True, True) - res1 = torch.clamp(m1, min=min_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = max(min_val, res2[i]) - self.assertEqual(res1, res2) + # test tensor.new_tensor() + copy = torch.randn(1) + check_copy(copy.new_tensor(source), True, False) + check_copy(copy.new_tensor(source, requires_grad=False), True, False) + check_copy(copy.new_tensor(source, requires_grad=True), True, True) - torch.clamp(m1, min=min_val, out=out) - self.assertEqual(out, res1) + # test torch.as_tensor() + check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad, source.data_ptr) # not copy + check_copy(torch.as_tensor(source, dtype=torch.float), False, True) # copy and keep the graph - res1 = torch.clamp(m1, max=max_val) - res2 = m1.clone() - for i in iter_indices(res2): - res2[i] = min(max_val, res2[i]) - self.assertEqual(res1, res2) + def test_tensor_factory_type_inference(self): + def test_inference(default_dtype): + saved_dtype = torch.get_default_dtype() + torch.set_default_dtype(default_dtype) + self.assertIs(default_dtype, torch.tensor(()).dtype) + self.assertIs(default_dtype, torch.tensor(5.).dtype) + self.assertIs(torch.int64, torch.tensor(5).dtype) + self.assertIs(torch.bool, torch.tensor(True).dtype) + self.assertIs(torch.int32, torch.tensor(5, dtype=torch.int32).dtype) + self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype) + self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype) + self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype) - torch.clamp(m1, max=max_val, out=out) - self.assertEqual(out, res1) + if TEST_NUMPY: + self.assertIs(torch.float64, torch.tensor(np.array(())).dtype) + self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype) + if np.array(5).dtype == np.int64: # np long, which can be 4 bytes (e.g. on windows) + self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype) + else: + self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype) + self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype) + self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype) + self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype) + self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype) + torch.set_default_dtype(saved_dtype) - # if the tensor contains nan case - test_tens = torch.tensor([nan], device=device) + test_inference(torch.float64) + test_inference(torch.float32) - res1 = test_tens.clone() - res1.clamp_(min_val, max_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = max(min(res2[i], max_val), min_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) + def test_new_tensor(self): + expected = torch.autograd.Variable(torch.ByteTensor([1, 1])) + # test data + res1 = expected.new_tensor([1, 1]) + self.assertEqual(res1, expected) + res1 = expected.new_tensor([1, 1], dtype=torch.int) + self.assertEqual(res1, expected) + self.assertIs(torch.int, res1.dtype) - out = test_tens.clone() - torch.clamp(test_tens, min=min_val, max=max_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) + # test copy + res2 = expected.new_tensor(expected) + self.assertEqual(res2, expected) + res2[1] = 2 + self.assertEqual(expected, torch.ones_like(expected)) + res2 = expected.new_tensor(expected, dtype=torch.int) + self.assertEqual(res2, expected) + self.assertIs(torch.int, res2.dtype) - res1 = torch.clamp(test_tens, min=min_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = max(res2[i], min_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) + # test copy with numpy + if TEST_NUMPY: + a = np.array([5.]) + res1 = torch.tensor(a) + res1 = res1.new_tensor(a) + self.assertEqual(5., res1[0].item()) + a[0] = 7. + self.assertEqual(5., res1[0].item()) - torch.clamp(test_tens, min=min_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) + if torch.cuda.device_count() >= 2: + expected = expected.cuda(1) + res1 = expected.new_tensor([1, 1]) + self.assertEqual(res1.get_device(), expected.get_device()) + res1 = expected.new_tensor([1, 1], dtype=torch.int) + self.assertIs(torch.int, res1.dtype) + self.assertEqual(res1.get_device(), expected.get_device()) - res1 = torch.clamp(test_tens, max=max_val) - res2 = test_tens.clone() - for i in iter_indices(res2): - res2[i] = min(res2[i], max_val) - self.assertEqual(torch.isnan(res1), torch.isnan(res2)) + res2 = expected.new_tensor(expected) + self.assertEqual(res2.get_device(), expected.get_device()) + res2 = expected.new_tensor(expected, dtype=torch.int) + self.assertIs(torch.int, res1.dtype) + self.assertEqual(res2.get_device(), expected.get_device()) + res2 = expected.new_tensor(expected, dtype=torch.int, device=0) + self.assertIs(torch.int, res1.dtype) + self.assertEqual(res2.get_device(), 0) - torch.clamp(test_tens, max=max_val, out=out) - self.assertEqual(torch.isnan(out), torch.isnan(res1)) + res1 = expected.new_tensor(1) + self.assertEqual(res1.get_device(), expected.get_device()) + res1 = expected.new_tensor(1, dtype=torch.int) + self.assertIs(torch.int, res1.dtype) + self.assertEqual(res1.get_device(), expected.get_device()) - error_msg = 'At least one of \'min\' or \'max\' must not be None' - with self.assertRaisesRegex(RuntimeError, error_msg): - m1.clamp() - with self.assertRaisesRegex(RuntimeError, error_msg): - m1.clamp_() + def test_as_tensor(self): + # from python data + x = [[0, 1], [2, 3]] + self.assertEqual(torch.tensor(x), torch.as_tensor(x)) + self.assertEqual(torch.tensor(x, dtype=torch.float32), torch.as_tensor(x, dtype=torch.float32)) - def test_clamp(self): - self._test_clamp(self) + # python data with heterogeneous types + z = [0, 'torch'] + with self.assertRaisesRegex(TypeError, "invalid data type"): + torch.tensor(z) + torch.as_tensor(z) - def test_pow(self): - # [res] torch.pow([res,] x) - - # pow has dedicated implementation for different exponents - for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]: - # base - tensor, exponent - number - # contiguous - m1 = torch.rand(100, 100) + 0.5 - res1 = torch.pow(m1[4], exponent) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = math.pow(m1[4][i], exponent) - self.assertEqual(res1, res2) - - # non-contiguous - m1 = torch.rand(100, 100) + 0.5 - res1 = torch.pow(m1[:, 4], exponent) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = math.pow(m1[i, 4], exponent) - self.assertEqual(res1, res2) - - # base - number, exponent - tensor - # contiguous - m1 = torch.randn(100, 100) - res1 = torch.pow(3, m1[4]) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = math.pow(3, m1[4, i]) - self.assertEqual(res1, res2) + # python data with self-referential lists + z = [0] + z += [z] + with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"): + torch.tensor(z) + torch.as_tensor(z) - # non-contiguous - m1 = torch.randn(100, 100) - res1 = torch.pow(3, m1[:, 4]) - res2 = res1.clone().zero_() - for i in range(res2.size(0)): - res2[i] = math.pow(3, m1[i][4]) - self.assertEqual(res1, res2) + z = [[1, 2], z] + with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"): + torch.tensor(z) + torch.as_tensor(z) - @staticmethod - def _test_rpow(self, cast): - m = cast(torch.randn(10, 10)) - self.assertEqual(torch.pow(2, m), 2**m) + # from tensor (doesn't copy unless type is different) + y = torch.tensor(x) + self.assertIs(y, torch.as_tensor(y)) + self.assertIsNot(y, torch.as_tensor(y, dtype=torch.float32)) + if torch.cuda.is_available(): + self.assertIsNot(y, torch.as_tensor(y, device='cuda')) + y_cuda = y.to('cuda') + self.assertIs(y_cuda, torch.as_tensor(y_cuda)) + self.assertIs(y_cuda, torch.as_tensor(y_cuda, device='cuda')) - # test with scalar - m = cast(torch.randn(1).squeeze()) - assert m.dim() == 0, "m is intentionally a scalar" - self.assertEqual(torch.pow(2, m), 2**m) + if TEST_NUMPY: + # doesn't copy + for dtype in [np.float64, np.int64, np.int8, np.uint8]: + n = np.random.rand(5, 6).astype(dtype) + n_astensor = torch.as_tensor(n) + self.assertEqual(torch.tensor(n), n_astensor) + n_astensor[0][0] = 25.7 + self.assertEqual(torch.tensor(n), n_astensor) - def test_rpow(self): - self._test_rpow(self, lambda x: x) + # changing dtype causes copy + n = np.random.rand(5, 6).astype(np.float32) + n_astensor = torch.as_tensor(n, dtype=torch.float64) + self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor) + n_astensor[0][1] = 250.8 + self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor) - def _test_pow(self, base, exponent, np_exponent=None): - if np_exponent is None: - np_exponent = exponent + # changing device causes copy + if torch.cuda.is_available(): + n = np.random.randn(5, 6) + n_astensor = torch.as_tensor(n, device='cuda') + self.assertEqual(torch.tensor(n, device='cuda'), n_astensor) + n_astensor[0][2] = 250.9 + self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor) - def to_np(value): - if isinstance(value, torch.Tensor): - return value.cpu().numpy() - return value + def test_diag(self): + x = torch.rand(100, 100) + res1 = torch.diag(x) + res2 = torch.Tensor() + torch.diag(x, out=res2) + self.assertEqual(res1, res2) - try: - expected = torch.from_numpy( - np.power(to_np(base), to_np(np_exponent))) - except ValueError as e: - err_msg = "Integers to negative integer powers are not allowed." - self.assertEqual(str(e), err_msg) - out = torch.empty_like(base) - test_cases = [ - lambda: base.pow(exponent), - lambda: base.pow_(exponent), - lambda: torch.pow(base, exponent), - lambda: torch.pow(base, exponent, out=out) - ] - for test_case in test_cases: - self.assertRaisesRegex(RuntimeError, err_msg, test_case) - else: - if isinstance(base, torch.Tensor): - actual = base.pow(exponent) - self.assertEqual(actual, expected, allow_inf=True) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_diagonal_multidim(self): + x = torch.randn(10, 11, 12, 13) + xn = x.numpy() + for args in [(2, 2, 3), + (2,), + (-2, 1, 2), + (0, -2, -1)]: + result = torch.diagonal(x, *args) + expected = xn.diagonal(*args) + self.assertEqual(expected.shape, result.shape) + self.assertTrue(np.allclose(expected, result.numpy())) + # test non-continguous + xp = x.permute(1, 2, 3, 0) + result = torch.diagonal(xp, 0, -2, -1) + expected = xp.numpy().diagonal(0, -2, -1) + self.assertEqual(expected.shape, result.shape) + self.assertTrue(np.allclose(expected, result.numpy())) - actual = base.clone() - actual2 = actual.pow_(exponent) - self.assertEqual(actual, expected, allow_inf=True) - self.assertEqual(actual2, expected, allow_inf=True) + @staticmethod + def _test_diag_embed(self, dtype, device): + x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4) + result = torch.diag_embed(x) + expected = torch.stack([torch.diag(r) for r in x], 0) + self.assertEqual(result, expected) - actual = torch.pow(base, exponent) - self.assertEqual(actual, expected, allow_inf=True) + result = torch.diag_embed(x, offset=1, dim1=0, dim2=2) + expected = torch.stack([torch.diag(r, 1) for r in x], 1) + self.assertEqual(result, expected) - actual2 = torch.pow(base, exponent, out=actual) - self.assertEqual(actual, expected, allow_inf=True) - self.assertEqual(actual2, expected, allow_inf=True) + def test_diag_embed(self): + self._test_diag_embed(self, dtype=torch.float32, device='cpu') - @torchtest.for_all_device_types() - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_int_pow(self, device): + def test_renorm(self): + m1 = torch.randn(10, 5) + res1 = torch.Tensor() - def _test_integral_pow(dt, range, dev): - tensor = torch.tensor((3, 3), dtype=dt, device=dev).random_(*range) - exps = [0, 1, 2, 4, - torch.tensor((3, 3), dtype=dt, device=dev).random_(0, 5)] - for exp in exps: - self._test_pow(tensor, exp) + def renorm(matrix, value, dim, max_norm): + m1 = matrix.transpose(dim, 0).contiguous() + # collapse non-dim dimensions. + m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0)))) + norms = m2.norm(value, 1, True) + # clip + new_norms = norms.clone() + new_norms[torch.gt(norms, max_norm)] = max_norm + new_norms.div_(norms.add_(1e-7)) + # renormalize + m1.mul_(new_norms.expand_as(m1)) + return m1.transpose(dim, 0) - _test_integral_pow(torch.int8, (-3, 4), device) - _test_integral_pow(torch.uint8, (0, 4), device) - _test_integral_pow(torch.int16, (-5, 5), device) - _test_integral_pow(torch.int64, (-10, 10), device) - _test_integral_pow(torch.int32, (-10, 10), device) + # note that the axis fed to torch.renorm is different (2~=1) + maxnorm = m1.norm(2, 1).mean() + m2 = renorm(m1, 2, 1, maxnorm) + m1.renorm_(2, 1, maxnorm) + self.assertEqual(m1, m2, 1e-5) + self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), 1e-5) - @torchtest.for_all_device_types() - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_int_tensor_pow_neg_ints(self, device): - ints = [torch.iinfo(torch.int32).min, - -3, -2, -1, 0, 1, 2, 3, - torch.iinfo(torch.int32).max] - neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] - tensor = torch.tensor(ints, dtype=torch.int32, device=device) - for pow in neg_ints: - self._test_pow(tensor, pow) + m1 = torch.randn(3, 4, 5) + m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) + maxnorm = m2.norm(2, 0).mean() + m2 = renorm(m2, 2, 1, maxnorm) + m1.renorm_(2, 1, maxnorm) + m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) + self.assertEqual(m3, m2) + self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) - @torchtest.for_all_device_types() - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_long_tensor_pow_floats(self, device): - ints = [0, 1, 23, 4567] - floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] - tensor = torch.tensor(ints, dtype=torch.int64, device=device) - for pow in floats: - if device == 'cuda' and not TEST_WITH_ROCM: - # Current pow CUDA implementation casts exponent - # to tensor dtype, but numpy does not, that's why: - # pow CUDA 4 ^ 0.5 = 1 - # numpy pow 4 ^ 0.5 = 2 - # This line must be deleted as soon as - # pow CUDA implementation is fixed. - self._test_pow(tensor, pow, np_exponent=int(pow)) + @staticmethod + def _test_multinomial(self, type): + def make_prob_dist(shape, is_contiguous): + if is_contiguous: + return type(*shape).uniform_() + elif len(shape) == 1: + return type(*(shape + [5])).uniform_()[:, 2] else: - # pow CPU implementation is already fixed and - # does not cast exponent to tensor dtype, - # that why it is compatible with numpy. - self._test_pow(tensor, pow) - - @torchtest.for_all_device_types() - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_float_scalar_pow_float_tensor(self, device): - floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, - 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] - tensor = torch.tensor(floats, dtype=torch.float32, device=device) - for base in floats: - self._test_pow(base, tensor) + # num dim = 2 + new_shape = [2, shape[1], 7, 1, shape[0], 1, 10] + prob_dist = type(*new_shape).uniform_() + prob_dist = prob_dist.transpose(1, 4) + prob_dist = prob_dist[1, :, 5, 0, :, 0, 4] + assert not prob_dist.is_contiguous() # sanity check + return prob_dist - @torchtest.for_all_device_types() - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_tensor_pow_tensor(self, dev): - def rotate(l, n): - return l[-n:] + l[:-n] + for is_contiguous in (True, False): + # with replacement + n_row = 3 + for n_col in range(4, 5 + 1): + prob_dist = make_prob_dist([n_row, n_col], is_contiguous) + # indices that shouldn't be sampled (<0 means none) + zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist() + for i, j in enumerate(zero_prob_indices): + if j >= 0: + prob_dist[i, j] = 0 + n_sample = n_col * 3 + sample_indices = torch.multinomial(prob_dist, n_sample, True) + self.assertEqual(prob_dist.dim(), 2) + self.assertEqual(sample_indices.size(1), n_sample) + for i in range(n_row): + zero_prob_idx = zero_prob_indices[i] + if zero_prob_idx < 0: + continue + for j in range(n_sample): + self.assertNotEqual(sample_indices[i, j], zero_prob_idx, + "sampled an index with zero probability") - def test_tensor_pow_tensor(values, torch_type, numpy_type): - vals_tensor = torch.tensor(values, dtype=torch_type, device=dev) - for i in range(len(values)): - pows = rotate(values, i) - pows_tensor = torch.tensor(pows, dtype=torch_type, device=dev) - self._test_pow(vals_tensor, pows_tensor) + # without replacement + n_row = 3 + for n_col in range(2, 10 + 1, 2): + prob_dist = make_prob_dist([n_row, n_col], is_contiguous) + # indices that shouldn't be sampled (<0 means none) + zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist() + for i, j in enumerate(zero_prob_indices): + if j >= 0: + prob_dist[i, j] = 0 + n_sample = max(1, n_col - 2) + sample_indices = torch.multinomial(prob_dist, n_sample, False) + self.assertEqual(prob_dist.dim(), 2) + self.assertEqual(sample_indices.size(1), n_sample) + for i in range(n_row): + row_samples = {} + zero_prob_idx = zero_prob_indices[i] + for j in range(n_sample): + sample_idx = sample_indices[i, j] + if zero_prob_idx >= 0: + self.assertNotEqual(sample_idx, zero_prob_idx, + "sampled an index with zero probability") + self.assertNotIn(sample_idx, row_samples, "sampled an index twice") + row_samples[sample_idx] = True - ints = [0, 1, 2, 3] - test_tensor_pow_tensor(ints, torch.int32, np.int32) - test_tensor_pow_tensor(ints, torch.int64, np.int64) + # vector + n_col = 4 + prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1) + zero_prob_idx = 1 # index that shouldn't be sampled + prob_dist[zero_prob_idx] = 0 + n_sample = 20 + sample_indices = torch.multinomial(prob_dist, n_sample, True) + for sample_index in sample_indices: + self.assertNotEqual(sample_index, zero_prob_idx, "sampled an index with zero probability") + s_dim = sample_indices.dim() + self.assertEqual(sample_indices.dim(), 1, "wrong number of dimensions") + self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions") + self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples") - floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, - 0.0, - 1 / 3, 1 / 2, 1.0, 2.0, 3.0] - test_tensor_pow_tensor(floats, torch.float32, np.float32) - test_tensor_pow_tensor(floats, torch.float64, np.float64) + def test_multinomial(self): + self._test_multinomial(self, torch.FloatTensor) - def _test_cop(self, torchfn, mathfn): - def reference_implementation(res2): - for i, j in iter_indices(sm1): - idx1d = i * sm1.size(0) + j - res2[i, j] = mathfn(sm1[i, j], sm2[idx1d]) - return res2 + def _spawn_method(self, method, arg): + try: + mp.set_start_method('spawn') + except RuntimeError: + pass + with mp.Pool(1) as pool: + self.assertTrue(pool.map(method, [arg])) - # contiguous - m1 = torch.randn(10, 10, 10) - m2 = torch.randn(10, 10 * 10) - sm1 = m1[4] - sm2 = m2[4] + @staticmethod + def _test_multinomial_invalid_probs(probs): + try: + # n_sample = 1 is a special case, test n_sample=2 which is more general + torch.multinomial(probs.to('cpu'), 2) + return False # Should not be reached + except RuntimeError as e: + return 'invalid multinomial distribution' in str(e) - res1 = torchfn(sm1, sm2.view(10, 10)) - res2 = reference_implementation(res1.clone()) - self.assertEqual(res1, res2) + @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ + don't support multiprocessing with spawn start method") + @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows') + @unittest.skipIf(not PY3, + "spawn start method is not supported in Python 2, \ + but we need it for for testing failure case for CPU RNG on Windows") + def test_multinomial_invalid_probs(self): + test_method = _TestTorchMixin._test_multinomial_invalid_probs + self._spawn_method(test_method, torch.Tensor([1, -1, 1])) + self._spawn_method(test_method, torch.Tensor([1, inf, 1])) + self._spawn_method(test_method, torch.Tensor([1, -inf, 1])) + self._spawn_method(test_method, torch.Tensor([1, 1, nan])) + self._spawn_method(test_method, torch.Tensor([0, 1, 0])) - # non-contiguous - m1 = torch.randn(10, 10, 10) - m2 = torch.randn(10 * 10, 10 * 10) - sm1 = m1[:, 4] - sm2 = m2[:, 4] - # view as sm1.size() - sm2.set_(sm2.storage(), sm2.storage_offset(), sm1.size(), (sm2.stride()[0] * 10, sm2.stride()[0])) - res1 = torchfn(sm1, sm2) - # reference_implementation assumes 1-d sm2 - sm2.set_(sm2.storage(), sm2.storage_offset(), m2[:, 4].size(), m2[:, 4].stride()) - res2 = reference_implementation(res1.clone()) - self.assertEqual(res1, res2) + @suppress_warnings + def test_range(self): + res1 = torch.range(0, 1) + res2 = torch.Tensor() + torch.range(0, 1, out=res2) + self.assertEqual(res1, res2, 0) - def test_cdiv(self): - self._test_cop(torch.div, lambda x, y: x / y) + # Check range for non-contiguous tensors. + x = torch.zeros(2, 3) + torch.range(0, 3, out=x.narrow(1, 1, 2)) + res2 = torch.Tensor(((0, 0, 1), (0, 2, 3))) + self.assertEqual(x, res2, 1e-16) - def test_cfmod(self): - self._test_cop(torch.fmod, math.fmod) + # Check negative + res1 = torch.Tensor((1, 0)) + res2 = torch.Tensor() + torch.range(1, 0, -1, out=res2) + self.assertEqual(res1, res2, 0) - def test_cremainder(self): - self._test_cop(torch.remainder, lambda x, y: x % y) + # Equal bounds + res1 = torch.ones(1) + res2 = torch.Tensor() + torch.range(1, 1, -1, out=res2) + self.assertEqual(res1, res2, 0) + torch.range(1, 1, 1, out=res2) + self.assertEqual(res1, res2, 0) - def test_cmul(self): - self._test_cop(torch.mul, lambda x, y: x * y) + # FloatTensor + res1 = torch.range(0.6, 0.9, 0.1, out=torch.FloatTensor()) + self.assertEqual(res1.size(0), 4) + res1 = torch.range(1, 10, 0.3, out=torch.FloatTensor()) + self.assertEqual(res1.size(0), 31) - def test_cpow(self): - self._test_cop(torch.pow, lambda x, y: nan if x < 0 else math.pow(x, y)) + # DoubleTensor + res1 = torch.range(0.6, 0.9, 0.1, out=torch.DoubleTensor()) + self.assertEqual(res1.size(0), 4) + res1 = torch.range(1, 10, 0.3, out=torch.DoubleTensor()) + self.assertEqual(res1.size(0), 31) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_einsum(self): - # test cases taken from https://gist.github.com/rockt/15ee013889d65342088e9260a377dc8f - x = torch.randn(5) - y = torch.randn(7) - A = torch.randn(3, 5) - B = torch.randn(2, 5) - C = torch.randn(2, 3, 5) - D = torch.randn(2, 5, 7) - E = torch.randn(7, 9) - F = torch.randn(2, 3, 5, 7) - G = torch.randn(7, 11, 13) - H = torch.randn(4, 4) - I = torch.randn(3, 4, 4) - l = torch.randn(5, 10) - r = torch.randn(5, 20) - w = torch.randn(30, 10, 20) - test_list = [ - # -- Vector - ("i->", x), # sum - ("i,i->", x, x), # dot - ("i,i->i", x, x), # vector element-wise mul - ("i,j->ij", x, y), # outer - # -- Matrix - ("ij->ji", A), # transpose - ("ij->j", A), # row sum - ("ij->i", A), # col sum - ("ij,ij->ij", A, A), # matrix element-wise mul - ("ij,j->i", A, x), # matrix vector multiplication - ("ij,kj->ik", A, B), # matmul - ("ij,ab->ijab", A, E), # matrix outer product - # -- Tensor - ("aij,ajk->aik", C, D), # batch matmul - ("ijk,jk->i", C, A), # tensor matrix contraction - ("aij,jk->aik", D, E), # tensor matrix contraction - ("abcd,dfg->abcfg", F, G), # tensor tensor contraction - ("ijk,jk->ik", C, A), # tensor matrix contraction with double indices - ("ijk,jk->ij", C, A), # tensor matrix contraction with double indices - ("ijk,ik->j", C, B), # non contiguous - ("ijk,ik->jk", C, B), # non contiguous with double indices - # -- Diagonal - ("ii", H), # trace - ("ii->i", H), # diagonal - # -- Ellipsis - ("i...->...", H), - ("ki,...k->i...", A.t(), B), - ("k...,jk", A.t(), B), - ("...ii->...i", I), # batch diagonal - # -- Other - ("bn,anm,bm->ba", l, w, r), # as torch.bilinear - ("... ii->...i ", I), # batch diagonal with spaces - ] - for test in test_list: - actual = torch.einsum(test[0], test[1:]) - expected = np.einsum(test[0], *[t.numpy() for t in test[1:]]) - self.assertEqual(expected.shape, actual.shape, test[0]) - self.assertTrue(np.allclose(expected, actual.numpy()), test[0]) - # test vararg - actual2 = torch.einsum(test[0], *test[1:]) - self.assertEqual(expected.shape, actual2.shape, test[0]) - self.assertTrue(np.allclose(expected, actual2.numpy()), test[0]) + def test_range_warning(self): + with warnings.catch_warnings(record=True) as w: + torch.range(0, 10) + self.assertEqual(len(w), 1) - def do_einsum(*args): - return torch.einsum(test[0], args) - # FIXME: following test cases fail gradcheck - if test[0] not in {"i,i->", "i,i->i", "ij,ij->ij"}: - gradcheck_inps = tuple(t.detach().requires_grad_() for t in test[1:]) - self.assertTrue(torch.autograd.gradcheck(do_einsum, gradcheck_inps)) - self.assertTrue(A._version == 0) # check that we do not use inplace ops + def test_arange(self): + res1 = torch.arange(0, 1) + res2 = torch.Tensor() + torch.arange(0, 1, out=res2) + self.assertEqual(res1, res2, 0) - def test_sum_all(self): - def check_sum_all(tensor): - pylist = tensor.reshape(-1).tolist() - self.assertEqual(tensor.sum(), sum(pylist)) + # Check arange with only one argument + res1 = torch.arange(10) + res2 = torch.arange(0, 10) + self.assertEqual(res1, res2, 0) - check_sum_all(torch.tensor([1, 2, 3, 4, 5])) - check_sum_all(torch.randn(200000)) - check_sum_all(torch.randn(2000, 2)[:, 0]) - check_sum_all(torch.tensor([True, False, True], dtype=torch.bool)) + # Check arange for non-contiguous tensors. + x = torch.zeros(2, 3) + torch.arange(0, 4, out=x.narrow(1, 1, 2)) + res2 = torch.Tensor(((0, 0, 1), (0, 2, 3))) + self.assertEqual(x, res2, 1e-16) - def _assert_matches_numpy(self, t, n): - self.assertEqual(n.shape, t.shape) - if t.dtype == torch.float: - self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05, - equal_nan=True)) - else: - self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True)) + # Check negative + res1 = torch.Tensor((1, 0)) + res2 = torch.Tensor() + torch.arange(1, -1, -1, out=res2) + self.assertEqual(res1, res2, 0) - def _test_dim_ops(self, pytorch_op, numpy_op, - use_floating=True, use_integral=True): - def do_one(tensors_dict, dim): - for category, tensors in tensors_dict.items(): - if category == "slice": - dim = 0 - for tensor in tensors: - # we have no control over NumPy warnings... - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - expected = numpy_op(tensor.numpy(), dim) - actual = pytorch_op(tensor, dim) - self._assert_matches_numpy(actual, expected) - if torch.cuda.is_available(): - self._assert_matches_numpy(pytorch_op(tensor.cuda(), - dim).cpu(), - expected) - do_one(self._make_tensors((5, 400000), use_floating=use_floating, - use_integral=use_integral), 1) - do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, - use_integral=use_integral), 0) - do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, - use_integral=use_integral), 1) - do_one(self._make_tensors((3, 5, 7), use_floating=use_floating, - use_integral=use_integral), 2) - do_one(self._make_tensors((100000, ), use_floating=use_floating, - use_integral=use_integral), -1) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), 0) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), 1) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), 2) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), (1, 2)) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), (1, -1)) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), (0, 2)) - do_one(self._make_tensors((50, 50, 50), use_floating=use_floating, - use_integral=use_integral), (0, 2, 1)) + # Equal bounds + res1 = torch.ones(1) + res2 = torch.Tensor() + torch.arange(1, 0, -1, out=res2) + self.assertEqual(res1, res2, 0) + torch.arange(1, 2, 1, out=res2) + self.assertEqual(res1, res2, 0) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_sum_dim(self): - self._test_dim_ops( - lambda t, d: t.sum(d), - lambda n, d: n.sum(d)) + # FloatTensor + res1 = torch.arange(0.6, 0.89, 0.1, out=torch.FloatTensor()) + self.assertEqual(res1, [0.6, 0.7, 0.8]) + res1 = torch.arange(1, 10, 0.3, out=torch.FloatTensor()) + self.assertEqual(res1.size(0), 30) + self.assertEqual(res1[0], 1) + self.assertEqual(res1[29], 9.7) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_mean_dim(self): - self._test_dim_ops( - lambda t, d: t.mean(d), - lambda n, d: n.mean(d), - use_integral=False) + # DoubleTensor + res1 = torch.arange(0.6, 0.89, 0.1, out=torch.DoubleTensor()) + self.assertEqual(res1, [0.6, 0.7, 0.8]) + res1 = torch.arange(1, 10, 0.3, out=torch.DoubleTensor()) + self.assertEqual(res1.size(0), 30) + self.assertEqual(res1[0], 1) + self.assertEqual(res1[29], 9.7) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_std_dim(self): - for unbiased in [False, True]: - self._test_dim_ops( - lambda t, d: t.std(d, unbiased=unbiased), - lambda n, d: n.std(d, ddof=1 if unbiased else 0), - use_integral=False) + # Check that it's exclusive + r = torch.arange(0, 5) + self.assertEqual(r.min(), 0) + self.assertEqual(r.max(), 4) + self.assertEqual(r.numel(), 5) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_var_dim(self): - for unbiased in [False, True]: - self._test_dim_ops( - lambda t, d: t.var(d, unbiased=unbiased), - lambda n, d: n.var(d, ddof=1 if unbiased else 0), - use_integral=False) + r = torch.arange(0, 5, 2) + self.assertEqual(r.min(), 0) + self.assertEqual(r.max(), 4) + self.assertEqual(r.numel(), 3) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - @unittest.skipIf(not TEST_SCIPY, 'Scipy not found') - def test_logsumexp_dim(self): - from scipy.special import logsumexp - self._test_dim_ops( - lambda t, d: t.logsumexp(d), - lambda n, d: logsumexp(n, d), - use_integral=False) + r1 = torch.arange(0, 5 + 1e-6) + r2 = torch.arange(0, 5) + r3 = torch.arange(0, 5 - 1e-6) + self.assertEqual(r1[:-1], r2, 0) + self.assertEqual(r2, r3, 0) - def test_sum_out(self): - x = torch.rand(100, 100) - res1 = torch.sum(x, 1) - res2 = torch.Tensor() - torch.sum(x, 1, out=res2) - self.assertEqual(res1, res2) - x = torch.rand(100, 100, 100) - res1 = x.sum(2).sum(1) - res2 = torch.Tensor() - torch.sum(x, (2, 1), out=res2) - self.assertEqual(res1, res2) + r1 = torch.arange(10, -1 + 1e-6, -1) + r2 = torch.arange(10, -1, -1) + r3 = torch.arange(10, -1 - 1e-6, -1) + self.assertEqual(r1, r2, 0) + self.assertEqual(r2, r3[:-1], 0) - # TODO: these tests only check if it's possible to pass a return value - # it'd be good to expand them - def test_prod(self): - x = torch.rand(100, 100) - res1 = torch.prod(x, 1) - res2 = torch.Tensor() - torch.prod(x, 1, out=res2) - self.assertEqual(res1, res2) + x = torch.empty(1).expand(10) + self.assertRaises(RuntimeError, lambda: torch.arange(10, out=x)) + msg = "unsupported range" + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'))) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'))) - def test_cumsum(self): - for d in torch.testing.get_all_device_types(): - x = torch.rand(100, 100, device=d) - res1 = torch.cumsum(x, 1) - res2 = torch.Tensor().to(d) - torch.cumsum(x, 1, out=res2) - self.assertEqual(res1, res2) + for device in torch.testing.get_all_device_types(): + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device)) + # check with step size + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device)) + self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device)) - a = torch.tensor([[True, False, True], - [False, False, False], - [True, True, True]], device=d) - b = a.byte() - aRes = torch.cumsum(a, 0) - bRes = torch.cumsum(b, 0) - self.assertEqual(aRes, bRes) - self.assertEqual(aRes, torch.tensor([[1, 0, 1], - [1, 0, 1], - [2, 1, 2]])) - - aRes = torch.cumsum(a, 1) - bRes = torch.cumsum(b, 1) - self.assertEqual(aRes, bRes) - self.assertEqual(aRes, torch.tensor([[1, 1, 2], - [0, 0, 0], - [1, 2, 3]])) - - def test_cumprod(self): - for d in torch.testing.get_all_device_types(): - x = torch.rand(100, 100, device=d) - res1 = torch.cumprod(x, 1) - res2 = torch.Tensor().to(d) - torch.cumprod(x, 1, out=res2) - self.assertEqual(res1, res2) + self.assertRaisesRegex( + RuntimeError, "overflow", + lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device)) - a = torch.tensor([[True, False, True], - [False, False, False], - [True, True, True]], dtype=torch.bool, device=d) - b = a.byte() - aRes = torch.cumprod(a, 0) - bRes = torch.cumprod(b, 0) - self.assertEqual(aRes, bRes) - self.assertEqual(aRes, torch.tensor([[1, 0, 1], - [0, 0, 0], - [0, 0, 0]])) - - aRes = torch.cumprod(a, 1) - bRes = torch.cumprod(b, 1) - self.assertEqual(aRes, bRes) - self.assertEqual(aRes, torch.tensor([[1, 0, 0], - [0, 0, 0], - [1, 1, 1]])) + # check that it holds a consistent output shape on precision-cornered step sizes + d = torch.arange(-4.0, 4.0, 0.01, dtype=torch.float32, device=device) + self.assertEqual(d.shape[0], 800) - def _test_reduce_integer_upcast(self, fn, has_out=True): - shape = (3, 4, 5) - reduced_shape = fn(torch.ones(shape)).shape + def test_arange_inference(self): + saved_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.float32) + # end only + self.assertIs(torch.float32, torch.arange(1.).dtype) + self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype) + self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64)).dtype) - def _test_out(dtype, other_dtype): - out = torch.ones(reduced_shape, dtype=dtype) - result = fn(x, out=out) - self.assertIs(out.dtype, result.dtype) - self.assertEqual(fn(x.type(dtype)), result) - result = fn(x, out=out, dtype=dtype) - self.assertIs(out.dtype, result.dtype) - self.assertEqual(fn(x.type(dtype)), result) - # 'out' is favored over dtype, check error - self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) + self.assertIs(torch.int64, torch.arange(1).dtype) + self.assertIs(torch.int64, torch.arange(torch.tensor(1)).dtype) + self.assertIs(torch.int64, torch.arange(torch.tensor(1, dtype=torch.int16)).dtype) - for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]: - x = torch.ones(shape, dtype=dtype) - expected_dtype = dtype if dtype.is_floating_point else torch.int64 - self.assertIs(expected_dtype, fn(x).dtype) - self.assertEqual(fn(x.type(expected_dtype)), fn(x)) + # start, end, [step] + self.assertIs(torch.float32, torch.arange(1., 3).dtype) + self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64), 3).dtype) + self.assertIs(torch.float32, torch.arange(1, 3.).dtype) + self.assertIs(torch.float32, torch.arange(torch.tensor(1, dtype=torch.int16), torch.tensor(3.)).dtype) + self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype) + self.assertIs(torch.float32, + torch.arange(torch.tensor(1), + torch.tensor(3, dtype=torch.int16), + torch.tensor(1., dtype=torch.float64)).dtype) - if dtype.is_floating_point: - other_dtype = torch.float32 if dtype == torch.float64 else torch.float64 - else: - other_dtype = torch.int32 if dtype != torch.int32 else torch.int16 - self.assertIs(other_dtype, fn(x, dtype=other_dtype).dtype) - self.assertEqual(fn(x.type(other_dtype)), fn(x, dtype=other_dtype)) + self.assertIs(torch.int64, torch.arange(1, 3).dtype) + self.assertIs(torch.int64, torch.arange(torch.tensor(1), 3).dtype) + self.assertIs(torch.int64, torch.arange(torch.tensor(1), torch.tensor(3, dtype=torch.int16)).dtype) + self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype) + self.assertIs(torch.int64, + torch.arange(torch.tensor(1), + torch.tensor(3), + torch.tensor(1, dtype=torch.int16)).dtype) + torch.set_default_dtype(saved_dtype) - # test mixed int/float - mixed_dtype = torch.int32 if dtype.is_floating_point else torch.float32 - self.assertIs(mixed_dtype, fn(x, dtype=mixed_dtype).dtype) - self.assertEqual(fn(x.type(mixed_dtype)), fn(x, dtype=mixed_dtype)) + def test_randint_inference(self): + size = (2, 1) + for args in [(3,), (1, 3)]: # (low,) and (low, high) + self.assertIs(torch.int64, torch.randint(*args, size=size).dtype) + self.assertIs(torch.int64, torch.randint(*args, size=size, layout=torch.strided).dtype) + self.assertIs(torch.int64, torch.randint(*args, size=size, generator=torch.default_generator).dtype) + self.assertIs(torch.float32, torch.randint(*args, size=size, dtype=torch.float32).dtype) + out = torch.empty(size, dtype=torch.float32) + self.assertIs(torch.float32, torch.randint(*args, size=size, out=out).dtype) + self.assertIs(torch.float32, torch.randint(*args, size=size, out=out, dtype=torch.float32).dtype) + out = torch.empty(size, dtype=torch.int64) + self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype) + self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype) - if has_out: - _test_out(dtype, other_dtype) - _test_out(dtype, mixed_dtype) + def test_broadcast_empty(self): + # empty + empty + self.assertRaises(RuntimeError, lambda: torch.randn(5, 0) + torch.randn(0, 5)) + self.assertEqual(torch.randn(5, 0), torch.randn(0) + torch.randn(5, 0)) + self.assertEqual(torch.randn(5, 0, 0), torch.randn(0) + torch.randn(5, 0, 1)) - def test_sum_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, **kwargs), False) - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.sum(x, 0, **kwargs)) + # scalar + empty + self.assertEqual(torch.randn(5, 0, 6), torch.randn(()) + torch.randn(5, 0, 6)) - def test_prod_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, **kwargs), False) - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.prod(x, 0, **kwargs)) + # non-empty, empty + self.assertEqual(torch.randn(0), torch.randn(0) + torch.randn(1)) + self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7), + torch.randn(0, 7, 0, 6, 5, 0, 1) + torch.randn(1, 1, 5, 1, 7)) + self.assertRaises(RuntimeError, lambda: torch.randn(7, 0) + torch.randn(2, 1)) - def test_cumsum_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumsum(x, 0, **kwargs)) + def test_broadcast_tensors(self): + x0 = torch.randn(2, 1, 3) + x1 = torch.randn(3) + x2 = torch.randn(3, 1) + expected_size = (2, 3, 3) - def test_cumprod_integer_upcast(self): - self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) - - def test_cross(self): - x = torch.rand(100, 3, 100) - y = torch.rand(100, 3, 100) - res1 = torch.cross(x, y) - res2 = torch.Tensor() - torch.cross(x, y, out=res2) - self.assertEqual(res1, res2) - - def test_cross_with_and_without_dim(self): - x = torch.rand(100, 3) - y = torch.rand(100, 3) - res1 = torch.cross(x, y, dim=1) - res2 = torch.cross(x, y, dim=-1) - res3 = torch.cross(x, y) - self.assertEqual(res1, res2) - self.assertEqual(res1, res3) + y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2) + self.assertTrue(y0.size() == expected_size) + self.assertTrue(y1.size() == expected_size) + self.assertTrue(y2.size() == expected_size) - def test_cross_validation(self): - self.assertRaisesRegex( - RuntimeError, "inconsistent tensors dimensions", - lambda: torch.cross(torch.rand(100, 3), torch.rand(100, 3, 10))) - self.assertRaisesRegex( - RuntimeError, "inconsistent tensors sizes", - lambda: torch.cross(torch.rand(5, 3), torch.rand(3, 5))) - self.assertRaisesRegex( - RuntimeError, "no dimension of size 3 in input", - lambda: torch.cross(torch.rand(5, 4), torch.rand(5, 4))) - self.assertRaisesRegex( - RuntimeError, "dimension 0 does not have size 3", - lambda: torch.cross(torch.rand(5, 4, 3), torch.rand(5, 4, 3), dim=0)) - self.assertRaisesRegex( - RuntimeError, "dimension -1 does not have size 3", - lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-1)) - self.assertRaisesRegex( - IndexError, "Dimension out of range", - lambda: torch.cross(torch.rand(5, 3, 4), torch.rand(5, 3, 4), dim=-5)) + def test_scalars_as_floats(self): + "zero-dim variables that don't require grad should bind to scalar arguments" + x = torch.tensor(2.) + y = torch.tensor(3.) + # 3 + (3 * 3) * 2 + self.assertEqual(y.addcmul(y, y, value=x), 21) - def test_zeros(self): - res1 = torch.zeros(100, 100) - res2 = torch.Tensor() - torch.zeros(100, 100, out=res2) - self.assertEqual(res1, res2) + x = torch.tensor(2., requires_grad=True) + self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) - boolTensor = torch.zeros(2, 2, dtype=torch.bool) - expected = torch.tensor([[False, False], [False, False]], dtype=torch.bool) - self.assertEqual(boolTensor, expected) + def test_copy_broadcast(self): + torch.zeros(5, 6).copy_(torch.zeros(6)) + self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30))) - halfTensor = torch.zeros(1, 1, dtype=torch.half) - expected = torch.tensor([[0.]], dtype=torch.float16) - self.assertEqual(halfTensor, expected) + def test_copy_many_to_one(self): + # Testing in-place copy where it attempt to write from many memory + # storage to a single storage would cause RuntimeError to be thrown + self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) - bfloat16Tensor = torch.zeros(1, 1, dtype=torch.bfloat16) - expected = torch.tensor([[0.]], dtype=torch.bfloat16) - self.assertEqual(bfloat16Tensor, expected) + def test_random(self): + # This test is flaky with p<=(2/(ub-lb))^200=6e-36 + t = torch.FloatTensor(200) + lb = 1 + ub = 4 - def test_std_mean(self): - for device in torch.testing.get_all_device_types(): - x = torch.rand(100, 50, 20, device=device) - for dim in range(x.dim()): - for unbiased in [False, True]: - for keepdim in [False, True]: - std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(std1, std2) - self.assertEqual(mean1, mean2) + t.fill_(-1) + t.random_(lb, ub) + self.assertEqual(t.min(), lb) + self.assertEqual(t.max(), ub - 1) - def test_std_mean_all_dims(self): - for device in torch.testing.get_all_device_types(): - x = torch.rand(100, 50, 20, device=device) - for unbiased in [False, True]: - std1, mean1 = torch.std_mean(x, unbiased=unbiased) - std2 = x.std(unbiased=unbiased) - mean2 = x.mean() - self.assertEqual(std1, std2) - self.assertEqual(mean1, mean2) + t.fill_(-1) + t.random_(ub) + self.assertEqual(t.min(), 0) + self.assertEqual(t.max(), ub - 1) - def test_var_mean(self): - for device in torch.testing.get_all_device_types(): - x = torch.rand(100, 300, 50, device=device) - for dim in range(x.dim()): - for unbiased in [False, True]: - for keepdim in [False, True]: - var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(var1, var2) - self.assertEqual(mean1, mean2) + def test_not_equal(self): + ones = torch.ones(10, dtype=torch.int) + self.assertRaisesRegex(AssertionError, "0 not greater than or equal to", + lambda: self.assertNotEqual(ones, ones)) - def test_var_mean_all_dims(self): - for device in torch.testing.get_all_device_types(): - x = torch.rand(100, 50, 20, device=device) - for unbiased in [False, True]: - var1, mean1 = torch.var_mean(x, unbiased=unbiased) - var2 = x.var(unbiased=unbiased) - mean2 = x.mean() - self.assertEqual(var1, var2) - self.assertEqual(mean1, mean2) + def assertIsOrdered(self, order, x, mxx, ixx, task): + SIZE = 4 + if order == 'descending': + def check_order(a, b): + # `a != a` because we put NaNs + # at the end of ascending sorted lists, + # and the beginning of descending ones. + return a != a or a >= b + elif order == 'ascending': + def check_order(a, b): + # see above + return b != b or a <= b + else: + error('unknown order "{}", must be "ascending" or "descending"'.format(order)) - def test_std_mean_some_dims(self): - sizes = (4, 6, 7, 5, 3) - dims = len(sizes) - for device in torch.testing.get_all_device_types(): - x = torch.rand(sizes, device=device) - for num_of_dims in range(2, dims): - dim_list = list(combinations(list(range(dims)), r=num_of_dims)) - for dim in dim_list: - for unbiased in [False, True]: - for keepdim in [False, True]: - std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(std1, std2) - self.assertEqual(mean1, mean2) - - @torchtest.for_all_device_types() - def test_var_mean_some_dims(self, device): - sizes = (4, 6, 7, 5, 3) - dims = len(sizes) + are_ordered = True + for j, k in product(range(SIZE), range(1, SIZE)): + self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]), + 'torch.sort ({}) values unordered for {}'.format(order, task)) - x = torch.rand(sizes, device=device) - for num_of_dims in range(2, dims): - dim_list = list(combinations(list(range(dims)), r=num_of_dims)) - for dim in dim_list: - for unbiased in [False, True]: - for keepdim in [False, True]: - var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) - var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) - mean2 = x.mean(dim=dim, keepdim=keepdim) - self.assertEqual(var1, var2) - self.assertEqual(mean1, mean2) + seen = set() + indicesCorrect = True + size = x.size(x.dim() - 1) + for k in range(size): + seen.clear() + for j in range(size): + self.assertEqual(x[k][ixx[k][j]], mxx[k][j], + 'torch.sort ({}) indices wrong for {}'.format(order, task)) + seen.add(ixx[k][j]) + self.assertEqual(len(seen), size) - @torchtest.for_all_device_types() - def test_zeros_like(self, device): - expected = torch.zeros((100, 100,), device=device) + def test_sort(self): + SIZE = 4 + x = torch.rand(SIZE, SIZE) + res1val, res1ind = torch.sort(x) - res1 = torch.zeros_like(expected) - self.assertEqual(res1, expected) + # Test use of result tensor + res2val = torch.Tensor() + res2ind = torch.LongTensor() + torch.sort(x, out=(res2val, res2ind)) + self.assertEqual(res1val, res2val, 0) + self.assertEqual(res1ind, res2ind, 0) + self.assertEqual(torch.argsort(x), res1ind) + self.assertEqual(x.argsort(), res1ind) - @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected') - def test_zeros_like_multiple_device(self): - expected = torch.zeros(100, 100).cuda() - x = torch.cuda.FloatTensor(100, 100, device=1) - output = torch.zeros_like(x) - self.assertEqual(output, expected) + # Test sorting of random numbers + self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') - def test_zeros_out(self): - shape = (3, 4) - out = torch.zeros(shape) - torch.zeros(shape, out=out) + # Test simple sort + self.assertEqual( + torch.sort(torch.Tensor((50, 40, 30, 20, 10)))[0], + torch.Tensor((10, 20, 30, 40, 50)), + 0 + ) - # change the dtype, layout, device - self.assertRaises(RuntimeError, lambda: torch.zeros(shape, dtype=torch.int64, out=out)) - self.assertRaises(RuntimeError, lambda: torch.zeros(shape, layout=torch.sparse_coo, out=out)) - if torch.cuda.is_available(): - self.assertRaises(RuntimeError, lambda: torch.zeros(shape, device='cuda', out=out)) + # Test that we still have proper sorting with duplicate keys + x = torch.floor(torch.rand(SIZE, SIZE) * 10) + torch.sort(x, out=(res2val, res2ind)) + self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') - # leave them the same - self.assertEqual(torch.zeros(shape), torch.zeros(shape, dtype=out.dtype, out=out)) - self.assertEqual(torch.zeros(shape), torch.zeros(shape, layout=torch.strided, out=out)) - self.assertEqual(torch.zeros(shape), torch.zeros(shape, device='cpu', out=out)) + # DESCENDING SORT + x = torch.rand(SIZE, SIZE) + res1val, res1ind = torch.sort(x, x.dim() - 1, True) - def test_histc(self): - for device in torch.testing.get_all_device_types(): - # negative nbins throws - with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'): - torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1) - - # without nbins - actual = torch.histc( - torch.tensor([2, 5], dtype=torch.float, device=device)) - expected = torch.zeros(100, dtype=torch.float, device=device) - expected.data[0] = 1 - expected.data[99] = 1 - self.assertEqual(expected, actual) - # tensor with the same element - actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5) - self.assertEqual( - torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device), - actual) - # no element falls between [min, max] - actual = torch.histc( - torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3) - self.assertEqual( - torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device), - actual) - # element falls below min + integral bin size and - actual = torch.histc( - torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device), - bins=5, min=1, max=5) - self.assertEqual( - torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device), - actual) - # non-integral bin size - actual = torch.histc( - torch.tensor([1, 2, 1], dtype=torch.float, device=device), - bins=4, min=0, max=3) - self.assertEqual( - torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), - actual) - # double input - actual = torch.histc( - torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3) - self.assertEqual( - torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device), - actual) - self.assertEqual(actual.dtype, torch.double) - # mixed input - actual = torch.histc( - torch.tensor([1., 2, 1], dtype=torch.float, device=device), - bins=4, min=0, max=3) - self.assertEqual( - torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), - actual) - self.assertEqual(actual.dtype, torch.float) - # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar. - actual = torch.histc( - torch.tensor(0, dtype=torch.float, device=device), - bins=1, min=0, max=3) - self.assertEqual( - torch.tensor([1], dtype=torch.float, device=device), - actual) + # Test use of result tensor + res2val = torch.Tensor() + res2ind = torch.LongTensor() + torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind)) + self.assertEqual(res1val, res2val, 0) + self.assertEqual(res1ind, res2ind, 0) + self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind) + self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) - # test against numpy.histogram() - def test_against_np(tensor, bins=100, min=0, max=0): - if min == 0 and max == 0: - min = tensor.min().item() - max = tensor.max().item() - nparr = tensor.cpu().numpy() - actual = torch.histc(tensor, bins=bins, min=min, max=max) - expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0]) - self.assertEqual(actual.cpu(), expected) + # Test sorting of random numbers + self.assertIsOrdered('descending', x, res2val, res2ind, 'random') - if TEST_NUMPY: - test_against_np(torch.tensor([1., 2, 1], device=device)) - test_against_np(torch.randn(5000, device=device)) + # Test simple sort task + self.assertEqual( + torch.sort(torch.Tensor((10, 20, 30, 40, 50)), 0, True)[0], + torch.Tensor((50, 40, 30, 20, 10)), + 0 + ) - # Test bins arg - test_against_np(torch.randn(301, device=device), bins=10) + # Test that we still have proper sorting with duplicate keys + self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') - # Test truncated range - test_against_np(torch.randn(201, device=device), min=0.1, max=1) + # Test sorting with NaNs + x = torch.rand(SIZE, SIZE) + x[1][2] = float('NaN') + x[3][0] = float('NaN') + torch.sort(x, out=(res2val, res2ind)) + self.assertIsOrdered('ascending', x, res2val, res2ind, + 'random with NaNs') + torch.sort(x, out=(res2val, res2ind), descending=True) + self.assertIsOrdered('descending', x, res2val, res2ind, + 'random with NaNs') - noncontig = torch.randn(100, 3, device=device)[:, 2] - test_against_np(noncontig) + def test_topk(self): + def topKViaSort(t, k, dim, dir): + sorted, indices = t.sort(dim, dir) + return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k) - multidim = torch.randn(3, 5, 7, 2, device=device) - test_against_np(multidim) + def compareTensors(t, res1, ind1, res2, ind2, dim): + # Values should be exactly equivalent + self.assertEqual(res1, res2, 0) - expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2) - test_against_np(expanded) - - def test_ones(self): - res1 = torch.ones(100, 100) - res2 = torch.Tensor() - torch.ones(100, 100, out=res2) - self.assertEqual(res1, res2) + # Indices might differ based on the implementation, since there is + # no guarantee of the relative order of selection + if not ind1.eq(ind2).all(): + # To verify that the indices represent equivalent elements, + # gather from the input using the topk indices and compare against + # the sort indices + vals = t.gather(dim, ind2) + self.assertEqual(res1, vals, 0) - # test boolean tensor - res1 = torch.ones(1, 2, dtype=torch.bool) - expected = torch.tensor([[True, True]], dtype=torch.bool) - self.assertEqual(res1, expected) + def compare(t, k, dim, dir): + topKVal, topKInd = t.topk(k, dim, dir, True) + sortKVal, sortKInd = topKViaSort(t, k, dim, dir) + compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) - def test_ones_like(self): - expected = torch.ones(100, 100) + t = torch.rand(random.randint(1, SIZE), + random.randint(1, SIZE), + random.randint(1, SIZE)) - res1 = torch.ones_like(expected) - self.assertEqual(res1, expected) + for _kTries in range(3): + for _dimTries in range(3): + for transpose in (True, False): + for dir in (True, False): + testTensor = t + if transpose: + dim1 = random.randrange(t.ndimension()) + dim2 = dim1 + while dim1 == dim2: + dim2 = random.randrange(t.ndimension()) - # test boolean tensor - expected = torch.tensor([True, True], dtype=torch.bool) - res1 = torch.ones_like(expected) - self.assertEqual(res1, expected) + testTensor = t.transpose(dim1, dim2) - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_ones_like_cuda(self): - expected = torch.ones(100, 100).cuda() + dim = random.randrange(testTensor.ndimension()) + k = random.randint(1, testTensor.size(dim)) + compare(testTensor, k, dim, dir) - res1 = torch.ones_like(expected) - self.assertEqual(res1, expected) + def test_topk_arguments(self): + q = torch.randn(10, 2, 10) + # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1) + self.assertRaises(TypeError, lambda: q.topk(4, True)) - @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected') - def test_ones_like_multiple_device(self): - expected = torch.ones(100, 100).cuda() - x = torch.cuda.FloatTensor(100, 100, device=1) - output = torch.ones_like(x) - self.assertEqual(output, expected) + def test_median(self): + for size in (155, 156): + x = torch.rand(size, size) + x0 = x.clone() - def test_dtypes(self): - all_dtypes = torch.testing.get_all_dtypes() - do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cpu')) - if torch.cuda.is_available(): - all_dtypes.remove(torch.bfloat16) # Remove once _th_zero_ is enabled on cuda for bfloat16 - do_test_dtypes(self, all_dtypes, torch.strided, torch.device('cuda:0')) + nelem = x.nelement() + res1val = torch.median(x) + res2val, _ = torch.sort(x.view(nelem)) + ind = int(math.floor((nelem + 1) / 2) - 1) - def test_copy_dtypes(self): - all_dtypes = torch.testing.get_all_dtypes() - for dtype in all_dtypes: - copied_dtype = copy.deepcopy(dtype) - self.assertIs(dtype, copied_dtype) + self.assertEqual(res2val[ind], res1val, 0) - def test_copy_transpose(self): - x = torch.arange(100 * 100, dtype=torch.float).reshape(100, 100).t() - y = torch.empty(100, 100, dtype=torch.float) - y.copy_(x) - self.assertEqual(y[:, 0], range(100)) - self.assertEqual(y[:, 40], range(4000, 4100)) + res1val, res1ind = torch.median(x, dim=1, keepdim=False) + res2val, res2ind = torch.sort(x) + ind = int(math.floor((size + 1) / 2) - 1) - y = torch.empty(100, 100, dtype=torch.double) - y.copy_(x) - self.assertEqual(y[:, 0], range(100)) - self.assertEqual(y[:, 40], range(4000, 4100)) + self.assertEqual(res2val.select(1, ind), res1val, 0) + self.assertEqual(res2val.select(1, ind), res1val, 0) - def test_device(self): - cpu = torch.device('cpu') - self.assertEqual('cpu', str(cpu)) - self.assertEqual('cpu', cpu.type) - self.assertEqual(None, cpu.index) + # Test use of result tensor + res2val = torch.Tensor() + res2ind = torch.LongTensor() + torch.median(x, dim=-1, keepdim=False, out=(res2val, res2ind)) + self.assertEqual(res2val, res1val, 0) + self.assertEqual(res2ind, res1ind, 0) - cpu0 = torch.device('cpu:0') - self.assertEqual('cpu:0', str(cpu0)) - self.assertEqual('cpu', cpu0.type) - self.assertEqual(0, cpu0.index) + # Test non-default dim + res1val, res1ind = torch.median(x, 0, keepdim=False) + res2val, res2ind = torch.sort(x, 0) + self.assertEqual(res1val, res2val[ind], 0) + self.assertEqual(res1ind, res2ind[ind], 0) - cpu0 = torch.device('cpu', 0) - self.assertEqual('cpu:0', str(cpu0)) - self.assertEqual('cpu', cpu0.type) - self.assertEqual(0, cpu0.index) + # input unchanged + self.assertEqual(x, x0, 0) - cuda = torch.device('cuda') - self.assertEqual('cuda', str(cuda)) - self.assertEqual('cuda', cuda.type) - self.assertEqual(None, cuda.index) + def test_mode(self): + x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE) + x[:2] = 1 + x[:, :2] = 1 + x0 = x.clone() - cuda1 = torch.device('cuda:1') - self.assertEqual('cuda:1', str(cuda1)) - self.assertEqual('cuda', cuda1.type) - self.assertEqual(1, cuda1.index) + # Pre-calculated results. + res1val = torch.Tensor(SIZE).fill_(1) + # The indices are the position of the last appearance of the mode element. + res1ind = torch.LongTensor(SIZE).fill_(1) + res1ind[0] = SIZE - 1 + res1ind[1] = SIZE - 1 - cuda1 = torch.device('cuda', 1) - self.assertEqual('cuda:1', str(cuda1)) - self.assertEqual('cuda', cuda1.type) - self.assertEqual(1, cuda1.index) + res2val, res2ind = torch.mode(x, keepdim=False) + self.assertEqual(res1val, res2val, 0) + self.assertEqual(res1ind, res2ind, 0) - self.assertRaises(RuntimeError, lambda: torch.device('cpu:-1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu:1')) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', -1)) - self.assertRaises(RuntimeError, lambda: torch.device('cpu', 1)) - self.assertRaises(RuntimeError, lambda: torch.device('cuda:-1')) - self.assertRaises(RuntimeError, lambda: torch.device('cuda', -1)) - self.assertRaises(RuntimeError, lambda: torch.device(-1)) + # Test use of result tensor + res2val = torch.Tensor() + res2ind = torch.LongTensor() + torch.mode(x, keepdim=False, out=(res2val, res2ind)) + self.assertEqual(res1val, res2val, 0) + self.assertEqual(res1ind, res2ind, 0) - self.assertRaises(RuntimeError, lambda: torch.device('other')) - self.assertRaises(RuntimeError, lambda: torch.device('other:0')) + # Test non-default dim + res2val, res2ind = torch.mode(x, 0, False) + self.assertEqual(res1val, res2val, 0) + self.assertEqual(res1ind, res2ind, 0) - device_set = {'cpu', 'cpu:0', 'cuda', 'cuda:0', 'cuda:1', 'cuda:10', 'cuda:100'} - device_hash_set = set() - for device in list(device_set): - device_hash_set.add(hash(torch.device(device))) - self.assertEqual(len(device_set), len(device_hash_set)) + # input unchanged + self.assertEqual(x, x0, 0) - def test_tensor_device(self): - def assertEqual(device_str, fn): - self.assertEqual(torch.device(device_str), fn().device) - self.assertEqual(device_str, str(fn().device)) + def test_trilu_indices(self): + for test_args in tri_tests_args: + _compare_trilu_indices(self, *test_args) + run_additional_tri_tests(self, 'cpu') - assertEqual('cpu', lambda: torch.tensor(5)) - assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu')) - # NOTE: 'cpu' is the canonical representation of 'cpu:0', but 'cuda:X' is the canonical - # representation of cuda devices. - assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu:0')) - assertEqual('cpu', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0')) - if TEST_NUMPY: - assertEqual('cpu', lambda: torch.tensor(np.random.randn(2, 3), device='cpu')) + # test default options + x = torch.ones( + 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) + self.assertEqual( + x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3)) + self.assertEqual( + x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3)) - if torch.cuda.is_available(): - assertEqual('cuda:0', lambda: torch.tensor(5).cuda(0)) - assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0')) - self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu')) - self.assertRaises(RuntimeError, lambda: torch.tensor(5).cuda('cpu:0')) - assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device=0)) - assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0')) - assertEqual('cuda:' + str(torch.cuda.current_device()), - lambda: torch.tensor(5, dtype=torch.int64, device='cuda')) - assertEqual('cuda:0', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0')) - if TEST_NUMPY: - assertEqual('cuda:0', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:0')) + # test stride 0 cases + x = torch.ones( + 3, 1, 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) + output = x.triu(2).expand(3, 3, 3, 3) + b = x.clone().expand(3, 3, 3, 3) + self.assertEqual(b.triu(2), output) + self.assertRaises(RuntimeError, lambda: b.triu_(2)) - if torch.cuda.device_count() > 1: - assertEqual('cuda:1', lambda: torch.tensor(5).cuda(1)) - assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1')) - assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1)) - assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1')) - assertEqual('cuda:1', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:1')) - if TEST_NUMPY: - assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1')) + def test_cat(self): + SIZE = 10 + for dtype in (torch.half, torch.double, torch.int): + for dim in range(-3, 3): + pos_dim = dim if dim >= 0 else 3 + dim + x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE)).to(dtype).transpose(0, pos_dim) + y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE)).to(dtype).transpose(0, pos_dim) + z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE)).to(dtype).transpose(0, pos_dim) - @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected') - def test_device_guard(self): - # verify that all operators with `device_guard: False` behave properly with multiple devices. - # TODO: if we had operator introspection we could figure out this set of operators automatically... - current_device = torch.cuda.current_device() - device = torch.device('cuda:1') if current_device == 0 else torch.device('cuda:0') - x = torch.randn((1, 2, 3), device=device) - y = torch.zeros((1, 3, 2), device=device) - scalar = torch.tensor(5, device=device) + res1 = torch.cat((x, y, z), dim) + self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0) + self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0) + self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0) - # property ops - torch.cudnn_is_acceptable(x) - x.is_distributed() - x.is_floating_point() - x.is_complex() - x.is_same_size(y) - x.is_signed() - x.size(0) - x.stride(0) - x.numel() - x.is_set_to(y) - x.data_ptr() - scalar.is_nonzero() + x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE)).to(dtype) + self.assertEqual(torch.cat(torch.split(x, 7)), x) + self.assertEqual(torch.cat(torch.chunk(x, 7)), x) - # sparse property ops - y[0][1] = 5 - y_sparse = y.to_sparse() - y_sparse.sparse_dim() - y_sparse._dimI() - y_sparse.dense_dim() - y_sparse._dimV() - y_sparse._nnz() - y_sparse.is_coalesced() - y_sparse._indices() - y_sparse._values() - y_sparse.indices() - y_sparse.values() - - # in-place ops - def inplace(): - return torch.randn((1, 2, 3), device=device) - inplace().as_strided_(y.size(), y.stride()) - inplace().resize_(y.size()) - inplace().squeeze_() - inplace().squeeze_(0) - inplace().unsqueeze_(2) - inplace().transpose_(1, 2) - inplace().squeeze_().t_() - inplace().set_(x.storage()) - inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride()) - inplace().set_(x) - inplace().set_() - y_sparse._coalesced_(True) + y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE)).to(dtype) + z = torch.cat([x, y]) + self.assertEqual(z.size(), (21, SIZE, SIZE)) - # shape modification - x.as_strided(y.size(), y.stride()) - x.expand((5, 2, 3)) - x.expand_as(x) - x.sum_to_size((1,)) - torch.broadcast_tensors(x , x) - x.reshape((1, 3, 2)) - x.reshape_as(y) - x.squeeze() - x.squeeze(0) - x.squeeze().t() - x.transpose(1, 2) - x.unsqueeze(2) - x.view((1, 3, 2)) - x.view_as(y) + self.assertRaises(RuntimeError, lambda: torch.cat([])) + self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None])) - # chunk, split, etc. - x.chunk(2, dim=1) - x.split(1, dim=2) - x.split_with_sizes([1, 2], dim=2) - x.unfold(dimension=2, size=1, step=1) + def test_cat_bad_input_sizes(self): + x = torch.randn(2, 1) + y = torch.randn(2, 1, 1) + z = torch.randn(2, 1, 1) + self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z])) - x.narrow(1, 1, 1) - x.select(1, 1) - torch.isnan(x) + x = torch.randn(2, 1, 2) + y = torch.randn(2, 1, 1) + z = torch.randn(2, 2, 1) + self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1)) - torch.empty((1, 3, 2), out=y) - torch.empty_like(x) - torch.empty_like(x, dtype=torch.int64) + def test_cat_scalars(self): + x = torch.tensor(0) + y = torch.tensor(1) + with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'): + torch.cat([x, y]) - # to - x.to(x) - x.to(y) - x.to(x, copy=True) + @slowTest + def test_cat_big(self): + SIZE1 = 6500 + SIZE2 = 4500 + concat_list = [] + concat_list.append(torch.ones((SIZE1, 1024 * 512), dtype=torch.uint8)) + concat_list.append(torch.ones((SIZE2, 1024 * 512), dtype=torch.uint8)) + result = torch.cat(concat_list) + self.assertEqual(result.size(0), SIZE1 + SIZE2) - def test_to(self): - def test_copy_behavior(t, non_blocking=False): - self.assertIs(t, t.to(t, non_blocking=non_blocking)) - self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking)) - self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) - self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)) + def test_narrow(self): + x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]])) + self.assertEqual(x.narrow(0, 0, 2), torch.Tensor([[0, 1, 2], [3, 4, 5]])) + self.assertEqual(x.narrow(0, 1, 1), torch.Tensor([[3, 4, 5]])) + self.assertEqual(x.narrow(0, -1, 1), torch.Tensor([[6, 7, 8]])) + self.assertEqual(x.narrow(0, -2, 2), torch.Tensor([[3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(0, -3, 3), torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) + self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]])) + self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]])) - devices = [t.device] - if t.device.type == 'cuda': - if t.device.index == -1: - devices.append('cuda:{}'.format(torch.cuda.current_device())) - elif t.device.index == torch.cuda.current_device(): - devices.append('cuda') - for device in devices: - self.assertIs(t, t.to(device, non_blocking=non_blocking)) - self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) - self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) - self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) + def test_stack(self): + for dtype in (torch.half, torch.double, torch.int): + x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + for dim in range(4): + res = torch.stack((x, y, z), dim) + res_neg = torch.stack((x, y, z), dim - 4) + expected_size = x.size()[:dim] + (3,) + x.size()[dim:] + self.assertEqual(res, res_neg) + self.assertEqual(res.size(), expected_size) + self.assertEqual(res.select(dim, 0), x, 0) + self.assertEqual(res.select(dim, 1), y, 0) + self.assertEqual(res.select(dim, 2), z, 0) - a = torch.tensor(5) - test_copy_behavior(a) - self.assertEqual(a.device, a.to('cpu').device) - self.assertEqual(a.device, a.to('cpu', dtype=torch.float32).device) - self.assertIs(torch.float32, a.to('cpu', dtype=torch.float32).dtype) - self.assertEqual(a.device, a.to(torch.float32).device) - self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype) - self.assertEqual(a.data_ptr(), a.to('cpu').data_ptr()) - self.assertEqual(a.data_ptr(), a.to(dtype=a.dtype, device=a.device, copy=False).data_ptr()) - self.assertEqual(a.data_ptr(), a.to('cpu', copy=False).data_ptr()) - self.assertNotEqual(a.data_ptr(), a.to('cpu', copy=True).data_ptr()) + def test_stack_out(self): + for dtype in (torch.half, torch.double, torch.int): + x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) + for dim in range(4): + expected_size = x.size()[:dim] + (3,) + x.size()[dim:] + res_out = x.new(expected_size) + res_neg_out = x.new(expected_size) + res_out_dp = res_out.data_ptr() + res_out_neg_dp = res_neg_out.data_ptr() + torch.stack((x, y, z), dim, out=res_out) + torch.stack((x, y, z), dim - 4, out=res_neg_out) + self.assertEqual(res_out, res_neg_out) + self.assertEqual(res_out.size(), expected_size) + self.assertEqual(res_out_dp, res_out.data_ptr()) + self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr()) + self.assertEqual(res_out.select(dim, 0), x, 0) + self.assertEqual(res_out.select(dim, 1), y, 0) + self.assertEqual(res_out.select(dim, 2), z, 0) - if torch.cuda.is_available(): - for non_blocking in [True, False]: - for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: - b = torch.tensor(5., device=cuda) - test_copy_behavior(b, non_blocking) - self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device) - self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device) - self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device) - self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype) - self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device) - self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype) - self.assertEqual(b.device, b.to(dtype=torch.int32).device) + def test_unbind(self): + x = torch.rand(2, 3, 4, 5) + for dim in range(4): + res = torch.unbind(x, dim) + res2 = x.unbind(dim) + self.assertEqual(x.size(dim), len(res)) + self.assertEqual(x.size(dim), len(res2)) + for i in range(dim): + self.assertEqual(x.select(dim, i), res[i]) + self.assertEqual(x.select(dim, i), res2[i]) - def test_to_with_tensor(self): - a = torch.tensor(5) - self.assertEqual(a.device, a.to(a).device) + def test_logspace(self): + _from = random.random() + to = _from + random.random() + res1 = torch.logspace(_from, to, 137) + res2 = torch.Tensor() + torch.logspace(_from, to, 137, out=res2) + self.assertEqual(res1, res2, 0) + self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1)) + self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0) - if torch.cuda.is_available(): - for non_blocking in [True, False]: - for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: - b = torch.tensor(5., device=cuda) - self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device) - self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device) - self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device) + # Check non-default base=2 + self.assertEqual(torch.logspace(1, 1, 1, 2), torch.ones(1) * 2) + self.assertEqual(torch.logspace(0, 2, 3, 2), torch.Tensor((1, 2, 4))) - def test_empty_full(self): - do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cpu')) - if torch.cuda.device_count() > 0: - do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, None) - do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cuda:0')) + # Check logspace_ for generating with start > end. + self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0) - def test_dtype_out_match(self): - d = torch.autograd.Variable(torch.DoubleTensor(2, 3)) - self.assertRaises(RuntimeError, lambda: torch.zeros((2, 3), out=d, dtype=torch.float32)) + # Check logspace_ for non-contiguous tensors. + x = torch.zeros(2, 3) + y = torch.logspace(0, 3, 4, out=x.narrow(1, 1, 2)) + self.assertEqual(x, torch.Tensor(((0, 1, 10), (0, 100, 1000))), 0) - def test_constructor_dtypes(self): - default_type = torch.Tensor().type() - self.assertIs(torch.Tensor().dtype, torch.get_default_dtype()) + def test_rand(self): + torch.manual_seed(123456) + res1 = torch.rand(SIZE, SIZE) + res2 = torch.Tensor() + torch.manual_seed(123456) + torch.rand(SIZE, SIZE, out=res2) + self.assertEqual(res1, res2) - self.assertIs(torch.uint8, torch.ByteTensor.dtype) - self.assertIs(torch.float32, torch.FloatTensor.dtype) - self.assertIs(torch.float64, torch.DoubleTensor.dtype) + def test_randint(self): + torch.manual_seed(123456) + res1 = torch.randint(0, 6, (SIZE, SIZE)) + res2 = torch.Tensor() + torch.manual_seed(123456) + torch.randint(0, 6, (SIZE, SIZE), out=res2) + torch.manual_seed(123456) + res3 = torch.randint(6, (SIZE, SIZE)) + res4 = torch.Tensor() + torch.manual_seed(123456) + torch.randint(6, (SIZE, SIZE), out=res4) + self.assertEqual(res1, res2) + self.assertEqual(res1, res3) + self.assertEqual(res1, res4) + self.assertEqual(res2, res3) + self.assertEqual(res2, res4) + self.assertEqual(res3, res4) + res1 = res1.view(-1) + high = (res1 < 6).type(torch.LongTensor) + low = (res1 >= 0).type(torch.LongTensor) + tensorSize = res1.size()[0] + assert(tensorSize == high.sum()) + assert(tensorSize == low.sum()) - torch.set_default_tensor_type('torch.FloatTensor') - self.assertIs(torch.float32, torch.get_default_dtype()) - self.assertIs(torch.FloatStorage, torch.Storage) + def test_randn(self): + torch.manual_seed(123456) + res1 = torch.randn(SIZE, SIZE) + res2 = torch.Tensor() + torch.manual_seed(123456) + torch.randn(SIZE, SIZE, out=res2) + self.assertEqual(res1, res2) - torch.set_default_dtype(torch.float64) - self.assertIs(torch.float64, torch.get_default_dtype()) - self.assertIs(torch.DoubleStorage, torch.Storage) + def test_slice(self): + empty = torch.empty(0, 4) + x = torch.arange(0., 16).view(4, 4) + self.assertEqual(x[:], x) + self.assertEqual(x[:4], x) + # start and stop are clamped to the size of dim + self.assertEqual(x[:5], x) + # if start >= stop then the result is empty + self.assertEqual(x[2:1], empty) + self.assertEqual(x[2:2], empty) + # out of bounds is also empty + self.assertEqual(x[10:12], empty) + # additional correctness checks + self.assertEqual(x[:1].data.tolist(), [[0, 1, 2, 3]]) + self.assertEqual(x[:-3].data.tolist(), [[0, 1, 2, 3]]) + self.assertEqual(x[:, -2:3].data.tolist(), [[2], [6], [10], [14]]) + self.assertEqual(x[0:-1:2].data.tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]]) - torch.set_default_tensor_type(torch.FloatTensor) - self.assertIs(torch.float32, torch.get_default_dtype()) - self.assertIs(torch.FloatStorage, torch.Storage) + @skipIfNoLapack + def test_ormqr(self): + mat1 = torch.randn(7, 7) + mat2 = torch.randn(7, 7) + q, r = torch.qr(mat1) + m, tau = torch.geqrf(mat1) + out_holder = torch.empty_like(mat1) - if torch.cuda.is_available(): - torch.set_default_tensor_type(torch.cuda.FloatTensor) - self.assertIs(torch.float32, torch.get_default_dtype()) - self.assertIs(torch.float32, torch.cuda.FloatTensor.dtype) - self.assertIs(torch.cuda.FloatStorage, torch.Storage) + res1 = torch.mm(q, mat2) + res2 = torch.ormqr(m, tau, mat2, left=True, transpose=False) + torch.ormqr(m, tau, mat2, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) - torch.set_default_dtype(torch.float64) - self.assertIs(torch.float64, torch.get_default_dtype()) - self.assertIs(torch.cuda.DoubleStorage, torch.Storage) + res1 = torch.mm(mat2, q) + res2 = torch.ormqr(m, tau, mat2, left=False, transpose=False) + torch.ormqr(m, tau, mat2, left=False, transpose=False, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) - # don't support integral or sparse default types. - self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor')) - self.assertRaises(TypeError, lambda: torch.set_default_dtype(torch.int64)) + res1 = torch.mm(q.t(), mat2) + res2 = torch.ormqr(m, tau, mat2, left=True, transpose=True) + torch.ormqr(m, tau, mat2, left=True, transpose=True, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) - # don't allow passing dtype to set_default_tensor_type - self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.float32)) + res1 = torch.mm(mat2, q.t()) + res2 = torch.ormqr(m, tau, mat2, left=False, transpose=True) + torch.ormqr(m, tau, mat2, left=False, transpose=True, out=out_holder) + self.assertEqual(res1, res2) + self.assertEqual(res2, out_holder) - torch.set_default_tensor_type(default_type) + @staticmethod + def _test_triangular_solve_batched(self, cast): + from common_utils import triangular_solve_test_helper - def test_constructor_device_legacy(self): - self.assertRaises(RuntimeError, lambda: torch.FloatTensor(device='cuda')) - self.assertRaises(RuntimeError, lambda: torch.FloatTensor(torch.Size([2, 3, 4]), device='cuda')) - self.assertRaises(RuntimeError, lambda: torch.FloatTensor((2.0, 3.0), device='cuda')) + def triangular_solve_batch_helper(A_dims, b_dims, cast, upper, unitriangular, transpose): + b, A = triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.triangular_solve(b[i], A[i], upper=upper, + unitriangular=unitriangular, transpose=transpose)[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.triangular_solve(b, A, upper=upper, + unitriangular=unitriangular, transpose=transpose)[0] # Actual output + self.assertEqual(x_act, x_exp) # Equality check + if transpose: + self.assertLessEqual(b.dist(torch.matmul(A.transpose(-2, -1), x_act)), 3e-12) # Correctness check + else: + self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 3e-12) # Correctness check - self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cuda')) - self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cuda')) - self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cuda')) + for (upper, unitriangular, transpose), batchsize in product(product([True, False], repeat=3), [1, 3, 4]): + triangular_solve_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), cast, + upper, unitriangular, transpose) - x = torch.randn((3,), device='cpu') - self.assertRaises(RuntimeError, lambda: x.new(device='cuda')) - self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cuda')) - self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cuda')) + @skipIfNoLapack + def test_triangular_solve_batched(self): + self._test_triangular_solve_batched(self, lambda t: t) - if torch.cuda.is_available(): - self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(device='cpu')) - self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor(torch.Size([2, 3, 4]), device='cpu')) - self.assertRaises(RuntimeError, lambda: torch.cuda.FloatTensor((2.0, 3.0), device='cpu')) + @skipIfNoLapack + def test_eig(self): + a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00), + (-6.49, 3.80, 0.00, 0.00, 0.00), + (-0.47, -6.39, 4.17, 0.00, 0.00), + (-7.20, 1.50, -1.51, 5.70, 0.00), + (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous() + e = torch.eig(a)[0] + ee, vv = torch.eig(a, True) + te = torch.Tensor() + tv = torch.Tensor() + eee, vvv = torch.eig(a, True, out=(te, tv)) + self.assertEqual(e, ee, 1e-12) + self.assertEqual(ee, eee, 1e-12) + self.assertEqual(ee, te, 1e-12) + self.assertEqual(vv, vvv, 1e-12) + self.assertEqual(vv, tv, 1e-12) - default_type = torch.Tensor().type() - torch.set_default_tensor_type(torch.cuda.FloatTensor) - self.assertRaises(RuntimeError, lambda: torch.Tensor(device='cpu')) - self.assertRaises(RuntimeError, lambda: torch.Tensor(torch.Size([2, 3, 4]), device='cpu')) - self.assertRaises(RuntimeError, lambda: torch.Tensor((2.0, 3.0), device='cpu')) - torch.set_default_tensor_type(torch.cuda.FloatTensor) - torch.set_default_tensor_type(default_type) + # test reuse + X = torch.randn(4, 4) + X = torch.mm(X.t(), X) + e, v = torch.zeros(4, 2), torch.zeros(4, 4) + torch.eig(X, True, out=(e, v)) + Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) + self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') + self.assertFalse(v.is_contiguous(), 'V is contiguous') - x = torch.randn((3,), device='cuda') - self.assertRaises(RuntimeError, lambda: x.new(device='cpu')) - self.assertRaises(RuntimeError, lambda: x.new(torch.Size([2, 3, 4]), device='cpu')) - self.assertRaises(RuntimeError, lambda: x.new((2.0, 3.0), device='cpu')) + torch.eig(X, True, out=(e, v)) + Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t())) + self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') + self.assertFalse(v.is_contiguous(), 'V is contiguous') - def test_type(self): - x = torch.randn(3, 3).double() - self.assertEqual(x.type('torch.FloatTensor').dtype, torch.float32) - self.assertEqual(x.type(torch.FloatTensor).dtype, torch.float32) - self.assertEqual(x.int().type(torch.Tensor).dtype, torch.get_default_dtype()) - self.assertEqual(x.type(torch.int32).dtype, torch.int32) + # test non-contiguous + X = torch.randn(4, 4) + X = torch.mm(X.t(), X) + e = torch.zeros(4, 2, 2)[:, 1] + v = torch.zeros(4, 2, 4)[:, 1] + self.assertFalse(v.is_contiguous(), 'V is contiguous') + self.assertFalse(e.is_contiguous(), 'E is contiguous') + torch.eig(X, True, out=(e, v)) + Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) + self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') - def test_tensor_factory(self): - expected = torch.Tensor([1, 1]) - # test data - res1 = torch.tensor([1, 1]) - self.assertEqual(res1, expected) + @staticmethod + def _test_fft_ifft_rfft_irfft(self, device='cpu'): + def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): + x = prepro_fn(torch.randn(*sizes, device=device)) + for normalized in (True, False): + res = x.fft(signal_ndim, normalized=normalized) + rec = res.ifft(signal_ndim, normalized=normalized) + self.assertEqual(x, rec, 1e-8, 'fft and ifft') + res = x.ifft(signal_ndim, normalized=normalized) + rec = res.fft(signal_ndim, normalized=normalized) + self.assertEqual(x, rec, 1e-8, 'ifft and fft') - res1 = torch.tensor([1, 1], dtype=torch.int) - self.assertEqual(res1, expected) - self.assertIs(torch.int, res1.dtype) + def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): + x = prepro_fn(torch.randn(*sizes, device=device)) + signal_numel = 1 + signal_sizes = x.size()[-signal_ndim:] + for normalized, onesided in product((True, False), repeat=2): + res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided) + if not onesided: # check Hermitian symmetry + def test_one_sample(res, test_num=10): + idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist() for s in signal_sizes] + for idx in zip(*idxs_per_dim): + reflected_idx = tuple((s - i) % s for i, s in zip(idx, res.size())) + idx_val = res.__getitem__(idx) + reflected_val = res.__getitem__(reflected_idx) + self.assertEqual(idx_val[0], reflected_val[0], 'rfft hermitian symmetry on real part') + self.assertEqual(idx_val[1], -reflected_val[1], 'rfft hermitian symmetry on imaginary part') + if len(sizes) == signal_ndim: + test_one_sample(res) + else: + output_non_batch_shape = res.size()[-(signal_ndim + 1):] + flatten_batch_res = res.view(-1, *output_non_batch_shape) + nb = flatten_batch_res.size(0) + test_idxs = torch.LongTensor(min(nb, 4)).random_(nb) + for test_idx in test_idxs.tolist(): + test_one_sample(flatten_batch_res[test_idx]) + # compare with C2C + xc = torch.stack([x, torch.zeros_like(x)], -1) + xc_res = xc.fft(signal_ndim, normalized=normalized) + self.assertEqual(res, xc_res) + test_input_signal_sizes = [signal_sizes] + rec = res.irfft(signal_ndim, normalized=normalized, + onesided=onesided, signal_sizes=signal_sizes) + self.assertEqual(x, rec, 1e-8, 'rfft and irfft') + if not onesided: # check that we can use C2C ifft + rec = res.ifft(signal_ndim, normalized=normalized) + self.assertEqual(x, rec.select(-1, 0), 1e-8, 'twosided rfft and ifft real') + self.assertEqual(rec.select(-1, 1).data.abs().mean(), 0, 1e-8, 'twosided rfft and ifft imaginary') - # test copy - res2 = torch.tensor(expected) - self.assertEqual(res2, expected) - res2[1] = 2 - self.assertEqual(expected, torch.ones_like(expected)) + # contiguous case + _test_real((100,), 1) + _test_real((10, 1, 10, 100), 1) + _test_real((100, 100), 2) + _test_real((2, 2, 5, 80, 60), 2) + _test_real((50, 40, 70), 3) + _test_real((30, 1, 50, 25, 20), 3) - res2 = torch.tensor(expected, dtype=torch.int) - self.assertEqual(res1, expected) - self.assertIs(torch.int, res1.dtype) + _test_complex((100, 2), 1) + _test_complex((100, 100, 2), 1) + _test_complex((100, 100, 2), 2) + _test_complex((1, 20, 80, 60, 2), 2) + _test_complex((50, 40, 70, 2), 3) + _test_complex((6, 5, 50, 25, 20, 2), 3) - # test copy with numpy - if TEST_NUMPY: - for dtype in [np.float64, np.int64, np.int8, np.uint8]: - a = np.array([5.]).astype(dtype) - res1 = torch.tensor(a) - self.assertEqual(5., res1[0].item()) - a[0] = 7. - self.assertEqual(5., res1[0].item()) + # non-contiguous case + _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type + _test_real((100, 100, 3), 1, lambda x: x[:, :, 0]) + _test_real((100, 100), 2, lambda x: x.t()) + _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]) + _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) + _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) - # test boolean tensor - a = torch.tensor([True, True, False, True, True], dtype=torch.bool) - b = torch.tensor([-1, -1.1, 0, 1, 1.1], dtype=torch.bool) - self.assertEqual(a, b) + _test_complex((2, 100), 1, lambda x: x.t()) + _test_complex((100, 2), 1, lambda x: x.expand(100, 100, 2)) + _test_complex((300, 200, 3), 2, lambda x: x[:100, :100, 1:]) # input is not aligned to complex type + _test_complex((20, 90, 110, 2), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) + _test_complex((40, 60, 3, 80, 2), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) + _test_complex((30, 55, 50, 22, 2), 3, lambda x: x[:, 3:53, 15:40, 1:21]) - def test_tensor_factory_copy_var(self): + # non-contiguous with strides not representable as aligned with complex type + _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [3, 2, 1])) + _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) + _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) + _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [3, 3, 1])) + _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) + _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) - def check_copy(copy, is_leaf, requires_grad, data_ptr=None): - if data_ptr is None: - data_ptr = copy.data_ptr - self.assertEqual(copy.data, source.data) - self.assertTrue(copy.is_leaf == is_leaf) - self.assertTrue(copy.requires_grad == requires_grad) - self.assertTrue(copy.data_ptr == data_ptr) + @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") + def test_fft_ifft_rfft_irfft(self): + self._test_fft_ifft_rfft_irfft(self) - source = torch.randn(5, 5, dtype=torch.double, requires_grad=True) - # test torch.tensor() - check_copy(torch.tensor(source), True, False) - check_copy(torch.tensor(source, requires_grad=False), True, False) - check_copy(torch.tensor(source, requires_grad=True), True, True) + @unittest.skip("Not implemented yet") + def test_conv2(self): + x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100))) + k = torch.rand(math.floor(torch.uniform(10, 20)), math.floor(torch.uniform(10, 20))) + imvc = torch.conv2(x, k) + imvc2 = torch.conv2(x, k, 'V') + imfc = torch.conv2(x, k, 'F') - # test tensor.new_tensor() - copy = torch.randn(1) - check_copy(copy.new_tensor(source), True, False) - check_copy(copy.new_tensor(source, requires_grad=False), True, False) - check_copy(copy.new_tensor(source, requires_grad=True), True, True) + ki = k.clone() + ks = k.storage() + kis = ki.storage() + for i in range(ks.size() - 1, 0, -1): + kis[ks.size() - i + 1] = ks[i] + # for i=ks.size(), 1, -1 do kis[ks.size()-i+1]=ks[i] end + imvx = torch.xcorr2(x, ki) + imvx2 = torch.xcorr2(x, ki, 'V') + imfx = torch.xcorr2(x, ki, 'F') - # test torch.as_tensor() - check_copy(torch.as_tensor(source), source.is_leaf, source.requires_grad, source.data_ptr) # not copy - check_copy(torch.as_tensor(source, dtype=torch.float), False, True) # copy and keep the graph + self.assertEqual(imvc, imvc2, 0, 'torch.conv2') + self.assertEqual(imvc, imvx, 0, 'torch.conv2') + self.assertEqual(imvc, imvx2, 0, 'torch.conv2') + self.assertEqual(imfc, imfx, 0, 'torch.conv2') + self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10, 'torch.conv2') - def test_tensor_factory_type_inference(self): - def test_inference(default_dtype): - saved_dtype = torch.get_default_dtype() - torch.set_default_dtype(default_dtype) - self.assertIs(default_dtype, torch.tensor(()).dtype) - self.assertIs(default_dtype, torch.tensor(5.).dtype) - self.assertIs(torch.int64, torch.tensor(5).dtype) - self.assertIs(torch.bool, torch.tensor(True).dtype) - self.assertIs(torch.int32, torch.tensor(5, dtype=torch.int32).dtype) - self.assertIs(default_dtype, torch.tensor(((7, 5), (9, 5.))).dtype) - self.assertIs(default_dtype, torch.tensor(((5., 5), (3, 5))).dtype) - self.assertIs(torch.int64, torch.tensor(((5, 3), (3, 5))).dtype) + xx = torch.Tensor(2, x.size(1), x.size(2)) + xx[1].copy_(x) + xx[2].copy_(x) + kk = torch.Tensor(2, k.size(1), k.size(2)) + kk[1].copy_(k) + kk[2].copy_(k) - if TEST_NUMPY: - self.assertIs(torch.float64, torch.tensor(np.array(())).dtype) - self.assertIs(torch.float64, torch.tensor(np.array(5.)).dtype) - if np.array(5).dtype == np.int64: # np long, which can be 4 bytes (e.g. on windows) - self.assertIs(torch.int64, torch.tensor(np.array(5)).dtype) - else: - self.assertIs(torch.int32, torch.tensor(np.array(5)).dtype) - self.assertIs(torch.uint8, torch.tensor(np.array(3, dtype=np.uint8)).dtype) - self.assertIs(default_dtype, torch.tensor(((7, np.array(5)), (np.array(9), 5.))).dtype) - self.assertIs(torch.float64, torch.tensor(((7, 5), (9, np.array(5.)))).dtype) - self.assertIs(torch.int64, torch.tensor(((5, np.array(3)), (np.array(3), 5))).dtype) - torch.set_default_dtype(saved_dtype) + immvc = torch.conv2(xx, kk) + immvc2 = torch.conv2(xx, kk, 'V') + immfc = torch.conv2(xx, kk, 'F') - test_inference(torch.float64) - test_inference(torch.float32) + self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv2') + self.assertEqual(immvc[0], imvc, 0, 'torch.conv2') + self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv2') + self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv2') + self.assertEqual(immfc[0], imfc, 0, 'torch.conv2') - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_tensor_factory_cuda_type_inference(self): - saved_type = torch.Tensor().type() - torch.set_default_tensor_type(torch.cuda.DoubleTensor) - torch.set_default_dtype(torch.float32) - self.assertIs(torch.float32, torch.tensor(0.).dtype) - self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device) - torch.set_default_dtype(torch.float64) - self.assertIs(torch.float64, torch.tensor(0.).dtype) - self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device) - torch.set_default_tensor_type(saved_type) + @unittest.skip("Not implemented yet") + def test_conv3(self): + x = torch.rand(math.floor(torch.uniform(20, 40)), + math.floor(torch.uniform(20, 40)), + math.floor(torch.uniform(20, 40))) + k = torch.rand(math.floor(torch.uniform(5, 10)), + math.floor(torch.uniform(5, 10)), + math.floor(torch.uniform(5, 10))) + imvc = torch.conv3(x, k) + imvc2 = torch.conv3(x, k, 'V') + imfc = torch.conv3(x, k, 'F') - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_tensor_factory_cuda_type(self): - saved_type = torch.Tensor().type() - torch.set_default_tensor_type(torch.cuda.FloatTensor) - x = torch.zeros((5, 5)) - self.assertIs(torch.float32, x.dtype) - self.assertTrue(x.is_cuda) - torch.set_default_tensor_type(torch.cuda.DoubleTensor) - x = torch.zeros((5, 5)) - self.assertIs(torch.float64, x.dtype) - self.assertTrue(x.is_cuda) - torch.set_default_tensor_type(saved_type) + ki = k.clone() + ks = k.storage() + kis = ki.storage() + for i in range(ks.size() - 1, 0, -1): + kis[ks.size() - i + 1] = ks[i] + imvx = torch.xcorr3(x, ki) + imvx2 = torch.xcorr3(x, ki, 'V') + imfx = torch.xcorr3(x, ki, 'F') - def test_bool_tensor_comparison_ops(self): - for device in torch.testing.get_all_device_types(): - a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) - b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) - self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)) - self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)) - self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), - torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) - self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), - torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) - self.assertFalse(a.equal(b)) - - def test_bool_tensor_value_change(self): - for device in torch.testing.get_all_device_types(): - x = torch.tensor([True, False], dtype=torch.bool, device=device) - x[0] = False - x[1] = True - self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device)) + self.assertEqual(imvc, imvc2, 0, 'torch.conv3') + self.assertEqual(imvc, imvx, 0, 'torch.conv3') + self.assertEqual(imvc, imvx2, 0, 'torch.conv3') + self.assertEqual(imfc, imfx, 0, 'torch.conv3') + self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10, 'torch.conv3') - @torchtest.for_all_device_types() - def test_unfold_all_devices_and_dtypes(self, device): - for dt in torch.testing.get_all_dtypes(): - if dt == torch.bfloat16: - self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device)) - continue + xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3)) + xx[1].copy_(x) + xx[2].copy_(x) + kk = torch.Tensor(2, k.size(1), k.size(2), k.size(3)) + kk[1].copy_(k) + kk[2].copy_(k) - if dt == torch.half and device == 'cpu': - # fix once random is implemented for Half on CPU - self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device)) - else: - x = torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device) - self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) + immvc = torch.conv3(xx, kk) + immvc2 = torch.conv3(xx, kk, 'V') + immfc = torch.conv3(xx, kk, 'F') - def test_copy_all_dtypes_and_devices(self): - from copy import copy - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device) - x_clone = x.clone() - if (device == 'cuda' and dt == torch.bfloat16): - self.assertRaises(RuntimeError, lambda: copy(x)) - continue - y = copy(x) - y.fill_(1) - # copy is a shallow copy, only copies the tensor view, - # not the data - self.assertEqual(x, y) + self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv3') + self.assertEqual(immvc[0], imvc, 0, 'torch.conv3') + self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv3') + self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv3') + self.assertEqual(immfc[0], imfc, 0, 'torch.conv3') - def test_resize_all_dtypes_and_devices(self): - shape = (2, 2) - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) - x.resize_(shape) - self.assertEqual(shape, x.shape) + @unittest.skip("Not implemented yet") + def _test_conv_corr_eq(self, fn, fn_2_to_3): + ix = math.floor(random.randint(20, 40)) + iy = math.floor(random.randint(20, 40)) + iz = math.floor(random.randint(20, 40)) + kx = math.floor(random.randint(5, 10)) + ky = math.floor(random.randint(5, 10)) + kz = math.floor(random.randint(5, 10)) - def test_resize_as_all_dtypes_and_devices(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) - y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) - x.resize_as_(y) - self.assertEqual(y.shape, x.shape) + x = torch.rand(ix, iy, iz) + k = torch.rand(kx, ky, kz) - def test_view_all_dtypes_and_devices(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) - if (device == 'cuda' and dt == torch.bfloat16): - self.assertRaises(RuntimeError, lambda: x.view(6)) - continue - self.assertEqual(x.view(6).shape, [6]) + o3 = fn(x, k) + o32 = torch.zeros(o3.size()) + fn_2_to_3(x, k, o3, o32) + self.assertEqual(o3, o32) - def test_fill_all_dtypes_and_devices(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor((1, 1), dtype=dt, device=device) - if (device == 'cuda' and dt == torch.bfloat16): - self.assertRaises(RuntimeError, lambda: x.fill_(1)) - continue - x.fill_(1) + @unittest.skip("Not implemented yet") + def test_xcorr3_xcorr2_eq(self): + def reference(x, k, o3, o32): + for i in range(o3.size(1)): + for j in range(k.size(1)): + o32[i].add(torch.xcorr2(x[i + j - 1], k[j])) + self._test_conv_corr_eq(torch.xcorr3, reference) - self.assertEqual(x, torch.tensor([1, 1], dtype=dt, device=device)) - self.assertEqual(dt, x.dtype) + @unittest.skip("Not implemented yet") + def test_xcorr3_xcorr2_eq_full(self): + def reference(x, k, o3, o32): + for i in range(x.size(1)): + for j in range(k.size(1)): + o32[i].add(torch.xcorr2(x[i], k[k.size(1) - j + 1], 'F')) + self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k, 'F'), reference) - def test_clone_all_dtypes_and_devices(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor((1, 1), dtype=dt, device=device) - y = x.clone() - if (device == 'cuda' and dt == torch.bfloat16): - # `x - y` is used inside of the assertEqual - self.assertRaises(RuntimeError, lambda: x - y) - continue - self.assertEqual(x, y) + @unittest.skip("Not implemented yet") + def test_conv3_conv2_eq_valid(self): + def reference(x, k, o3, o32): + for i in range(o3.size(1)): + for j in range(k.size(1)): + o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1])) + self._test_conv_corr_eq(torch.conv3, reference) - def test_cat_all_dtypes_and_devices(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) - if (device == 'cuda' and dt == torch.bfloat16): - self.assertRaises(RuntimeError, lambda: torch.cat((x, x), 0)) - continue + @unittest.skip("Not implemented yet") + def test_fconv3_fconv2_eq(self): + def reference(x, k, o3, o32): + for i in range(o3.size(1)): + for j in range(k.size(1)): + o32[i + j - 1].add(torch.conv2(x[i], k[j], 'F')) + self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k, 'F'), reference) - expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) - self.assertEqual(torch.cat((x, x), 0), expected1) + def test_isfinite(self): + x = torch.Tensor([1, inf, 2, -inf, nan, -10]) + self.assertEqual(torch.isfinite(x), torch.BoolTensor([True, False, True, False, False, True])) - expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device) - self.assertEqual(torch.cat((x, x), 1), expected2) + def test_isfinite_int(self): + x = torch.tensor([1, 2, 3]) + self.assertEqual(torch.isfinite(x), torch.BoolTensor([True, True, True])) - def test_tensor_factories_empty(self): - # ensure we can create empty tensors from each factory function - shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)] + def test_isfinite_type(self): + with self.assertRaises(TypeError): + torch.isfinite(1) # Parameter must be a tensor - for device in torch.testing.get_all_device_types(): - for shape in shapes: - for dt in torch.testing.get_all_dtypes(): + def test_isinf_type(self): + with self.assertRaises(TypeError): + torch.isinf(1) # Parameter must be a tensor - if (device == 'cuda' and dt == torch.bfloat16): - self.assertRaises(RuntimeError, lambda: torch.zeros(shape, device=device, dtype=dt).shape) - self.assertRaises(RuntimeError, lambda: torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape) - self.assertRaises(RuntimeError, lambda: torch.full(shape, 3, device=device, dtype=dt).shape) - self.assertRaises(RuntimeError, lambda: torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3)) - self.assertRaises(RuntimeError, lambda: torch.ones(shape, device=device, dtype=dt).shape) - self.assertRaises(RuntimeError, lambda: torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape) - self.assertRaises(RuntimeError, lambda: torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape) - else: - self.assertEqual(shape, torch.zeros(shape, device=device, dtype=dt).shape) - self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape) - self.assertEqual(shape, torch.full(shape, 3, device=device, dtype=dt).shape) - self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3).shape) - self.assertEqual(shape, torch.ones(shape, device=device, dtype=dt).shape) - self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape) - self.assertEqual(shape, torch.empty(shape, device=device, dtype=dt).shape) - self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape) - self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape) - - if dt == torch.half and device == "cpu": - # update once random is implemented for half on CPU - self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape) - else: - if dt == torch.bfloat16: - self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt)) - continue # Remove once random is supported for bfloat16 on cuda - self.assertEqual(shape, torch.randint(6, shape, device=device, dtype=dt).shape) - self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 6).shape) - - if dt != torch.double and dt != torch.float and dt != torch.half: - self.assertRaises(RuntimeError, lambda: torch.rand(shape, device=device, dtype=dt).shape) - - if dt == torch.double or dt == torch.float: - self.assertEqual(shape, torch.randn(shape, device=device, dtype=dt).shape) - self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device, dtype=dt)).shape) - - self.assertEqual((0,), torch.arange(0, device=device).shape) - self.assertEqual((0, 0), torch.eye(0, device=device).shape) - self.assertEqual((0, 0), torch.eye(0, 0, device=device).shape) - self.assertEqual((5, 0), torch.eye(5, 0, device=device).shape) - self.assertEqual((0, 5), torch.eye(0, 5, device=device).shape) - self.assertEqual((0,), torch.linspace(1, 1, 0, device=device).shape) - self.assertEqual((0,), torch.logspace(1, 1, 0, device=device).shape) - self.assertEqual((0,), torch.randperm(0, device=device).shape) - self.assertEqual((0,), torch.bartlett_window(0, device=device).shape) - self.assertEqual((0,), torch.bartlett_window(0, periodic=False, device=device).shape) - self.assertEqual((0,), torch.hamming_window(0, device=device).shape) - self.assertEqual((0,), torch.hann_window(0, device=device).shape) - self.assertEqual((1, 1, 0), torch.tensor([[[]]], device=device).shape) - self.assertEqual((1, 1, 0), torch.as_tensor([[[]]], device=device).shape) + def test_isnan(self): + x = torch.Tensor([1, nan, 2]) + self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0])) - def test_new_tensor(self): - expected = torch.autograd.Variable(torch.ByteTensor([1, 1])) - # test data - res1 = expected.new_tensor([1, 1]) - self.assertEqual(res1, expected) - res1 = expected.new_tensor([1, 1], dtype=torch.int) - self.assertEqual(res1, expected) - self.assertIs(torch.int, res1.dtype) - - # test copy - res2 = expected.new_tensor(expected) - self.assertEqual(res2, expected) - res2[1] = 2 - self.assertEqual(expected, torch.ones_like(expected)) - res2 = expected.new_tensor(expected, dtype=torch.int) - self.assertEqual(res2, expected) - self.assertIs(torch.int, res2.dtype) + def test_RNGState(self): + state = torch.get_rng_state() + stateCloned = state.clone() + before = torch.rand(1000) - # test copy with numpy - if TEST_NUMPY: - a = np.array([5.]) - res1 = torch.tensor(a) - res1 = res1.new_tensor(a) - self.assertEqual(5., res1[0].item()) - a[0] = 7. - self.assertEqual(5., res1[0].item()) + self.assertEqual(state.ne(stateCloned).long().sum(), 0, 0) - if torch.cuda.device_count() >= 2: - expected = expected.cuda(1) - res1 = expected.new_tensor([1, 1]) - self.assertEqual(res1.get_device(), expected.get_device()) - res1 = expected.new_tensor([1, 1], dtype=torch.int) - self.assertIs(torch.int, res1.dtype) - self.assertEqual(res1.get_device(), expected.get_device()) + torch.set_rng_state(state) + after = torch.rand(1000) + self.assertEqual(before, after, 0) - res2 = expected.new_tensor(expected) - self.assertEqual(res2.get_device(), expected.get_device()) - res2 = expected.new_tensor(expected, dtype=torch.int) - self.assertIs(torch.int, res1.dtype) - self.assertEqual(res2.get_device(), expected.get_device()) - res2 = expected.new_tensor(expected, dtype=torch.int, device=0) - self.assertIs(torch.int, res1.dtype) - self.assertEqual(res2.get_device(), 0) + def test_RNGStateAliasing(self): + # Fork the random number stream at this point + gen = torch.Generator() + gen.set_state(torch.get_rng_state()) + self.assertEqual(gen.get_state(), torch.get_rng_state()) - res1 = expected.new_tensor(1) - self.assertEqual(res1.get_device(), expected.get_device()) - res1 = expected.new_tensor(1, dtype=torch.int) - self.assertIs(torch.int, res1.dtype) - self.assertEqual(res1.get_device(), expected.get_device()) + target_value = torch.rand(1000) + # Dramatically alter the internal state of the main generator + _ = torch.rand(100000) + forked_value = torch.rand(1000, generator=gen) + self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.") - def test_as_tensor(self): - # from python data - x = [[0, 1], [2, 3]] - self.assertEqual(torch.tensor(x), torch.as_tensor(x)) - self.assertEqual(torch.tensor(x, dtype=torch.float32), torch.as_tensor(x, dtype=torch.float32)) + def test_RNG_after_pickle(self): + torch.random.manual_seed(100) + before = torch.rand(10) - # python data with heterogeneous types - z = [0, 'torch'] - with self.assertRaisesRegex(TypeError, "invalid data type"): - torch.tensor(z) - torch.as_tensor(z) + torch.random.manual_seed(100) + buf = io.BytesIO() + tensor = torch.Tensor([1, 2, 3]) + ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor) + after = torch.rand(10) - # python data with self-referential lists - z = [0] - z += [z] - with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"): - torch.tensor(z) - torch.as_tensor(z) + self.assertEqual(before, after, 0) - z = [[1, 2], z] - with self.assertRaisesRegex(TypeError, "self-referential lists are incompatible"): - torch.tensor(z) - torch.as_tensor(z) + def test_boxMullerState(self): + torch.manual_seed(123) + odd_number = 101 + seeded = torch.randn(odd_number) + state = torch.get_rng_state() + midstream = torch.randn(odd_number) + torch.set_rng_state(state) + repeat_midstream = torch.randn(odd_number) + torch.manual_seed(123) + reseeded = torch.randn(odd_number) + self.assertEqual(midstream, repeat_midstream, 0, + 'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers') + self.assertEqual(seeded, reseeded, 0, + 'repeated calls to manual_seed not generating same sequence of normally distributed numbers') - # from tensor (doesn't copy unless type is different) - y = torch.tensor(x) - self.assertIs(y, torch.as_tensor(y)) - self.assertIsNot(y, torch.as_tensor(y, dtype=torch.float32)) - if torch.cuda.is_available(): - self.assertIsNot(y, torch.as_tensor(y, device='cuda')) - y_cuda = y.to('cuda') - self.assertIs(y_cuda, torch.as_tensor(y_cuda)) - self.assertIs(y_cuda, torch.as_tensor(y_cuda, device='cuda')) + def test_manual_seed(self): + rng_state = torch.get_rng_state() + torch.manual_seed(2) + x = torch.randn(100) + self.assertEqual(torch.initial_seed(), 2) + torch.manual_seed(2) + y = torch.randn(100) + self.assertEqual(x, y) + torch.set_rng_state(rng_state) - if TEST_NUMPY: - # doesn't copy - for dtype in [np.float64, np.int64, np.int8, np.uint8]: - n = np.random.rand(5, 6).astype(dtype) - n_astensor = torch.as_tensor(n) - self.assertEqual(torch.tensor(n), n_astensor) - n_astensor[0][0] = 25.7 - self.assertEqual(torch.tensor(n), n_astensor) + def test_numel(self): + b = torch.ByteTensor(3, 100, 100) + self.assertEqual(b.nelement(), 3 * 100 * 100) + self.assertEqual(b.numel(), 3 * 100 * 100) - # changing dtype causes copy - n = np.random.rand(5, 6).astype(np.float32) - n_astensor = torch.as_tensor(n, dtype=torch.float64) - self.assertEqual(torch.tensor(n, dtype=torch.float64), n_astensor) - n_astensor[0][1] = 250.8 - self.assertNotEqual(torch.tensor(n, dtype=torch.float64), n_astensor) + def _consecutive(self, size, start=1): + sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0) + sequence.add_(start - 1) + return sequence.resize_(*size) - # changing device causes copy - if torch.cuda.is_available(): - n = np.random.randn(5, 6) - n_astensor = torch.as_tensor(n, device='cuda') - self.assertEqual(torch.tensor(n, device='cuda'), n_astensor) - n_astensor[0][2] = 250.9 - self.assertNotEqual(torch.tensor(n, device='cuda'), n_astensor) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_empty_storage_view(self): + # we should be able to "modify" slices of a 0-element + # array without an error being raised due to + # trying to resize its storage + t = torch.from_numpy(np.empty((0, 4))) + t[:, 1::2] *= 1 - def test_diag(self): - x = torch.rand(100, 100) - res1 = torch.diag(x) - res2 = torch.Tensor() - torch.diag(x, out=res2) - self.assertEqual(res1, res2) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_newaxis_numpy_comparison(self): + def run_test(tensor, *idx): + npt = tensor.numpy() + self.assertEqual(tensor[idx], npt[idx]) - @staticmethod - def _test_diagonal(self, dtype, device): - x = torch.randn((100, 100), dtype=dtype, device=device) - result = torch.diagonal(x) - expected = torch.diag(x) - self.assertEqual(result, expected) + # 1D Tensor Tests + x = torch.arange(0, 10) + cases = [ + [None], + [None, None], + [Ellipsis, None], + [None, Ellipsis], + [2, None], + [None, 2], + [Ellipsis, None, 2], + [Ellipsis, 2, None], + [2, Ellipsis, None], + [2, None, Ellipsis], + [None, 2, Ellipsis], + [None, Ellipsis, 2], + ] - x = torch.randn((100, 100), dtype=dtype, device=device) - result = torch.diagonal(x, 17) - expected = torch.diag(x, 17) - self.assertEqual(result, expected) + for case in cases: + run_test(x, *case) - def test_diagonal(self): - self._test_diagonal(self, dtype=torch.float32, device='cpu') + # 2D Tensor Tests + x = torch.arange(0, 12).view(3, 4) + cases = [ + [None], + [None, None], + [None, None, None], + [Ellipsis, None], + [Ellipsis, None, None], + [None, Ellipsis], + [None, Ellipsis, None], + [None, None, Ellipsis], + [2, None], + [2, None, Ellipsis], + [2, Ellipsis, None], + [None, 2, Ellipsis], + [Ellipsis, 2, None], + [Ellipsis, None, 2], + [None, Ellipsis, 2], + [1, 2, None], + [1, 2, Ellipsis, None], + [1, Ellipsis, 2, None], + [Ellipsis, 1, None, 2], + [Ellipsis, 1, 2, None], + [1, None, 2, Ellipsis], + [None, 1, Ellipsis, 2], + [None, 1, 2, Ellipsis], + ] - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_diagonal_multidim(self): - x = torch.randn(10, 11, 12, 13) - xn = x.numpy() - for args in [(2, 2, 3), - (2,), - (-2, 1, 2), - (0, -2, -1)]: - result = torch.diagonal(x, *args) - expected = xn.diagonal(*args) - self.assertEqual(expected.shape, result.shape) - self.assertTrue(np.allclose(expected, result.numpy())) - # test non-continguous - xp = x.permute(1, 2, 3, 0) - result = torch.diagonal(xp, 0, -2, -1) - expected = xp.numpy().diagonal(0, -2, -1) - self.assertEqual(expected.shape, result.shape) - self.assertTrue(np.allclose(expected, result.numpy())) + for case in cases: + run_test(x, *case) - @staticmethod - def _test_diag_embed(self, dtype, device): - x = torch.arange(3 * 4, dtype=dtype, device=device).view(3, 4) - result = torch.diag_embed(x) - expected = torch.stack([torch.diag(r) for r in x], 0) - self.assertEqual(result, expected) + def test_newindex(self): + reference = self._consecutive((3, 3, 3)) + # This relies on __index__() being correct - but we have separate tests for that - result = torch.diag_embed(x, offset=1, dim1=0, dim2=2) - expected = torch.stack([torch.diag(r, 1) for r in x], 1) - self.assertEqual(result, expected) + def checkPartialAssign(index): + reference = torch.zeros(3, 3, 3) + reference[index] = self._consecutive((3, 3, 3))[index] + self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], 0) + reference[index] = 0 + self.assertEqual(reference, torch.zeros(3, 3, 3), 0) - def test_diag_embed(self): - self._test_diag_embed(self, dtype=torch.float32, device='cpu') + checkPartialAssign(0) + checkPartialAssign(1) + checkPartialAssign(2) + checkPartialAssign((0, 1)) + checkPartialAssign((1, 2)) + checkPartialAssign((0, 2)) + checkPartialAssign(torch.LongTensor((0, 2))) - @staticmethod - def _test_diagflat(self, dtype, device): - # Basic sanity test - x = torch.randn((100,), dtype=dtype, device=device) - result = torch.diagflat(x) - expected = torch.diag(x) - self.assertEqual(result, expected) + with self.assertRaises(IndexError): + reference[1, 1, 1, 1] = 1 + with self.assertRaises(IndexError): + reference[1, 1, 1, (1, 1)] = 1 + with self.assertRaises(IndexError): + reference[3, 3, 3, 3, 3, 3, 3, 3] = 1 + with self.assertRaises(IndexError): + reference[0.0] = 1 + with self.assertRaises(TypeError): + reference[0.0:2.0] = 1 + with self.assertRaises(IndexError): + reference[0.0, 0.0:2.0] = 1 + with self.assertRaises(IndexError): + reference[0.0, :, 0.0:2.0] = 1 + with self.assertRaises(IndexError): + reference[0.0, ..., 0.0:2.0] = 1 + with self.assertRaises(IndexError): + reference[0.0, :, 0.0] = 1 - # Test offset - x = torch.randn((100,), dtype=dtype, device=device) - result = torch.diagflat(x, 17) - expected = torch.diag(x, 17) - self.assertEqual(result, expected) + def test_index_add(self): + num_copy, num_dest = 3, 3 + dest = torch.randn(num_dest, 4, 5) + src = torch.randn(num_copy, 4, 5) + idx = torch.randperm(num_dest).narrow(0, 0, num_copy) + dest2 = dest.clone() + dest.index_add_(0, idx, src) + for i in range(idx.size(0)): + dest2[idx[i]] += src[i] + self.assertEqual(dest, dest2) - # Test where input has more than one dimension - x = torch.randn((2, 3, 4), dtype=dtype, device=device) - result = torch.diagflat(x) - expected = torch.diag(x.contiguous().view(-1)) - self.assertEqual(result, expected) + dest = torch.randn(num_dest) + src = torch.randn(num_copy) + idx = torch.randperm(num_dest).narrow(0, 0, num_copy) + dest2 = dest.clone() + dest.index_add_(0, idx, src) + for i in range(idx.size(0)): + dest2[idx[i]] = dest2[idx[i]] + src[i] + self.assertEqual(dest, dest2) - # Noncontig input - x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0) - self.assertFalse(x.is_contiguous()) - result = torch.diagflat(x) - expected = torch.diag(x.contiguous().view(-1)) - self.assertEqual(result, expected) + def test_t(self): + # Test 0D tensors + x = torch.randn(()) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) - def test_diagflat(self): - self._test_diagflat(self, dtype=torch.float32, device='cpu') + # Test 1D tensors + x = torch.arange(4) + self.assertEqual(x, x.t()) + x = x.to_sparse() + self.assertEqual(x, x.t()) - def test_eye(self): - for dtype, device in product(torch.testing.get_all_dtypes(), torch.testing.get_all_device_types()): - if dtype == torch.bfloat16: - continue + # Test 2D tensors + x = torch.rand((2, 2)) + self.assertEqual(x.t(), x.transpose(0, 1)) + x = x.to_sparse() + self.assertEqual(x.t(), x.transpose(0, 1)) - for n, m in product([3, 5, 7], repeat=2): - # Construct identity using diagonal and fill - res1 = torch.eye(n, m, device=device, dtype=dtype) - naive_eye = torch.zeros(n, m, dtype=dtype, device=device) - naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1) - self.assertEqual(naive_eye, res1) + # Test 3D tensor + x = torch.rand((2, 2, 2)) + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): + x.t() + x = x.to_sparse() + with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): + x.t() - # Check eye_out outputs - res2 = torch.empty(0, device=device, dtype=dtype) - torch.eye(n, m, out=res2) - self.assertEqual(res1, res2) + def test_take(self): + def check(src, idx): + expected = src.contiguous().view(-1).index_select( + 0, idx.contiguous().view(-1)).view_as(idx) + actual = src.take(idx) + self.assertEqual(actual.size(), idx.size()) + self.assertEqual(expected, actual) - def test_renorm(self): - m1 = torch.randn(10, 5) - res1 = torch.Tensor() + src = torch.randn(2, 3, 5) + idx = torch.LongTensor([[0, 2], [3, 4]]) + check(src, idx) + check(src.transpose(1, 2), idx) + check(src.bool(), idx) - def renorm(matrix, value, dim, max_norm): - m1 = matrix.transpose(dim, 0).contiguous() - # collapse non-dim dimensions. - m2 = m1.clone().resize_(m1.size(0), int(math.floor(m1.nelement() / m1.size(0)))) - norms = m2.norm(value, 1, True) - # clip - new_norms = norms.clone() - new_norms[torch.gt(norms, max_norm)] = max_norm - new_norms.div_(norms.add_(1e-7)) - # renormalize - m1.mul_(new_norms.expand_as(m1)) - return m1.transpose(dim, 0) + def test_put_(self): + def check(dst, idx, value): + expected = dst.clone().view(-1).index_copy_( + 0, idx.contiguous().view(-1), value.contiguous().view(-1)) + expected = expected.view_as(dst) + dst.put_(idx, value) + self.assertEqual(expected, dst) - # note that the axis fed to torch.renorm is different (2~=1) - maxnorm = m1.norm(2, 1).mean() - m2 = renorm(m1, 2, 1, maxnorm) - m1.renorm_(2, 1, maxnorm) - self.assertEqual(m1, m2, 1e-5) - self.assertEqual(m1.norm(2, 0), m2.norm(2, 0), 1e-5) + dst = torch.randn(2, 3, 5) + idx = torch.LongTensor([[0, 2], [3, 4]]) + values = torch.randn(2, 2) + check(dst, idx, values) + check(dst.transpose(1, 2), idx, values) - m1 = torch.randn(3, 4, 5) - m2 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) - maxnorm = m2.norm(2, 0).mean() - m2 = renorm(m2, 2, 1, maxnorm) - m1.renorm_(2, 1, maxnorm) - m3 = m1.transpose(1, 2).contiguous().clone().resize_(15, 4) - self.assertEqual(m3, m2) - self.assertEqual(m3.norm(2, 0), m2.norm(2, 0)) + values = torch.tensor([[False, False], [False, False]]) + check(dst.bool(), idx, values) + + def test_put_accumulate(self): + dst = torch.ones(2, 2) + idx = torch.LongTensor([[0, 1], [0, 1]]) + src = torch.Tensor([1, 2, 3, 4]) + dst.put_(idx, src, accumulate=True) + self.assertEqual(dst.tolist(), [[5, 7], [1, 1]]) + # Fill idx with valid indices. @staticmethod - def _test_renorm_ps(self, device): - # full reduction - x = torch.randn(5, 5) - xn = x.numpy() - for p in [1, 2, 3, 4, inf]: - res = x.renorm(p, 1, 1) - expected = x / x.norm(p, 0, keepdim=True).clamp(min=1) - self.assertEqual(res.numpy(), expected.numpy(), "renorm failed for {}-norm".format(p)) + def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o): + for i in range(1 if dim == 0 else m): + for j in range(1 if dim == 1 else n): + for k in range(1 if dim == 2 else o): + ii = [i, j, k] + ii[dim] = slice(0, idx.size(dim) + 1) + idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] - def test_renorm_ps(self): - self._test_renorm_ps(self, device='cpu') + def test_flatten(self): + # Test that flatten returns 1-dim tensor when given a 0-dim tensor + zero_dim_tensor = torch.tensor(123) + flat0 = zero_dim_tensor.flatten() + one_dim_tensor = torch.tensor([123]) + flat1 = zero_dim_tensor.flatten() - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_renorm_ps_cuda(self): - self._test_renorm_ps(self, device='cuda') + self.assertEqual(zero_dim_tensor.shape, torch.Size([])) + self.assertEqual(flat0.shape, torch.Size([1])) + self.assertEqual(one_dim_tensor.shape, torch.Size([1])) + self.assertEqual(flat1.shape, torch.Size([1])) + self.assertEqual(flat0, one_dim_tensor) + self.assertEqual(flat0, flat1) + self.assertEqual(flat0.shape, flat1.shape) - @staticmethod - def _test_multinomial(self, type): - def make_prob_dist(shape, is_contiguous): - if is_contiguous: - return type(*shape).uniform_() - elif len(shape) == 1: - return type(*(shape + [5])).uniform_()[:, 2] - else: - # num dim = 2 - new_shape = [2, shape[1], 7, 1, shape[0], 1, 10] - prob_dist = type(*new_shape).uniform_() - prob_dist = prob_dist.transpose(1, 4) - prob_dist = prob_dist[1, :, 5, 0, :, 0, 4] - assert not prob_dist.is_contiguous() # sanity check - return prob_dist + # Test both float tensor and quantized tensor + tensors = [torch.randn(5, 5, 5, 5), + torch._empty_affine_quantized([5, 5, 5, 5], + scale=2, + zero_point=3, + dtype=torch.quint8)] + for src in tensors: + flat = src.flatten(0, -1) + self.assertEqual(flat.shape, torch.Size([625])) + self.assertEqual(src.view(-1), flat.view(-1)) - for is_contiguous in (True, False): - # with replacement - n_row = 3 - for n_col in range(4, 5 + 1): - prob_dist = make_prob_dist([n_row, n_col], is_contiguous) - # indices that shouldn't be sampled (<0 means none) - zero_prob_indices = torch.LongTensor(n_row).random_(-2, n_col).tolist() - for i, j in enumerate(zero_prob_indices): - if j >= 0: - prob_dist[i, j] = 0 - n_sample = n_col * 3 - sample_indices = torch.multinomial(prob_dist, n_sample, True) - self.assertEqual(prob_dist.dim(), 2) - self.assertEqual(sample_indices.size(1), n_sample) - for i in range(n_row): - zero_prob_idx = zero_prob_indices[i] - if zero_prob_idx < 0: - continue - for j in range(n_sample): - self.assertNotEqual(sample_indices[i, j], zero_prob_idx, - "sampled an index with zero probability") + flat = src.flatten(0, 2) + self.assertEqual(flat.shape, torch.Size([125, 5])) + self.assertEqual(src.view(-1), flat.view(-1)) - # without replacement - n_row = 3 - for n_col in range(2, 10 + 1, 2): - prob_dist = make_prob_dist([n_row, n_col], is_contiguous) - # indices that shouldn't be sampled (<0 means none) - zero_prob_indices = torch.LongTensor(n_row).random_(-1, n_col).tolist() - for i, j in enumerate(zero_prob_indices): - if j >= 0: - prob_dist[i, j] = 0 - n_sample = max(1, n_col - 2) - sample_indices = torch.multinomial(prob_dist, n_sample, False) - self.assertEqual(prob_dist.dim(), 2) - self.assertEqual(sample_indices.size(1), n_sample) - for i in range(n_row): - row_samples = {} - zero_prob_idx = zero_prob_indices[i] - for j in range(n_sample): - sample_idx = sample_indices[i, j] - if zero_prob_idx >= 0: - self.assertNotEqual(sample_idx, zero_prob_idx, - "sampled an index with zero probability") - self.assertNotIn(sample_idx, row_samples, "sampled an index twice") - row_samples[sample_idx] = True + flat = src.flatten(0, 1) + self.assertEqual(flat.shape, torch.Size([25, 5, 5])) + self.assertEqual(src.view(-1), flat.view(-1)) - # vector - n_col = 4 - prob_dist = make_prob_dist([n_col], is_contiguous).fill_(1) - zero_prob_idx = 1 # index that shouldn't be sampled - prob_dist[zero_prob_idx] = 0 - n_sample = 20 - sample_indices = torch.multinomial(prob_dist, n_sample, True) - for sample_index in sample_indices: - self.assertNotEqual(sample_index, zero_prob_idx, "sampled an index with zero probability") - s_dim = sample_indices.dim() - self.assertEqual(sample_indices.dim(), 1, "wrong number of dimensions") - self.assertEqual(prob_dist.dim(), 1, "wrong number of prob_dist dimensions") - self.assertEqual(sample_indices.size(0), n_sample, "wrong number of samples") + flat = src.flatten(1, 2) + self.assertEqual(flat.shape, torch.Size([5, 25, 5])) + self.assertEqual(src.view(-1), flat.view(-1)) - def test_multinomial(self): - self._test_multinomial(self, torch.FloatTensor) + flat = src.flatten(2, 3) + self.assertEqual(flat.shape, torch.Size([5, 5, 25])) + self.assertEqual(src.view(-1), flat.view(-1)) - @staticmethod - def _test_multinomial_alias(self, cast): - # Get probs vector to use in setup - def get_probs(length, is_contiguous): - probs = torch.softmax(torch.randn(length), 0) - if not is_contiguous: - probs = torch.softmax(torch.randn(length, 2), 0)[:, 1] - assert not (is_contiguous ^ probs.is_contiguous()), "contiguity requirement not met" - return cast(probs) + flat = src.flatten(-2, -1) + self.assertEqual(flat.shape, torch.Size([5, 5, 25])) + self.assertEqual(src.view(-1), flat.view(-1)) - for is_contiguous in [True, False]: - probs = get_probs(4, is_contiguous) - alias_table, prob_table = torch._multinomial_alias_setup(probs) - for n_samples in [-1, 1, 10]: - if n_samples > 0: - samples = torch._multinomial_alias_draw(prob_table, alias_table, n_samples) - self.assertEqual(prob_table.size(), torch.Size([4]), "size mismatch: probability table") - self.assertEqual(alias_table.size(), torch.Size([4]), "size mismatch: alias table") - self.assertEqual(samples.size(), torch.Size([n_samples]), "wrong number of samples") - else: - with self.assertRaisesRegex(RuntimeError, "cannot sample <= 0 samples"): - torch._multinomial_alias_draw(prob_table, alias_table, n_samples) + flat = src.flatten(2, 2) + self.assertEqual(flat, src) - with self.assertRaisesRegex(RuntimeError, "expected 1-D"): - probs = probs.view(2, 2) - torch._multinomial_alias_setup(probs) + # out of bounds index + with self.assertRaisesRegex(IndexError, 'Dimension out of range'): + src.flatten(5, 10) - with self.assertRaisesRegex(RuntimeError, "expected 1-D"): - a_t, p_t = torch._multinomial_alias_setup(probs) - torch._multinomial_alias_draw(p_t.view(2, 2), a_t.view(2, 2)) + # invalid start and end + with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'): + src.flatten(2, 0) - MAX_SAMPLES = 200000 - for probs in [get_probs(4, True), - cast(torch.tensor([0.8, 0.2])), - cast(torch.tensor([0.7, 0.2, 0.1]))]: - # Check how different the alias distribution and the original distribution are - alias_dist = torch.zeros_like(probs) - alias_table, prob_table = torch._multinomial_alias_setup(probs) - alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES) - alias_dist = torch.unique(alias_samples, return_counts=True)[1].to(dtype=probs.dtype) / MAX_SAMPLES - self.assertTrue(torch.allclose(alias_dist, probs, rtol=0.02, atol=0.0), - "Actual: {}\nExpected: {}".format(alias_dist, probs)) + @staticmethod + def _test_gather(self, cast, test_bounds=True): + m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20) + elems_per_row = random.randint(1, 10) + dim = random.randrange(3) - for probs in [cast(torch.tensor([0.2501, 0.25, 0.2499, 0.25])), - cast(torch.tensor([0.8, 0.199, 0.001])), - cast(torch.tensor([0.25001, 0.25, 0.24999, 0.25])), - cast(torch.tensor([0.33, 0.34, 0.33])), - cast(torch.tensor([0.8, 0.1999, 0.0001]))]: - # Check the difference between the original probabilities and the reconstructed - # probabilities from the alias and probability tables output by _multinomial_alias_setup - alias_table, prob_table = torch._multinomial_alias_setup(probs) - actual = torch.zeros_like(probs) - for i, vals in enumerate(zip(alias_table, prob_table)): - idx, p = vals - actual[i] += p - actual[idx] += 1. - p - actual = actual / len(probs) - self.assertEqual(actual, probs, 1e-6) + src = torch.randn(m, n, o) + idx_size = [m, n, o] + idx_size[dim] = elems_per_row + idx = torch.LongTensor().resize_(*idx_size) + _TestTorchMixin._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o) - # Some special cases - test_cases = [cast(torch.tensor([1.0, 0.0, 0.0])), cast(torch.tensor([0.0, 1.0]))] - for probs in test_cases: - alias_table, prob_table = torch._multinomial_alias_setup(probs) - alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES) - self.assertEqual(alias_samples.unique(), probs.nonzero().squeeze(-1)) + src = cast(src) + idx = cast(idx) - def test_multinomial_alias(self): - self._test_multinomial_alias(self, lambda t: t) + actual = torch.gather(src, dim, idx) + expected = cast(torch.Tensor().resize_(*idx_size)) + for i in range(idx_size[0]): + for j in range(idx_size[1]): + for k in range(idx_size[2]): + ii = [i, j, k] + ii[dim] = idx[i, j, k] + expected[i, j, k] = src[tuple(ii)] + self.assertEqual(actual, expected, 0) - def _spawn_method(self, method, arg): - try: - mp.set_start_method('spawn') - except RuntimeError: - pass - with mp.Pool(1) as pool: - self.assertTrue(pool.map(method, [arg])) + if test_bounds: + idx[0][0][0] = 23 + self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx)) - def test_addcmul(self): - def rand_tensor(size, dtype, device): - if dtype.is_floating_point: - return torch.rand(size=size, dtype=dtype, device=device) - if dtype == torch.uint8: - return torch.randint(1, 5, size=size, dtype=dtype, device=device) - else: - return torch.randint(-5, 5, size=size, dtype=dtype, device=device) - for device in torch.testing.get_all_device_types(): - for dtype in torch.testing.get_all_math_dtypes(device): - a = rand_tensor((2, 2), dtype=dtype, device=device) - b = rand_tensor((2, 2), dtype=dtype, device=device) - c = rand_tensor((2, 2), dtype=dtype, device=device) - if dtype.is_floating_point: - alpha = 0.1 - else: - alpha = 3 - actual = torch.addcmul(a, alpha, b, c) - expected = a + alpha * b * c - self.assertTrue(torch.allclose(expected, actual)) + src = cast(torch.randn(3, 4, 5)) + expected, idx = src.max(2, True) + expected = cast(expected) + idx = cast(idx) + actual = torch.gather(src, 2, idx) + self.assertEqual(actual, expected, 0) - @staticmethod - def _test_multinomial_invalid_probs(probs): - try: - # n_sample = 1 is a special case, test n_sample=2 which is more general - torch.multinomial(probs.to('cpu'), 2) - return False # Should not be reached - except RuntimeError as e: - return 'invalid multinomial distribution' in str(e) + # Bool test case + t = torch.tensor([[False, True], [True, True]]) + self.assertEqual(torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])), torch.tensor([[False, False], [True, True]])) - @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ - don't support multiprocessing with spawn start method") - @unittest.skipIf(IS_WINDOWS, 'FIXME: CUDA OOM error on Windows') - @unittest.skipIf(not PY3, - "spawn start method is not supported in Python 2, \ - but we need it for for testing failure case for CPU RNG on Windows") - def test_multinomial_invalid_probs(self): - test_method = _TestTorchMixin._test_multinomial_invalid_probs - self._spawn_method(test_method, torch.Tensor([1, -1, 1])) - self._spawn_method(test_method, torch.Tensor([1, inf, 1])) - self._spawn_method(test_method, torch.Tensor([1, -inf, 1])) - self._spawn_method(test_method, torch.Tensor([1, 1, nan])) - self._spawn_method(test_method, torch.Tensor([0, 1, 0])) + def test_gather(self): + self._test_gather(self, lambda t: t) - @suppress_warnings - def test_range(self): - res1 = torch.range(0, 1) - res2 = torch.Tensor() - torch.range(0, 1, out=res2) - self.assertEqual(res1, res2, 0) + @staticmethod + def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True): + m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20) + elems_per_row = random.randint(1, 10) + dim = random.randrange(3) - # Check range for non-contiguous tensors. - x = torch.zeros(2, 3) - torch.range(0, 3, out=x.narrow(1, 1, 2)) - res2 = torch.Tensor(((0, 0, 1), (0, 2, 3))) - self.assertEqual(x, res2, 1e-16) + idx_size = [m, n, o] + idx_size[dim] = elems_per_row + idx = cast(torch.LongTensor().resize_(*idx_size)) + _TestTorchMixin._fill_indices(self, idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o) - # Check negative - res1 = torch.Tensor((1, 0)) - res2 = torch.Tensor() - torch.range(1, 0, -1, out=res2) - self.assertEqual(res1, res2, 0) + if is_scalar: + src = random.random() + else: + src = cast(torch.Tensor(*idx_size).normal_()) - # Equal bounds - res1 = torch.ones(1) - res2 = torch.Tensor() - torch.range(1, 1, -1, out=res2) - self.assertEqual(res1, res2, 0) - torch.range(1, 1, 1, out=res2) - self.assertEqual(res1, res2, 0) + base = cast(torch.randn(m, n, o)) + actual = getattr(base.clone(), method)(dim, idx, src) + expected = base.clone() + for i in range(idx_size[0]): + for j in range(idx_size[1]): + for k in range(idx_size[2]): + ii = [i, j, k] + ii[dim] = idx[i, j, k] + if method == 'scatter_' and not is_scalar: + expected[tuple(ii)] = src[i, j, k] + elif method == 'scatter_add_': + expected[tuple(ii)] += src[i, j, k] + else: + expected[tuple(ii)] = src + self.assertEqual(actual, expected, 0) - # FloatTensor - res1 = torch.range(0.6, 0.9, 0.1, out=torch.FloatTensor()) - self.assertEqual(res1.size(0), 4) - res1 = torch.range(1, 10, 0.3, out=torch.FloatTensor()) - self.assertEqual(res1.size(0), 31) + if test_bounds: + idx[0][0][0] = 34 + with self.assertRaises(RuntimeError): + getattr(base.clone(), method)(dim, idx, src) - # DoubleTensor - res1 = torch.range(0.6, 0.9, 0.1, out=torch.DoubleTensor()) - self.assertEqual(res1.size(0), 4) - res1 = torch.range(1, 10, 0.3, out=torch.DoubleTensor()) - self.assertEqual(res1.size(0), 31) + # test for empty index, should be a no-op + idx = cast(torch.LongTensor()) + actual = getattr(base.clone(), method)(dim, idx, src) + self.assertEqual(actual, base, 0) - def test_range_warning(self): - with warnings.catch_warnings(record=True) as w: - torch.range(0, 10) - self.assertEqual(len(w), 1) + def test_scatter(self): + self._test_scatter_base(self, lambda t: t, 'scatter_') - def test_arange(self): - res1 = torch.arange(0, 1) - res2 = torch.Tensor() - torch.arange(0, 1, out=res2) - self.assertEqual(res1, res2, 0) + def test_scatterAdd(self): + self._test_scatter_base(self, lambda t: t, 'scatter_add_') - # Check arange with only one argument - res1 = torch.arange(10) - res2 = torch.arange(0, 10) - self.assertEqual(res1, res2, 0) + def test_scatterFill(self): + self._test_scatter_base(self, lambda t: t, 'scatter_', True) - # Check arange for non-contiguous tensors. - x = torch.zeros(2, 3) - torch.arange(0, 4, out=x.narrow(1, 1, 2)) - res2 = torch.Tensor(((0, 0, 1), (0, 2, 3))) - self.assertEqual(x, res2, 1e-16) + def test_masked_scatter(self): + with warnings.catch_warnings(record=True) as w: + for maskType in [torch.uint8, torch.bool]: + for dt in torch.testing.get_all_dtypes(): + num_copy, num_dest = 3, 10 + dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt) + dest2 = dest.clone() + src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt) + mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType) - # Check negative - res1 = torch.Tensor((1, 0)) - res2 = torch.Tensor() - torch.arange(1, -1, -1, out=res2) - self.assertEqual(res1, res2, 0) + if dt == torch.bool: + # torch.bool is a special case and is being tested + # in a separate test + continue - # Equal bounds - res1 = torch.ones(1) - res2 = torch.Tensor() - torch.arange(1, 0, -1, out=res2) - self.assertEqual(res1, res2, 0) - torch.arange(1, 2, 1, out=res2) - self.assertEqual(res1, res2, 0) + if dt == torch.half: + self.assertRaises(RuntimeError, lambda: dest.masked_scatter_(mask, src)) + continue - # FloatTensor - res1 = torch.arange(0.6, 0.89, 0.1, out=torch.FloatTensor()) - self.assertEqual(res1, [0.6, 0.7, 0.8]) - res1 = torch.arange(1, 10, 0.3, out=torch.FloatTensor()) - self.assertEqual(res1.size(0), 30) - self.assertEqual(res1[0], 1) - self.assertEqual(res1[29], 9.7) + dest.masked_scatter_(mask, src) + j = 0 + for i in range(num_dest): + if mask[i]: + dest2[i] = src[j] + j += 1 + self.assertEqual(dest, dest2, 0) - # DoubleTensor - res1 = torch.arange(0.6, 0.89, 0.1, out=torch.DoubleTensor()) - self.assertEqual(res1, [0.6, 0.7, 0.8]) - res1 = torch.arange(1, 10, 0.3, out=torch.DoubleTensor()) - self.assertEqual(res1.size(0), 30) - self.assertEqual(res1[0], 1) - self.assertEqual(res1[29], 9.7) + # make source bigger than number of 1s in mask + src = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt) + dest.masked_scatter_(mask, src) - # Check that it's exclusive - r = torch.arange(0, 5) - self.assertEqual(r.min(), 0) - self.assertEqual(r.max(), 4) - self.assertEqual(r.numel(), 5) + # make src smaller. this should fail + src = torch.randn(num_copy - 1) + with self.assertRaises(RuntimeError): + dest.masked_scatter_(mask, src) + self.assertEqual(len(w), 25) - r = torch.arange(0, 5, 2) - self.assertEqual(r.min(), 0) - self.assertEqual(r.max(), 4) - self.assertEqual(r.numel(), 3) + warn = 'masked_scatter_ received a mask with dtype torch.uint8,' + for wi in w: + self.assertEqual(str(wi.message)[0:55], str(warn)) - r1 = torch.arange(0, 5 + 1e-6) - r2 = torch.arange(0, 5) - r3 = torch.arange(0, 5 - 1e-6) - self.assertEqual(r1[:-1], r2, 0) - self.assertEqual(r2, r3, 0) + def test_masked_fill(self): + with warnings.catch_warnings(record=True) as w: + for dt in torch.testing.get_all_dtypes(): + for dtype in [torch.uint8, torch.bool]: + num_dest = 10 + dst = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt) + mask = torch.rand(num_dest).mul(2).floor().to(dtype) + val = random.random() + dst2 = dst.clone() - r1 = torch.arange(10, -1 + 1e-6, -1) - r2 = torch.arange(10, -1, -1) - r3 = torch.arange(10, -1 - 1e-6, -1) - self.assertEqual(r1, r2, 0) - self.assertEqual(r2, r3[:-1], 0) + if dt == torch.half: + self.assertRaises(RuntimeError, lambda: dst.masked_fill_(mask, val)) + continue - x = torch.empty(1).expand(10) - self.assertRaises(RuntimeError, lambda: torch.arange(10, out=x)) - msg = "unsupported range" - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'))) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'))) + dst.masked_fill_(mask, val) + for i in range(num_dest): + if mask[i]: + dst2[i] = val + self.assertEqual(dst, dst2, 0) - for device in torch.testing.get_all_device_types(): - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device)) - # check with step size - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'), device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('-inf'), 10, device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), 10, device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'), device=device)) - self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('nan'), device=device)) + # test non-contiguous case + dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1)) + dst2 = dst.clone() + dst.masked_fill_((dst > 0).to(dtype), val) + dst2.masked_fill_((dst2 > 0).to(dtype), val) + self.assertEqual(dst, dst2, 0) + self.assertEqual(len(w), 28) - self.assertRaisesRegex( - RuntimeError, "overflow", - lambda: torch.arange(1.175494351e-38, 3.402823466e+38, device=device)) + warn = 'masked_fill_ received a mask with dtype torch.uint8,' + for wi in w: + self.assertEqual(str(wi.message)[0:52], str(warn)) - # check that it holds a consistent output shape on precision-cornered step sizes - d = torch.arange(-4.0, 4.0, 0.01, dtype=torch.float32, device=device) - self.assertEqual(d.shape[0], 800) + def test_abs(self): + def _test_abs(tensors_dict): + for _category, tensors in tensors_dict.items(): + for data in tensors: + _test_abs_single(data) - def test_arange_inference(self): - saved_dtype = torch.get_default_dtype() - torch.set_default_dtype(torch.float32) - # end only - self.assertIs(torch.float32, torch.arange(1.).dtype) - self.assertIs(torch.float32, torch.arange(torch.tensor(1.)).dtype) - self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64)).dtype) + def _test_abs_single(data): + switch = torch.rand(data.size()).mul(2).floor().mul(2).add(-1).type(data.dtype) + res = torch.mul(data, switch) + self.assertTensorsSlowEqual(res.abs(), data, 1e-16) - self.assertIs(torch.int64, torch.arange(1).dtype) - self.assertIs(torch.int64, torch.arange(torch.tensor(1)).dtype) - self.assertIs(torch.int64, torch.arange(torch.tensor(1, dtype=torch.int16)).dtype) + shapes = [(3, 4), (3, 5, 7), (2, 2, 5, 8, 2, 3), (1000,), (10, 10, 10)] - # start, end, [step] - self.assertIs(torch.float32, torch.arange(1., 3).dtype) - self.assertIs(torch.float32, torch.arange(torch.tensor(1., dtype=torch.float64), 3).dtype) - self.assertIs(torch.float32, torch.arange(1, 3.).dtype) - self.assertIs(torch.float32, torch.arange(torch.tensor(1, dtype=torch.int16), torch.tensor(3.)).dtype) - self.assertIs(torch.float32, torch.arange(1, 3, 1.).dtype) - self.assertIs(torch.float32, - torch.arange(torch.tensor(1), - torch.tensor(3, dtype=torch.int16), - torch.tensor(1., dtype=torch.float64)).dtype) + for shape in shapes: + # Test all except char/byte + _test_abs(self._make_tensors(shape, val_range=(0, 1000))) - self.assertIs(torch.int64, torch.arange(1, 3).dtype) - self.assertIs(torch.int64, torch.arange(torch.tensor(1), 3).dtype) - self.assertIs(torch.int64, torch.arange(torch.tensor(1), torch.tensor(3, dtype=torch.int16)).dtype) - self.assertIs(torch.int64, torch.arange(1, 3, 1).dtype) - self.assertIs(torch.int64, - torch.arange(torch.tensor(1), - torch.tensor(3), - torch.tensor(1, dtype=torch.int16)).dtype) - torch.set_default_dtype(saved_dtype) + # Test char + _test_abs_single(torch.CharTensor(*shape).random_(0, 100)) - def test_randint_inference(self): - size = (2, 1) - for args in [(3,), (1, 3)]: # (low,) and (low, high) - self.assertIs(torch.int64, torch.randint(*args, size=size).dtype) - self.assertIs(torch.int64, torch.randint(*args, size=size, layout=torch.strided).dtype) - self.assertIs(torch.int64, torch.randint(*args, size=size, generator=torch.default_generator).dtype) - self.assertIs(torch.float32, torch.randint(*args, size=size, dtype=torch.float32).dtype) - out = torch.empty(size, dtype=torch.float32) - self.assertIs(torch.float32, torch.randint(*args, size=size, out=out).dtype) - self.assertIs(torch.float32, torch.randint(*args, size=size, out=out, dtype=torch.float32).dtype) - out = torch.empty(size, dtype=torch.int64) - self.assertIs(torch.int64, torch.randint(*args, size=size, out=out).dtype) - self.assertIs(torch.int64, torch.randint(*args, size=size, out=out, dtype=torch.int64).dtype) + # Test byte + byte_tensor = torch.ByteTensor(*shape).random_(0, 100) + self.assertTensorsSlowEqual(byte_tensor, byte_tensor.abs(), 1e-16) - @staticmethod - def _select_broadcastable_dims(dims_full=None): - # select full dimensionality - if dims_full is None: - dims_full = [] - ndims = random.randint(1, 4) - dims_full = [random.randint(1, 8) for _ in range(ndims)] - else: - ndims = len(dims_full) + # Checking that the right abs function is called for LongTensor + bignumber = 2 ^ 31 + 1 + res = torch.LongTensor((-bignumber,)) + self.assertGreater(res.abs()[0], 0) - # select actual dimensions for ops: - # larger: full ndims, individual sizes may be reduced - # smaller: possibly reduced ndims, sizes may be reduced - smaller_ndims = random.randint(1, ndims) - dims_small = [] - dims_large = [] - for i in range(ndims - 1, -1, -1): - j = random.randint(1, 3) - if j == 1: # no reduced singleton dimension - ds = dims_full[i] - dl = dims_full[i] - elif j == 2: # larger may have reduced singleton dimension - ds = dims_full[i] - dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] - elif j == 3: # smaller may have reduced singleton dimension - ds = 1 - dl = dims_full[i] - dims_large = [dl] + dims_large - if len(dims_small) < smaller_ndims: - dims_small = [ds] + dims_small - return (dims_small, dims_large, dims_full) + # One of + rec = torch.randn(2, 2, 3, 7, 6, 2).type(torch.float64).clamp(0, 1) + val1 = rec.select(-1, -1).data[0][0][0].sum() + val2 = rec.select(-1, -1).data.abs()[0][0][0].sum() + self.assertEqual(val1, val2, 1e-8, 'absolute value') - @staticmethod - def _test_broadcast(self, cast): + # Both abs(0.0) and abs(-0.0) should result in 0.0 + for dtype in (torch.float, torch.double): + abs_zeros = torch.tensor([0.0, -0.0], dtype=dtype).abs().tolist() + for num in abs_zeros: + self.assertGreater(math.copysign(1.0, num), 0.0) - # all functions - fns = { - "dist", "atan2", "pow", "lerp", "add", - "sub", "mul", "div", "fmod", "remainder", - "eq", "ge", "gt", "le", "lt", "max", "min", "ne", - "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", - "map", "map2", "copy" - } - # functions with three tensor arguments - fns_3_args = {"addcdiv", "addcmul", "map2"} + def test_hardshrink(self): + data_original = torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2) + float_types = [ + 'torch.DoubleTensor', + 'torch.FloatTensor' + ] + for t in float_types: + data = data_original.type(t) + self.assertEqual(torch.tensor([1, 0.5, 0, 0.6]).view(2, 2), data.hardshrink(0.3)) + self.assertEqual(torch.tensor([1, 0, 0, 0.6]).view(2, 2), data.hardshrink(0.5)) - for fn in fns: - (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() - full1d = cast(torch.randn(*dims_full).flatten().float()) - small = cast(torch.randn(*dims_small).float()) - large = cast(torch.randn(*dims_large).float()) - small_expanded = small.expand(*dims_full) - large_expanded = large.expand(*dims_full) - small2 = None - small2_expanded = None - if fn in fns_3_args: - # create another smaller tensor - (dims_small2, _, _) = self._select_broadcastable_dims(dims_full) - small2 = cast(torch.randn(*dims_small2).float()) - small2_expanded = small2.expand(*dims_full) + # test default lambd=0.5 + self.assertEqual(data.hardshrink(), data.hardshrink(0.5)) - if small.is_cuda and fn in ['map', 'map2']: - # map and map2 are not implementd on CUDA tensors - continue + # test non-contiguous case + self.assertEqual(torch.tensor([1, 0, 0.5, 0.6]).view(2, 2), data.t().hardshrink(0.3)) - if hasattr(large_expanded, fn): - # run through tensor versions of functions - # and verify fully expanded inputs give same results - expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} + def test_hardshrink_edge_cases(self): + def h(t, values, l_expected): + for l, expected in l_expected.items(): + values_tensor = torch.tensor([float(v) for v in values]).type(t) + expected_tensor = torch.tensor([float(v) for v in expected]).type(t) + self.assertEqual(expected_tensor == values_tensor.hardshrink(l), + torch.ones_like(values_tensor)) - def tensorfn(myfn, t1, t2): - if fn == "lerp": - return myfn(t1, 0.5) - elif fn == "masked_select": - return myfn(t1 < 0) - elif fn == "masked_scatter": - return myfn(t1 < 0.5, full1d) - elif fn == "masked_fill": - return myfn(t1 < 0.5, 1.0) - elif fn in fns_3_args: - return myfn(1, t1, t2) - else: - return myfn(t1) + def test_helper(t, min, max): + h(t, [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], + 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], + 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], + max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], + inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}) - # test various orders - for first, second, third in [(large, small, small2), (small, large, small2), - (small2, small, large), (small2, large, small)]: - if first is None: - break # ignore last iter when small2 is None - method_expanded = getattr(expanded[first], fn) - method = getattr(first, fn) - r1 = tensorfn(method_expanded, expanded[second], expanded[third]) - r2 = tensorfn(method, second, third) - self.assertEqual(r1, r2) + test_helper(torch.DoubleTensor, + torch.finfo(torch.double).tiny, torch.finfo(torch.double).max) + test_helper(torch.FloatTensor, + torch.finfo(torch.float).tiny, torch.finfo(torch.float).max) - # now for torch. versions of functions - if hasattr(torch, fn): - fntorch = getattr(torch, fn) - expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} + def test_unbiased(self): + tensor = torch.randn(100) + self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True)) + self.assertEqual(tensor.var(), tensor.var(unbiased=True)) + self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False)) - def torchfn(t1, t2, t3): - if fn == "lerp": - return fntorch(t1, t2, 0.5) - elif fn == "masked_select": - return fntorch(t1, t2 < 0) - elif fn == "masked_scatter": - return fntorch(t1, t2 < 0.5, full1d) - elif fn == "masked_fill": - return fntorch(t1, t2 < 0.5, 1.0) - elif fn in fns_3_args: - return fntorch(t1, 1.0, t2, t3) - else: - return fntorch(t1, t2) + tensor = torch.FloatTensor([1.0, 2.0]) + self.assertEqual(tensor.var(unbiased=True), 0.5) + self.assertEqual(tensor.var(unbiased=False), 0.25) - # test various orders - for first, second, third in [(large, small, small2), (small, large, small2), - (small2, small, large), (small2, large, small)]: - if first is None: - break # ignore last iter when small2 is None - r1 = torchfn(expanded[first], expanded[second], expanded[third]) - r2 = torchfn(first, second, third) - self.assertEqual(r1, r2) + tensor = torch.FloatTensor([1.0, 2.0, 3.0]) + self.assertEqual(tensor.var(unbiased=True), 1.0) + self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0) - # now for in place functions - # in-place tensor is not broadcastable; test only guaranteed - # to work by broadcasting other argument(s) - if not hasattr(large_expanded, fn + "_"): - continue + tensor = torch.randn(100) + self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True)) + self.assertEqual(tensor.std(), tensor.std(unbiased=True)) + self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False)) - # need to clone largeExpanded so we can reuse, since functions are in-place - large_expanded_clone = large_expanded.clone() + def test_structseq_repr(self): + a = torch.arange(250).reshape(5, 5, 10) + expected = """ + torch.return_types.max( + values=tensor([[ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], + [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], + [140, 141, 142, 143, 144, 145, 146, 147, 148, 149], + [190, 191, 192, 193, 194, 195, 196, 197, 198, 199], + [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]), + indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))""" + self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip()) - def tensorfn_inplace(t0, t1, t2=None): - t0_fn = getattr(t0, fn + "_") - if fn == "lerp": - return t0_fn(t1, 0.5) - elif fn == "masked_scatter": - return t0_fn(t1 < 0.5, full1d) - elif fn == "masked_fill": - return t0_fn(t1 < 0.5, 1.0) - elif fn == "map": - return t0_fn(t1, lambda x, y: x + y) - elif fn == "map2": - return t0_fn(t1, t2, lambda x, y, z: x + y + z) - elif fn in fns_3_args: - return t0_fn(1.0, t1, t2) - else: - return t0_fn(t1) - # in-place pointwise operations don't actually work if the in-place - # tensor is 0-strided (numpy has the same issue) - if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()): - r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded) - r2 = tensorfn_inplace(large_expanded_clone, small, small2) - self.assertEqual(r1, r2) + def test_var_stability(self): + tensor = torch.FloatTensor([2281.5, 2281.25]) + self.assertEqual(tensor.var(dim=0), 0.03125) + self.assertEqual(tensor.var(), 0.03125) - def broadcastable(t0, t1, t2=None): - try: - t1.expand_as(t0) - if t2 is not None: - t2.expand_as(t0) - except RuntimeError: - return False - return True + def test_view_empty(self): + x = torch.randn(0, 6) + self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) - def _test_in_place_broadcastable(t0, t1, t2=None): - if not broadcastable(t0, t1, t2): - same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True) - if not same_size: - self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2)) - else: - tensorfn_inplace(t0, t1, t2) + def test_reshape(self): + x = torch.randn(3, 3) + self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) + self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) + self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) - if fn not in fns_3_args: - _test_in_place_broadcastable(small, large_expanded) - _test_in_place_broadcastable(small, large) - else: - _test_in_place_broadcastable(small2, small_expanded, large_expanded) - _test_in_place_broadcastable(small2, small, large) + y = torch.randn(4, 4, 4)[:, 0, :] + self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) + self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) + self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) - def test_broadcast(self): - self._test_broadcast(self, lambda t: t) + s = torch.randn(()) + self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) + self.assertEqual(s.reshape(-1).shape, (1,)) + self.assertRaises(RuntimeError, lambda: s.reshape(2)) - def test_broadcast_empty(self): - # empty + empty - self.assertRaises(RuntimeError, lambda: torch.randn(5, 0) + torch.randn(0, 5)) - self.assertEqual(torch.randn(5, 0), torch.randn(0) + torch.randn(5, 0)) - self.assertEqual(torch.randn(5, 0, 0), torch.randn(0) + torch.randn(5, 0, 1)) + empty = torch.tensor([]) + self.assertEqual(empty, empty.reshape(-1)) + self.assertEqual(empty, empty.reshape([0])) + # TODO: fix these once we have multi-dimensional empty tensors + self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) + self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) + self.assertRaises(RuntimeError, lambda: empty.reshape(1)) - # scalar + empty - self.assertEqual(torch.randn(5, 0, 6), torch.randn(()) + torch.randn(5, 0, 6)) + x = torch.randn(3, 3) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) + self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) + self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10))) - # non-empty, empty - self.assertEqual(torch.randn(0), torch.randn(0) + torch.randn(1)) - self.assertEqual(torch.randn(0, 7, 0, 6, 5, 0, 7), - torch.randn(0, 7, 0, 6, 5, 0, 1) + torch.randn(1, 1, 5, 1, 7)) - self.assertRaises(RuntimeError, lambda: torch.randn(7, 0) + torch.randn(2, 1)) + def test_empty_reshape(self): + x = torch.randn(0, 6) + self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) + # should be viewable -- i.e. data_ptr is the same. + self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) - def test_broadcast_tensors(self): - x0 = torch.randn(2, 1, 3) - x1 = torch.randn(3) - x2 = torch.randn(3, 1) - expected_size = (2, 3, 3) + # match NumPy semantics -- don't infer the size of dimension with a degree of freedom + self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) - y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2) - self.assertTrue(y0.size() == expected_size) - self.assertTrue(y1.size() == expected_size) - self.assertTrue(y2.size() == expected_size) + def check_single_matmul(self, x, y, shape): + a = np.array(x, copy=False) + b = np.array(y, copy=False) + expected = np.matmul(a, b) - @staticmethod - def _test_contiguous(self, cast): - x = cast(torch.randn(1, 16, 5, 5)) - self.assertTrue(x.is_contiguous()) - stride = list(x.stride()) - stride[0] = 20 - # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 - x.set_(x.storage(), 0, x.size(), stride) - self.assertTrue(x.is_contiguous()) + ans = torch.matmul(x, y) + self.assertTrue(ans.is_contiguous()) + self.assertTrue(np.array_equal(ans, expected)) - def test_contiguous(self): - return self._test_contiguous(self, lambda t: t) + out = torch.zeros(*shape, dtype=torch.int64) + ans = torch.matmul(x, y, out=out) + self.assertIs(ans, out) + self.assertTrue(ans.is_contiguous()) + self.assertTrue(np.array_equal(ans, expected)) - def test_empty_tensor_props(self): - sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)] - for size in sizes: - for device in torch.testing.get_all_device_types(): - x = torch.empty(tuple(size), device=device) - self.assertEqual(size, x.shape) - self.assertTrue(x.is_contiguous()) - size_ones_instead_of_zeros = (x if x != 0 else 1 for x in size) - y = torch.empty(tuple(size_ones_instead_of_zeros), device=device) - self.assertEqual(x.stride(), y.stride()) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_matmul_small_brute_force_1d_Nd(self): + # Issue #20452: range(0, 10) does not work. + n = 1 + for m in range(1, 8): + for p in range(1, 8): + for o in range(1, 5): + # 1d, 3d, inner dimensions C + x = torch.arange(m) + y = torch.arange(o * m * p).reshape(o, m, p) + self.check_single_matmul(x, y, (o, n, p)) - def test_scalars_as_floats(self): - "zero-dim variables that don't require grad should bind to scalar arguments" - x = torch.tensor(2.) - y = torch.tensor(3.) - # 3 + (3 * 3) * 2 - self.assertEqual(y.addcmul(y, y, value=x), 21) + # 1d, 3d, inner dimensions Fortran + x = torch.arange(m) + y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (o, n, p)) - x = torch.tensor(2., requires_grad=True) - self.assertRaises(Exception, lambda: y.addcmul(y, y, value=x)) + # 1d, 3d, inner dimensions non-contiguous + x = torch.arange(2 * m)[::2] + y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2] + self.check_single_matmul(x, y, (o, n, p)) - @staticmethod - def _test_broadcast_fused_matmul(self, cast): - fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] + for r in range(1, 5): + # 1d, 4d, inner dimensions C + x = torch.arange(m) + y = torch.arange(r * o * m * p).reshape(r, o, m, p) + self.check_single_matmul(x, y, (r, o, n, p)) - for fn in fns: - batch_dim = random.randint(1, 8) - n_dim = random.randint(1, 8) - m_dim = random.randint(1, 8) - p_dim = random.randint(1, 8) + # 1d, 4d, inner dimensions Fortran + x = torch.arange(m) + y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (r, o, n, p)) - def dims_full_for_fn(): - if fn == "baddbmm": - return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) - elif fn == "addbmm": - return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) - elif fn == "addmm": - return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) - elif fn == "addmv": - return ([n_dim], [n_dim, m_dim], [m_dim]) - elif fn == "addr": - return ([n_dim, m_dim], [n_dim], [m_dim]) - else: - raise AssertionError("unknown function") + # 1d, 4d, inner dimensions non-contiguous + x = torch.arange(2 * m)[::2] + y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2] + self.check_single_matmul(x, y, (r, o, n, p)) - (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() - (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_matmul_small_brute_force_2d_Nd(self): + # Issue #20452: range(0, 10) does not work. + for n in range(1, 5): + for m in range(1, 5): + for p in range(1, 5): + for o in range(1, 3): + # 2d, 3d, inner dimensions C + x = torch.arange(n * m).reshape(n, m) + y = torch.arange(o * m * p).reshape(o, m, p) + self.check_single_matmul(x, y, (o, n, p)) - t0_small = cast(torch.randn(*t0_dims_small).float()) - t1 = cast(torch.randn(*t1_dims).float()) - t2 = cast(torch.randn(*t2_dims).float()) + # 2d, 3d, inner dimensions Fortran + x = torch.arange(m * n).reshape(m, n).transpose(-1, -2) + y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (o, n, p)) - t0_full = cast(t0_small.expand(*t0_dims_full)) + # 2d, 3d, inner dimensions non-contiguous + x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2] + y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2] + self.check_single_matmul(x, y, (o, n, p)) - fntorch = getattr(torch, fn) - r0 = fntorch(t0_small, t1, t2) - r1 = fntorch(t0_full, t1, t2) - self.assertEqual(r0, r1) + for r in range(1, 2): + # 2d, 4d, inner dimensions C + x = torch.arange(n * m).reshape(n, m) + y = torch.arange(r * o * m * p).reshape(r, o, m, p) + self.check_single_matmul(x, y, (r, o, n, p)) - def test_broadcast_fused_matmul(self): - self._test_broadcast_fused_matmul(self, lambda t: t) + # 2d, 4d, inner dimensions Fortran + x = torch.arange(m * n).reshape(m, n).transpose(-1, -2) + y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2) + self.check_single_matmul(x, y, (r, o, n, p)) - @staticmethod - def _test_broadcast_batched_matmul(self, cast): - n_dim = random.randint(1, 8) - m_dim = random.randint(1, 8) - p_dim = random.randint(1, 8) - full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))] - (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims) + # 2d, 4d, inner dimensions non-contiguous + x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2] + y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2] + self.check_single_matmul(x, y, (r, o, n, p)) - def verify_batched_matmul(full_lhs, one_dimensional): - if not one_dimensional: - lhs_dims = [n_dim, m_dim] - rhs_dims = [m_dim, p_dim] - result_dims = [n_dim, p_dim] - else: - lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim] - rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim] - result_dims = [n_dim] if full_lhs else [p_dim] + def test_expand(self): + tensor = torch.rand(1, 8, 1) + tensor2 = torch.rand(5) + template = torch.rand(4, 8, 5) + target = template.size() + self.assertEqual(tensor.expand_as(template).size(), target) + self.assertEqual(tensor.expand(4, 8, 5).size(), target) + self.assertEqual(tensor.expand(target).size(), target) + self.assertEqual(tensor2.expand_as(template).size(), target) + self.assertEqual(tensor2.expand(4, 8, 5).size(), target) + self.assertEqual(tensor2.expand(target).size(), target) - lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim] - rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1] - full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims - dim0_dims = rhs_dims if full_lhs else lhs_dims - small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims) + # test double expand + self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) - small = cast(torch.randn(*(small_dims)).float()) - dim0 = cast(torch.randn(*(dim0_dims)).float()) - full = cast(torch.randn(*(full_batch_dims + full_mat_dims)).float()) - if not one_dimensional: - (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,)) - else: - (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,)) + # test non-contiguous + noncontig = torch.randn(5, 2, 1, 3)[:, 0] + self.assertFalse(noncontig.is_contiguous()) + self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) - def maybe_squeeze_result(l, r, result): - if len(lhs_dims) == 1 and l.dim() != 1: - return result.squeeze(-2) - elif len(rhs_dims) == 1 and r.dim() != 1: - return result.squeeze(-1) - else: - return result + # make sure it's compatible with unsqueeze + expanded = tensor2.expand(1, 1, 5) + unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) + self.assertEqual(expanded, unsqueezed) + self.assertEqual(expanded.stride(), unsqueezed.stride()) - for lhs in lhsTensors: - lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims))) - lhs_expanded_matmul_fn = lhs_expanded.matmul - for rhs in rhsTensors: - rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)). - expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims)))) - truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded)) - for l in (lhs, lhs_expanded): - for r in (rhs, rhs_expanded): - l_matmul_fn = l.matmul - result = maybe_squeeze_result(l, r, l_matmul_fn(r)) - self.assertEqual(truth, result) - # test torch.matmul function as well - torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r)) - self.assertEqual(truth, torch_result) - # test torch.matmul with out - out = torch.zeros_like(torch_result) - torch.matmul(l, r, out=out) - self.assertEqual(truth, maybe_squeeze_result(l, r, out)) + # test -1 as target size + self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) + self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) - # compare to bmm - bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims), - rhs_expanded.contiguous().view(-1, *rhs_mat_dims))) - self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims)) + # test expanding empty to empty + self.assertEqual(torch.zeros(0).expand((0,)), torch.zeros(0)) - for indices in product((True, False), repeat=2): - verify_batched_matmul(*indices) + def test_repeat(self): + initial_shape = (8, 4) + tensor = torch.rand(*initial_shape) - def test_broadcast_batched_matmul(self): - self._test_broadcast_batched_matmul(self, lambda t: t) + size = (3, 1, 1) + torchSize = torch.Size(size) + target = [3, 8, 4] + self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat') + self.assertEqual(tensor.repeat(torchSize).size(), target, + 'Error in repeat using LongStorage') + result = tensor.repeat(*size) + self.assertEqual(result.size(), target, 'Error in repeat using result') + result = tensor.repeat(torchSize) + self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage') + self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)') - def test_copy_broadcast(self): - torch.zeros(5, 6).copy_(torch.zeros(6)) - self.assertRaises(RuntimeError, lambda: torch.zeros(5, 6).copy_(torch.zeros(30))) + zeroDimTarget = torch.Size([24, 0]) + self.assertEqual(tensor.repeat((3, 0)).size(), zeroDimTarget, "Error when calling with 0 repeats") - def test_copy_many_to_one(self): - # Testing in-place copy where it attempt to write from many memory - # storage to a single storage would cause RuntimeError to be thrown - self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) + def test_repeat_interleave(self): + x = torch.tensor([0, 1, 2, 3]) + expected = torch.tensor([1, 2, 2, 3, 3, 3]) + self.assertEqual(torch.repeat_interleave(x), expected) - @staticmethod - def _test_randperm(self, device): - if device == 'cpu': - rng_device = None - else: - rng_device = [0] + with self.assertRaises(RuntimeError): + torch.repeat_interleave(torch.arange(4).reshape(2, 2)) - # Test core functionality. On CUDA, for small n, randperm is offloaded to CPU instead. For large n, randperm is - # executed on GPU. - for n in (100, 50000, 100000): - # Ensure both integer and floating-point numbers are tested. Half follows an execution path that is - # different from others on CUDA. - for dtype in (torch.long, torch.half, torch.float): - if n > 2049 and dtype == torch.half: # Large n for torch.half will raise an exception, do not test here. - continue - with torch.random.fork_rng(devices=rng_device): - res1 = torch.randperm(n, dtype=dtype, device=device) - res2 = torch.empty(0, dtype=dtype, device=device) - torch.randperm(n, out=res2, dtype=dtype, device=device) - self.assertEqual(res1, res2, 0) + with self.assertRaises(RuntimeError): + torch.repeat_interleave(torch.arange(4.0)) - # Default type is long - for n in (100, 10000): - self.assertEqual(torch.randperm(n, device=device).dtype, torch.long) + with self.assertRaises(RuntimeError): + torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4])) - # randperm of 0 elements is an empty tensor - res1 = torch.randperm(0) - res2 = torch.tensor(5, dtype=dtype, device=device) - torch.randperm(0, out=res2) - self.assertEqual(res1.numel(), 0) - self.assertEqual(res2.numel(), 0) + y = torch.tensor([[1, 2], [3, 4]]) - # Test exceptions when n is too large for a floating point type - for dtype, small_n, large_n in ((torch.half, 2**11 + 1, 2**11 + 2), - (torch.float, 2**24 + 1, 2**24 + 2), - (torch.double, 2**25, # 2**53 + 1 is too large to run - 2**53 + 2)): - res = torch.empty(0, dtype=dtype, device=device) - torch.randperm(small_n, out=res) # No exception expected - self.assertRaises(RuntimeError, lambda: torch.randperm(large_n, out=res, device=device)) + y1_v1 = torch.repeat_interleave(y, 2) + y1_v2 = torch.repeat_interleave(y, torch.tensor(2)) + y1_v3 = torch.repeat_interleave(y, torch.tensor([2])) + y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4]) + self.assertEqual(y1_v1, y1_expect) + self.assertEqual(y1_v2, y1_expect) + self.assertEqual(y1_v3, y1_expect) - # Test non-contiguous tensors - for n in (4, 5, 6, 10, 20): - non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t() - self.assertFalse(non_contiguous_tensor.is_contiguous()) - with torch.random.fork_rng(devices=rng_device): - res = torch.randperm(n, dtype=torch.long, device=device) - torch.randperm(n, out=non_contiguous_tensor) - self.assertEqual(non_contiguous_tensor, res) + y2 = torch.repeat_interleave(y, 3, dim=1) + y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2], + [3, 3, 3, 4, 4, 4]]) + self.assertEqual(y2, y2_expect) - def test_randperm(self): - self._test_randperm(self, 'cpu') + y3 = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) + y3_expect = torch.tensor([[1, 2], + [3, 4], + [3, 4]]) + self.assertEqual(y3, y3_expect) - def test_random(self): - # This test is flaky with p<=(2/(ub-lb))^200=6e-36 - t = torch.FloatTensor(200) - lb = 1 - ub = 4 + with self.assertRaises(RuntimeError): + torch.repeat_interleave(y, torch.tensor([1, 2, 3]), dim=0) - t.fill_(-1) - t.random_(lb, ub) - self.assertEqual(t.min(), lb) - self.assertEqual(t.max(), ub - 1) + with self.assertRaises(RuntimeError): + torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0) - t.fill_(-1) - t.random_(ub) - self.assertEqual(t.min(), 0) - self.assertEqual(t.max(), ub - 1) + # test zero sized dimension + x = torch.zeros((5, 0)) + y = torch.repeat_interleave(x, repeats=3, dim=1) + self.assertEqual(y, x.new_zeros(5, 0)) - def test_not_equal(self): - ones = torch.ones(10, dtype=torch.int) - self.assertRaisesRegex(AssertionError, "0 not greater than or equal to", - lambda: self.assertNotEqual(ones, ones)) + x = torch.tensor([], dtype=torch.int64) + y = torch.repeat_interleave(x, x) + self.assertEqual(y, x) - @staticmethod - def _test_random_neg_values(self, use_cuda=False): - signed_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor', - 'torch.IntTensor', 'torch.ShortTensor'] - for tname in signed_types: - res = torch.rand(SIZE, SIZE).type(tname) - if use_cuda: - res = res.cuda() - res.random_(-10, -1) - self.assertLessEqual(res.max().item(), 9) - self.assertGreaterEqual(res.min().item(), -10) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_repeat_tile(self): - def test_random_neg_values(self): - self._test_random_neg_values(self) + initial_shape = (8, 4) - def assertIsOrdered(self, order, x, mxx, ixx, task): - SIZE = 4 - if order == 'descending': - def check_order(a, b): - # `a != a` because we put NaNs - # at the end of ascending sorted lists, - # and the beginning of descending ones. - return a != a or a >= b - elif order == 'ascending': - def check_order(a, b): - # see above - return b != b or a <= b - else: - error('unknown order "{}", must be "ascending" or "descending"'.format(order)) + repeats = ((3, 1, 1), + (3, 3, 3), + (1, 2, 1), + (2, 2, 2, 2)) - are_ordered = True - for j, k in product(range(SIZE), range(1, SIZE)): - self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]), - 'torch.sort ({}) values unordered for {}'.format(order, task)) + def _generate_noncontiguous_input(): - seen = set() - indicesCorrect = True - size = x.size(x.dim() - 1) - for k in range(size): - seen.clear() - for j in range(size): - self.assertEqual(x[k][ixx[k][j]], mxx[k][j], - 'torch.sort ({}) indices wrong for {}'.format(order, task)) - seen.add(ixx[k][j]) - self.assertEqual(len(seen), size) + out = np.broadcast_to(np.random.random((1, 4)), + initial_shape) - def test_sort(self): - SIZE = 4 - x = torch.rand(SIZE, SIZE) - res1val, res1ind = torch.sort(x) + assert not (out.flags.c_contiguous or out.flags.f_contiguous) - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.sort(x, out=(res2val, res2ind)) - self.assertEqual(res1val, res2val, 0) - self.assertEqual(res1ind, res2ind, 0) - self.assertEqual(torch.argsort(x), res1ind) - self.assertEqual(x.argsort(), res1ind) + return out - # Test sorting of random numbers - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random') + for repeat in repeats: + for tensor in (torch.from_numpy(np.random.random(initial_shape)), + torch.from_numpy(_generate_noncontiguous_input()),): - # Test simple sort - self.assertEqual( - torch.sort(torch.Tensor((50, 40, 30, 20, 10)))[0], - torch.Tensor((10, 20, 30, 40, 50)), - 0 - ) + self.assertEqual(tensor.repeat(*repeat).numpy(), + np.tile(tensor.numpy(), repeat)) - # Test that we still have proper sorting with duplicate keys - x = torch.floor(torch.rand(SIZE, SIZE) * 10) - torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys') + def test_is_same_size(self): + t1 = torch.Tensor(3, 4, 9, 10) + t2 = torch.Tensor(3, 4) + t3 = torch.Tensor(1, 9, 3, 3) + t4 = torch.Tensor(3, 4, 9, 10) - # DESCENDING SORT - x = torch.rand(SIZE, SIZE) - res1val, res1ind = torch.sort(x, x.dim() - 1, True) + self.assertFalse(t1.is_same_size(t2)) + self.assertFalse(t1.is_same_size(t3)) + self.assertTrue(t1.is_same_size(t4)) - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind)) - self.assertEqual(res1val, res2val, 0) - self.assertEqual(res1ind, res2ind, 0) - self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind) - self.assertEqual(x.argsort(x.dim() - 1, True), res1ind) + def test_is_set_to(self): + t1 = torch.Tensor(3, 4, 9, 10) + t2 = torch.Tensor(3, 4, 9, 10) + t3 = torch.Tensor().set_(t1) + t4 = t3.clone().resize_(12, 90) + self.assertFalse(t1.is_set_to(t2)) + self.assertTrue(t1.is_set_to(t3)) + self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric") + self.assertFalse(t1.is_set_to(t4)) + self.assertFalse(torch.Tensor().is_set_to(torch.Tensor()), + "Tensors with no storages should not appear to be set " + "to each other") - # Test sorting of random numbers - self.assertIsOrdered('descending', x, res2val, res2ind, 'random') + t1 = torch.tensor([True, True], dtype=torch.bool) + t2 = torch.tensor([0], dtype=torch.bool).set_(t1) + self.assertTrue(t1.is_set_to(t2)) - # Test simple sort task - self.assertEqual( - torch.sort(torch.Tensor((10, 20, 30, 40, 50)), 0, True)[0], - torch.Tensor((50, 40, 30, 20, 10)), - 0 - ) + def test_tensor_set(self): + t1 = torch.Tensor() + t2 = torch.Tensor(3, 4, 9, 10).uniform_() + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + size = torch.Size([9, 3, 4, 10]) + t1.set_(t2.storage(), 0, size) + self.assertEqual(t1.size(), size) + t1.set_(t2.storage(), 0, tuple(size)) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), (120, 40, 10, 1)) + stride = (10, 360, 90, 1) + t1.set_(t2.storage(), 0, size, stride) + self.assertEqual(t1.stride(), stride) + t1.set_(t2.storage(), 0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) - # Test that we still have proper sorting with duplicate keys - self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys') + # test argument names + t1 = torch.Tensor() + # 1. case when source is tensor + t1.set_(source=t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 2. case when source is storage + t1.set_(source=t2.storage()) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + # 3. case when source is storage, and other args also specified + t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) + self.assertEqual(t1.size(), size) + self.assertEqual(t1.stride(), stride) - # Test sorting with NaNs - x = torch.rand(SIZE, SIZE) - x[1][2] = float('NaN') - x[3][0] = float('NaN') - torch.sort(x, out=(res2val, res2ind)) - self.assertIsOrdered('ascending', x, res2val, res2ind, - 'random with NaNs') - torch.sort(x, out=(res2val, res2ind), descending=True) - self.assertIsOrdered('descending', x, res2val, res2ind, - 'random with NaNs') + t1 = torch.tensor([True, True], dtype=torch.bool) + t2 = torch.tensor([False, False], dtype=torch.bool) + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') - def test_tensordot(self): - for d in torch.testing.get_all_device_types(): - a = torch.arange(60., device=d).reshape(3, 4, 5) - b = torch.arange(24., device=d).reshape(4, 3, 2) - c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() - cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), - axes=([1, 0], [0, 1]))) - self.assertEqual(c, cn) - a = torch.randn(2, 3, 4, 5, device=d) - b = torch.randn(4, 5, 6, 7, device=d) - c = torch.tensordot(a, b, dims=2).cpu() - cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), - axes=2)) - self.assertEqual(c, cn) - c = torch.tensordot(a, b).cpu() - cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(c, cn) + def test_tensor_set_errors(self): + f_cpu = torch.randn((2, 3), dtype=torch.float32) + d_cpu = torch.randn((2, 3), dtype=torch.float64) - def test_topk(self): - def topKViaSort(t, k, dim, dir): - sorted, indices = t.sort(dim, dir) - return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k) + # change dtype + self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage())) + self.assertRaises(RuntimeError, + lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride())) + self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu)) - def compareTensors(t, res1, ind1, res2, ind2, dim): - # Values should be exactly equivalent - self.assertEqual(res1, res2, 0) + # change device + if torch.cuda.is_available(): + f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda') - # Indices might differ based on the implementation, since there is - # no guarantee of the relative order of selection - if not ind1.eq(ind2).all(): - # To verify that the indices represent equivalent elements, - # gather from the input using the topk indices and compare against - # the sort indices - vals = t.gather(dim, ind2) - self.assertEqual(res1, vals, 0) + # cpu -> cuda + self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage())) + self.assertRaises(RuntimeError, + lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride())) + self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda)) - def compare(t, k, dim, dir): - topKVal, topKInd = t.topk(k, dim, dir, True) - sortKVal, sortKInd = topKViaSort(t, k, dim, dir) - compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim) + # cuda -> cpu + self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage())) + self.assertRaises(RuntimeError, + lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride())) + self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu)) - t = torch.rand(random.randint(1, SIZE), - random.randint(1, SIZE), - random.randint(1, SIZE)) + def test_equal(self): + # Contiguous, 1D + t1 = torch.Tensor((3, 4, 9, 10)) + t2 = t1.contiguous() + t3 = torch.Tensor((1, 9, 3, 10)) + t4 = torch.Tensor((3, 4, 9)) + t5 = torch.Tensor() + self.assertTrue(t1.equal(t2)) + self.assertFalse(t1.equal(t3)) + self.assertFalse(t1.equal(t4)) + self.assertFalse(t1.equal(t5)) + self.assertTrue(torch.equal(t1, t2)) + self.assertFalse(torch.equal(t1, t3)) + self.assertFalse(torch.equal(t1, t4)) + self.assertFalse(torch.equal(t1, t5)) - for _kTries in range(3): - for _dimTries in range(3): - for transpose in (True, False): - for dir in (True, False): - testTensor = t - if transpose: - dim1 = random.randrange(t.ndimension()) - dim2 = dim1 - while dim1 == dim2: - dim2 = random.randrange(t.ndimension()) - - testTensor = t.transpose(dim1, dim2) + # Non contiguous, 2D + s = torch.Tensor(((1, 2, 3, 4), (5, 6, 7, 8))) + s1 = s[:, 1:3] + s2 = s1.clone() + s3 = torch.Tensor(((2, 3), (6, 7))) + s4 = torch.Tensor(((0, 0), (0, 0))) - dim = random.randrange(testTensor.ndimension()) - k = random.randint(1, testTensor.size(dim)) - compare(testTensor, k, dim, dir) + self.assertFalse(s1.is_contiguous()) + self.assertTrue(s1.equal(s2)) + self.assertTrue(s1.equal(s3)) + self.assertFalse(s1.equal(s4)) + self.assertTrue(torch.equal(s1, s2)) + self.assertTrue(torch.equal(s1, s3)) + self.assertFalse(torch.equal(s1, s4)) - def test_topk_arguments(self): - q = torch.randn(10, 2, 10) - # Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1) - self.assertRaises(TypeError, lambda: q.topk(4, True)) + def test_element_size(self): + byte = torch.ByteStorage().element_size() + char = torch.CharStorage().element_size() + short = torch.ShortStorage().element_size() + int = torch.IntStorage().element_size() + long = torch.LongStorage().element_size() + float = torch.FloatStorage().element_size() + double = torch.DoubleStorage().element_size() + bool = torch.BoolStorage().element_size() + bfloat16 = torch.BFloat16Storage().element_size() - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_topk_noncontiguous_gpu(self): - t = torch.randn(20, device="cuda")[::2] - top1, idx1 = t.topk(5) - top2, idx2 = t.contiguous().topk(5) - self.assertEqual(top1, top2) - self.assertEqual(idx1, idx2) + self.assertEqual(byte, torch.ByteTensor().element_size()) + self.assertEqual(char, torch.CharTensor().element_size()) + self.assertEqual(short, torch.ShortTensor().element_size()) + self.assertEqual(int, torch.IntTensor().element_size()) + self.assertEqual(long, torch.LongTensor().element_size()) + self.assertEqual(float, torch.FloatTensor().element_size()) + self.assertEqual(double, torch.DoubleTensor().element_size()) + self.assertEqual(bool, torch.BoolTensor().element_size()) - @staticmethod - def _test_kthvalue(self, device='cpu'): - SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE, device=device) - x0 = x.clone() + self.assertGreater(byte, 0) + self.assertGreater(char, 0) + self.assertGreater(short, 0) + self.assertGreater(int, 0) + self.assertGreater(long, 0) + self.assertGreater(float, 0) + self.assertGreater(double, 0) + self.assertGreater(bool, 0) + self.assertGreater(bfloat16, 0) - k = random.randint(1, SIZE) - res1val, res1ind = torch.kthvalue(x, k, keepdim=False) - res2val, res2ind = torch.sort(x) + # These tests are portable, not necessarily strict for your system. + self.assertEqual(byte, 1) + self.assertEqual(char, 1) + self.assertEqual(bool, 1) + self.assertGreaterEqual(short, 2) + self.assertGreaterEqual(int, 2) + self.assertGreaterEqual(int, short) + self.assertGreaterEqual(long, 4) + self.assertGreaterEqual(long, int) + self.assertGreaterEqual(double, float) - self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) - self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) - # test use of result tensors - k = random.randint(1, SIZE) - res1val = torch.tensor([], device=device) - res1ind = torch.tensor([], dtype=torch.long, device=device) - torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind)) - res2val, res2ind = torch.sort(x) - self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) - self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) + def test_split(self): + tensor = torch.rand(7, 4) + split_size = 3 + dim = 0 + target_sizes = ([3, 4], [3, 4], [1, 4]) + splits = tensor.split(split_size, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) + start = start + target_size[dim] - # test non-default dim - k = random.randint(1, SIZE) - res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False) - res2val, res2ind = torch.sort(x, 0) - self.assertEqual(res1val, res2val[k - 1], 0) - self.assertEqual(res1ind, res2ind[k - 1], 0) + # Variable sections split + tensor = torch.randn(20, 10) + dim = 0 + split_sizes = [5, 5, 10] + target_sizes = ([[5, 10], [5, 10], [10, 10]]) + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) + start = start + target_size[dim] - # non-contiguous - y = x.narrow(1, 0, 1) - y0 = y.contiguous() - k = random.randint(1, SIZE) - res1val, res1ind = torch.kthvalue(y, k) - res2val, res2ind = torch.kthvalue(y0, k) - self.assertEqual(res1val, res2val, 0) - self.assertEqual(res1ind, res2ind, 0) + split_sizes = [2, 2, 6] + target_sizes = ([20, 2], [20, 2], [20, 6]) + dim = 1 + splits = tensor.split(split_sizes, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) + start = start + target_size[dim] - # check that the input wasn't modified - self.assertEqual(x, x0, 0) + def test_chunk(self): + tensor = torch.rand(4, 7) + num_chunks = 3 + dim = 1 + target_sizes = ([4, 3], [4, 3], [4, 1]) + splits = tensor.chunk(num_chunks, dim) + start = 0 + for target_size, split in zip(target_sizes, splits): + self.assertEqual(split.size(), target_size) + self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) + start = start + target_size[dim] - # simple test case (with repetitions) - y = torch.tensor((3., 5, 4, 1, 1, 5), device=device) - self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0) - self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0) + # Invalid chunk sizes + error_regex = 'chunk expects.*greater than 0' + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(0) + with self.assertRaisesRegex(RuntimeError, error_regex): + tensor.chunk(-2) - # simple test case (with NaN) - SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE, device=device) - x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan - ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1] - res2val, res2ind = torch.sort(x) - for k in ks: - res1val, res1ind = torch.kthvalue(x, k, keepdim=False) - self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) - self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) + def test_tolist(self): + list0D = [] + tensor0D = torch.Tensor(list0D) + self.assertEqual(tensor0D.tolist(), list0D) - def test_kthvalue(self): - self._test_kthvalue(self) + table1D = [1, 2, 3] + tensor1D = torch.Tensor(table1D) + storage = torch.Storage(table1D) + self.assertEqual(tensor1D.tolist(), table1D) + self.assertEqual(storage.tolist(), table1D) + self.assertEqual(tensor1D.tolist(), table1D) + self.assertEqual(storage.tolist(), table1D) - def test_median(self): - for size in (155, 156): - x = torch.rand(size, size) - x0 = x.clone() + table2D = [[1, 2], [3, 4]] + tensor2D = torch.Tensor(table2D) + self.assertEqual(tensor2D.tolist(), table2D) - nelem = x.nelement() - res1val = torch.median(x) - res2val, _ = torch.sort(x.view(nelem)) - ind = int(math.floor((nelem + 1) / 2) - 1) + tensor3D = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + tensorNonContig = tensor3D.select(1, 1) + self.assertFalse(tensorNonContig.is_contiguous()) + self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]]) - self.assertEqual(res2val[ind], res1val, 0) + def test_permute(self): + orig = [1, 2, 3, 4, 5, 6, 7] + perm = torch.randperm(7).tolist() + x = torch.Tensor(*orig).fill_(0) + new = list(map(lambda x: x - 1, x.permute(*perm).size())) + self.assertEqual(perm, new) + self.assertEqual(x.size(), orig) - res1val, res1ind = torch.median(x, dim=1, keepdim=False) - res2val, res2ind = torch.sort(x) - ind = int(math.floor((size + 1) / 2) - 1) + def test_reversed(self): + val = torch.arange(0, 10) + self.assertEqual(reversed(val), torch.arange(9, -1, -1)) - self.assertEqual(res2val.select(1, ind), res1val, 0) - self.assertEqual(res2val.select(1, ind), res1val, 0) + val = torch.arange(1, 10).view(3, 3) + self.assertEqual(reversed(val), torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]])) - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.median(x, dim=-1, keepdim=False, out=(res2val, res2ind)) - self.assertEqual(res2val, res1val, 0) - self.assertEqual(res2ind, res1ind, 0) + val = torch.tensor(42) + self.assertEqual(reversed(val), torch.tensor(42)) - # Test non-default dim - res1val, res1ind = torch.median(x, 0, keepdim=False) - res2val, res2ind = torch.sort(x, 0) - self.assertEqual(res1val, res2val[ind], 0) - self.assertEqual(res1ind, res2ind[ind], 0) + def test_contains(self): + x = torch.arange(0, 10) + self.assertEqual(4 in x, True) + self.assertEqual(12 in x, False) - # input unchanged - self.assertEqual(x, x0, 0) + x = torch.arange(1, 10).view(3, 3) + val = torch.arange(1, 4) + self.assertEqual(val in x, True) + val += 10 + self.assertEqual(val in x, False) - def test_mode(self): - x = torch.arange(1., SIZE * SIZE + 1).clone().resize_(SIZE, SIZE) - x[:2] = 1 - x[:, :2] = 1 - x0 = x.clone() + self.assertRaisesRegex( + RuntimeError, + "Tensor.__contains__ only supports Tensor or scalar, but you passed in a {}.".format(type("foo")), + lambda: "foo" in x) + self.assertRaisesRegex( + RuntimeError, + "Tensor.__contains__ only supports Tensor or scalar, but you passed in a {}.".format(type([1, 2])), + lambda: [1, 2] in x) - # Pre-calculated results. - res1val = torch.Tensor(SIZE).fill_(1) - # The indices are the position of the last appearance of the mode element. - res1ind = torch.LongTensor(SIZE).fill_(1) - res1ind[0] = SIZE - 1 - res1ind[1] = SIZE - 1 + def test_storage(self): + v = torch.randn(3, 5) + self.assertEqual(v.storage()[0], v.data[0][0]) + self.assertEqual(v.storage()[14], v.data[2][4]) - res2val, res2ind = torch.mode(x, keepdim=False) - self.assertEqual(res1val, res2val, 0) - self.assertEqual(res1ind, res2ind, 0) + def test_deepcopy(self): + from copy import deepcopy + a = torch.randn(5, 5) + b = torch.randn(5, 5) + c = a.view(25) + q = [a, [a.storage(), b.storage()], b, c] + w = deepcopy(q) + self.assertEqual(w[0], q[0], 0) + self.assertEqual(w[1][0], q[1][0], 0) + self.assertEqual(w[1][1], q[1][1], 0) + self.assertEqual(w[1], q[1], 0) + self.assertEqual(w[2], q[2], 0) - # Test use of result tensor - res2val = torch.Tensor() - res2ind = torch.LongTensor() - torch.mode(x, keepdim=False, out=(res2val, res2ind)) - self.assertEqual(res1val, res2val, 0) - self.assertEqual(res1ind, res2ind, 0) + # Check that deepcopy preserves sharing + w[0].add_(1) + for i in range(a.numel()): + self.assertEqual(w[1][0][i], q[1][0][i] + 1) + self.assertEqual(w[3], c + 1) + w[2].sub_(1) + for i in range(a.numel()): + self.assertEqual(w[1][1][i], q[1][1][i] - 1) - # Test non-default dim - res2val, res2ind = torch.mode(x, 0, False) - self.assertEqual(res1val, res2val, 0) - self.assertEqual(res1ind, res2ind, 0) + def test_deepcopy_scalar(self): + from copy import deepcopy + a = torch.tensor(5) + self.assertEqual(a.size(), deepcopy(a).size()) + self.assertEqual(a, deepcopy(a)) - # input unchanged - self.assertEqual(x, x0, 0) + def test_deepcopy_parameter(self): + from copy import deepcopy + l = torch.nn.Linear(10, 1) + s = l.state_dict(keep_vars=True) + self.assertEqual(torch.nn.Parameter, type(s['weight'])) + self.assertEqual(torch.nn.Parameter, type(s['bias'])) - def test_trilu_indices(self): - for test_args in tri_tests_args: - _compare_trilu_indices(self, *test_args) - run_additional_tri_tests(self, 'cpu') + s2 = deepcopy(s) + self.assertEqual(torch.nn.Parameter, type(s2['weight'])) + self.assertEqual(torch.nn.Parameter, type(s2['bias'])) - # test default options - x = torch.ones( - 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) - self.assertEqual( - x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3)) - self.assertEqual( - x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3)) + def test_pickle(self): + if sys.version_info[0] == 2: + import cPickle as pickle + else: + import pickle + a = torch.randn(5, 5) + serialized = pickle.dumps(a) + b = pickle.loads(serialized) + self.assertEqual(a, b) - # test stride 0 cases - x = torch.ones( - 3, 1, 3, 3, dtype=torch.long, device='cpu', layout=torch.strided) - output = x.triu(2).expand(3, 3, 3, 3) - b = x.clone().expand(3, 3, 3, 3) - self.assertEqual(b.triu(2), output) - self.assertRaises(RuntimeError, lambda: b.triu_(2)) + def test_pickle_parameter(self): + if sys.version_info[0] == 2: + import cPickle as pickle + else: + import pickle + a = torch.nn.Parameter(torch.randn(5, 5)) + serialized = pickle.dumps(a) + b = pickle.loads(serialized) + self.assertTrue(isinstance(b, torch.nn.Parameter)) + self.assertEqual(a.requires_grad, b.requires_grad) + self.assertEqual(a, b) - @staticmethod - def _test_triu_tril(self, cast): - def gen_mask(shape, diagonal, cast, upper): - mask = torch.zeros(*shape[-2:]).byte() - for i in range(shape[-2]): - for j in range(shape[-1]): - cond = j - i < diagonal if upper else j - i > diagonal - if cond: - mask[i, j] = 1 - return cast(mask.expand(*shape)) + def test_pickle_parameter_no_requires_grad(self): + if sys.version_info[0] == 2: + import cPickle as pickle + else: + import pickle + a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False) + serialized = pickle.dumps(a) + b = pickle.loads(serialized) + self.assertTrue(isinstance(b, torch.nn.Parameter)) + self.assertEqual(a.requires_grad, b.requires_grad) + self.assertEqual(a, b) - torch_functions = {True: torch.triu, False: torch.tril} - if TEST_NUMPY: - numpy_functions = {True: np.triu, False: np.tril} + def test_pickle_dtype(self): + t = torch.float32 + serialized = pickle.dumps(t) + b = pickle.loads(serialized) + self.assertTrue(isinstance(b, torch.dtype)) + self.assertEqual(id(b), id(t)) - # TODO: remove this when bool and half are supported for torch.where - def bool_half_compat_where(pred, true_tensor, false_tensor, dtype): - if dtype == torch.bool or dtype == torch.half: - return torch.where(pred.byte(), true_tensor.byte(), false_tensor.byte()).to(dtype=dtype) - else: - return torch.where(pred, true_tensor, false_tensor) + def test_pickle_size(self): + a = torch.rand(10).size() + serialized = pickle.dumps(a) + b = pickle.loads(serialized) + self.assertTrue(isinstance(b, torch.Size)) + self.assertEqual(a, b) - def run_test(shape, cast, diagonal, dtype): - x_cpu = torch.empty(*shape, dtype=dtype).fill_(2) - x = cast(x_cpu) + def test_norm_fastpaths(self): + x = torch.randn(3, 5) - for upper in [True, False]: - # normal test with mask - torch_tri_func = torch_functions[upper] - res1 = torch_tri_func(x, diagonal=diagonal) - res2 = cast(torch.empty(0, dtype=dtype)) - torch_tri_func(x, diagonal=diagonal, out=res2) - exp_mask = gen_mask(shape, diagonal, cast, upper) - expected = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x, dtype) - self.assertEqual(res1, res2, 0) - self.assertEqual(expected, res1, 0) + # slow path + result = torch.norm(x, 4.5, 1) + expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5) + self.assertEqual(result, expected) - # non-contiguous and expanded tensors test - if 0 not in shape: - for s in range(-len(shape), -1): - # non-contiguous tensors - x_nc = x.clone().transpose(s, s + 1) - exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper) - if 1 not in shape: - assert not x_nc.is_contiguous(), "x is intentionally non-contiguous" - exp_nc = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x_nc, dtype) - self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0) - x_nc_is_contiguous = x_nc.is_contiguous() - if upper: - self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0) - else: - self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0) + # fast 0-norm + result = torch.norm(x, 0, 1) + expected = (x != 0).type_as(x).sum(1) + self.assertEqual(result, expected) - self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous, - "contiguity of x_nc should not be changed") + # fast 1-norm + result = torch.norm(x, 1, 1) + expected = x.abs().sum(1) + self.assertEqual(result, expected) - # expanded tensors - expanded_size = (x.size(0),) + x.size() - x_expanded = x.clone().expand(*expanded_size) - if x.size(0) != 1: - assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride" - output = torch_tri_func(x_expanded, diagonal) - self.assertEqual(output, expected.expand(expanded_size), 0) - if x.size(0) != 1: - self.assertTrue(0 in x_expanded.stride(), - "geometry of x_expanded should be the same") - if upper: - self.assertEqual(output, x_expanded.triu_(diagonal), 0) - else: - self.assertEqual(output, x_expanded.tril_(diagonal), 0) + # fast 2-norm + result = torch.norm(x, 2, 1) + expected = torch.sqrt(x.pow(2).sum(1)) + self.assertEqual(result, expected) - if not TEST_NUMPY: - continue + # fast 3-norm + result = torch.norm(x, 3, 1) + expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) + self.assertEqual(result, expected) - # numpy test - numpy_tri_func = numpy_functions[upper] - self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy()) + @staticmethod + def _test_bernoulli(self, t_dtype, p_dtype, device): + for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): + x = torch.tensor(trivial_p, dtype=p_dtype, device=device) + self.assertEqual(x.bernoulli().tolist(), trivial_p) - diagonals = [-2, -1, 0, 1, 2] - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7), # thin matrices - (3, 0), (0, 3, 3), (3, 3, 0, 0), # no numel matrices - (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices - (1, 3), (5, 1, 3), (7, 5, 1, 3), # very thin matrices - (1, 3, 3, 3), (3, 1, 3, 3, 3)] # unsqueezed batch dimensions - dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.bfloat16] - for s, d, dtype in product(shapes, diagonals, dtypes): - run_test(s, cast, d, dtype) + def isBinary(t): + return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0 - def test_triu_tril(self): - self._test_triu_tril(self, lambda t: t) + p = torch.rand(5, 5, dtype=p_dtype, device=device) + self.assertTrue(isBinary(p.bernoulli())) - def test_cat(self): - SIZE = 10 - for dtype in (torch.half, torch.double, torch.int): - for dim in range(-3, 3): - pos_dim = dim if dim >= 0 else 3 + dim - x = torch.randint(low=-100, high=100, size=(13, SIZE, SIZE)).to(dtype).transpose(0, pos_dim) - y = torch.randint(low=-100, high=100, size=(17, SIZE, SIZE)).to(dtype).transpose(0, pos_dim) - z = torch.randint(low=-100, high=100, size=(19, SIZE, SIZE)).to(dtype).transpose(0, pos_dim) + p = torch.rand(5, dtype=p_dtype, device=device).expand(5, 5) + self.assertTrue(isBinary(p.bernoulli())) - res1 = torch.cat((x, y, z), dim) - self.assertEqual(res1.narrow(pos_dim, 0, 13), x, 0) - self.assertEqual(res1.narrow(pos_dim, 13, 17), y, 0) - self.assertEqual(res1.narrow(pos_dim, 30, 19), z, 0) + p = torch.rand(5, 5, dtype=p_dtype, device=device) + torch.bernoulli(torch.rand_like(p), out=p) + self.assertTrue(isBinary(p)) - x = torch.randint(low=-100, high=100, size=(20, SIZE, SIZE)).to(dtype) - self.assertEqual(torch.cat(torch.split(x, 7)), x) - self.assertEqual(torch.cat(torch.chunk(x, 7)), x) + p = torch.rand(5, dtype=p_dtype, device=device).expand(5, 5) + torch.bernoulli(torch.rand_like(p), out=p) + self.assertTrue(isBinary(p)) - y = torch.randint(low=-100, high=100, size=(1, SIZE, SIZE)).to(dtype) - z = torch.cat([x, y]) - self.assertEqual(z.size(), (21, SIZE, SIZE)) + t = torch.empty(10, 10, dtype=t_dtype, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([])) - self.assertRaisesRegex(TypeError, 'got None', lambda: torch.cat([x, None])) + t.fill_(2) + t.bernoulli_(0.5) + self.assertTrue(isBinary(t)) - def test_cat_bad_input_sizes(self): - x = torch.randn(2, 1) - y = torch.randn(2, 1, 1) - z = torch.randn(2, 1, 1) - self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z])) + p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10) + t.fill_(2) + t.bernoulli_(p) + self.assertTrue(isBinary(t)) - x = torch.randn(2, 1, 2) - y = torch.randn(2, 1, 1) - z = torch.randn(2, 2, 1) - self.assertRaises(RuntimeError, lambda: torch.cat([x, y, z], dim=1)) + t.fill_(2) + torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t) + self.assertTrue(isBinary(t)) - def test_cat_scalars(self): - x = torch.tensor(0) - y = torch.tensor(1) - with self.assertRaisesRegex(RuntimeError, 'zero-dimensional.*cannot be concatenated'): - torch.cat([x, y]) - - @staticmethod - def _test_cat_empty_legacy(self, use_cuda=False): - # FIXME: this is legacy behavior and should be removed - # when we support empty tensors with arbitrary sizes - dtype = torch.float32 - device = 'cuda' if use_cuda else 'cpu' + t.fill_(2) + t.bernoulli_(torch.rand_like(t, dtype=p_dtype)) + self.assertTrue(isBinary(t)) - x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) - empty = torch.randn((0,), dtype=dtype, device=device) + def test_bernoulli(self): + self._test_bernoulli(self, torch.float32, torch.float64, 'cpu') + # test that it works with integral tensors + self._test_bernoulli(self, torch.uint8, torch.float64, 'cpu') + # test that it works with bool tensors + self._test_bernoulli(self, torch.bool, torch.float32, 'cpu') - res1 = torch.cat([x, empty], dim=1) - res2 = torch.cat([empty, x], dim=1) - self.assertEqual(res1, res2) + def test_generator_cpu(self): + # test default generators are equal + self.assertEqual(torch.default_generator, torch.default_generator) - conv = torch.nn.Conv2d(3, 3, kernel_size=1).float() - if use_cuda: - conv = conv.cuda() - res1 = torch.cat([conv(x), empty], dim=1) - res2 = torch.cat([empty, conv(x)], dim=1) - self.assertEqual(res1, res2) + # tests Generator API + # manual_seed, seed, initial_seed, get_state, set_state + g1 = torch.Generator() + g2 = torch.Generator() + g1.manual_seed(12345) + g2.manual_seed(12345) + self.assertEqual(g1.initial_seed(), g2.initial_seed()) - res1 = torch.cat([empty, empty], dim=1) - self.assertEqual(res1, empty) + g1.seed() + g2.seed() + self.assertNotEqual(g1.initial_seed(), g2.initial_seed()) - with self.assertRaisesRegex(RuntimeError, - 'expected a non-empty list of Tensors'): - torch.cat([], dim=1) + g1 = torch.Generator() + g2_state = g2.get_state() + g2_randn = torch.randn(1, generator=g2) + g1.set_state(g2_state) + g1_randn = torch.randn(1, generator=g1) + self.assertEqual(g1_randn, g2_randn) - def test_cat_empty_legacy(self): - self._test_cat_empty_legacy(self) + default_state = torch.default_generator.get_state() + q = torch.Tensor(100) + g1_normal = q.normal_() + g2 = torch.Generator() + g2.set_state(default_state) + g2_normal = q.normal_(generator=g2) + self.assertEqual(g1_normal, g2_normal) - @staticmethod - def _test_cat_empty(self, use_cuda=False): - dtype = torch.float32 - device = 'cuda' if use_cuda else 'cpu' + def test_sobolengine_unscrambled_lowdim(self): + engine_1d = torch.quasirandom.SobolEngine(1) + expected_1d = torch.tensor([0.5, 0.75, 0.25, 0.375, 0.875, 0.625, 0.125, 0.1875, 0.6875, 0.9375]) + actual_1d = engine_1d.draw(10) + self.assertEqual(actual_1d.view(-1), expected_1d) + self.assertEqual(actual_1d.size(), torch.Size([10, 1])) - x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) - empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device) + # Test out kwarg + engine_1d.reset() + actual_1d_out = torch.Tensor().float() + engine_1d.draw(10, out=actual_1d_out) + self.assertEqual(actual_1d.view(-1), expected_1d) - res1 = torch.cat([x, empty], dim=1) - res2 = torch.cat([empty, x], dim=1) - self.assertEqual(res1, res2) + engine_3d = torch.quasirandom.SobolEngine(3) + expected_3d = torch.tensor([0.5, 0.75, 0.25, 0.625, 0.125, 0.375, 0.875, 0.3125, 0.8125, 0.5625]) + actual_3d = engine_3d.draw(10) + self.assertEqual(actual_3d[:, 2], expected_3d) + self.assertEqual(actual_3d[:, 0], expected_1d) + self.assertEqual(actual_3d.size(), torch.Size([10, 3])) - conv = torch.nn.Conv2d(3, 3, kernel_size=1).float() - if use_cuda: - conv = conv.cuda() - res1 = torch.cat([conv(x), empty], dim=1) - res2 = torch.cat([empty, conv(x)], dim=1) - self.assertEqual(res1, res2) + engine_3d = torch.quasirandom.SobolEngine(3) + draws = torch.cat([engine_3d.draw() for _ in range(0, 10)]) + self.assertEqual(draws, actual_3d) - res1 = torch.cat([empty, empty], dim=1) - self.assertEqual(res1, empty) + engine_3d = torch.quasirandom.SobolEngine(3).fast_forward(5) + draws = engine_3d.draw(5) + self.assertEqual(draws, actual_3d[5:]) + engine_3d.reset() + self.assertEqual(engine_3d.draw(3), actual_3d[:3]) + engine_3d.fast_forward(2) + self.assertEqual(engine_3d.draw(5), actual_3d[5:]) - # check non-legacy-behavior (sizes don't match) - empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) - self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) + def test_sobolengine_unscrambled_highdim(self): + from collections import Counter + engine = torch.quasirandom.SobolEngine(1111) + count1 = dict(Counter(engine.draw().view(-1).tolist())) + count2 = dict(Counter(engine.draw().view(-1).tolist())) + count3 = dict(Counter(engine.draw().view(-1).tolist())) + self.assertTrue(count1 == {0.5: 1111}) + self.assertTrue(count2 == {0.25: 580, 0.75: 531}) + self.assertTrue(count3 == {0.25: 531, 0.75: 580}) - # check non-legacy-behavior (dimensions don't match) - empty = torch.randn((4, 0), dtype=dtype, device=device) - self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) - self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) + engine = torch.quasirandom.SobolEngine(1111) + draws = engine.draw(1000) + self.assertTrue(torch.all(draws <= 1)) + self.assertTrue(torch.all(draws >= 0)) - def test_cat_empty(self): - self._test_cat_empty(self) + def test_sobolengine_scrambled_lowdim(self): + engine_1d = torch.quasirandom.SobolEngine(1, scramble=True, seed=1729) + expected_1d = [0.16478512, 0.43221009, 0.84261382, 0.99750268, 0.27460563, + 0.01084163, 0.73373985, 0.65039611, 0.12329865, 0.35587373] + actual_1d = engine_1d.draw(10) + self.assertEqual(actual_1d.flatten(), torch.tensor(expected_1d)) + self.assertEqual(actual_1d.size(), torch.Size([10, 1])) + # make sure random seed if chosen if none is provided + engine_1d_a = torch.quasirandom.SobolEngine(1, scramble=True) + engine_1d_b = torch.quasirandom.SobolEngine(1, scramble=True) + self.assertNotEqual(engine_1d_a.draw(2), engine_1d_b.draw(2)) - @slowTest - def test_cat_big(self): - SIZE1 = 6500 - SIZE2 = 4500 - concat_list = [] - concat_list.append(torch.ones((SIZE1, 1024 * 512), dtype=torch.uint8)) - concat_list.append(torch.ones((SIZE2, 1024 * 512), dtype=torch.uint8)) - result = torch.cat(concat_list) - self.assertEqual(result.size(0), SIZE1 + SIZE2) + engine_3d = torch.quasirandom.SobolEngine(3, scramble=True, seed=1729) + expected_3d = [0.32642800, 0.17881306, 0.68837059, 0.46492538, 0.91789097, + 0.58075899, 0.03642474, 0.68229187, 0.20051685, 0.30083340] + actual_3d = engine_3d.draw(10) + self.assertEqual(actual_3d[:, 2], torch.tensor(expected_3d)) + self.assertEqual(actual_3d.size(), torch.Size([10, 3])) - def test_narrow(self): - x = torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) - self.assertEqual(x.narrow(0, 0, 1), torch.Tensor([[0, 1, 2]])) - self.assertEqual(x.narrow(0, 0, 2), torch.Tensor([[0, 1, 2], [3, 4, 5]])) - self.assertEqual(x.narrow(0, 1, 1), torch.Tensor([[3, 4, 5]])) - self.assertEqual(x.narrow(0, -1, 1), torch.Tensor([[6, 7, 8]])) - self.assertEqual(x.narrow(0, -2, 2), torch.Tensor([[3, 4, 5], [6, 7, 8]])) - self.assertEqual(x.narrow(0, -3, 3), torch.Tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) - self.assertEqual(x.narrow(-1, -1, 1), torch.Tensor([[2], [5], [8]])) - self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]])) + engine_3d = torch.quasirandom.SobolEngine(3, scramble=True, seed=1729) + draws = torch.cat([engine_3d.draw() for _ in range(0, 10)]) + self.assertEqual(draws, actual_3d) - def test_narrow_empty(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn(2, 3, 4, device=device) - for d in range(x.dim()): - y = x.narrow(d, x.size(d), 0) - sz = list(x.size()) - sz[d] = 0 - self.assertEqual(sz, y.size()) + engine_3d = torch.quasirandom.SobolEngine(3, scramble=True, seed=1729) + engine_3d.fast_forward(5) + draws = engine_3d.draw(5) + self.assertEqual(draws, actual_3d[5:]) + engine_3d.reset() + self.assertEqual(engine_3d.draw(3), actual_3d[:3]) + engine_3d.fast_forward(2) + self.assertEqual(engine_3d.draw(5), actual_3d[5:]) - def test_stack(self): - for dtype in (torch.half, torch.double, torch.int): - x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - for dim in range(4): - res = torch.stack((x, y, z), dim) - res_neg = torch.stack((x, y, z), dim - 4) - expected_size = x.size()[:dim] + (3,) + x.size()[dim:] - self.assertEqual(res, res_neg) - self.assertEqual(res.size(), expected_size) - self.assertEqual(res.select(dim, 0), x, 0) - self.assertEqual(res.select(dim, 1), y, 0) - self.assertEqual(res.select(dim, 2), z, 0) + def test_sobolengine_scrambled_highdim(self): + engine = torch.quasirandom.SobolEngine(1111, scramble=True) + draws = engine.draw(1000) + self.assertTrue(torch.all(draws <= 1)) + self.assertTrue(torch.all(draws >= 0)) - def test_stack_out(self): - for dtype in (torch.half, torch.double, torch.int): - x = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - y = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - z = torch.randint(low=-100, high=100, size=(2, 3, 4)).to(dtype) - for dim in range(4): - expected_size = x.size()[:dim] + (3,) + x.size()[dim:] - res_out = x.new(expected_size) - res_neg_out = x.new(expected_size) - res_out_dp = res_out.data_ptr() - res_out_neg_dp = res_neg_out.data_ptr() - torch.stack((x, y, z), dim, out=res_out) - torch.stack((x, y, z), dim - 4, out=res_neg_out) - self.assertEqual(res_out, res_neg_out) - self.assertEqual(res_out.size(), expected_size) - self.assertEqual(res_out_dp, res_out.data_ptr()) - self.assertEqual(res_out_neg_dp, res_neg_out.data_ptr()) - self.assertEqual(res_out.select(dim, 0), x, 0) - self.assertEqual(res_out.select(dim, 1), y, 0) - self.assertEqual(res_out.select(dim, 2), z, 0) + def test_parsing_int64(self): + # accepts integer arguments + x = torch.cumsum(torch.ones(5, 5), 0) + self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0))) + # doesn't accept floating point variables + self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.))) - def test_unbind(self): - x = torch.rand(2, 3, 4, 5) - for dim in range(4): - res = torch.unbind(x, dim) - res2 = x.unbind(dim) - self.assertEqual(x.size(dim), len(res)) - self.assertEqual(x.size(dim), len(res2)) - for i in range(dim): - self.assertEqual(x.select(dim, i), res[i]) - self.assertEqual(x.select(dim, i), res2[i]) + def test_parsing_double(self): + # accepts floating point and integer arguments + x = torch.randn(2, 3) + torch.isclose(x, x, 1, 1) + self.assertTrue(torch.isclose(x, x, 1, 1).all()) + self.assertTrue(torch.isclose(x, x, 1.5, 1.).all()) + # accepts floating point and integer tensors + self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all()) + self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all()) + # doesn't accept variables with requires_grad + self.assertRaises(TypeError, + lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all()) - def test_linspace(self): - for device in torch.testing.get_all_device_types(): - _from = random.random() - to = _from + random.random() - res1 = torch.linspace(_from, to, 137, device=device) - res2 = torch.tensor((), device=device) - torch.linspace(_from, to, 137, out=res2) - self.assertEqual(res1, res2, 0) - self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1, device=device)) - self.assertEqual(torch.linspace(0, 1, 1, device=device), torch.zeros(1, device=device), 0) + def test_parsing_intlist(self): + # parse with integer variables + self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape) + # parse with numpy integers + if TEST_NUMPY: + self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape) + self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape) - # Check linspace for generating with start > end. - self.assertEqual(torch.linspace(2, 0, 3, device=device), torch.tensor((2, 1, 0), device=device), 0) + # fail parse with float variables + self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4)))) + # fail parse with numpy floats + if TEST_NUMPY: + self.assertRaises(TypeError, lambda: torch.ones((np.float(3.), torch.tensor(4)))) + self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4)))) - # Check linspace for non-contiguous tensors. - x = torch.zeros(2, 3, device=device) - y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2)) - self.assertEqual(x, torch.tensor(((0, 0, 1), (0, 2, 3)), device=device), 0) + # fail parse with > 1 element variables + self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3))) + self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3, 3)))) + if TEST_NUMPY: + self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3))) + self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3)))) - def test_logspace(self): - _from = random.random() - to = _from + random.random() - res1 = torch.logspace(_from, to, 137) - res2 = torch.Tensor() - torch.logspace(_from, to, 137, out=res2) - self.assertEqual(res1, res2, 0) - self.assertRaises(RuntimeError, lambda: torch.logspace(0, 1, -1)) - self.assertEqual(torch.logspace(0, 1, 1), torch.ones(1), 0) + # fail parse with additional positional args after intlist arg + self.assertRaisesRegex(TypeError, + "received an invalid combination of arguments", + lambda: torch.LongTensor((6, 0), 1, 1, 0)) + self.assertRaisesRegex(TypeError, + "missing 1 required positional arguments", + lambda: torch.tensor().new_zeros((5, 5), 0)) - # Check non-default base=2 - self.assertEqual(torch.logspace(1, 1, 1, 2), torch.ones(1) * 2) - self.assertEqual(torch.logspace(0, 2, 3, 2), torch.Tensor((1, 2, 4))) + def _test_serialization_data(self): + a = [torch.randn(5, 5).float() for i in range(2)] + b = [a[i % 2] for i in range(4)] # 0-3 + b += [a[0].storage()] # 4 + b += [a[0].reshape(-1)[1:4].storage()] # 5 + b += [torch.arange(1, 11).int()] # 6 + t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) + t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) + b += [(t1.storage(), t1.storage(), t2.storage())] # 7 + b += [a[0].reshape(-1)[0:2].storage()] # 8 + return b - # Check logspace_ for generating with start > end. - self.assertEqual(torch.logspace(1, 0, 2), torch.Tensor((10, 1)), 0) + def _test_serialization_assert(self, b, c): + self.assertEqual(b, c, 0) + self.assertTrue(isinstance(c[0], torch.FloatTensor)) + self.assertTrue(isinstance(c[1], torch.FloatTensor)) + self.assertTrue(isinstance(c[2], torch.FloatTensor)) + self.assertTrue(isinstance(c[3], torch.FloatTensor)) + self.assertTrue(isinstance(c[4], torch.FloatStorage)) + c[0].fill_(10) + self.assertEqual(c[0], c[2], 0) + self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0) + c[1].fill_(20) + self.assertEqual(c[1], c[3], 0) + # I have to do it in this roundabout fashion, because there's no + # way to slice storages + for i in range(4): + self.assertEqual(c[4][i + 1], c[5][i]) - # Check logspace_ for non-contiguous tensors. - x = torch.zeros(2, 3) - y = torch.logspace(0, 3, 4, out=x.narrow(1, 1, 2)) - self.assertEqual(x, torch.Tensor(((0, 1, 10), (0, 100, 1000))), 0) + # check that serializing the same storage view object unpickles + # it as one object not two (and vice versa) + views = c[7] + self.assertEqual(views[0]._cdata, views[1]._cdata) + self.assertEqual(views[0], views[2]) + self.assertNotEqual(views[0]._cdata, views[2]._cdata) - def test_rand(self): - torch.manual_seed(123456) - res1 = torch.rand(SIZE, SIZE) - res2 = torch.Tensor() - torch.manual_seed(123456) - torch.rand(SIZE, SIZE, out=res2) - self.assertEqual(res1, res2) + rootview = c[8] + self.assertEqual(rootview.data_ptr(), c[0].data_ptr()) - def test_randint(self): - torch.manual_seed(123456) - res1 = torch.randint(0, 6, (SIZE, SIZE)) - res2 = torch.Tensor() - torch.manual_seed(123456) - torch.randint(0, 6, (SIZE, SIZE), out=res2) - torch.manual_seed(123456) - res3 = torch.randint(6, (SIZE, SIZE)) - res4 = torch.Tensor() - torch.manual_seed(123456) - torch.randint(6, (SIZE, SIZE), out=res4) - self.assertEqual(res1, res2) - self.assertEqual(res1, res3) - self.assertEqual(res1, res4) - self.assertEqual(res2, res3) - self.assertEqual(res2, res4) - self.assertEqual(res3, res4) - res1 = res1.view(-1) - high = (res1 < 6).type(torch.LongTensor) - low = (res1 >= 0).type(torch.LongTensor) - tensorSize = res1.size()[0] - assert(tensorSize == high.sum()) - assert(tensorSize == low.sum()) + def test_serialization(self): + # Test serialization with a real file + b = self._test_serialization_data() + for use_name in (False, True): + # Passing filename to torch.save(...) will cause the file to be opened twice, + # which is not supported on Windows + if sys.platform == "win32" and use_name: + continue + with tempfile.NamedTemporaryFile() as f: + handle = f if not use_name else f.name + torch.save(b, handle) + f.seek(0) + c = torch.load(handle) + self._test_serialization_assert(b, c) + # test non-ascii encoding of bytes arrays/strings + # The following bytes are produced by serializing + # [b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc', torch.zeros(1, dtype=torch.float), 2] + # in Python 2.7.12 and PyTorch 0.4.1, where the first element contains + # bytes of some utf-8 characters (i.e., `utf8_str.encode('utf-8')`). + serialized = ( + b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.' + b'\x80\x02}q\x01(U\x10protocol_versionq\x02M\xe9\x03U\n' + b'type_sizesq\x03}q\x04(U\x03intq\x05K\x04U\x05shortq\x06K\x02U' + b'\x04longq\x07K\x04uU\rlittle_endianq\x08\x88u.\x80\x02]q' + b'\x01(U\x0e\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85' + b'\xc5\xbcq\x02ctorch._utils\n_rebuild_tensor_v2\nq\x03((U' + b'\x07storageq\x04ctorch\nFloatStorage\nq\x05U\x0845640624q' + b'\x06U\x03cpuq\x07\x8a\x01\x01NtQK\x00K\x01\x85K\x01\x85' + b'\x89NtRq\x08K\x02e.\x80\x02]q\x01U\x0845640624q\x02a.\x01\x00' + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + ) + buf = io.BytesIO(serialized) + utf8_bytes = b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc' + utf8_str = utf8_bytes.decode('utf-8') + if PY3: + with self.assertRaisesRegex(UnicodeDecodeError, "'ascii' codec can't decode byte"): + loaded = torch.load(buf) + buf.seek(0) + loaded_utf8 = torch.load(buf, encoding='utf-8') + self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2]) + buf.seek(0) + loaded_bytes = torch.load(buf, encoding='bytes') + else: + loaded_bytes = torch.load(buf) + self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2]) - def test_randn(self): - torch.manual_seed(123456) - res1 = torch.randn(SIZE, SIZE) - res2 = torch.Tensor() - torch.manual_seed(123456) - torch.randn(SIZE, SIZE, out=res2) - self.assertEqual(res1, res2) + def test_serialization_filelike(self): + # Test serialization (load and save) with a filelike object + b = self._test_serialization_data() + with BytesIOContext() as f: + torch.save(b, f) + f.seek(0) + c = torch.load(f) + self._test_serialization_assert(b, c) - def test_slice(self): - empty = torch.empty(0, 4) - x = torch.arange(0., 16).view(4, 4) - self.assertEqual(x[:], x) - self.assertEqual(x[:4], x) - # start and stop are clamped to the size of dim - self.assertEqual(x[:5], x) - # if start >= stop then the result is empty - self.assertEqual(x[2:1], empty) - self.assertEqual(x[2:2], empty) - # out of bounds is also empty - self.assertEqual(x[10:12], empty) - # additional correctness checks - self.assertEqual(x[:1].data.tolist(), [[0, 1, 2, 3]]) - self.assertEqual(x[:-3].data.tolist(), [[0, 1, 2, 3]]) - self.assertEqual(x[:, -2:3].data.tolist(), [[2], [6], [10], [14]]) - self.assertEqual(x[0:-1:2].data.tolist(), [[0, 1, 2, 3], [8, 9, 10, 11]]) + @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") + def test_serialization_fake_zip(self): + data = [ + ord('P'), + ord('K'), + 5, + 6 + ] + for i in range(0, 100): + data.append(0) + t = torch.tensor(data, dtype=torch.uint8) - def test_is_signed(self): - self.assertEqual(torch.IntTensor(5).is_signed(), True) - self.assertEqual(torch.ByteTensor(5).is_signed(), False) - self.assertEqual(torch.CharTensor(5).is_signed(), True) - self.assertEqual(torch.FloatTensor(5).is_signed(), True) - self.assertEqual(torch.HalfTensor(10).is_signed(), True) - - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_is_signed_cuda(self): - self.assertEqual(torch.cuda.IntTensor(5).is_signed(), True) - self.assertEqual(torch.cuda.ByteTensor(5).is_signed(), False) - self.assertEqual(torch.cuda.CharTensor(5).is_signed(), True) - self.assertEqual(torch.cuda.FloatTensor(5).is_signed(), True) - self.assertEqual(torch.cuda.HalfTensor(10).is_signed(), True) + with tempfile.NamedTemporaryFile() as f: + torch.save(t, f.name) - @staticmethod - def _test_solve(self, cast): - a = cast(torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23), - (-6.05, -3.30, 5.36, -4.44, 1.08), - (-0.45, 2.58, -2.70, 0.27, 9.04), - (8.32, 2.71, 4.35, -7.17, 2.14), - (-9.67, -5.14, -7.26, 6.08, -6.87)))).t() - b = cast(torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03), - (-1.56, 4.00, -8.67, 1.75, 2.86), - (9.81, -4.09, -4.57, -8.61, 8.99)))).t() - - res1 = torch.solve(b, a)[0] - self.assertLessEqual(b.dist(torch.mm(a, res1)), 1e-12) - - ta = cast(torch.Tensor()) - tb = cast(torch.Tensor()) - res2 = torch.solve(b, a, out=(tb, ta))[0] - res3 = torch.solve(b, a, out=(b, a))[0] - self.assertEqual(res1, tb) - self.assertEqual(res1, b) - self.assertEqual(res1, res2) - self.assertEqual(res1, res3) + # If this check is False for all Python versions (i.e. the fix + # has been backported), this test and torch.serialization._is_zipfile + # can be deleted + self.assertTrue(zipfile.is_zipfile(f)) + self.assertFalse(torch.serialization._is_zipfile(f)) + self.assertEqual(torch.load(f.name), t) - # test reuse - res1 = torch.solve(b, a)[0] - ta = cast(torch.Tensor()) - tb = cast(torch.Tensor()) - torch.solve(b, a, out=(tb, ta))[0] - self.assertEqual(res1, tb) - torch.solve(b, a, out=(tb, ta))[0] - self.assertEqual(res1, tb) + def test_serialization_gzip(self): + # Test serialization with gzip file + b = self._test_serialization_data() + f1 = tempfile.NamedTemporaryFile(delete=False) + f2 = tempfile.NamedTemporaryFile(delete=False) + torch.save(b, f1) + with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) - @skipIfNoLapack - def test_solve(self): - self._test_solve(self, lambda t: t) + with gzip.open(f2.name, 'rb') as f: + c = torch.load(f) + self._test_serialization_assert(b, c) - @staticmethod - def _test_solve_batched(self, cast): - from common_utils import random_fullrank_matrix_distinct_singular_value - # test against solve: one batch - A = cast(random_fullrank_matrix_distinct_singular_value(5, 1)) - b = cast(torch.randn(1, 5, 10)) - x_exp, LU_exp = torch.solve(b.squeeze(0), A.squeeze(0)) - x, LU = torch.solve(b, A) - self.assertEqual(x, x_exp.unsqueeze(0)) - self.assertEqual(LU, LU_exp.unsqueeze(0)) - - # test against solve in a loop: four batches - A = cast(random_fullrank_matrix_distinct_singular_value(5, 4)) - b = cast(torch.randn(4, 5, 10)) - - x_exp_list = [] - LU_exp_list = [] - for i in range(4): - x_exp, LU_exp = torch.solve(b[i], A[i]) - x_exp_list.append(x_exp) - LU_exp_list.append(LU_exp) - x_exp = torch.stack(x_exp_list) - LU_exp = torch.stack(LU_exp_list) + def test_serialization_offset(self): + a = torch.randn(5, 5) + b = torch.randn(2, 2) + m = torch.nn.Conv2d(1, 1, (1, 3)) + i, j = 41, 43 + with tempfile.NamedTemporaryFile() as f: + pickle.dump(i, f) + torch.save(a, f) + pickle.dump(j, f) + torch.save(b, f) + torch.save(m, f) + f.seek(0) + i_loaded = pickle.load(f) + a_loaded = torch.load(f) + j_loaded = pickle.load(f) + b_loaded = torch.load(f) + m_loaded = torch.load(f) + self.assertTrue(torch.equal(a, a_loaded)) + self.assertTrue(torch.equal(b, b_loaded)) + self.assertTrue(m.kernel_size == m_loaded.kernel_size) + self.assertEqual(i, i_loaded) + self.assertEqual(j, j_loaded) - x, LU = torch.solve(b, A) - self.assertEqual(x, x_exp) - self.assertEqual(LU, LU_exp) + def test_serialization_offset_filelike(self): + a = torch.randn(5, 5) + b = torch.randn(2, 3) + i, j = 41, 43 + with BytesIOContext() as f: + pickle.dump(i, f) + torch.save(a, f) + pickle.dump(j, f) + torch.save(b, f) + f.seek(0) + i_loaded = pickle.load(f) + a_loaded = torch.load(f) + j_loaded = pickle.load(f) + b_loaded = torch.load(f) + self.assertTrue(torch.equal(a, a_loaded)) + self.assertTrue(torch.equal(b, b_loaded)) + self.assertEqual(i, i_loaded) + self.assertEqual(j, j_loaded) - # basic correctness test - A = cast(random_fullrank_matrix_distinct_singular_value(5, 3)) - b = cast(torch.randn(3, 5, 10)) - x, LU = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b) + def test_serialization_offset_gzip(self): + a = torch.randn(5, 5) + i = 41 + f1 = tempfile.NamedTemporaryFile(delete=False) + f2 = tempfile.NamedTemporaryFile(delete=False) + with open(f1.name, 'wb') as f: + pickle.dump(i, f) + torch.save(a, f) + with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) - # Test non-contiguous inputs. - if not TEST_NUMPY: - return - from numpy.linalg import solve - A = cast(random_fullrank_matrix_distinct_singular_value(2, 2)).permute(1, 0, 2) - b = cast(torch.randn(2, 2, 2)).permute(2, 1, 0) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x.data, cast(x_exp)) + with gzip.open(f2.name, 'rb') as f: + j = pickle.load(f) + b = torch.load(f) + self.assertTrue(torch.equal(a, b)) + self.assertEqual(i, j) - @skipIfNoLapack - def test_solve_batched(self): - self._test_solve_batched(self, lambda t: t) + def test_half_tensor(self): + x = torch.randn(5, 5).float() + y = torch.randn(5, 5).float() + xh, yh = x.half(), y.half() - @staticmethod - def _test_solve_batched_many_batches(self, cast): - from common_utils import random_fullrank_matrix_distinct_singular_value + self.assertEqual(x.half().float(), x, 1e-3) - A = cast(random_fullrank_matrix_distinct_singular_value(5, 256, 256)) - b = cast(torch.randn(5, 1)) - x, _ = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) + z = torch.Tensor(5, 5) + self.assertEqual(z.copy_(xh), x, 1e-3) - A = cast(random_fullrank_matrix_distinct_singular_value(3)) - b = cast(torch.randn(512, 512, 3, 1)) - x, _ = torch.solve(b, A) - self.assertEqual(torch.matmul(A, x), b) + with tempfile.NamedTemporaryFile() as f: + torch.save(xh, f) + f.seek(0) + xh2 = torch.load(f) + self.assertEqual(xh.float(), xh2.float()) - @slowTest - @skipIfNoLapack - def test_solve_batched_many_batches(self): - self._test_solve_batched_many_batches(self, lambda t: t.cuda()) + def test_serialize_device(self): + device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0'] + device_obj = [torch.device(d) for d in device_str] + for device in device_obj: + device_copied = copy.deepcopy(device) + self.assertEqual(device, device_copied) - @staticmethod - def _test_solve_batched_dims(self, cast): - if not TEST_NUMPY: - return + def test_serialization_backwards_compat(self): + a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)] + b = [a[i % 2] for i in range(4)] + b += [a[0].storage()] + b += [a[0].reshape(-1)[1:4].clone().storage()] + path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') + c = torch.load(path) + self.assertEqual(b, c, 0) + self.assertTrue(isinstance(c[0], torch.FloatTensor)) + self.assertTrue(isinstance(c[1], torch.FloatTensor)) + self.assertTrue(isinstance(c[2], torch.FloatTensor)) + self.assertTrue(isinstance(c[3], torch.FloatTensor)) + self.assertTrue(isinstance(c[4], torch.FloatStorage)) + c[0].fill_(10) + self.assertEqual(c[0], c[2], 0) + self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0) + c[1].fill_(20) + self.assertEqual(c[1], c[3], 0) - from numpy.linalg import solve - from common_utils import random_fullrank_matrix_distinct_singular_value - # test against numpy.linalg.solve - A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)) - b = cast(torch.randn(2, 1, 3, 4, 6)) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x.data, cast(x_exp)) - - # test column major format - A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)).transpose(-2, -1) - b = cast(torch.randn(2, 1, 3, 6, 4)).transpose(-2, -1) - assert not A.is_contiguous() - assert not b.is_contiguous() - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x.data, cast(x_exp)) + # test some old tensor serialization mechanism + class OldTensorBase(object): + def __init__(self, new_tensor): + self.new_tensor = new_tensor - # broadcasting b - A = cast(random_fullrank_matrix_distinct_singular_value(4, 2, 1, 3)) - b = cast(torch.randn(4, 6)) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x.data, cast(x_exp)) + def __getstate__(self): + return (self.new_tensor.storage(), + self.new_tensor.storage_offset(), + tuple(self.new_tensor.size()), + self.new_tensor.stride()) - # broadcasting A - A = cast(random_fullrank_matrix_distinct_singular_value(4)) - b = cast(torch.randn(2, 1, 3, 4, 2)) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x.data, cast(x_exp)) + class OldTensorV1(OldTensorBase): + def __reduce__(self): + return (torch.Tensor, (), self.__getstate__()) - # broadcasting both A & b - A = cast(random_fullrank_matrix_distinct_singular_value(4, 1, 3, 1)) - b = cast(torch.randn(2, 1, 3, 4, 5)) - x, _ = torch.solve(b, A) - x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) - self.assertEqual(x.data, cast(x_exp)) + class OldTensorV2(OldTensorBase): + def __reduce__(self): + return (_rebuild_tensor, self.__getstate__()) - @skipIfNoLapack - def test_solve_batched_dims(self): - self._test_solve_batched_dims(self, lambda t: t) + x = torch.randn(30).as_strided([2, 3], [9, 3], 2) + for old_cls in [OldTensorV1, OldTensorV2]: + with tempfile.NamedTemporaryFile() as f: + old_x = old_cls(x) + torch.save(old_x, f) + f.seek(0) + load_x = torch.load(f) + self.assertEqual(x.storage(), load_x.storage()) + self.assertEqual(x.storage_offset(), load_x.storage_offset()) + self.assertEqual(x.size(), load_x.size()) + self.assertEqual(x.stride(), load_x.stride()) - def test_solve_methods_arg_device(self): - if not torch.cuda.is_available(): - return + # unique_key is necessary because on Python 2.7, if a warning passed to + # the warning module is the same, it is not raised again. + def _test_serialization_container(self, unique_key, filecontext_lambda): + tmpmodule_name = 'tmpmodule{}'.format(unique_key) - for b_device, A_device in product(['cpu', 'cuda'], repeat=2): - if b_device == A_device: - continue + def import_module(name, filename): + if sys.version_info >= (3, 5): + import importlib.util + spec = importlib.util.spec_from_file_location(name, filename) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + else: + import imp + module = imp.load_source(name, filename) + sys.modules[module.__name__] = module + return module - b = torch.randn(3, 1, device=b_device) - A = torch.randn(3, 3, device=A_device) - err_str = "Expected b and A to be on the same device" - with self.assertRaisesRegex(RuntimeError, err_str): - torch.solve(b, A) + with filecontext_lambda() as checkpoint: + fname = get_file_path_2(os.path.dirname(__file__), 'data', 'network1.py') + module = import_module(tmpmodule_name, fname) + torch.save(module.Net(), checkpoint) - with self.assertRaisesRegex(RuntimeError, err_str): - torch.cholesky_solve(b, A) + # First check that the checkpoint can be loaded without warnings + checkpoint.seek(0) + with warnings.catch_warnings(record=True) as w: + loaded = torch.load(checkpoint) + self.assertTrue(isinstance(loaded, module.Net)) + if can_retrieve_source: + self.assertEquals(len(w), 0) - with self.assertRaisesRegex(RuntimeError, err_str): - torch.triangular_solve(b, A) + # Replace the module with different source + fname = get_file_path_2(os.path.dirname(__file__), 'data', 'network2.py') + module = import_module(tmpmodule_name, fname) + checkpoint.seek(0) + with warnings.catch_warnings(record=True) as w: + loaded = torch.load(checkpoint) + self.assertTrue(isinstance(loaded, module.Net)) + if can_retrieve_source: + self.assertEquals(len(w), 1) + self.assertTrue(w[0].category, 'SourceChangeWarning') - # b and A have to be modified to match accepted inputs sizes for lu_solve - b = b.unsqueeze(0) - A = A.unsqueeze(0) - with self.assertRaisesRegex(RuntimeError, err_str): - torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int()) + def test_serialization_container(self): + self._test_serialization_container('file', tempfile.NamedTemporaryFile) - # This checks if a suitable error message is thrown - # when LU output and pivots are on the same device - with self.assertRaisesRegex(RuntimeError, - "Expected LU_pivots and LU_data to be on the same device"): - torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) + def test_serialization_container_filelike(self): + self._test_serialization_container('filelike', BytesIOContext) - @staticmethod - def _test_qr(self, cast): - def run_test(tensor_dims, some): - A = cast(torch.randn(*tensor_dims)) - Q, R = torch.qr(A, some=some) + def test_serialization_map_location(self): + test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt') - # Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n) - m, n = tensor_dims[-2:] - n_columns = m if (not some) and m > n else min(m, n) - self.assertEqual(Q.size(-2), m) - self.assertEqual(R.size(-1), n) - self.assertEqual(Q.size(-1), n_columns) + def map_location(storage, loc): + return storage - # Check1: A = QR - self.assertEqual(A, torch.matmul(Q, R)) + def load_bytes(): + with open(test_file_path, 'rb') as f: + return io.BytesIO(f.read()) - # Check2: A = QR (with out) - Q_out, R_out = cast(torch.Tensor()), cast(torch.Tensor()) - torch.qr(A, some=some, out=(Q_out, R_out)) - self.assertEqual(A, torch.matmul(Q_out, R_out)) + fileobject_lambdas = [lambda: test_file_path, load_bytes] + cpu_map_locations = [ + map_location, + {'cuda:0': 'cpu'}, + 'cpu', + torch.device('cpu'), + ] + gpu_0_map_locations = [ + {'cuda:0': 'cuda:0'}, + 'cuda', + 'cuda:0', + torch.device('cuda'), + torch.device('cuda', 0) + ] + gpu_last_map_locations = [ + 'cuda:{}'.format(torch.cuda.device_count() - 1), + ] - # Check3: Q == Q_out, R == R_out - self.assertEqual(Q, Q_out) - self.assertEqual(R, R_out) + def check_map_locations(map_locations, tensor_class, intended_device): + for fileobject_lambda in fileobject_lambdas: + for map_location in map_locations: + tensor = torch.load(fileobject_lambda(), map_location=map_location) - # Check4: Q^{T}Q = I, triu(R) = R - self.assertEqual(torch.matmul(Q.transpose(-2, -1), Q), - cast(torch.eye(n_columns).expand(Q.shape[:-2] + (n_columns, n_columns)))) - self.assertEqual(R.triu(), R) + self.assertEqual(tensor.device, intended_device) + self.assertIsInstance(tensor, tensor_class) + self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]])) - tensor_dims_list = [(3, 5), (5, 5), (5, 3), # Single matrix - (7, 3, 5), (7, 5, 5), (7, 5, 3), # 3-dim Tensors - (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)] # 4-dim Tensors - for tensor_dims, some in product(tensor_dims_list, [True, False]): - run_test(tensor_dims, some) + check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu')) + if torch.cuda.is_available(): + check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0)) + check_map_locations( + gpu_last_map_locations, + torch.cuda.FloatTensor, + torch.device('cuda', torch.cuda.device_count() - 1) + ) - @skipIfNoLapack - def test_qr(self): - self._test_qr(self, lambda t: t) + @unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine") + @unittest.skipIf(not PY3, "Test tensors were serialized using python 3") + def test_load_nonexistent_device(self): + # Setup: create a serialized file object with a 'cuda:0' restore location + # The following was generated by saving a torch.randn(2, device='cuda') tensor. + serialized = (b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9' + b'\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq' + b'\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n' + b'\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq' + b'\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00' + b'longq\x07K\x04uu.\x80\x02ctorch._utils\n_rebuild_tensor_v2' + b'\nq\x00((X\x07\x00\x00\x00storageq\x01ctorch\nFloatStorage' + b'\nq\x02X\x0e\x00\x00\x0094919395964320q\x03X\x06\x00\x00' + b'\x00cuda:0q\x04K\x02Ntq\x05QK\x00K\x02\x85q\x06K\x01\x85q' + b'\x07\x89Ntq\x08Rq\t.\x80\x02]q\x00X\x0e\x00\x00\x00' + b'94919395964320q\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\xbb' + b'\x1f\x82\xbe\xea\x81\xd1>') - @skipIfNoLapack - def test_ormqr(self): - mat1 = torch.randn(7, 7) - mat2 = torch.randn(7, 7) - q, r = torch.qr(mat1) - m, tau = torch.geqrf(mat1) - out_holder = torch.empty_like(mat1) + buf = io.BytesIO(serialized) - res1 = torch.mm(q, mat2) - res2 = torch.ormqr(m, tau, mat2, left=True, transpose=False) - torch.ormqr(m, tau, mat2, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) + error_msg = r'Attempting to deserialize object on a CUDA device' + with self.assertRaisesRegex(RuntimeError, error_msg): + _ = torch.load(buf) - res1 = torch.mm(mat2, q) - res2 = torch.ormqr(m, tau, mat2, left=False, transpose=False) - torch.ormqr(m, tau, mat2, left=False, transpose=False, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) + def test_serialization_filelike_api_requirements(self): + filemock = FilelikeMock(b'', has_readinto=False) + tensor = torch.randn(3, 5) + torch.save(tensor, filemock) + expected_superset = {'write', 'flush'} + self.assertTrue(expected_superset.issuperset(filemock.calls)) - res1 = torch.mm(q.t(), mat2) - res2 = torch.ormqr(m, tau, mat2, left=True, transpose=True) - torch.ormqr(m, tau, mat2, left=True, transpose=True, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) + # Reset between save and load + filemock.seek(0) + filemock.calls.clear() - res1 = torch.mm(mat2, q.t()) - res2 = torch.ormqr(m, tau, mat2, left=False, transpose=True) - torch.ormqr(m, tau, mat2, left=False, transpose=True, out=out_holder) - self.assertEqual(res1, res2) - self.assertEqual(res2, out_holder) + _ = torch.load(filemock) + expected_superset = {'read', 'readline', 'seek', 'tell'} + self.assertTrue(expected_superset.issuperset(filemock.calls)) - @staticmethod - def _test_geqrf(self, cast): - a = cast(torch.randn(5, 5)) - b, c = torch.geqrf(a) - b_placeholder, c_placeholder = torch.empty_like(b), torch.empty_like(c) - torch.geqrf(a, out=(b_placeholder, c_placeholder)) - self.assertEqual(b, b_placeholder) - self.assertEqual(c, c_placeholder) + def _test_serialization_filelike(self, tensor, mock, desc): + f = mock(b'') + torch.save(tensor, f) + f.seek(0) + data = mock(f.read()) - @skipIfNoLapack - def test_geqrf(self): - self._test_geqrf(self, lambda t: t) + msg = 'filelike serialization with {}' - @staticmethod - def _test_triangular_solve(self, cast): - a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23), - (-6.05, -3.30, 5.36, -4.44, 1.08), - (-0.45, 2.58, -2.70, 0.27, 9.04), - (8.32, 2.71, 4.35, -7.17, 2.14), - (-9.67, -5.14, -7.26, 6.08, -6.87))).t() - b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03), - (-1.56, 4.00, -8.67, 1.75, 2.86), - (9.81, -4.09, -4.57, -8.61, 8.99))).t() - - a = cast(a) - b = cast(b) - - U = torch.triu(a) - L = torch.tril(a) - - # solve Ux = b - x = torch.triangular_solve(b, U)[0] - self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12) - x = torch.triangular_solve(b, U, True, False, False)[0] - self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12) - - # solve Lx = b - x = torch.triangular_solve(b, L, False)[0] - self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12) - x = torch.triangular_solve(b, L, False, False, False)[0] - self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12) - - # solve U'x = b - x = torch.triangular_solve(b, U, True, True)[0] - self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12) - x = torch.triangular_solve(b, U, True, True, False)[0] - self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12) - - # solve U'x = b by manual transposition - y = torch.triangular_solve(b, U.t(), False, False)[0] - self.assertLessEqual(x.dist(y), 1e-12) - - # solve L'x = b - x = torch.triangular_solve(b, L, False, True)[0] - self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12) - x = torch.triangular_solve(b, L, False, True, False)[0] - self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12) - - # solve L'x = b by manual transposition - y = torch.triangular_solve(b, L.t(), True, False)[0] - self.assertLessEqual(x.dist(y), 1e-12) + b = torch.load(data) + self.assertTrue(torch.equal(tensor, b), msg.format(desc)) - # test reuse - res1 = torch.triangular_solve(b, a)[0] - ta = cast(torch.Tensor()) - tb = cast(torch.Tensor()) - torch.triangular_solve(b, a, out=(tb, ta)) - self.assertEqual(res1, tb, 0) - tb.zero_() - torch.triangular_solve(b, a, out=(tb, ta)) - self.assertEqual(res1, tb, 0) + def test_serialization_filelike_missing_attrs(self): + # Test edge cases where filelike objects are missing attributes. + # The Python io docs suggests that these attributes should really exist + # and throw io.UnsupportedOperation, but that isn't always the case. + mocks = [ + ('no readinto', lambda x: FilelikeMock(x)), + ('has readinto', lambda x: FilelikeMock(x, has_readinto=True)), + ('no fileno', lambda x: FilelikeMock(x, has_fileno=False)), + ] - @skipIfNoLapack - def test_triangular_solve(self): - self._test_triangular_solve(self, lambda t: t) + to_serialize = torch.randn(3, 10) + for desc, mock in mocks: + self._test_serialization_filelike(to_serialize, mock, desc) - @staticmethod - def _test_triangular_solve_batched(self, cast): - def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular): - A = cast(torch.randn(*A_dims)) - A = A.triu() if upper else A.tril() - if unitriangular: - A.diagonal(dim1=-2, dim2=-1).fill_(1.) - b = cast(torch.randn(*b_dims)) - return A, b + def test_serialization_filelike_stress(self): + a = torch.randn(11 * (2 ** 9) + 1, 5 * (2 ** 9)) - for upper, transpose, unitriangular in product([True, False], repeat=3): - # test against triangular_solve: one batch with all possible arguments - A, b = triangular_solve_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular) - x_exp = torch.triangular_solve(b.squeeze(0), A.squeeze(0), - upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - x = torch.triangular_solve(b, A, - upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - self.assertEqual(x, x_exp.unsqueeze(0)) - - # test against triangular_solve in a loop: four batches with all possible arguments - A, b = triangular_solve_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular) - x_exp_list = [] - for i in range(4): - x_exp = torch.triangular_solve(b[i], A[i], - upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - x_exp_list.append(x_exp) - x_exp = torch.stack(x_exp_list) + # This one should call python read multiple times + self._test_serialization_filelike(a, lambda x: FilelikeMock(x, has_readinto=False), + 'read() stress test') + self._test_serialization_filelike(a, lambda x: FilelikeMock(x, has_readinto=True), + 'readinto() stress test') - x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - self.assertEqual(x, x_exp) + def test_serialization_filelike_uses_readinto(self): + # For maximum effiency, when reading a file-like object, + # ensure the C API calls readinto instead of read. + a = torch.randn(5, 4) - # basic correctness test - A, b = triangular_solve_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular) - x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - if transpose: - self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12) - else: - self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12) + f = io.BytesIO() + torch.save(a, f) + f.seek(0) + data = FilelikeMock(f.read(), has_readinto=True) - @skipIfNoLapack - def test_triangular_solve_batched(self): - self._test_triangular_solve_batched(self, lambda t: t) + b = torch.load(data) + self.assertTrue(data.was_called('readinto')) - @staticmethod - def _test_triangular_solve_batched_many_batches(self, cast): - def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular): - A = cast(torch.randn(*A_dims)) - A = A.triu() if upper else A.tril() - if unitriangular: - A.diagonal(dim1=-2, dim2=-1).fill_(1.) - b = cast(torch.randn(*b_dims)) - return A, b + def test_serialization_storage_slice(self): + # Generated using: + # + # t = torch.zeros(2); + # s1 = t.storage()[:1] + # s2 = t.storage()[1:] + # torch.save((s1, s2), 'foo.ser') + # + # with PyTorch 0.3.1 + serialized = (b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03' + b'.\x80\x02}q\x00(X\n\x00\x00\x00type_sizesq\x01}q\x02(X\x03' + b'\x00\x00\x00intq\x03K\x04X\x05\x00\x00\x00shortq\x04K\x02X' + b'\x04\x00\x00\x00longq\x05K\x04uX\x10\x00\x00\x00protocol_versionq' + b'\x06M\xe9\x03X\r\x00\x00\x00little_endianq\x07\x88u.\x80\x02' + b'(X\x07\x00\x00\x00storageq\x00ctorch\nFloatStorage\nq\x01X\x0e' + b'\x00\x00\x0094279043900432q\x02X\x03\x00\x00\x00cpuq\x03K\x02' + b'X\x0e\x00\x00\x0094279029750368q\x04K\x00K\x01\x87q\x05tq\x06' + b'Q(h\x00h\x01X\x0e\x00\x00\x0094279043900432q\x07h\x03K\x02X' + b'\x0e\x00\x00\x0094279029750432q\x08K\x01K\x01\x87q\ttq\nQ' + b'\x86q\x0b.\x80\x02]q\x00X\x0e\x00\x00\x0094279043900432q' + b'\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' + b'\x00\x00\x00\x00') - for upper, transpose, unitriangular in product([True, False], repeat=3): - A, b = triangular_solve_test_helper((256, 256, 5, 5), (5, 1), cast, upper, unitriangular) - x, _ = torch.triangular_solve(b, A, - upper=upper, transpose=transpose, unitriangular=unitriangular) - if transpose: - A = A.transpose(-2, -1) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) + buf = io.BytesIO(serialized) + (s1, s2) = torch.load(buf) + self.assertEqual(s1[0], 0) + self.assertEqual(s2[0], 0) + self.assertEqual(s1.data_ptr() + 4, s2.data_ptr()) - A, b = triangular_solve_test_helper((3, 3), (512, 512, 3, 1), cast, upper, unitriangular) - x, _ = torch.triangular_solve(b, A, - upper=upper, transpose=transpose, unitriangular=unitriangular) - if transpose: - A = A.transpose(-2, -1) - self.assertEqual(torch.matmul(A, x), b) + def test_load_error_msg(self): + expected_err_msg = (".*You can only torch.load from a file that is seekable. " + + "Please pre-load the data into a buffer like io.BytesIO and " + + "try to load from it instead.") - @slowTest - @skipIfNoLapack - def test_triangular_solve_batched_many_batches(self): - self._test_triangular_solve_batched_many_batches(self, lambda t: t) + resource = FilelikeMock(data=b"data") + delattr(resource, "tell") + delattr(resource, "seek") + self.assertRaisesRegex(AttributeError, expected_err_msg, lambda: torch.load(resource)) - @staticmethod - def _test_triangular_solve_batched_dims(self, cast): - if not TEST_SCIPY: - return + def test_from_buffer(self): + a = bytearray([1, 2, 3, 4]) + self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4]) + shorts = torch.ShortStorage.from_buffer(a, 'big') + self.assertEqual(shorts.size(), 2) + self.assertEqual(shorts.tolist(), [258, 772]) + ints = torch.IntStorage.from_buffer(a, 'little') + self.assertEqual(ints.size(), 1) + self.assertEqual(ints[0], 67305985) + f = bytearray([0x40, 0x10, 0x00, 0x00]) + floats = torch.FloatStorage.from_buffer(f, 'big') + self.assertEqual(floats.size(), 1) + self.assertEqual(floats[0], 2.25) - from scipy.linalg import solve_triangular as tri_solve + f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) + bools = torch.BoolStorage.from_buffer(f, 'big') + self.assertEqual(bools.size(), 8) + self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True]) + self.assertEqual(bools.type(), 'torch.BoolStorage') - def scipy_tri_solve_batched(A, B, upper, trans, diag): - batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] - single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] - expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), - torch.Size(batch_dims_B))) - expand_A = np.broadcast_to(A, expand_dims + single_dim_A) - expand_B = np.broadcast_to(B, expand_dims + single_dim_B) - flat_A = expand_A.reshape((-1,) + single_dim_A) - flat_B = expand_B.reshape((-1,) + single_dim_B) - flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) - for a, b in zip(flat_A, flat_B)]) - return flat_X.reshape(expand_B.shape) + f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9') + bools = torch.BoolStorage.from_buffer(f, 'big') + self.assertEqual(bools.size(), 19) - def run_test(A_dims, b_dims, cast, upper, transpose, unitriangular): - A = torch.randn(*A_dims) - A = A.triu() if upper else A.tril() - if unitriangular: - A.diagonal(dim1=-2, dim2=-1).fill_(1.) - b = torch.randn(*b_dims) - x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(), - upper, transpose, unitriangular)) - A, b = cast(A), cast(b) - x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] + f = bytearray(b'\0x4A') + bools = torch.BoolStorage.from_buffer(f, 'big') + self.assertEqual(bools.size(), 4) + self.assertEqual(bools.tolist(), [False, True, True, True]) - self.assertEqual(x, cast(x_exp)) + def test_storage_casts(self): + storage = torch.IntStorage([-1, 0, 1, 2, 3, 4]) + self.assertEqual(storage.size(), 6) + self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(storage.type(), 'torch.IntStorage') + self.assertIs(storage.dtype, torch.int32) - for upper, transpose, unitriangular in product([True, False], repeat=3): - # test against scipy.linalg.solve_triangular - run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper, transpose, unitriangular) # no broadcasting - run_test((2, 1, 3, 4, 4), (4, 6), cast, upper, transpose, unitriangular) # broadcasting b - run_test((4, 4), (2, 1, 3, 4, 2), cast, upper, transpose, unitriangular) # broadcasting A - run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular) # broadcasting A & b + floatStorage = storage.float() + self.assertEqual(floatStorage.size(), 6) + self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(floatStorage.type(), 'torch.FloatStorage') + self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(floatStorage.dtype, torch.float32) - @skipIfNoLapack - def test_triangular_solve_batched_dims(self): - self._test_triangular_solve_batched_dims(self, lambda t: t) + halfStorage = storage.half() + self.assertEqual(halfStorage.size(), 6) + self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(halfStorage.type(), 'torch.HalfStorage') + self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(halfStorage.dtype, torch.float16) - @staticmethod - def _test_lstsq(self, device): - def cast_fn(tensor): - return tensor.to(device=device) + bfloat16Storage = storage.bfloat16() + self.assertEqual(bfloat16Storage.size(), 6) + self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage') + self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(bfloat16Storage.dtype, torch.bfloat16) - def _test_underdetermined(a, b, expectedNorm): - # underdetermined systems are not supported on the GPU - if 'cuda' in device: - return + longStorage = storage.long() + self.assertEqual(longStorage.size(), 6) + self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(longStorage.type(), 'torch.LongStorage') + self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(longStorage.dtype, torch.int64) - m = a.size()[0] - n = a.size()[1] - assert(m <= n) + shortStorage = storage.short() + self.assertEqual(shortStorage.size(), 6) + self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertEqual(shortStorage.type(), 'torch.ShortStorage') + self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(shortStorage.dtype, torch.int16) - a_copy = a.clone() - b_copy = b.clone() - res1 = torch.lstsq(b, a)[0] - self.assertEqual(a, a_copy, 0) - self.assertEqual(b, b_copy, 0) - self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8) + doubleStorage = storage.double() + self.assertEqual(doubleStorage.size(), 6) + self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) + self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage') + self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(doubleStorage.dtype, torch.float64) - ta = cast_fn(torch.Tensor()) - tb = cast_fn(torch.Tensor()) - res2 = torch.lstsq(b, a, out=(tb, ta))[0] - self.assertEqual(a, a_copy, 0) - self.assertEqual(b, b_copy, 0) - self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8) + charStorage = storage.char() + self.assertEqual(charStorage.size(), 6) + self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) + self.assertEqual(charStorage.type(), 'torch.CharStorage') + self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) + self.assertIs(charStorage.dtype, torch.int8) - res3 = torch.lstsq(b, a, out=(b, a))[0] - self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, 1e-8) - self.assertEqual(res1, tb, 0) - self.assertEqual(res1, b, 0) - self.assertEqual(res1, res2, 0) - self.assertEqual(res1, res3, 0) - - def _test_overdetermined(a, b, expectedNorm): - m = a.size()[0] - n = a.size()[1] - assert(m > n) - - def check_norm(a, b, expected_norm, gels_result): - # Checks |ax - b| and the residual info from the result + byteStorage = storage.byte() + self.assertEqual(byteStorage.size(), 6) + self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4]) + self.assertEqual(byteStorage.type(), 'torch.ByteStorage') + self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4]) + self.assertIs(byteStorage.dtype, torch.uint8) - # The first n rows is the least square solution. - # Rows n to m-1 contain residual information. - x = gels_result[:n] - resid_info = gels_result[n:] + boolStorage = storage.bool() + self.assertEqual(boolStorage.size(), 6) + self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True]) + self.assertEqual(boolStorage.type(), 'torch.BoolStorage') + self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1]) + self.assertIs(boolStorage.dtype, torch.bool) - resid_norm = (torch.mm(a, x) - b).norm() - self.assertEqual(resid_norm, expectedNorm, 1e-8) - self.assertEqual(resid_info.norm(), resid_norm, 1e-8) + @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") + def test_from_file(self): + size = 10000 + with tempfile.NamedTemporaryFile() as f: + s1 = torch.FloatStorage.from_file(f.name, True, size) + t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) - a_copy = a.clone() - b_copy = b.clone() - res1 = torch.lstsq(b, a)[0] - self.assertEqual(a, a_copy, 0) - self.assertEqual(b, b_copy, 0) - check_norm(a, b, expectedNorm, res1) + # check mapping + s2 = torch.FloatStorage.from_file(f.name, True, size) + t2 = torch.FloatTensor(s2) + self.assertEqual(t1, t2, 0) - ta = cast_fn(torch.Tensor()) - tb = cast_fn(torch.Tensor()) - res2 = torch.lstsq(b, a, out=(tb, ta))[0] - self.assertEqual(a, a_copy, 0) - self.assertEqual(b, b_copy, 0) - check_norm(a, b, expectedNorm, res2) + # check changes to t1 from t2 + rnum = random.uniform(-1, 1) + t1.fill_(rnum) + self.assertEqual(t1, t2, 0) - res3 = torch.lstsq(b, a, out=(b, a))[0] - check_norm(a_copy, b_copy, expectedNorm, res3) + # check changes to t2 from t1 + rnum = random.uniform(-1, 1) + t2.fill_(rnum) + self.assertEqual(t1, t2, 0) - self.assertEqual(res1, tb, 0) - self.assertEqual(res1, b, 0) - self.assertEqual(res1, res2, 0) - self.assertEqual(res1, res3, 0) + @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") + def test_torch_from_file(self): + size = 10000 + with tempfile.NamedTemporaryFile() as f: + s1 = torch.from_file(f.name, True, size, dtype=torch.float) + t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) - # basic test - expectedNorm = 0 - a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55, 8.34), - (-7.84, -0.28, 3.24, 8.09), - (-4.39, -3.24, 6.27, 5.28), - (4.53, 3.83, -6.64, 2.06)))).t() - b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28), - (9.35, -4.43, -0.70, -0.26)))).t() - _test_underdetermined(a, b, expectedNorm) + # check mapping + s2 = torch.from_file(f.name, True, size, dtype=torch.float) + t2 = torch.FloatTensor(s2) + self.assertEqual(t1, t2, 0) - # test overdetermined - expectedNorm = 17.390200628863 - a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45), - (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70), - (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19), - (4.53, 3.83, -6.64, 2.06, -2.47, 4.70)))).t() - b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93), - (9.35, -4.43, -0.70, -0.26, -7.36, -2.52)))).t() - _test_overdetermined(a, b, expectedNorm) + # check changes to t1 from t2 + rnum = random.uniform(-1, 1) + t1.fill_(rnum) + self.assertEqual(t1, t2, 0) - # test underdetermined - expectedNorm = 0 - a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55), - (-7.84, -0.28, 3.24), - (-4.39, -3.24, 6.27), - (4.53, 3.83, -6.64)))).t() - b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48), - (9.35, -4.43, -0.70)))).t() - _test_underdetermined(a, b, expectedNorm) + # check changes to t2 from t1 + rnum = random.uniform(-1, 1) + t2.fill_(rnum) + self.assertEqual(t1, t2, 0) - # test reuse - expectedNorm = 0 - a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55, 8.34), - (-7.84, -0.28, 3.24, 8.09), - (-4.39, -3.24, 6.27, 5.28), - (4.53, 3.83, -6.64, 2.06)))).t() - b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28), - (9.35, -4.43, -0.70, -0.26)))).t() - ta = cast_fn(torch.Tensor()) - tb = cast_fn(torch.Tensor()) - torch.lstsq(b, a, out=(tb, ta)) - self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) - torch.lstsq(b, a, out=(tb, ta)) - self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) - torch.lstsq(b, a, out=(tb, ta)) - self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) + def test_print(self): + default_type = torch.Tensor().type() + for t in torch._tensor_classes: + if t == torch.HalfTensor: + continue # HalfTensor does not support fill + if t.is_sparse: + continue + if t.is_cuda and not torch.cuda.is_available(): + continue + if t == torch.cuda.BFloat16Tensor: + self.assertRaises(RuntimeError, lambda: t(100, 100).fill_(1)) + continue + obj = t(100, 100).fill_(1) + obj.__repr__() + str(obj) + # test half tensor + obj = torch.rand(100, 100, device='cpu').half() + obj.__repr__() + str(obj) + for t in torch._storage_classes: + if t == torch.BFloat16Storage: + continue # Fix once fill is enabled for bfloat16 + if t.is_cuda and not torch.cuda.is_available(): + continue + if t == torch.BoolStorage or t == torch.cuda.BoolStorage: + obj = t(100).fill_(True) + else: + obj = t(100).fill_(1) + obj.__repr__() + str(obj) - @skipIfNoLapack - def test_lstsq(self): - self._test_lstsq(self, 'cpu') + # test big integer + x = torch.tensor(2341234123412341) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor(2341234123412341)''') - @skipIfNoLapack - def test_eig(self): - a = torch.Tensor(((1.96, 0.00, 0.00, 0.00, 0.00), - (-6.49, 3.80, 0.00, 0.00, 0.00), - (-0.47, -6.39, 4.17, 0.00, 0.00), - (-7.20, 1.50, -1.51, 5.70, 0.00), - (-0.65, -6.34, 2.67, 1.80, -7.10))).t().contiguous() - e = torch.eig(a)[0] - ee, vv = torch.eig(a, True) - te = torch.Tensor() - tv = torch.Tensor() - eee, vvv = torch.eig(a, True, out=(te, tv)) - self.assertEqual(e, ee, 1e-12) - self.assertEqual(ee, eee, 1e-12) - self.assertEqual(ee, te, 1e-12) - self.assertEqual(vv, vvv, 1e-12) - self.assertEqual(vv, tv, 1e-12) + # test scientific notation + x = torch.tensor([1e28, 1e-28]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1.0000e+28, 1.0000e-28])''') - # test reuse - X = torch.randn(4, 4) - X = torch.mm(X.t(), X) - e, v = torch.zeros(4, 2), torch.zeros(4, 4) - torch.eig(X, True, out=(e, v)) - Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) - self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') - self.assertFalse(v.is_contiguous(), 'V is contiguous') + # test scientific notation using set_printoptions + x = torch.tensor([1e2, 1e-2]) + torch.set_printoptions(sci_mode=True) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''') + torch.set_printoptions(sci_mode=False) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''') + torch.set_printoptions(sci_mode=None) # reset to the default value - torch.eig(X, True, out=(e, v)) - Xhat = torch.mm(v, torch.mm(e.select(1, 0).diag(), v.t())) - self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') - self.assertFalse(v.is_contiguous(), 'V is contiguous') + # test no leading space if all elements positive + x = torch.tensor([1, 2]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1, 2])''') - # test non-contiguous - X = torch.randn(4, 4) - X = torch.mm(X.t(), X) - e = torch.zeros(4, 2, 2)[:, 1] - v = torch.zeros(4, 2, 4)[:, 1] - self.assertFalse(v.is_contiguous(), 'V is contiguous') - self.assertFalse(e.is_contiguous(), 'E is contiguous') - torch.eig(X, True, out=(e, v)) - Xhat = torch.mm(torch.mm(v, torch.diag(e.select(1, 0))), v.t()) - self.assertEqual(X, Xhat, 1e-8, 'VeV\' wrong') + # test for leading space if there are negative elements + x = torch.tensor([1, -2]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([ 1, -2])''') - @staticmethod - def _test_symeig(self, conv_fn): - from common_utils import random_symmetric_matrix + # test inf and nan + x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([4.0000, inf, 1.5000, -inf, 0.0000, nan, 1.0000])''') - def run_test(dims, eigenvectors, upper): - x = conv_fn(random_symmetric_matrix(*dims)) - oute = conv_fn(torch.empty(dims[1:] + dims[:1])) - outv = conv_fn(torch.empty(dims[1:] + dims[:1] * 2)) - torch.symeig(x, eigenvectors=eigenvectors, upper=upper, out=(oute, outv)) + # test dtype + torch.set_default_dtype(torch.float) + x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64) + self.assertEqual(x.__repr__(), str(x)) + expected_str = '''\ +tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, + inf], dtype=torch.float64)''' + self.assertExpectedInline(str(x), expected_str) - if eigenvectors: - x_recon = torch.matmul(torch.matmul(outv, torch.diag_embed(oute)), outv.transpose(-2, -1)) - self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using V @ diag(e) @ V.T') - else: - eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) - self.assertEqual(eigvals, oute, 'Eigenvalues mismatch') - self.assertEqual(torch.zeros_like(outv), outv, 'Eigenvector matrix not zero') + # test changing default dtype + torch.set_default_dtype(torch.float64) + self.assertEqual(x.__repr__(), str(x)) + expected_str = '''\ +tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, + inf])''' + self.assertExpectedInline(str(x), expected_str) - rese, resv = x.symeig(eigenvectors=eigenvectors, upper=upper) - self.assertEqual(rese, oute, "outputs of symeig and symeig with out don't match") - self.assertEqual(resv, outv, "outputs of symeig and symeig with out don't match") + # test summary + x = torch.zeros(10000) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([0., 0., 0., ..., 0., 0., 0.])''') - # test non-contiguous - x = conv_fn(random_symmetric_matrix(*dims)) - n_dim = len(dims) + 1 - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper) - if eigenvectors: - x_recon = torch.matmul(torch.matmul(resv, torch.diag_embed(rese)), resv.transpose(-2, -1)) - self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using V @ diag(e) @ V.T') - else: - eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) - self.assertEqual(eigvals, rese, 'Eigenvalues mismatch') - self.assertEqual(torch.zeros_like(resv), resv, 'Eigenvector matrix not zero') + # test internal summary function + x = torch.rand(1, 20, 5, 30) + summary = torch._tensor_str.get_summarized_data(x) + self.assertEqual(summary.shape, (1, 6, 5, 6)) + first_and_last = [0, 1, 2, -3, -2, -1] + self.assertEqual(summary, x[:, first_and_last][..., first_and_last]) - batch_dims_set = [(), (3,), (3, 5), (5, 3, 5)] - for batch_dims, eigenvectors, upper in product(batch_dims_set, (True, False), (True, False)): - run_test((5,) + batch_dims, eigenvectors, upper) + # test device + if torch.cuda.is_available(): + x = torch.tensor([123], device='cuda:0') + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''') - @skipIfNoLapack - def test_symeig(self): - self._test_symeig(self, lambda x: x) + # test changing default to cuda + torch.set_default_tensor_type(torch.cuda.FloatTensor) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([123])''') - @staticmethod - def _test_svd(self, conv_fn): - def run_test(dims, some, compute_uv): - x = conv_fn(torch.randn(*dims)) - outu, outs, outv = conv_fn(torch.Tensor()), conv_fn(torch.Tensor()), conv_fn(torch.Tensor()) - torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + # test printing a tensor on a different gpu than current one. + if torch.cuda.device_count() >= 2: + with torch.cuda.device(1): + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''') - if compute_uv: - if some: - x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) - self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = outu[..., :min(*dims[-2:])] - narrow_v = outv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, outs, 'Singular values mismatch') - self.assertEqual(outu, torch.zeros_like(outu), 'U not zero') - self.assertEqual(outv, torch.zeros_like(outv), 'V not zero') + # test printing cpu tensor when default device is cuda + y = torch.tensor([123], device='cpu') + self.assertEqual(y.__repr__(), str(y)) + self.assertExpectedInline(str(y), '''tensor([123], device='cpu')''') + torch.set_default_tensor_type(default_type) - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - self.assertEqual(resu, outu, 'outputs of svd and svd with out differ') - self.assertEqual(ress, outs, 'outputs of svd and svd with out differ') - self.assertEqual(resv, outv, 'outputs of svd and svd with out differ') + # test integral floats and requires_grad + x = torch.tensor([123.], requires_grad=True) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''') - # test non-contiguous - x = conv_fn(torch.randn(*dims)) - n_dim = len(dims) - # Reverse the batch dimensions and the matrix dimensions and then concat them - x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) - assert not x.is_contiguous(), "x is intentionally non-contiguous" - resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) - if compute_uv: - if some: - x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) - self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') - else: - narrow_u = resu[..., :min(*dims[-2:])] - narrow_v = resv[..., :min(*dims[-2:])] - x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) - self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') - else: - _, singvals, _ = torch.svd(x, compute_uv=True) - self.assertEqual(singvals, ress, 'Singular values mismatch') - self.assertEqual(resu, torch.zeros_like(resu), 'U not zero') - self.assertEqual(resv, torch.zeros_like(resv), 'V not zero') + # test non-contiguous print + # sliced tensor should have > PRINT_OPTS.threshold elements + x = torch.ones(100, 2, 2, 10) + y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1)) + self.assertEqual(str(y), y.__repr__()) + expected_str = '''\ +tensor([[[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], - shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices - (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices - (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices - for dims, some, compute_uv in product(shapes, [True, False], [True, False]): - run_test(dims, some, compute_uv) + [[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], - @skipIfNoLapack - def test_svd(self): - self._test_svd(self, lambda t: t) + [[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], - @staticmethod - def _test_svd_no_singularvectors(self, cast): - for size in [(5, 5), (5, 20), (20, 5)]: - a = cast(torch.randn(*size)) - u, s_expect, v = torch.svd(a) - u, s_actual, v = torch.svd(a, compute_uv=False) - self.assertEqual(s_expect, s_actual, "Singular values don't match") + ..., - @skipIfNoLapack - def test_svd_no_singularvectors(self): - self._test_svd_no_singularvectors(self, lambda t: t) + [[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], - @staticmethod - def _test_matrix_rank(self, conv_fn): - a = conv_fn(torch.eye(10)) - self.assertEqual(torch.matrix_rank(a).item(), 10) - self.assertEqual(torch.matrix_rank(a, True).item(), 10) + [[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]], - a[5, 5] = 0 - self.assertEqual(torch.matrix_rank(a).item(), 9) - self.assertEqual(torch.matrix_rank(a, True).item(), 9) + [[1., 1., 1., ..., 1., 1., 1.], + [1., 1., 1., ..., 1., 1., 1.]]])\ +''' - a = conv_fn(torch.randn(24, 42)) - self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t())) - aaT = torch.mm(a, a.t()) - self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT, True)) - aTa = torch.mm(a.t(), a) - self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa, True)) + self.assertExpectedInline(str(y), expected_str) - if TEST_NUMPY: - from numpy.linalg import matrix_rank - a = conv_fn(torch.randn(35, 75)) - self.assertEqual(torch.matrix_rank(a).item(), matrix_rank(a.cpu().numpy())) - self.assertEqual(torch.matrix_rank(a, 0.01).item(), matrix_rank(a.cpu().numpy(), 0.01)) + # test print 0-dim tensor: there's no 0-dim in Numpy, we match arrayprint style + x = torch.tensor(0.00002) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''') - aaT = torch.mm(a, a.t()) - self.assertEqual(torch.matrix_rank(aaT).item(), matrix_rank(aaT.cpu().numpy())) - self.assertEqual(torch.matrix_rank(aaT, 0.01).item(), matrix_rank(aaT.cpu().numpy(), 0.01)) + # test print boolean tensor + x = torch.tensor([True]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([True])''') - if np.lib.NumpyVersion(np.__version__) >= '1.14.0': - self.assertEqual(torch.matrix_rank(aaT, True).item(), matrix_rank(aaT.cpu().numpy(), True)) - self.assertEqual(torch.matrix_rank(aaT, 0.01, True).item(), - matrix_rank(aaT.cpu().numpy(), 0.01, True)) + x = torch.tensor(True) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor(True)''') - @skipIfNoLapack - def test_matrix_rank(self): - self._test_matrix_rank(self, lambda x: x) + # [Numpy] test print float in sci_mode when min < 0.0001. + x = torch.tensor([0.00002]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''') - @staticmethod - def _test_signal_window_functions(self, device='cpu'): - if not TEST_SCIPY: - raise unittest.SkipTest('Scipy not found') + # [Numpy] test print float in sci_mode when max > 1e8. + # TODO: Pytorch uses fixed precision to print, while Numpy uses dragon4_scientific + # to do automatic trimming and padding. + x = torch.tensor([123456789.]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''') - def test(name): - torch_method = getattr(torch, name + '_window') - for size in [1, 2, 5, 10, 50, 100, 1024, 2048]: - for periodic in [True, False]: - res = torch_method(size, periodic=periodic, device=device) - ref = torch.from_numpy(signal.get_window(name, size, fftbins=periodic)) - self.assertEqual(res, ref) - with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): - torch_method(3, layout=torch.sparse_coo) - with self.assertRaisesRegex(RuntimeError, r'floating point'): - torch_method(3, dtype=torch.long) - self.assertTrue(torch_method(3, requires_grad=True).requires_grad) - self.assertFalse(torch_method(3).requires_grad) + # [Numpy] test print float in sci_mode when max / min > 1000. + x = torch.tensor([0.01, 11]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''') - for window in ['hann', 'hamming', 'bartlett', 'blackman']: - test(window) + # [Numpy] test print int max / min > 1000, no sci_mode + x = torch.tensor([1, 1010]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([ 1, 1010])''') - def test_signal_window_functions(self): - self._test_signal_window_functions(self) + # [Numpy] test print int > 1e8, no sci_mode + x = torch.tensor([1000000000]) # 1e9 + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1000000000])''') - @staticmethod - def _test_inverse(self, conv_fn): - from common_utils import random_fullrank_matrix_distinct_singular_value + # [Numpy] test printing float in int_mode + x = torch.tensor([1., 1000.]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([ 1., 1000.])''') - # no batches: 2-D tensors - matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5)) - matrix_inverse = torch.inverse(matrix) - identity = conv_fn(torch.eye(5)) - self.assertEqual(identity, torch.mm(matrix, matrix_inverse), 1e-8, 'inverse value') - self.assertEqual(identity, torch.mm(matrix_inverse, matrix), 1e-8, 'inverse value') + # [Numpy] test printing float in int_mode in sci format when max / min > 1000. + x = torch.tensor([1., 1010.]) + self.assertEqual(x.__repr__(), str(x)) + self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''') - matrix_inverse_out = conv_fn(torch.empty(5, 5)) - torch.inverse(matrix, out=matrix_inverse_out) - self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place') - # second call, now that matrix_inverse_out is transposed - torch.inverse(matrix, out=matrix_inverse_out) - self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place') + def test_sizeof(self): + sizeof_empty = torch.randn(0).storage().__sizeof__() + sizeof_10 = torch.randn(10).storage().__sizeof__() + sizeof_100 = torch.randn(100).storage().__sizeof__() + self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) + self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) - # one batch - matrix = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 1)) - matrix_inverse = torch.inverse(matrix) - expected_inv = matrix.squeeze(0).inverse() - self.assertEqual(matrix_inverse, expected_inv.unsqueeze(0)) + sizeof_empty = torch.randn(0).type(torch.ByteTensor).storage().__sizeof__() + sizeof_10 = torch.randn(10).type(torch.ByteTensor).storage().__sizeof__() + sizeof_100 = torch.randn(100).type(torch.ByteTensor).storage().__sizeof__() + self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) + self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) - # four batches - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 4)) - expected_inv_list = [] - for i in range(0, 4): - expected_inv_list.append(torch.inverse(matrices[i])) - expected_inv = torch.stack(expected_inv_list) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(matrices_inverse, expected_inv) + def test_unsqueeze(self): + x = torch.randn(2, 3, 4) + y = x.unsqueeze(1) + self.assertEqual(y, x.view(2, 1, 3, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.view(2, 3, 1, 4)) - # six batches (2 x 3) - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 2, 3)) - expected_inv_list = [] - for mat in matrices.view(-1, 5, 5): - expected_inv_list.append(torch.inverse(mat)) - expected_inv = torch.stack(expected_inv_list).view(2, 3, 5, 5) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(matrices_inverse, expected_inv) + x = x[:, 1] + self.assertFalse(x.is_contiguous()) + y = x.unsqueeze(1) + self.assertEqual(y, x.contiguous().view(2, 1, 4)) + y = x.clone().unsqueeze_(2) + self.assertEqual(y, x.contiguous().view(2, 4, 1)) - # incorrect input test - with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): - torch.inverse(torch.randn(2, 3, 4, 3)) - - # correctness test - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3)) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(torch.matmul(matrices, matrices_inverse), identity.expand_as(matrices)) - self.assertEqual(torch.matmul(matrices_inverse, matrices), identity.expand_as(matrices)) - - # torch.inverse with out and batches - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 3)) - matrices_inverse = conv_fn(torch.empty(3, 5, 5)) - torch.inverse(matrices, out=matrices_inverse) - self.assertEqual(torch.inverse(matrices), matrices_inverse) - - # non-contiguous inputs - if not TEST_NUMPY: - return - - from numpy.linalg import inv - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2)).permute(0, 2, 1) - assert not matrices.is_contiguous() - matrices_inverse = torch.inverse(matrices) - expected_inv = torch.as_tensor(inv(matrices.cpu().numpy())) - self.assertEqual(matrices_inverse, conv_fn(expected_inv)) - - @staticmethod - def _test_inverse_slow(self, conv_fn): - from common_utils import random_fullrank_matrix_distinct_singular_value - - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(5, 256, 256)) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(torch.matmul(matrices_inverse, matrices), - conv_fn(torch.eye(5)).expand_as(matrices)) + def test_iter(self): + x = torch.randn(5, 5) + for i, sub in enumerate(x): + self.assertEqual(sub, x[i]) - matrices = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 512, 512)) - matrices_inverse = torch.inverse(matrices) - self.assertEqual(torch.matmul(matrices, matrices_inverse), - conv_fn(torch.eye(3)).expand_as(matrices)) + x = torch.Tensor() + self.assertEqual(list(x), []) - @skipIfNoLapack - def test_inverse(self): - self._test_inverse(self, lambda t: t) + def test_accreal_type(self): + x = torch.ones(2, 3, 4) + self.assertIsInstance(x.double().sum().item(), float) + self.assertIsInstance(x.float().sum().item(), float) + self.assertIsInstance(x.long().sum().item(), int) + self.assertIsInstance(x.int().sum().item(), int) + self.assertIsInstance(x.short().sum().item(), int) + self.assertIsInstance(x.char().sum().item(), int) + self.assertIsInstance(x.byte().sum().item(), int) - @slowTest - @skipIfNoLapack - def test_inverse_many_batches(self): - self._test_inverse_slow(self, lambda t: t) + def test_assertEqual(self): + x = torch.FloatTensor([0]) + self.assertEqual(x, 0) + xv = torch.autograd.Variable(x) + self.assertEqual(xv, 0) + self.assertEqual(x, xv) + self.assertEqual(xv, x) - @staticmethod - def _test_pinverse(self, conv_fn): - def run_test(M): - # Testing against definition for pseudo-inverses - MPI = torch.pinverse(M) - self.assertEqual(M, M.mm(MPI).mm(M), 1e-8, 'pseudo-inverse condition 1') - self.assertEqual(MPI, MPI.mm(M).mm(MPI), 1e-8, 'pseudo-inverse condition 2') - self.assertEqual(M.mm(MPI), (M.mm(MPI)).t(), 1e-8, 'pseudo-inverse condition 3') - self.assertEqual(MPI.mm(M), (MPI.mm(M)).t(), 1e-8, 'pseudo-inverse condition 4') + def test_new(self): + x = torch.autograd.Variable(torch.Tensor()) + y = torch.autograd.Variable(torch.randn(4, 4)) + z = torch.autograd.Variable(torch.IntTensor([1, 2, 3])) + self.assertEqual(x.new().shape, [0]) + self.assertEqual(x.new(), x) + self.assertEqual(x.new(1, 2).shape, [1, 2]) + self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4]) + self.assertEqual(x.new([3, 4]).shape, [2]) + self.assertEqual(x.new([3, 4]).tolist(), [3, 4]) + self.assertEqual(x.new((3, 4)).tolist(), [3, 4]) + if TEST_NUMPY: + self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4]) + self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4]) + self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4]) + self.assertEqual(x.new(size=(3, 4)).shape, [3, 4]) + self.assertEqual(x.new(()).shape, [0]) + self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr()) + self.assertEqual(x.new(y).data_ptr(), y.data_ptr()) + self.assertIsNot(x.new(y), y) - # Square matrix - M = conv_fn(torch.randn(5, 5)) - run_test(M) + self.assertRaises(TypeError, lambda: x.new(z)) + # TypeError would be better + self.assertRaises(RuntimeError, lambda: x.new(z.storage())) - # Rectangular matrix - M = conv_fn(torch.randn(3, 4)) - run_test(M) + def test_empty_like(self): + x = torch.autograd.Variable(torch.Tensor()) + y = torch.autograd.Variable(torch.randn(4, 4)) + z = torch.autograd.Variable(torch.IntTensor([1, 2, 3])) + for a in (x, y, z): + self.assertEqual(torch.empty_like(a).shape, a.shape) + self.assertEqual(torch.empty_like(a).type(), a.type()) - # Test inverse and pseudo-inverse for invertible matrix - M = torch.randn(5, 5) - M = conv_fn(M.mm(M.t())) - self.assertEqual(conv_fn(torch.eye(5)), M.pinverse().mm(M), 1e-7, 'pseudo-inverse for invertible matrix') + def test_pin_memory(self): + x = torch.randn(3, 5) + self.assertFalse(x.is_pinned()) + if not torch.cuda.is_available(): + self.assertRaises(RuntimeError, lambda: x.pin_memory()) + else: + pinned = x.pin_memory() + self.assertTrue(pinned.is_pinned()) + self.assertEqual(pinned, x) + self.assertNotEqual(pinned.data_ptr(), x.data_ptr()) + # test that pin_memory on already pinned tensor has no effect + self.assertIs(pinned, pinned.pin_memory()) + self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) - @skipIfNoLapack - def test_pinverse(self): - self._test_pinverse(self, conv_fn=lambda x: x) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_numpy_unresizable(self): + x = np.zeros((2, 2)) + y = torch.from_numpy(x) + with self.assertRaises(ValueError): + x.resize((5, 5)) - @staticmethod - def _test_matrix_power(self, conv_fn): - def run_test(M, sign=1): - if sign == -1: - M = M.inverse() - MP2 = torch.matrix_power(M, 2) - self.assertEqual(MP2, torch.matmul(M, M)) + z = torch.randn(5, 5) + w = z.numpy() + with self.assertRaises(RuntimeError): + z.resize_(10, 10) + with self.assertRaises(ValueError): + w.resize((10, 10)) - MP3 = torch.matrix_power(M, 3) - self.assertEqual(MP3, torch.matmul(MP2, M)) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_to_numpy(self): + def get_castable_tensor(shape, tp): + dtype = tp.dtype + if dtype.is_floating_point: + dtype_info = torch.finfo(dtype) + # can't directly use min and max, because for double, max - min + # is greater than double range and sampling always gives inf. + low = max(dtype_info.min, -1e10) + high = min(dtype_info.max, 1e10) + t = torch.empty(shape, dtype=torch.float64).uniform_(low, high) + else: + # can't directly use min and max, because for int64_t, max - min + # is greater than int64_t range and triggers UB. + dtype_info = torch.iinfo(dtype) + low = max(dtype_info.min, int(-1e10)) + high = min(dtype_info.max, int(1e10)) + dtype_info = torch.iinfo(dtype) + t = torch.empty(shape, dtype=torch.int64).random_(low, high) + return t.to(dtype) - MP4 = torch.matrix_power(M, 4) - self.assertEqual(MP4, torch.matmul(MP2, MP2)) + types = [ + torch.ByteTensor, + torch.CharTensor, + torch.ShortTensor, + torch.IntTensor, + torch.HalfTensor, + torch.FloatTensor, + torch.DoubleTensor, + torch.LongTensor, + ] + for tp in types: + # 1D + sz = 10 + x = get_castable_tensor(sz, tp) + y = x.numpy() + for i in range(sz): + self.assertEqual(x[i], y[i]) - MP6 = torch.matrix_power(M, 6) - self.assertEqual(MP6, torch.matmul(MP3, MP3)) + # 1D > 0 storage offset + xm = get_castable_tensor(sz * 2, tp) + x = xm.narrow(0, sz - 1, sz) + self.assertTrue(x.storage_offset() > 0) + y = x.numpy() + for i in range(sz): + self.assertEqual(x[i], y[i]) - MP0 = torch.matrix_power(M, 0) - self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M)) + def check2d(x, y): + for i in range(sz1): + for j in range(sz2): + self.assertEqual(x[i][j], y[i][j]) - # Single matrix - M = conv_fn(torch.randn(5, 5)) - run_test(M) + # empty + x = torch.Tensor().type(tp) + y = x.numpy() + self.assertEqual(y.size, 0) - # Batch matrices - M = conv_fn(torch.randn(3, 3, 3)) - run_test(M) + # contiguous 2D + sz1 = 3 + sz2 = 5 + x = get_castable_tensor((sz1, sz2), tp) + y = x.numpy() + check2d(x, y) + self.assertTrue(y.flags['C_CONTIGUOUS']) - # Many batch matrices - M = conv_fn(torch.randn(2, 3, 3, 3)) - run_test(M) + # with storage offset + xm = get_castable_tensor((sz1 * 2, sz2), tp) + x = xm.narrow(0, sz1 - 1, sz1) + y = x.numpy() + self.assertTrue(x.storage_offset() > 0) + check2d(x, y) + self.assertTrue(y.flags['C_CONTIGUOUS']) - # This is for negative powers - from common_utils import random_fullrank_matrix_distinct_singular_value - M = conv_fn(random_fullrank_matrix_distinct_singular_value(5)) - run_test(M, sign=-1) + # non-contiguous 2D + x = get_castable_tensor((sz2, sz1), tp).t() + y = x.numpy() + check2d(x, y) + self.assertFalse(y.flags['C_CONTIGUOUS']) - M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 3)) - run_test(M, sign=-1) + # with storage offset + xm = get_castable_tensor((sz2 * 2, sz1), tp) + x = xm.narrow(0, sz2 - 1, sz2).t() + y = x.numpy() + self.assertTrue(x.storage_offset() > 0) + check2d(x, y) - M = conv_fn(random_fullrank_matrix_distinct_singular_value(3, 2, 3)) - run_test(M, sign=-1) + # non-contiguous 2D with holes + xm = get_castable_tensor((sz2 * 2, sz1 * 2), tp) + x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t() + y = x.numpy() + self.assertTrue(x.storage_offset() > 0) + check2d(x, y) - @skipIfNoLapack - def test_matrix_power(self): - self._test_matrix_power(self, conv_fn=lambda x: x) + if tp != torch.HalfTensor: + # check writeable + x = get_castable_tensor((3, 4), tp) + y = x.numpy() + self.assertTrue(y.flags.writeable) + y[0][1] = 3 + self.assertTrue(x[0][1] == 3) + y = x.t().numpy() + self.assertTrue(y.flags.writeable) + y[0][1] = 3 + self.assertTrue(x[0][1] == 3) - @staticmethod - def _test_chain_matmul(self, cast): - def product(matrices): - for mat in matrices[1:]: - matrices[0] = matrices[0].mm(mat) - return matrices[0] + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_to_numpy_bool(self): + x = torch.tensor([True, False], dtype=torch.bool) + self.assertEqual(x.dtype, torch.bool) - def run_test(p, cast): - matrices = [] - for (pi, pi_1) in zip(p[:-1], p[1:]): - matrices.append(cast(torch.randn(pi, pi_1))) - self.assertEqual(torch.chain_matmul(*matrices), product(matrices)) + y = x.numpy() + self.assertEqual(y.dtype, np.bool) + for i in range(len(x)): + self.assertEqual(x[i], y[i]) - run_test([10, 20, 30, 5], cast) - run_test([15, 5, 10, 20, 25], cast) + x = torch.tensor([True], dtype=torch.bool) + self.assertEqual(x.dtype, torch.bool) - def test_chain_matmul(self): - self._test_chain_matmul(self, cast=lambda x: x) + y = x.numpy() + self.assertEqual(y.dtype, np.bool) + self.assertEqual(x[0], y[0]) - @staticmethod - def _test_det_logdet_slogdet(self, device): - def reference_slogdet(M): - if TEST_NUMPY: - sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy()) - return M.new_tensor(sdet), M.new_tensor(logabsdet) - else: - # naive row reduction - M = M.clone() - l = M.size(0) - multiplier = 1 - for i in range(l): - if M[i, 0].item() != 0: - if i != 0: - M[0], M[i] = M[i], M[0] - multiplier = -1 - break - else: - return 0 - for i in range(1, l): - row = M[i] - for j in range(i): - row -= row[j] / M[j, j] * M[j] - M[i] = row - sdet = M.diag().sign().prod() - logabsdet = M.diag().abs_().log_().sum().add_(math.log(multiplier)) - return sdet, logabsdet - - def test_single_det(M, target, desc): - target_sdet, target_logabsdet = target + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_from_numpy(self): + dtypes = [ + np.double, + np.float, + np.float16, + np.int64, + np.int32, + np.int16, + np.int8, + np.uint8, + np.longlong, + np.bool, + ] + for dtype in dtypes: + array = np.array([1, 2, 3, 4], dtype=dtype) + tensor_from_array = torch.from_numpy(array) + # TODO: change to tensor equality check once HalfTensor + # implements `==` + for i in range(len(array)): + self.assertEqual(tensor_from_array[i], array[i]) + # This is a special test case for Windows + # https://github.com/pytorch/pytorch/issues/22615 + array2 = array % 2 + tensor_from_array2 = torch.from_numpy(array2) + for i in range(len(array2)): + self.assertEqual(tensor_from_array2[i], array2[i]) - det = M.det() - logdet = M.logdet() - sdet, logabsdet = M.slogdet() + # Test unsupported type + array = np.array([1, 2, 3, 4], dtype=np.complex) + with self.assertRaises(TypeError): + tensor_from_array = torch.from_numpy(array) - # Test det - self.assertEqual(det, target_sdet * target_logabsdet.exp(), 1e-7, '{} (det)'.format(desc)) + # check storage offset + x = np.linspace(1, 125, 125) + x.shape = (5, 5, 5) + x = x[1] + expected = torch.arange(1, 126).view(5, 5, 5)[1] + self.assertEqual(torch.from_numpy(x), expected) - # Test slogdet - # Compare the overall value rather than individual parts because of - # precision issues when det is near zero. - self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), 1e-7, '{} (slogdet)'.format(desc)) + # check noncontiguous + x = np.linspace(1, 25, 25) + x.shape = (5, 5) + expected = torch.arange(1, 26).view(5, 5).t() + self.assertEqual(torch.from_numpy(x.T), expected) - # Test logdet - # Compare logdet against our own pytorch slogdet because they should - # be consistent, while it may behave slightly differently with other - # slogdet implementations when det is near zero due to precision - # issues. - if sdet.item() < 0: - self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc)) - else: - self.assertEqual(logdet.exp(), target_logabsdet.exp(), 1e-7, '{} (logdet non-negative case)'.format(desc)) + # check noncontiguous with holes + x = np.linspace(1, 125, 125) + x.shape = (5, 5, 5) + x = x[:, 1] + expected = torch.arange(1, 126).view(5, 5, 5)[:, 1] + self.assertEqual(torch.from_numpy(x), expected) - eye = torch.eye(5, device=device) - test_single_det(eye, (torch.ones((), device=device), torch.zeros((), device=device)), 'identity') + # check zero dimensional + x = np.zeros((0, 2)) + self.assertEqual(torch.from_numpy(x).shape, (0, 2)) + x = np.zeros((2, 0)) + self.assertEqual(torch.from_numpy(x).shape, (2, 0)) - def test(M): - assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' - M = M.to(device) + # check ill-sized strides raise exception + x = np.array([3., 5., 8.]) + x.strides = (3,) + self.assertRaises(ValueError, lambda: torch.from_numpy(x)) - ref_M_sdet, ref_M_logabsdet = reference_slogdet(M) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_ctor_with_numpy_scalar_ctor(self): + dtypes = [ + np.double, + np.float, + np.float16, + np.int64, + np.int32, + np.int16, + np.uint8, + np.bool, + ] + for dtype in dtypes: + self.assertEqual(dtype(42), torch.tensor(dtype(42)).item()) - test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic') - if ref_M_logabsdet.exp().item() >= 1e-6: # skip singular - M_inv = M.inverse() - test_single_det(M_inv, reference_slogdet(M_inv), 'inverse') + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_numpy_index(self): + i = np.int32([0, 1, 2]) + x = torch.randn(5, 5) + for idx in i: + self.assertFalse(isinstance(idx, int)) + self.assertEqual(x[idx], x[int(idx)]) - test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose') + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_numpy_array_interface(self): + types = [ + torch.DoubleTensor, + torch.FloatTensor, + torch.HalfTensor, + torch.LongTensor, + torch.IntTensor, + torch.ShortTensor, + torch.ByteTensor, + ] + dtypes = [ + np.float64, + np.float32, + np.float16, + np.int64, + np.int32, + np.int16, + np.uint8, + ] + for tp, dtype in zip(types, dtypes): + if np.dtype(dtype).kind == 'u': + x = torch.Tensor([1, 2, 3, 4]).type(tp) + array = np.array([1, 2, 3, 4], dtype=dtype) + else: + x = torch.Tensor([1, -2, 3, -4]).type(tp) + array = np.array([1, -2, 3, -4], dtype=dtype) - for x in [0, 2, 4]: - for scale in [-2, -0.1, 0, 10]: - if scale > 0: - target = ref_M_sdet, ref_M_logabsdet + math.log(scale) - elif scale == 0: - target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) - else: - target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale) + # Test __array__ w/o dtype argument + asarray = np.asarray(x) + self.assertIsInstance(asarray, np.ndarray) + self.assertEqual(asarray.dtype, dtype) + for i in range(len(x)): + self.assertEqual(asarray[i], x[i]) - # dim 0 - M_clone = M.clone() - M_clone[:, x] *= scale - test_single_det(M_clone, target, 'scale a row') - # dim 1 - M_clone = M.clone() - M_clone[x, :] *= scale - test_single_det(M_clone, target, 'scale a column') + # Test __array_wrap__, same dtype + abs_x = np.abs(x) + abs_array = np.abs(array) + self.assertIsInstance(abs_x, tp) + for i in range(len(x)): + self.assertEqual(abs_x[i], abs_array[i]) - for x1, x2 in [(0, 3), (4, 1), (3, 2)]: - assert x1 != x2, 'x1 and x2 needs to be different for this test' - target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) - # dim 0 - M_clone = M.clone() - M_clone[:, x2] = M_clone[:, x1] - test_single_det(M_clone, target, 'two rows are same') - # dim 1 - M_clone = M.clone() - M_clone[x2, :] = M_clone[x1, :] - test_single_det(M_clone, target, 'two columns are same') + # Test __array__ with dtype argument + for dtype in dtypes: + x = torch.IntTensor([1, -2, 3, -4]) + asarray = np.asarray(x, dtype=dtype) + self.assertEqual(asarray.dtype, dtype) + if np.dtype(dtype).kind == 'u': + wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) + for i in range(len(x)): + self.assertEqual(asarray[i], wrapped_x[i]) + else: + for i in range(len(x)): + self.assertEqual(asarray[i], x[i]) - for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]: - det_scale = scale1 * scale2 * -1 - if det_scale > 0: - target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale) - elif det_scale == 0: - target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) - else: - target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale) + # Test some math functions with float types + float_types = [torch.DoubleTensor, torch.FloatTensor] + float_dtypes = [np.float64, np.float32] + for tp, dtype in zip(float_types, float_dtypes): + x = torch.Tensor([1, 2, 3, 4]).type(tp) + array = np.array([1, 2, 3, 4], dtype=dtype) + for func in ['sin', 'sqrt', 'ceil']: + ufunc = getattr(np, func) + res_x = ufunc(x) + res_array = ufunc(array) + self.assertIsInstance(res_x, tp) + for i in range(len(x)): + self.assertEqual(res_x[i], res_array[i]) - # dim 0 - M_clone = M.clone() - t = M_clone[:, x1] * scale1 - M_clone[:, x1] += M_clone[:, x2] * scale2 - M_clone[:, x2] = t - test_single_det(M_clone, target, 'exchanging rows') - # dim 1 - M_clone = M.clone() - t = M_clone[x1, :] * scale1 - M_clone[x1, :] += M_clone[x2, :] * scale2 - M_clone[x2, :] = t - test_single_det(M_clone, target, 'exchanging columns') + # Test functions with boolean return value + for tp, dtype in zip(types, dtypes): + x = torch.Tensor([1, 2, 3, 4]).type(tp) + array = np.array([1, 2, 3, 4], dtype=dtype) + geq2_x = np.greater_equal(x, 2) + geq2_array = np.greater_equal(array, 2).astype('uint8') + self.assertIsInstance(geq2_x, torch.ByteTensor) + for i in range(len(x)): + self.assertEqual(geq2_x[i], geq2_array[i]) - def get_random_mat_scale(n): - # For matrices with values i.i.d. with 0 mean, unit variance, and - # subexponential tail, we have: - # E[log det(A^2)] \approx log((n-1)!) - # - # Notice: - # log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)] - # - # So: - # stddev[det(A)] >= sqrt( (n-1)! ) - # - # We use this as an intuitive guideline to scale random generated - # matrices so our closeness tests can work more robustly: - # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n)) - # - # source: https://arxiv.org/pdf/1112.0752.pdf - - # TODO: technically we need subexponential distn for this to hold, - # but we mostly use gaussian entries below. Consider switching - # to Chi-sq if this turns out not stable enough, since Chi-sq - # is easy enough to sample from. - return math.factorial(n - 1) ** (-1.0 / (2 * n)) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_multiplication_numpy_scalar(self): + for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]: + for t_dtype in [torch.float, torch.double]: + np_sc = np_dtype(2.0) + t = torch.ones(2, requires_grad=True, dtype=t_dtype) + r1 = t * np_sc + self.assertIsInstance(r1, torch.Tensor) + self.assertTrue(r1.dtype == t_dtype) + self.assertTrue(r1.requires_grad) + r2 = np_sc * t + self.assertIsInstance(r2, torch.Tensor) + self.assertTrue(r2.dtype == t_dtype) + self.assertTrue(r2.requires_grad) - for n in [5, 10, 25]: - scale = get_random_mat_scale(n) - test(torch.randn(n, n, device=device) * scale) - r = torch.randn(n, n, device=device) * scale - # symmetric psd - test(r.mm(r.t())) - # symmetric pd - r = torch.randn(n, n, device=device) * scale - test(r.mm(r.t()) + torch.eye(n, device=device) * 1e-6) - # symmetric - r = torch.randn(n, n, device=device) * scale - for i in range(n): - for j in range(i): - r[i, j] = r[j, i] - test(r) - # non-contiguous - test((torch.randn(n, n, n + 1, device=device) * scale)[:, 2, 1:]) - # det = 0 - r = torch.randn(n, n, device=device) * scale - u, s, v = r.svd() - if reference_slogdet(u)[0] < 0: - u = -u - if reference_slogdet(v)[0] < 0: - v = -v - s[0] *= -1 - s[-1] = 0 - test(u.mm(s.diag()).mm(v)) + def test_error_msg_type_translation(self): + with self.assertRaisesRegex( + RuntimeError, + # message includes both Double and Long + '(?=.*Double)(?=.*Long)'): - # Small values to test numerical stability. Note that we don't scale - # this matrix. - r = torch.randn(512, 512, device=device) - u, s, v = r.svd() - s.fill_(1. / (100 * s.numel())) - test(u.mm(s.diag()).mm(v)) + # Calls model with a DoubleTensor input but LongTensor weights + input = torch.autograd.Variable(torch.randn(1, 1, 1, 6).double()) + weight = torch.zeros(1, 1, 1, 3).long() + model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False) + model.weight.data = weight + out = model(input) - @skipIfNoLapack - @skipIfRocm - def test_det_logdet_slogdet(self): - self._test_det_logdet_slogdet(self, 'cpu') + def test_tensor_from_sequence(self): + class MockSequence(object): + def __init__(self, lst): + self.lst = lst - @staticmethod - def _test_det_logdet_slogdet_batched(self, device): - from common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix, - random_symmetric_pd_matrix, random_square_matrix_of_rank) + def __len__(self): + return len(self.lst) - # mat_chars denotes matrix characteristics - # possible values are: sym, sym_psd, sym_pd, sing, non_sym - def run_test(matsize, batchdims, mat_chars): - num_matrices = reduce(lambda x, y: x * y, batchdims, 1) - list_of_matrices = [] + def __getitem__(self, item): + raise TypeError - for idx in range(num_matrices): - mat_type = idx % len(mat_chars) - if mat_chars[mat_type] == 'sym': - list_of_matrices.append(random_symmetric_matrix(matsize).to(device=device)) - elif mat_chars[mat_type] == 'sym_psd': - list_of_matrices.append(random_symmetric_psd_matrix(matsize).to(device=device)) - elif mat_chars[mat_type] == 'sym_pd': - list_of_matrices.append(random_symmetric_pd_matrix(matsize).to(device=device)) - elif mat_chars[mat_type] == 'sing': - list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize // 2).to(device=device)) - elif mat_chars[mat_type] == 'non_sing': - list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize).to(device=device)) - full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) - # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet - full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize))) + class GoodMockSequence(MockSequence): + def __getitem__(self, item): + return self.lst[item] - for fn in [torch.det, torch.logdet, torch.slogdet]: - expected_value = [] - actual_value = fn(full_tensor) - for full_idx in product(*map(lambda x: list(range(x)), batchdims)): - expected_value.append(fn(full_tensor[full_idx])) + bad_mock_seq = MockSequence([1.0, 2.0, 3.0]) + good_mock_seq = GoodMockSequence([1.0, 2.0, 3.0]) + with self.assertRaisesRegex(ValueError, 'could not determine the shape'): + torch.Tensor(bad_mock_seq) + self.assertEqual(torch.Tensor([1.0, 2.0, 3.0]), torch.Tensor(good_mock_seq)) - if fn == torch.slogdet: - sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims) - expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims) - self.assertEqual(sign_value, actual_value[0], allow_inf=True) - self.assertEqual(expected_value, actual_value[1], allow_inf=True) - else: - expected_value = torch.stack(expected_value, dim=0).reshape(batchdims) - self.assertEqual(actual_value, expected_value, allow_inf=True) + def test_comparison_ops(self): + x = torch.randn(5, 5) + y = torch.randn(5, 5) - for matsize, batchdims in product([3, 5], [(3,), (5, 3)]): - run_test(matsize, batchdims, mat_chars=['sym_pd']) - run_test(matsize, batchdims, mat_chars=['sing']) - run_test(matsize, batchdims, mat_chars=['non_sing']) - run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) - run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) + eq = x == y + for idx in iter_indices(x): + self.assertEqual(x[idx] == y[idx], eq[idx] == 1) - @skipIfNoLapack - def test_det_logdet_slogdet_batched(self): - self._test_det_logdet_slogdet_batched(self, 'cpu') + ne = x != y + for idx in iter_indices(x): + self.assertEqual(x[idx] != y[idx], ne[idx] == 1) - @staticmethod - def _test_fft_ifft_rfft_irfft(self, device='cpu'): - def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x): - x = prepro_fn(torch.randn(*sizes, device=device)) - for normalized in (True, False): - res = x.fft(signal_ndim, normalized=normalized) - rec = res.ifft(signal_ndim, normalized=normalized) - self.assertEqual(x, rec, 1e-8, 'fft and ifft') - res = x.ifft(signal_ndim, normalized=normalized) - rec = res.fft(signal_ndim, normalized=normalized) - self.assertEqual(x, rec, 1e-8, 'ifft and fft') + lt = x < y + for idx in iter_indices(x): + self.assertEqual(x[idx] < y[idx], lt[idx] == 1) - def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x): - x = prepro_fn(torch.randn(*sizes, device=device)) - signal_numel = 1 - signal_sizes = x.size()[-signal_ndim:] - for normalized, onesided in product((True, False), repeat=2): - res = x.rfft(signal_ndim, normalized=normalized, onesided=onesided) - if not onesided: # check Hermitian symmetry - def test_one_sample(res, test_num=10): - idxs_per_dim = [torch.LongTensor(test_num).random_(s).tolist() for s in signal_sizes] - for idx in zip(*idxs_per_dim): - reflected_idx = tuple((s - i) % s for i, s in zip(idx, res.size())) - idx_val = res.__getitem__(idx) - reflected_val = res.__getitem__(reflected_idx) - self.assertEqual(idx_val[0], reflected_val[0], 'rfft hermitian symmetry on real part') - self.assertEqual(idx_val[1], -reflected_val[1], 'rfft hermitian symmetry on imaginary part') - if len(sizes) == signal_ndim: - test_one_sample(res) - else: - output_non_batch_shape = res.size()[-(signal_ndim + 1):] - flatten_batch_res = res.view(-1, *output_non_batch_shape) - nb = flatten_batch_res.size(0) - test_idxs = torch.LongTensor(min(nb, 4)).random_(nb) - for test_idx in test_idxs.tolist(): - test_one_sample(flatten_batch_res[test_idx]) - # compare with C2C - xc = torch.stack([x, torch.zeros_like(x)], -1) - xc_res = xc.fft(signal_ndim, normalized=normalized) - self.assertEqual(res, xc_res) - test_input_signal_sizes = [signal_sizes] - rec = res.irfft(signal_ndim, normalized=normalized, - onesided=onesided, signal_sizes=signal_sizes) - self.assertEqual(x, rec, 1e-8, 'rfft and irfft') - if not onesided: # check that we can use C2C ifft - rec = res.ifft(signal_ndim, normalized=normalized) - self.assertEqual(x, rec.select(-1, 0), 1e-8, 'twosided rfft and ifft real') - self.assertEqual(rec.select(-1, 1).data.abs().mean(), 0, 1e-8, 'twosided rfft and ifft imaginary') + le = x <= y + for idx in iter_indices(x): + self.assertEqual(x[idx] <= y[idx], le[idx] == 1) - # contiguous case - _test_real((100,), 1) - _test_real((10, 1, 10, 100), 1) - _test_real((100, 100), 2) - _test_real((2, 2, 5, 80, 60), 2) - _test_real((50, 40, 70), 3) - _test_real((30, 1, 50, 25, 20), 3) + gt = x > y + for idx in iter_indices(x): + self.assertEqual(x[idx] > y[idx], gt[idx] == 1) - _test_complex((100, 2), 1) - _test_complex((100, 100, 2), 1) - _test_complex((100, 100, 2), 2) - _test_complex((1, 20, 80, 60, 2), 2) - _test_complex((50, 40, 70, 2), 3) - _test_complex((6, 5, 50, 25, 20, 2), 3) + ge = x >= y + for idx in iter_indices(x): + self.assertEqual(x[idx] >= y[idx], ge[idx] == 1) - # non-contiguous case - _test_real((165,), 1, lambda x: x.narrow(0, 25, 100)) # input is not aligned to complex type - _test_real((100, 100, 3), 1, lambda x: x[:, :, 0]) - _test_real((100, 100), 2, lambda x: x.t()) - _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60]) - _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80]) - _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3)) + def test_bitwise_ops(self): + x = torch.randn(5, 5).gt(0) + y = torch.randn(5, 5).gt(0) - _test_complex((2, 100), 1, lambda x: x.t()) - _test_complex((100, 2), 1, lambda x: x.expand(100, 100, 2)) - _test_complex((300, 200, 3), 2, lambda x: x[:100, :100, 1:]) # input is not aligned to complex type - _test_complex((20, 90, 110, 2), 2, lambda x: x[:, 5:85].narrow(2, 5, 100)) - _test_complex((40, 60, 3, 80, 2), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:]) - _test_complex((30, 55, 50, 22, 2), 3, lambda x: x[:, 3:53, 15:40, 1:21]) + and_result = x & y + for idx in iter_indices(x): + if and_result[idx]: + self.assertTrue(x[idx] and y[idx]) + else: + self.assertFalse(x[idx] and y[idx]) - # non-contiguous with strides not representable as aligned with complex type - _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [3, 2, 1])) - _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) - _test_complex((50,), 1, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) - _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [3, 3, 1])) - _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 2, 2])) - _test_complex((50,), 2, lambda x: x.as_strided([5, 5, 2], [4, 3, 1])) + or_result = x | y + for idx in iter_indices(x): + if or_result[idx]: + self.assertTrue(x[idx] or y[idx]) + else: + self.assertFalse(x[idx] or y[idx]) - @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") - def test_fft_ifft_rfft_irfft(self): - self._test_fft_ifft_rfft_irfft(self) - - @staticmethod - def _test_stft(self, device='cpu'): - if not TEST_LIBROSA: - raise unittest.SkipTest('librosa not found') - - def librosa_stft(x, n_fft, hop_length, win_length, window, center): - if window is None: - window = np.ones(n_fft if win_length is None else win_length) + xor_result = x ^ y + for idx in iter_indices(x): + if xor_result[idx]: + self.assertTrue(x[idx] ^ y[idx]) else: - window = window.cpu().numpy() - input_1d = x.dim() == 1 - if input_1d: - x = x.view(1, -1) - result = [] - for xi in x: - ri = librosa.stft(xi.cpu().numpy(), n_fft, hop_length, win_length, window, center=center) - result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1))) - result = torch.stack(result, 0) - if input_1d: - result = result[0] - return result + self.assertFalse(x[idx] ^ y[idx]) - def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, - center=True, expected_error=None): - x = torch.randn(*sizes, device=device) - if win_sizes is not None: - window = torch.randn(*win_sizes, device=device) - else: - window = None - if expected_error is None: - result = x.stft(n_fft, hop_length, win_length, window, center=center) - ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) - self.assertEqual(result, ref_result, 7e-6, 'stft comparison against librosa') - else: - self.assertRaises(expected_error, - lambda: x.stft(n_fft, hop_length, win_length, window, center=center)) + x_clone = x.clone() + x_clone &= y + self.assertEqual(x_clone, and_result) - for center in [True, False]: - _test((10,), 7, center=center) - _test((10, 4000), 1024, center=center) + x_clone = x.clone() + x_clone |= y + self.assertEqual(x_clone, or_result) - _test((10,), 7, 2, center=center) - _test((10, 4000), 1024, 512, center=center) + x_clone = x.clone() + x_clone ^= y + self.assertEqual(x_clone, xor_result) - _test((10,), 7, 2, win_sizes=(7,), center=center) - _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center) + def test_op_invert(self): + res = 0xffff - torch.arange(127, dtype=torch.int8) + for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + a = torch.arange(127, dtype=dtype) + self.assertEqual(res.to(dtype), ~a) - # spectral oversample - _test((10,), 7, 2, win_length=5, center=center) - _test((10, 4000), 1024, 512, win_length=100, center=center) + self.assertEqual(torch.tensor([True, False]), + ~torch.tensor([False, True])) - _test((10, 4, 2), 1, 1, expected_error=RuntimeError) - _test((10,), 11, 1, center=False, expected_error=RuntimeError) - _test((10,), -1, 1, expected_error=RuntimeError) - _test((10,), 3, win_length=5, expected_error=RuntimeError) - _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) - _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) + # test exceptions + for dtype in(torch.half, torch.float, torch.double): + a = torch.zeros(10, dtype=dtype) + with self.assertRaises(TypeError): + b = ~a - # passes on ROCm w/ python 2.7, fails w/ python 3.6 - @skipIfRocm - def test_stft(self): - self._test_stft(self) + def test_apply(self): + x = torch.arange(1, 6) + res = x.clone().apply_(lambda k: k + k) + self.assertEqual(res, x * 2) + self.assertRaises(TypeError, lambda: x.apply_(lambda k: "str")) - @unittest.skip("Not implemented yet") - def test_conv2(self): - x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100))) - k = torch.rand(math.floor(torch.uniform(10, 20)), math.floor(torch.uniform(10, 20))) - imvc = torch.conv2(x, k) - imvc2 = torch.conv2(x, k, 'V') - imfc = torch.conv2(x, k, 'F') + def test_map(self): + x = torch.autograd.Variable(torch.randn(3, 3)) + y = torch.autograd.Variable(torch.randn(3)) + res = x.clone() + res.map_(y, lambda a, b: a + b) + self.assertEqual(res, x + y) + self.assertRaisesRegex(TypeError, "not callable", lambda: res.map_(y, "str")) - ki = k.clone() - ks = k.storage() - kis = ki.storage() - for i in range(ks.size() - 1, 0, -1): - kis[ks.size() - i + 1] = ks[i] - # for i=ks.size(), 1, -1 do kis[ks.size()-i+1]=ks[i] end - imvx = torch.xcorr2(x, ki) - imvx2 = torch.xcorr2(x, ki, 'V') - imfx = torch.xcorr2(x, ki, 'F') + def test_map2(self): + x = torch.autograd.Variable(torch.randn(3, 3)) + y = torch.autograd.Variable(torch.randn(3)) + z = torch.autograd.Variable(torch.randn(1, 3)) + res = x.clone() + res.map2_(y, z, lambda a, b, c: a + b * c) + self.assertEqual(res, x + y * z) + z.requires_grad = True + self.assertRaisesRegex( + RuntimeError, "requires grad", + lambda: res.map2_(y, z, lambda a, b, c: a + b * c)) - self.assertEqual(imvc, imvc2, 0, 'torch.conv2') - self.assertEqual(imvc, imvx, 0, 'torch.conv2') - self.assertEqual(imvc, imvx2, 0, 'torch.conv2') - self.assertEqual(imfc, imfx, 0, 'torch.conv2') - self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr2(x, x)[0][0]), 1e-10, 'torch.conv2') + def test_Size(self): + x = torch.Size([1, 2, 3]) + self.assertIsInstance(x, tuple) + self.assertEqual(x[0], 1) + self.assertEqual(x[1], 2) + self.assertEqual(x[2], 3) + self.assertEqual(len(x), 3) + self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3))) - xx = torch.Tensor(2, x.size(1), x.size(2)) - xx[1].copy_(x) - xx[2].copy_(x) - kk = torch.Tensor(2, k.size(1), k.size(2)) - kk[1].copy_(k) - kk[2].copy_(k) + self.assertIsInstance(x * 2, torch.Size) + self.assertIsInstance(x[:-1], torch.Size) + self.assertIsInstance(x + x, torch.Size) - immvc = torch.conv2(xx, kk) - immvc2 = torch.conv2(xx, kk, 'V') - immfc = torch.conv2(xx, kk, 'F') + def test_Size_scalar(self): + three = torch.tensor(3) + two = torch.tensor(2) + x = torch.Size([0, 1, two, three, 4]) + for i in range(1, 5): + self.assertEqual(x[i], i) - self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv2') - self.assertEqual(immvc[0], imvc, 0, 'torch.conv2') - self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv2') - self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv2') - self.assertEqual(immfc[0], imfc, 0, 'torch.conv2') + def test_Size_iter(self): + for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]: + x = torch.Size(sizes) + for i in range(0, 5): + self.assertEqual(x[i], i + 1) - @unittest.skip("Not implemented yet") - def test_conv3(self): - x = torch.rand(math.floor(torch.uniform(20, 40)), - math.floor(torch.uniform(20, 40)), - math.floor(torch.uniform(20, 40))) - k = torch.rand(math.floor(torch.uniform(5, 10)), - math.floor(torch.uniform(5, 10)), - math.floor(torch.uniform(5, 10))) - imvc = torch.conv3(x, k) - imvc2 = torch.conv3(x, k, 'V') - imfc = torch.conv3(x, k, 'F') + def test_t_not_2d_error(self): + self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t()) + self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_()) - ki = k.clone() - ks = k.storage() - kis = ki.storage() - for i in range(ks.size() - 1, 0, -1): - kis[ks.size() - i + 1] = ks[i] - imvx = torch.xcorr3(x, ki) - imvx2 = torch.xcorr3(x, ki, 'V') - imfx = torch.xcorr3(x, ki, 'F') + # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_big_transpose(self): + t = torch.rand(456, 789) + t1 = t.t().contiguous() + t2 = torch.from_numpy(t.numpy().transpose()) + self.assertEqual(t1, t2) - self.assertEqual(imvc, imvc2, 0, 'torch.conv3') - self.assertEqual(imvc, imvx, 0, 'torch.conv3') - self.assertEqual(imvc, imvx2, 0, 'torch.conv3') - self.assertEqual(imfc, imfx, 0, 'torch.conv3') - self.assertLessEqual(math.abs(x.dot(x) - torch.xcorr3(x, x)[0][0][0]), 4e-10, 'torch.conv3') + def test_inplace_division(self): + t = torch.rand(5, 5) + id_before = id(t) + t /= 2 + id_after = id(t) + self.assertEqual(id_before, id_after) - xx = torch.Tensor(2, x.size(1), x.size(2), x.size(3)) - xx[1].copy_(x) - xx[2].copy_(x) - kk = torch.Tensor(2, k.size(1), k.size(2), k.size(3)) - kk[1].copy_(k) - kk[2].copy_(k) + def test_simple_scalar_cast(self): + ok = [torch.Tensor([1.5]), torch.zeros(1, 1, 1, 1)] + ok_values = [1.5, 0] - immvc = torch.conv3(xx, kk) - immvc2 = torch.conv3(xx, kk, 'V') - immfc = torch.conv3(xx, kk, 'F') + not_ok = map(torch.Tensor, [[], [1, 2], [[1, 2], [3, 4]]]) - self.assertEqual(immvc[0], immvc[1], 0, 'torch.conv3') - self.assertEqual(immvc[0], imvc, 0, 'torch.conv3') - self.assertEqual(immvc2[0], imvc2, 0, 'torch.conv3') - self.assertEqual(immfc[0], immfc[1], 0, 'torch.conv3') - self.assertEqual(immfc[0], imfc, 0, 'torch.conv3') + for tensor, value in zip(ok, ok_values): + self.assertEqual(int(tensor), int(value)) + self.assertEqual(float(tensor), float(value)) + if sys.version_info[0] < 3: + self.assertEqual(long(tensor), long(value)) - @unittest.skip("Not implemented yet") - def _test_conv_corr_eq(self, fn, fn_2_to_3): - ix = math.floor(random.randint(20, 40)) - iy = math.floor(random.randint(20, 40)) - iz = math.floor(random.randint(20, 40)) - kx = math.floor(random.randint(5, 10)) - ky = math.floor(random.randint(5, 10)) - kz = math.floor(random.randint(5, 10)) + for tensor in not_ok: + self.assertRaises(ValueError, lambda: int(tensor)) + self.assertRaises(ValueError, lambda: float(tensor)) + if sys.version_info[0] < 3: + self.assertRaises(ValueError, lambda: long(tensor)) - x = torch.rand(ix, iy, iz) - k = torch.rand(kx, ky, kz) + def test_offset_scalar_cast(self): + x = torch.Tensor([1, 2, 3]) + y = x[2:] + self.assertEqual(int(y), 3) - o3 = fn(x, k) - o32 = torch.zeros(o3.size()) - fn_2_to_3(x, k, o3, o32) - self.assertEqual(o3, o32) + # skip this test for now as it affects all tests + @unittest.skipIf(True, "flush_denormal not supported") + def test_set_flush_denormal(self): + tiny_float = 1e-42 + tiny_double = 1e-320 + float_tensor = torch.FloatTensor([1.0, tiny_float]) + double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double]) - @unittest.skip("Not implemented yet") - def test_xcorr3_xcorr2_eq(self): - def reference(x, k, o3, o32): - for i in range(o3.size(1)): - for j in range(k.size(1)): - o32[i].add(torch.xcorr2(x[i + j - 1], k[j])) - self._test_conv_corr_eq(torch.xcorr3, reference) + self.assertEqual(float_tensor[0], 1.0, prec=0.0) + self.assertEqual(float_tensor[1], tiny_float, prec=tiny_float / 16) + self.assertEqual(double_tensor[0], 1.0, prec=0.0) + self.assertEqual(double_tensor[1], tiny_float, prec=0.0) + self.assertEqual(double_tensor[2], tiny_double, prec=0.0) - @unittest.skip("Not implemented yet") - def test_xcorr3_xcorr2_eq_full(self): - def reference(x, k, o3, o32): - for i in range(x.size(1)): - for j in range(k.size(1)): - o32[i].add(torch.xcorr2(x[i], k[k.size(1) - j + 1], 'F')) - self._test_conv_corr_eq(lambda x, k: torch.xcorr3(x, k, 'F'), reference) + torch.set_flush_denormal(True) + self.assertEqual(float_tensor[0], 1.0, prec=0.0) + self.assertEqual(float_tensor[1], 0.0, prec=0.0) # tiny_float to zero + self.assertEqual(double_tensor[0], 1.0, prec=0.0) + # tiny_float is not converted to zero in double type + self.assertEqual(double_tensor[1], tiny_float, prec=0.0) + self.assertEqual(double_tensor[2], 0.0, prec=0.0) # tiny_double to zero + torch.set_flush_denormal(False) - @unittest.skip("Not implemented yet") - def test_conv3_conv2_eq_valid(self): - def reference(x, k, o3, o32): - for i in range(o3.size(1)): - for j in range(k.size(1)): - o32[i].add(torch.conv2(x[i + j - 1], k[k.size(1) - j + 1])) - self._test_conv_corr_eq(torch.conv3, reference) + def test_show_config(self): + # We can't usefully test the output; just make sure this doesn't crash + torch.__config__.show() - @unittest.skip("Not implemented yet") - def test_fconv3_fconv2_eq(self): - def reference(x, k, o3, o32): - for i in range(o3.size(1)): - for j in range(k.size(1)): - o32[i + j - 1].add(torch.conv2(x[i], k[j], 'F')) - self._test_conv_corr_eq(lambda x, k: torch.conv3(x, k, 'F'), reference) + def test_parallel_info(self): + torch.__config__.parallel_info() - def test_logical(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - x = torch.tensor([1, 2, 3, 4], device=device, dtype=dt) - b = torch.tensor([2], device=device, dtype=dt) + @staticmethod + def _test_bincount(self, device): + # negative input throws + with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): + torch.bincount(torch.tensor([1, -1], device=device)) + # n-d input, with n > 1 throws + with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): + torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device)) + # floating input type throws + with self.assertRaisesRegex(RuntimeError, 'not implemented'): + torch.bincount(torch.tensor([1., 0.3], device=device)) + # minlength < 0 throws + with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'): + torch.bincount(torch.tensor([1, 3], device=device), + torch.tensor([.2, .2], device=device), + minlength=-1) + # input and weights dim mismatch + with self.assertRaisesRegex(RuntimeError, 'same length'): + torch.bincount(torch.tensor([1, 0], device=device), + torch.tensor([1., 0.3, 0.5], device=device)) + # 1-d input with no elements and default minlength + self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)), + torch.zeros(0, dtype=torch.long, device=device)) + # 1-d input with no elements and specified minlength + self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10), + torch.zeros(10, dtype=torch.long, device=device)) - if dt == torch.half and device == 'cpu': - self.assertRaises(RuntimeError, lambda: x.lt(2)) - continue + # test tensor method without weights + long_counts = torch.tensor( + [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount() + self.assertEqual( + torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device), + long_counts) + # test minlength functionality + int_counts = torch.bincount( + torch.tensor([1, 1, 1, 1], device=device), minlength=5) + self.assertEqual( + torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device), + int_counts) + # test weights + byte_counts = torch.bincount( + torch.tensor([0, 1, 1, 1, 4], device=device), + torch.tensor([.1, .2, .3, .4, .5], device=device)) + self.assertEqual( + torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts) + byte_counts = torch.bincount( + torch.tensor([0, 1, 1, 1, 4], device=device), + torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device)) + self.assertEqual( + torch.tensor([1, 9, 0, 0, 5], device=device), byte_counts) + # test non-contiguous inputs and weights + inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device) + weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device) + for i in [0, 1]: + assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous" + assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous" + # inputs are non-contiguous but weights are contiguous + self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2])) + # inputs and weights are non-contiguous + self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), torch.tensor([1, 9, 0, 0, 5])) + # weights are non-contiguous but inputs are contiguous + self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), + torch.tensor([1, 9, 0, 0, 5])) - if dt == torch.bool: - # torch.bool is a special case and is being tested later - # in this test - continue + # test bincount on non-contiguous slices + all0s = torch.zeros((32, 2), dtype=torch.int64, device=device) + self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32])) - if device == 'cuda' and dt == torch.bfloat16: - self.assertRaises(RuntimeError, lambda: x > b) - self.assertRaises(RuntimeError, lambda: x < b) - self.assertRaises(RuntimeError, lambda: x == b) - self.assertRaises(RuntimeError, lambda: x != b) - self.assertRaises(RuntimeError, lambda: x >= b) - self.assertRaises(RuntimeError, lambda: x <= b) - continue + all1s = torch.ones((32, 2), dtype=torch.int64, device=device) + self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32])) - self.assertEqual(x.lt(2), torch.tensor([True, False, False, False])) - self.assertEqual(x.le(2), torch.tensor([True, True, False, False])) - self.assertEqual(x.ge(2), torch.tensor([False, True, True, True])) - self.assertEqual(x.gt(2), torch.tensor([False, False, True, True])) - self.assertEqual(x.eq(2), torch.tensor([False, True, False, False])) - self.assertEqual(x.ne(2), torch.tensor([True, False, True, True])) + # test large number of bins - global memory use + big_exp = torch.zeros(10000000, device=device) + big_exp[-1] = 50.0 + big_w = torch.tensor([.5] * 100, device=device) + big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w) + self.assertEqual(big_exp, big_out) + # test large input size + big_exp = torch.zeros(2, device=device) + big_exp[1] = 1000000 + big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount() + self.assertEqual(big_exp, big_out) - self.assertEqual(x.lt(b), torch.tensor([True, False, False, False])) - self.assertEqual(x.le(b), torch.tensor([True, True, False, False])) - self.assertEqual(x.ge(b), torch.tensor([False, True, True, True])) - self.assertEqual(x.gt(b), torch.tensor([False, False, True, True])) - self.assertEqual(x.eq(b), torch.tensor([False, True, False, False])) - self.assertEqual(x.ne(b), torch.tensor([True, False, True, True])) + @slowTest + def test_slow_test(self): + # Just a smoketest to make sure our slowTest decorator works. + pass - with warnings.catch_warnings(record=True) as warningsCount: - byteRes = torch.empty_like(x, device=device).byte() - boolRes = torch.empty_like(x, device=device).bool() + def test_bincount_cpu(self): + self._test_bincount(self, device='cpu') - torch.lt(x, b, out=byteRes) - torch.lt(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) + def test_is_nonzero(self): + self.assertExpectedRaises(RuntimeError, lambda: torch.tensor([]).is_nonzero(), subname="empty") + self.assertExpectedRaises(RuntimeError, lambda: torch.tensor([0, 0]).is_nonzero(), subname="multiple") + self.assertFalse(torch.tensor(0).is_nonzero()) + self.assertTrue(torch.tensor(1).is_nonzero()) + self.assertFalse(torch.tensor([0]).is_nonzero()) + self.assertTrue(torch.tensor([1]).is_nonzero()) + self.assertFalse(torch.tensor([[0]]).is_nonzero()) + self.assertTrue(torch.tensor([[1]]).is_nonzero()) - torch.le(x, b, out=byteRes) - torch.le(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) + def test_meshgrid(self): + a = torch.tensor(1) + b = torch.tensor([1, 2, 3]) + c = torch.tensor([1, 2]) + grid_a, grid_b, grid_c = torch.meshgrid([a, b, c]) + self.assertEqual(grid_a.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_b.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_c.shape, torch.Size([1, 3, 2])) + grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c) + self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2])) + self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2])) + expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64) + expected_grid_b = torch.tensor([[[1, 1], + [2, 2], + [3, 3]]]) + expected_grid_c = torch.tensor([[[1, 2], + [1, 2], + [1, 2]]]) + self.assertTrue(grid_a.equal(expected_grid_a)) + self.assertTrue(grid_b.equal(expected_grid_b)) + self.assertTrue(grid_c.equal(expected_grid_c)) + self.assertTrue(grid_a2.equal(expected_grid_a)) + self.assertTrue(grid_b2.equal(expected_grid_b)) + self.assertTrue(grid_c2.equal(expected_grid_c)) - torch.ge(x, b, out=byteRes) - torch.ge(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) + # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA + # is available, we get a different error. + @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error") + def test_cuda_not_built(self): + msg = "Torch not compiled with CUDA enabled" + self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device()) + self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda")) + self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda()) + self.assertRaisesRegex(TypeError, msg, lambda: torch.cuda.FloatTensor()) + self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor)) + self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda")) - torch.gt(x, b, out=byteRes) - torch.gt(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) + def test_cast_binary_op(self): + # Scalar + a = torch.tensor(2) + b = torch.tensor(3) + a_copy = a.clone() + b_copy = b.clone() - torch.eq(x, b, out=byteRes) - torch.eq(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) + self.assertEqual(torch.tensor(6), a.float() * b) - torch.ne(x, b, out=byteRes) - torch.ne(x, b, out=boolRes) - self.assertEqual(byteRes.bool(), boolRes) + self.assertEqual(a.type(), a_copy.type()) + self.assertEqual(a.data.type(), a_copy.data.type()) + self.assertEqual(b.type(), b_copy.type()) + self.assertEqual(b.data.type(), b_copy.type()) - self.assertEquals(len(warningsCount), 6) + def test_cartesian_prod(self): + a = torch.tensor([1]) + b = torch.tensor([1, 2, 3]) + c = torch.tensor([1, 2]) + prod = torch.cartesian_prod(a, b, c) + expected = torch.tensor(list(product([a], b, c))) + self.assertEqual(expected, prod) - # Bool Tensor - x = torch.tensor([True, False, True, False], device=device) - self.assertEqual(x.lt(True), torch.tensor([False, True, False, True])) - self.assertEqual(x.le(True), torch.tensor([True, True, True, True])) - self.assertEqual(x.ge(True), torch.tensor([True, False, True, False])) - self.assertEqual(x.gt(True), torch.tensor([False, False, False, False])) - self.assertEqual(x.eq(True), torch.tensor([True, False, True, False])) - self.assertEqual(x.ne(True), torch.tensor([False, True, False, True])) + # test 0 size input + d = torch.empty(0, dtype=b.dtype) + prod = torch.cartesian_prod(a, b, c, d) + expected = torch.empty(0, 4, dtype=b.dtype) + self.assertEqual(expected, prod) + # test single input + prod = torch.cartesian_prod(b) + self.assertEqual(b, prod) - def test_isfinite(self): - x = torch.Tensor([1, inf, 2, -inf, nan, -10]) - self.assertEqual(torch.isfinite(x), torch.BoolTensor([True, False, True, False, False, True])) + def test_combinations(self): + a = torch.tensor([1, 2, 3]) - def test_isfinite_int(self): - x = torch.tensor([1, 2, 3]) - self.assertEqual(torch.isfinite(x), torch.BoolTensor([True, True, True])) + c = torch.combinations(a, r=1) + expected = torch.tensor(list(combinations(a, r=1))) + self.assertEqual(c, expected) - def test_isfinite_type(self): - with self.assertRaises(TypeError): - torch.isfinite(1) # Parameter must be a tensor + c = torch.combinations(a, r=1, with_replacement=True) + expected = torch.tensor(list(combinations_with_replacement(a, r=1))) + self.assertEqual(c, expected) - @staticmethod - def _test_isinf(self, cast): - t1 = cast(torch.Tensor([1, inf, 2, -inf, nan])) - t2 = cast(torch.ByteTensor([1, 2, 3])) - t3 = cast(torch.CharTensor([1, 2, 3])) - t4 = cast(torch.ShortTensor([1, 2, 3])) - t5 = cast(torch.IntTensor([1, 2, 3])) - t6 = cast(torch.LongTensor([1, 2, 3])) - self.assertEqual(torch.isinf(t1), cast(torch.ByteTensor([0, 1, 0, 1, 0]))) - self.assertEqual(torch.isinf(t2), cast(torch.ByteTensor([0, 0, 0]))) - self.assertEqual(torch.isinf(t3), cast(torch.ByteTensor([0, 0, 0]))) - self.assertEqual(torch.isinf(t4), cast(torch.ByteTensor([0, 0, 0]))) - self.assertEqual(torch.isinf(t5), cast(torch.ByteTensor([0, 0, 0]))) - self.assertEqual(torch.isinf(t6), cast(torch.ByteTensor([0, 0, 0]))) - - def test_isinf(self): - self._test_isinf(self, lambda t: t) + c = torch.combinations(a) + expected = torch.tensor(list(combinations(a, r=2))) + self.assertEqual(c, expected) - def test_isinf_type(self): - with self.assertRaises(TypeError): - torch.isinf(1) # Parameter must be a tensor + c = torch.combinations(a, with_replacement=True) + expected = torch.tensor(list(combinations_with_replacement(a, r=2))) + self.assertEqual(c, expected) - def test_isnan(self): - x = torch.Tensor([1, nan, 2]) - self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0])) + c = torch.combinations(a, r=3) + expected = torch.tensor(list(combinations(a, r=3))) + self.assertEqual(c, expected) - def test_RNGState(self): - state = torch.get_rng_state() - stateCloned = state.clone() - before = torch.rand(1000) + c = torch.combinations(a, r=4) + expected = torch.empty(0, 4, dtype=a.dtype) + self.assertEqual(c, expected) - self.assertEqual(state.ne(stateCloned).long().sum(), 0, 0) + c = torch.combinations(a, r=5) + expected = torch.empty(0, 5, dtype=a.dtype) + self.assertEqual(c, expected) - torch.set_rng_state(state) - after = torch.rand(1000) - self.assertEqual(before, after, 0) + # test empty imput + a = torch.empty(0) + c1 = torch.combinations(a) + c2 = torch.combinations(a, with_replacement=True) + expected = torch.empty(0, 2, dtype=a.dtype) + self.assertEqual(c1, expected) + self.assertEqual(c2, expected) - def test_RNGStateAliasing(self): - # Fork the random number stream at this point - gen = torch.Generator() - gen.set_state(torch.get_rng_state()) - self.assertEqual(gen.get_state(), torch.get_rng_state()) - - target_value = torch.rand(1000) - # Dramatically alter the internal state of the main generator - _ = torch.rand(100000) - forked_value = torch.rand(1000, generator=gen) - self.assertEqual(target_value, forked_value, 0, "RNG has not forked correctly.") - - def test_RNG_after_pickle(self): - torch.random.manual_seed(100) - before = torch.rand(10) + def test_has_internal_overlap(self): + OVERLAP_NO = 0 + OVERLAP_YES = 1 + OVERLAP_TOO_HARD = 2 - torch.random.manual_seed(100) - buf = io.BytesIO() - tensor = torch.Tensor([1, 2, 3]) - ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(tensor) - after = torch.rand(10) + # Check for contiguous tensors + a = torch.randn(3, 3) + self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO) - self.assertEqual(before, after, 0) + # Checks for zero strides + b = torch.randn(1, 3) + b_expanded = b.expand(4, 3) + self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES) - def test_boxMullerState(self): - torch.manual_seed(123) - odd_number = 101 - seeded = torch.randn(odd_number) - state = torch.get_rng_state() - midstream = torch.randn(odd_number) - torch.set_rng_state(state) - repeat_midstream = torch.randn(odd_number) - torch.manual_seed(123) - reseeded = torch.randn(odd_number) - self.assertEqual(midstream, repeat_midstream, 0, - 'get_rng_state/set_rng_state not generating same sequence of normally distributed numbers') - self.assertEqual(seeded, reseeded, 0, - 'repeated calls to manual_seed not generating same sequence of normally distributed numbers') + def test_allow_tensor_metadata_change(self): + def do_test(t): + with self.assertRaisesRegex( + RuntimeError, + "set_sizes_contiguous is not allowed on a Tensor created from .data or .detach()"): + t.resize_((2, 1)) + with self.assertRaisesRegex( + RuntimeError, + "set_storage is not allowed on a Tensor created from .data or .detach()"): + t.set_() + with self.assertRaisesRegex( + RuntimeError, + "set_storage_offset is not allowed on a Tensor created from .data or .detach()"): + t.set_(t.storage(), 0, t.size(), list(t.stride())) - def test_manual_seed(self): - rng_state = torch.get_rng_state() - torch.manual_seed(2) - x = torch.randn(100) - self.assertEqual(torch.initial_seed(), 2) - torch.manual_seed(2) - y = torch.randn(100) - self.assertEqual(x, y) - torch.set_rng_state(rng_state) + do_test(torch.tensor([[1, 2]]).data) + do_test(torch.tensor([[1, 2]]).detach()) - @staticmethod - def _test_cholesky(self, cast): - x = cast(torch.rand(10, 10) + 1e-1) - A = torch.mm(x, x.t()) + def test_c10_layer_norm(self): + # test that we can call c10 ops and they return a reasonable result + X = torch.rand(5, 5, dtype=torch.float) + weight = torch.rand(*X.size()[1:], dtype=torch.float) + bias = torch.rand(*X.size()[1:], dtype=torch.float) + epsilon = 1e-4 - # default Case - C = torch.cholesky(A) - B = torch.mm(C, C.t()) - self.assertEqual(A, B, 1e-14) + expected_norm = torch.nn.functional.layer_norm( + X, X.size()[1:], weight=weight, bias=bias, eps=epsilon) + actual_norm, actual_mean, actual_stdev = \ + torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor( + weight), torch.tensor(bias), 1, epsilon, True) + torch.testing.assert_allclose(expected_norm, actual_norm) - # test Upper Triangular - U = torch.cholesky(A, True) - B = torch.mm(U.t(), U) - self.assertEqual(A, B, 1e-14, 'cholesky (upper) did not allow rebuilding the original matrix') + def test_memory_format(self): + x = torch.randn(10, 3, 32, 32) + nhwc = x.contiguous(memory_format=torch.channels_last) + self.assertFalse(nhwc.is_contiguous()) + self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last)) + self.assertEqual(nhwc, x) - # test Lower Triangular - L = torch.cholesky(A, False) - B = torch.mm(L, L.t()) - self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix') + def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self): + x = torch.randn(10, 32, 32, 3).permute(0, 3, 1, 2) + alias = x.contiguous(memory_format=torch.channels_last) + alias.fill_(7) + self.assertEqual(x, alias) - @skipIfNoLapack - def test_cholesky(self): - self._test_cholesky(self, lambda t: t) + def test_memory_format_empty(self): + with self.assertRaises(RuntimeError): + x = torch.empty((3, 3), memory_format=torch.channels_last) + x = torch.empty((3, 3, 3, 3), memory_format=torch.channels_last) + self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) - @staticmethod - def _test_cholesky_batched(self, cast): - from common_utils import random_symmetric_pd_matrix + def test_subclass_tensors(self): + # raise an error when trying to subclass FloatTensor + with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"): + class Foo1(torch.FloatTensor): + pass - def cholesky_test_helper(n, batch_dims, cast, upper): - A = cast(random_symmetric_pd_matrix(n, *batch_dims)) - cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) - cholesky_exp = cholesky_exp.reshape_as(A) - self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) + # but allow subclassing Tensor: + class Foo2(torch.Tensor): + def foo(self): + return 5 + f = Foo2() + self.assertEqual(f.foo(), 5) - for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]): - cholesky_test_helper(3, batchsize, cast, upper) + def test_ndim(self): + a = torch.randn(1, 2, 3) + self.assertEqual(3, a.ndim) + b = torch.randn(()) + self.assertEqual(0, b.ndim) + c = torch.randn(1, 0) + self.assertEqual(2, c.ndim) - @skipIfNoLapack - def test_cholesky_batched(self): - self._test_cholesky_batched(self, lambda t: t) + def test_T(self): + a = torch.randn(2, 3, 4) + t1 = a.T + t2 = a.permute(2, 1, 0) + self.assertEqual(t2, t1) + b = torch.randn(10) + self.assertEqual(b, b.T) + scalar = torch.tensor(5) + self.assertEqual(scalar, scalar.T) - @staticmethod - def _test_cholesky_batched_many_batches(self, cast): - from common_utils import random_symmetric_pd_matrix + def test_python_types(self): + a1 = torch.randn((1, 2), dtype=torch.float64) + a2 = torch.randn((1, 2), dtype=float) + self.assertEqual(a1.dtype, a2.dtype) - def cholesky_test_helper(n, batchsize, cast, upper): - A = cast(random_symmetric_pd_matrix(n, batchsize)) - chol_fact = torch.cholesky(A, upper=upper) - if upper: - # Correctness check - self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) - # Upper triangular check - self.assertEqual(chol_fact, chol_fact.triu()) - else: - # Correctness check - self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) - # Lower triangular check - self.assertEqual(chol_fact, chol_fact.tril()) + b1 = torch.arange(10, 20, dtype=torch.int64) + b2 = torch.arange(10, 20, dtype=int) + self.assertEqual(b1.dtype, b2.dtype) - for upper, batchsize in product([True, False], [262144, 524288]): - cholesky_test_helper(2, batchsize, cast, upper) + c1 = torch.tensor([True, False], dtype=torch.bool) + c2 = torch.tensor([True, False], dtype=bool) + self.assertEqual(c1.dtype, c2.dtype) - @skipIfNoLapack - @slowTest - def test_cholesky_batched_many_batches(self): - self._test_cholesky_batched_many_batches(self, lambda t: t) + def test_fill_diagonal(self): + a1 = torch.randn(7, 3) + a2 = a1.clone() + v = 1 + for i in range(3): + a2[i][i] = v + a1.fill_diagonal_(v) + self.assertEqual(a1, a2) - @staticmethod - def _test_cholesky_solve(self, cast): - a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23), - (-6.05, -3.30, 5.36, -4.44, 1.08), - (-0.45, 2.58, -2.70, 0.27, 9.04), - (8.32, 2.71, 4.35, -7.17, 2.14), - (-9.67, -5.14, -7.26, 6.08, -6.87))).t() - b = torch.Tensor(((4.02, 6.19, -8.22, -7.57, -3.03), - (-1.56, 4.00, -8.67, 1.75, 2.86), - (9.81, -4.09, -4.57, -8.61, 8.99))).t() - - # make sure 'a' is symmetric PSD - a = torch.mm(a, a.t()) - a, b = cast(a), cast(b) + b1 = torch.randn(7, 3) + b2 = b1.clone() + for i in range(3): + b2[i][i] = v + b2[i + 4][i] = v + b1.fill_diagonal_(v, wrap=True) + self.assertEqual(b1, b2) - # upper Triangular Test - U = torch.cholesky(a, True) - x = torch.cholesky_solve(b, U, True) - self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12) + c1 = torch.rand(3, 3, 3) + c2 = c1.clone() + for i in range(3): + c2[i][i][i] = v + c1.fill_diagonal_(v) + self.assertEqual(c1, c2) - # lower Triangular Test - L = torch.cholesky(a, False) - x = torch.cholesky_solve(b, L, False) - self.assertLessEqual(b.dist(torch.mm(a, x)), 1e-12) + # non-contiguous tensor + d1 = torch.rand(3, 3, 3)[:, 1, ...] + d2 = d1.clone() + for i in range(3): + d2[i][i] = v + d1.fill_diagonal_(v) + self.assertEqual(d1, d2) - # default arg Test - L_def = torch.cholesky(a) - x_def = torch.cholesky_solve(b, L_def) - self.assertLessEqual(b.dist(torch.mm(a, x_def)), 1e-12) + e1 = torch.rand(7, 3, 3)[:, 1, ...] + e2 = e1.clone() + for i in range(3): + e2[i][i] = v + e2[i + 4][i] = v + e1.fill_diagonal_(v, wrap=True) + self.assertEqual(e1, e2) - @skipIfNoLapack - def test_cholesky_solve(self): - self._test_cholesky_solve(self, lambda t: t) + def test_function_unwrap_message(self): + self.assertRaisesRegex(RuntimeError, ' call to _th_lt', + lambda: torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double)) - @staticmethod - def _test_cholesky_solve_batched(self, cast): - from common_utils import random_symmetric_pd_matrix - def cholesky_solve_test_helper(A_dims, b_dims, cast, upper): - A = cast(random_symmetric_pd_matrix(*A_dims)) - L = torch.cholesky(A, upper) - b = cast(torch.randn(*b_dims)) - return A, L, b +# Functions to test negative dimension wrapping +METHOD = 1 +INPLACE_METHOD = 2 +FUNCTIONAL = 4 +DIM_ARG = None - for upper in [True, False]: - # test against cholesky_solve: one batch with both choices of upper - A, L, b = cholesky_solve_test_helper((5, 1), (1, 5, 10), cast, upper) - x_exp = torch.cholesky_solve(b.squeeze(0), L.squeeze(0), upper=upper) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, x_exp.unsqueeze(0)) +def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0): + def neg_dim_test(self): + if isinstance(tensor_arg, list): + assert METHOD not in types and INPLACE_METHOD not in types + x = [torch.randn(arg) for arg in tensor_arg] + ndim = len(tensor_arg[-1]) + else: + x = torch.randn(*tensor_arg) + ndim = len(tensor_arg) + ndim += extra_dim - # test against cholesky_solve in a loop: four batches with both choices of upper - A, L, b = cholesky_solve_test_helper((5, 4), (4, 5, 10), cast, upper) - x_exp_list = [] - for i in range(4): - x_exp = torch.cholesky_solve(b[i], L[i], upper=upper) - x_exp_list.append(x_exp) - x_exp = torch.stack(x_exp_list) + n_dim_to_test = sum(map(lambda e: e is DIM_ARG, arg_constr())) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, x_exp) + for dims_val in combinations(range(ndim), n_dim_to_test): + arg = arg_constr() + arg_neg = copy.deepcopy(arg) + idx = 0 + for i, v in enumerate(arg): + if v is DIM_ARG: + arg[i] = dims_val[idx] + arg_neg[i] = dims_val[idx] - ndim + idx += 1 - # basic correctness test - A, L, b = cholesky_solve_test_helper((5, 3), (3, 5, 10), cast, upper) - x = torch.cholesky_solve(b, L, upper) - self.assertLessEqual(b.dist(torch.matmul(A, x)), 1e-12) + if METHOD in types: + a = getattr(x, name)(*arg) + b = getattr(x, name)(*arg_neg) + self.assertEqual(a, b) - # Test non-contiguous inputs. - if not TEST_NUMPY: - return - from numpy.linalg import solve - A = random_symmetric_pd_matrix(2, 2) - b = torch.randn(2, 2, 2) - x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())) - A = cast(A).permute(0, 2, 1) - b = cast(b).permute(2, 1, 0) - assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" - L = torch.cholesky(A, upper) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, cast(x_exp)) + if INPLACE_METHOD in types: + a = x.clone() + getattr(a, name + '_')(*arg) + b = x.clone() + getattr(b, name + '_')(*arg_neg) + self.assertEqual(a, b) - @skipIfNoLapack - def test_cholesky_solve_batched(self): - self._test_cholesky_solve_batched(self, lambda t: t) + if FUNCTIONAL in types: + a = getattr(torch, name)(x, *arg) + b = getattr(torch, name)(x, *arg_neg) + self.assertEqual(a, b) - @staticmethod - def _test_cholesky_solve_batched_many_batches(self, cast): - from common_utils import random_symmetric_pd_matrix + return neg_dim_test - def cholesky_solve_test_helper(A_dims, b_dims, cast, upper): - A = cast(random_symmetric_pd_matrix(*A_dims)) - L = torch.cholesky(A, upper) - b = cast(torch.randn(*b_dims)) - return A, L, b - for upper in [True, False]: - A, L, b = cholesky_solve_test_helper((5, 256, 256), (5, 10), cast, upper) - x = torch.cholesky_solve(b, L, upper) - self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 10))) +def idx_tensor(size, max_val): + return torch.LongTensor(*size).random_(0, max_val - 1) - A, L, b = cholesky_solve_test_helper((5,), (512, 512, 5, 10), cast, upper) - x = torch.cholesky_solve(b, L, upper) - self.assertEqual(torch.matmul(A, x), b) - @skipIfNoLapack - @slowTest - def test_cholesky_solve_batched_many_batches(self): - self._test_cholesky_solve_batched_many_batches(self, lambda t: t) +def add_neg_dim_tests(): + neg_dim_tests = [ + ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]), + ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]), + ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]), + ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]), + ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), + ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]), + ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]), + ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), + ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]), + ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]), + ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1), + ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]), + ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]), + ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), + ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]), + ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]), + ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]), + ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]), + ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]), + ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]), + ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]), + ] - @staticmethod - def _test_cholesky_solve_batched_dims(self, cast): - if not TEST_NUMPY: - return + for decl in neg_dim_tests: + if len(decl) == 4: + name, tensor_arg, arg_constr, types = decl + extra_dim = 0 + elif len(decl) == 5: + name, tensor_arg, arg_constr, types, extra_dim = decl - from numpy.linalg import solve - from common_utils import random_symmetric_pd_matrix + test_name = 'test_' + name + '_neg_dim' - def run_test(A_dims, b_dims, cast, upper): - A = random_symmetric_pd_matrix(*A_dims) - b = torch.randn(*b_dims) - x_exp = torch.Tensor(solve(A.numpy(), b.numpy())) - A, b = cast(A), cast(b) - L = torch.cholesky(A, upper) - x = torch.cholesky_solve(b, L, upper=upper) - self.assertEqual(x, cast(x_exp)) + assert not hasattr(_TestTorchMixin, test_name), "Duplicated test name: " + test_name + setattr(_TestTorchMixin, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim)) - for upper in [True, False]: - # test against numpy.linalg.solve - run_test((4, 2, 1, 3), (2, 1, 3, 4, 6), cast, upper) # no broadcasting - run_test((4, 2, 1, 3), (4, 6), cast, upper) # broadcasting b - run_test((4,), (2, 1, 3, 4, 2), cast, upper) # broadcasting A - run_test((4, 1, 3, 1), (2, 1, 3, 4, 5), cast, upper) # broadcasting A & b - @skipIfNoLapack - def test_cholesky_solve_batched_dims(self): - self._test_cholesky_solve_batched_dims(self, lambda t: t) +# Device-generic tests. Instantiated below and not run directly. +class TestTorchDeviceType(TestCase): + def check_internal_mem_overlap(self, inplace_op, num_inputs, device, + expected_failure=False): + if isinstance(inplace_op, str): + inplace_op = getattr(torch.Tensor, inplace_op) + input = torch.randn(1, device=device).expand(3, 3) + inputs = [input] + [torch.randn_like(input) + for i in range(num_inputs - 1)] + if not expected_failure: + with self.assertRaisesRegex(RuntimeError, 'single memory location'): + inplace_op(*inputs) + else: + with self.assertRaises(AssertionError): + with self.assertRaisesRegex(RuntimeError, 'single memory location'): + inplace_op(*inputs) - @staticmethod - def _test_cholesky_inverse(self, cast): - from common_utils import random_symmetric_pd_matrix - a = cast(random_symmetric_pd_matrix(5)) + def unary_check_input_output_mem_overlap(self, data, sz, op, + expected_failure=False): - # compute inverse directly - inv0 = torch.inverse(a) + def _test(op, output, input): + output_exp = torch.empty_like(output) + op(input, out=output_exp) + self.assertEqual(op(input, out=output), output_exp, op.__name__) - # default case - chol = torch.cholesky(a) - inv1 = torch.cholesky_inverse(chol, False) - self.assertLessEqual(inv0.dist(inv1), 1e-12) + # output is identical to input: + _test(op, output=data[0:sz], input=data[0:sz]) + # output and input are independent: + _test(op, output=data[0:sz], input=data[sz:2 * sz]) + # output partially overlaps with input: + if not expected_failure: + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + _test(op, data[0:sz], data[1:sz + 1]) + else: + with self.assertRaises(AssertionError): + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + _test(op, data[0:sz], data[1:sz + 1]) - # upper Triangular Test - chol = torch.cholesky(a, True) - inv1 = torch.cholesky_inverse(chol, True) - self.assertLessEqual(inv0.dist(inv1), 1e-12) + def binary_check_input_output_mem_overlap(self, op, device, + expected_failure=False): + sz = 3 + data = torch.randn(2 * sz, device=device) + other = torch.randn(sz, device=device) - # lower Triangular Test - chol = torch.cholesky(a, False) - inv1 = torch.cholesky_inverse(chol, False) - self.assertLessEqual(inv0.dist(inv1), 1e-12) + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(other, input, out=out), + expected_failure=expected_failure) - @skipIfNoLapack - def test_cholesky_inverse(self): - self._test_cholesky_inverse(self, lambda t: t) + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(input, other, out=out), + expected_failure=expected_failure) - def test_numel(self): - b = torch.ByteTensor(3, 100, 100) - self.assertEqual(b.nelement(), 3 * 100 * 100) - self.assertEqual(b.numel(), 3 * 100 * 100) + def ternary_check_input_output_mem_overlap(self, op, device, + expected_failure=False): + sz = 3 + data = torch.randn(2 * sz, device=device) + other1 = torch.randn(sz, device=device) + other2 = torch.randn(sz, device=device) - def _consecutive(self, size, start=1): - sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0) - sequence.add_(start - 1) - return sequence.resize_(*size) + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(input, other1, other2, out=out), + expected_failure=expected_failure) - @staticmethod - def _test_index(self, conv_fn): + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(other1, input, other2, out=out), + expected_failure=expected_failure) - def consec(size, start=1): - sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0) - sequence.add_(start - 1) - return sequence.view(*size) + self.unary_check_input_output_mem_overlap( + data, sz, lambda input, out: op(other1, other2, input, out=out), + expected_failure=expected_failure) - reference = conv_fn(consec((3, 3, 3))) + def _test_pow(self, base, exponent, np_exponent=None): + if np_exponent is None: + np_exponent = exponent - # empty tensor indexing - self.assertEqual(reference[conv_fn(torch.LongTensor())], reference.new(0, 3, 3)) + def to_np(value): + if isinstance(value, torch.Tensor): + return value.cpu().numpy() + return value - self.assertEqual(reference[0], consec((3, 3)), 0) - self.assertEqual(reference[1], consec((3, 3), 10), 0) - self.assertEqual(reference[2], consec((3, 3), 19), 0) - self.assertEqual(reference[0, 1], consec((3,), 4), 0) - self.assertEqual(reference[0:2], consec((2, 3, 3)), 0) - self.assertEqual(reference[2, 2, 2], 27, 0) - self.assertEqual(reference[:], consec((3, 3, 3)), 0) + try: + expected = torch.from_numpy( + np.power(to_np(base), to_np(np_exponent))) + except ValueError as e: + err_msg = "Integers to negative integer powers are not allowed." + self.assertEqual(str(e), err_msg) + out = torch.empty_like(base) + test_cases = [ + lambda: base.pow(exponent), + lambda: base.pow_(exponent), + lambda: torch.pow(base, exponent), + lambda: torch.pow(base, exponent, out=out) + ] + for test_case in test_cases: + self.assertRaisesRegex(RuntimeError, err_msg, test_case) + else: + if isinstance(base, torch.Tensor): + actual = base.pow(exponent) + self.assertEqual(actual, expected, allow_inf=True) - # indexing with Ellipsis - self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9], - [12, 15, 18], - [21, 24, 27]]), 0) - self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), 0) - self.assertEqual(reference[..., 2], reference[:, :, 2], 0) - self.assertEqual(reference[0, ..., 2], reference[0, :, 2], 0) - self.assertEqual(reference[0, 2, ...], reference[0, 2], 0) - self.assertEqual(reference[..., 2, 2, 2], 27, 0) - self.assertEqual(reference[2, ..., 2, 2], 27, 0) - self.assertEqual(reference[2, 2, ..., 2], 27, 0) - self.assertEqual(reference[2, 2, 2, ...], 27, 0) - self.assertEqual(reference[...], reference, 0) + actual = base.clone() + actual2 = actual.pow_(exponent) + self.assertEqual(actual, expected, allow_inf=True) + self.assertEqual(actual2, expected, allow_inf=True) - reference_5d = conv_fn(consec((3, 3, 3, 3, 3))) - self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], 0) - self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], 0) - self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], 0) - self.assertEqual(reference_5d[...], reference_5d, 0) + actual = torch.pow(base, exponent) + self.assertEqual(actual, expected, allow_inf=True) - # LongTensor indexing - reference = conv_fn(consec((5, 5, 5))) - idx = conv_fn(torch.LongTensor([2, 4])) - self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]])) - # TODO: enable one indexing is implemented like in numpy - # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]])) - # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1]) + actual2 = torch.pow(base, exponent, out=actual) + self.assertEqual(actual, expected, allow_inf=True) + self.assertEqual(actual2, expected, allow_inf=True) - # None indexing - self.assertEqual(reference[2, None], reference[2].unsqueeze(0)) - self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)) - self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1)) - self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0)) - self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2)) + def _select_broadcastable_dims(self, dims_full=None): + # select full dimensionality + if dims_full is None: + dims_full = [] + ndims = random.randint(1, 4) + dims_full = [random.randint(1, 8) for _ in range(ndims)] + else: + ndims = len(dims_full) - # indexing 0-length slice - self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)]) - self.assertEqual(torch.empty(0, 5), reference[slice(0), 2]) - self.assertEqual(torch.empty(0, 5), reference[2, slice(0)]) - self.assertEqual(torch.tensor([]), reference[2, 1:1, 2]) - - # indexing with step - reference = consec((10, 10, 10)) - self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0)) - self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)) - self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0)) - self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1)) - self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0)) - self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0)) - self.assertEqual(reference[:, 2, 1:6:2], - torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1)) - - lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] - tensor = conv_fn(torch.DoubleTensor(lst)) - for _i in range(100): - idx1_start = random.randrange(10) - idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) - idx1_step = random.randrange(1, 8) - idx1 = slice(idx1_start, idx1_end, idx1_step) - if random.randrange(2) == 0: - idx2_start = random.randrange(10) - idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) - idx2_step = random.randrange(1, 8) - idx2 = slice(idx2_start, idx2_end, idx2_step) - lst_indexed = list(map(lambda l: l[idx2], lst[idx1])) - tensor_indexed = tensor[idx1, idx2] - else: - lst_indexed = lst[idx1] - tensor_indexed = tensor[idx1] - self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) - - self.assertRaises(ValueError, lambda: reference[1:9:0]) - self.assertRaises(ValueError, lambda: reference[1:9:-1]) - - self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) - self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) - self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) - - self.assertRaises(IndexError, lambda: reference[0.0]) - self.assertRaises(TypeError, lambda: reference[0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) - self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) - - def delitem(): - del reference[0] + # select actual dimensions for ops: + # larger: full ndims, individual sizes may be reduced + # smaller: possibly reduced ndims, sizes may be reduced + smaller_ndims = random.randint(1, ndims) + dims_small = [] + dims_large = [] + for i in range(ndims - 1, -1, -1): + j = random.randint(1, 3) + if j == 1: # no reduced singleton dimension + ds = dims_full[i] + dl = dims_full[i] + elif j == 2: # larger may have reduced singleton dimension + ds = dims_full[i] + dl = 1 if len(dims_small) < smaller_ndims else dims_full[i] + elif j == 3: # smaller may have reduced singleton dimension + ds = 1 + dl = dims_full[i] + dims_large = [dl] + dims_large + if len(dims_small) < smaller_ndims: + dims_small = [ds] + dims_small + return (dims_small, dims_large, dims_full) - self.assertRaises(TypeError, delitem) + def test_diagonal(self, device): + x = torch.randn((100, 100), device=device) + result = torch.diagonal(x) + expected = torch.diag(x) + self.assertEqual(result, expected) - def test_index(self): - self._test_index(self, lambda x: x) + x = torch.randn((100, 100), device=device) + result = torch.diagonal(x, 17) + expected = torch.diag(x, 17) + self.assertEqual(result, expected) - @staticmethod - def _test_advancedindex(self, conv_fn): - # Tests for Integer Array Indexing, Part I - Purely integer array - # indexing + def test_neg(self, device): + int_types = [torch.int, torch.short, torch.int8, torch.uint8] + float_types = [torch.float, torch.double, torch.long] - def consec(size, start=1): - numel = reduce(lambda x, y: x * y, size, 1) - sequence = torch.ones(numel).cumsum(0) - sequence.add_(start - 1) - return sequence.view(*size) + # Tests bool tensor negation raises the correct error + self.assertRaisesRegex( + RuntimeError, + r"Negation, the `\-` operator, on a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: - torch.tensor([False, True], device=device)) - # pick a random valid indexer type - def ri(indices): - choice = random.randint(0, 2) - if choice == 0: - return conv_fn(torch.LongTensor(indices)) - elif choice == 1: - return list(indices) + for dtype in float_types + int_types: + if dtype in float_types: + a = torch.randn(100, 90).type(dtype).to(device) else: - return tuple(indices) + a = torch.randint(-128, 128, (100, 90), dtype=dtype, device=device) + zeros = torch.Tensor().type(dtype).resize_as_(a).zero_().to(device) - def validate_indexing(x): - self.assertEqual(x[[0]], consec((1,))) - self.assertEqual(x[ri([0]), ], consec((1,))) - self.assertEqual(x[ri([3]), ], consec((1,), 4)) - self.assertEqual(x[[2, 3, 4]], consec((3,), 3)) - self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3)) - self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5])) + if dtype == torch.uint8: + res_add = torch.add(zeros, a, alpha=255) + else: + res_add = torch.add(zeros, a, alpha=-1) - def validate_setting(x): - dtype = x.type() - x[[0]] = -2 - self.assertEqual(x[[0]], torch.Tensor([-2]).type(dtype)) - x[[0]] = -1 - self.assertEqual(x[ri([0]), ], torch.Tensor([-1]).type(dtype)) - x[[2, 3, 4]] = 4 - self.assertEqual(x[[2, 3, 4]], torch.Tensor([4, 4, 4]).type(dtype)) - x[ri([2, 3, 4]), ] = 3 - self.assertEqual(x[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]).type(dtype)) - x[ri([0, 2, 4]), ] = conv_fn(torch.Tensor([5, 4, 3])).type(dtype) - self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([5, 4, 3]).type(dtype)) + res_neg = a.clone() + res_neg.neg_() + self.assertEqual(res_neg, res_add) - # First, we will test indexing to generate return values + # test out of place as well + res_neg_out_place = a.clone().neg() + self.assertEqual(res_neg_out_place, res_add) - # Case 1: Purely Integer Array Indexing - reference = conv_fn(consec((10,))) - validate_indexing(reference) - validate_indexing(reference.type(torch.half)) + # test via __neg__ operator + res_neg_op = -a.clone() + self.assertEqual(res_neg_op, res_add) - # setting values - validate_setting(reference) - validate_setting(reference.type(torch.half)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_inverse(self, device): + from common_utils import random_fullrank_matrix_distinct_singular_value - # Tensor with stride != 1 + # no batches: 2-D tensors + matrix = random_fullrank_matrix_distinct_singular_value(5).to(device) + matrix_inverse = torch.inverse(matrix) + identity = torch.eye(5).to(device) + self.assertEqual(identity, torch.mm(matrix, matrix_inverse), 1e-8, 'inverse value') + self.assertEqual(identity, torch.mm(matrix_inverse, matrix), 1e-8, 'inverse value') - # strided is [1, 3, 5, 7] - reference = conv_fn(consec((10,))) - strided = conv_fn(torch.Tensor()) - strided.set_(reference.storage(), storage_offset=0, - size=torch.Size([4]), stride=[2]) + matrix_inverse_out = torch.empty(5, 5).to(device) + torch.inverse(matrix, out=matrix_inverse_out) + self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place') + # second call, now that matrix_inverse_out is transposed + torch.inverse(matrix, out=matrix_inverse_out) + self.assertEqual(matrix_inverse_out, matrix_inverse, 0, 'inverse value in-place') - self.assertEqual(strided[[0]], torch.Tensor([1])) - self.assertEqual(strided[ri([0]), ], torch.Tensor([1])) - self.assertEqual(strided[ri([3]), ], torch.Tensor([7])) - self.assertEqual(strided[[1, 2]], torch.Tensor([3, 5])) - self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5])) - self.assertEqual(strided[ri([[2, 1], [0, 3]]), ], - torch.Tensor([[5, 3], [1, 7]])) + # one batch + matrix = random_fullrank_matrix_distinct_singular_value(5, 1).to(device) + matrix_inverse = torch.inverse(matrix) + expected_inv = matrix.squeeze(0).inverse() + self.assertEqual(matrix_inverse, expected_inv.unsqueeze(0)) - # stride is [4, 8] - strided = conv_fn(torch.Tensor()) - strided.set_(reference.storage(), storage_offset=4, - size=torch.Size([2]), stride=[4]) - self.assertEqual(strided[[0]], torch.Tensor([5])) - self.assertEqual(strided[ri([0]), ], torch.Tensor([5])) - self.assertEqual(strided[ri([1]), ], torch.Tensor([9])) - self.assertEqual(strided[[0, 1]], torch.Tensor([5, 9])) - self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9])) - self.assertEqual(strided[ri([[0, 1], [1, 0]]), ], - torch.Tensor([[5, 9], [9, 5]])) + # four batches + matrices = random_fullrank_matrix_distinct_singular_value(5, 4).to(device) + expected_inv_list = [] + for i in range(0, 4): + expected_inv_list.append(torch.inverse(matrices[i])) + expected_inv = torch.stack(expected_inv_list) + matrices_inverse = torch.inverse(matrices) + self.assertEqual(matrices_inverse, expected_inv) - # reference is 1 2 - # 3 4 - # 5 6 - reference = conv_fn(consec((3, 2))) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([1, 3, 5])) - self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([2, 4, 6])) - self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) - self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) - self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([1, 2])) - self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], - torch.Tensor([2, 4, 4, 2, 6])) - self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - torch.Tensor([1, 2, 3, 3])) + # six batches (2 x 3) + matrices = random_fullrank_matrix_distinct_singular_value(5, 2, 3).to(device) + expected_inv_list = [] + for mat in matrices.view(-1, 5, 5): + expected_inv_list.append(torch.inverse(mat)) + expected_inv = torch.stack(expected_inv_list).view(2, 3, 5, 5) + matrices_inverse = torch.inverse(matrices) + self.assertEqual(matrices_inverse, expected_inv) - rows = ri([[0, 0], - [1, 2]]) - columns = [0], - self.assertEqual(reference[rows, columns], torch.Tensor([[1, 1], - [3, 5]])) + # incorrect input test + with self.assertRaisesRegex(RuntimeError, "must be batches of square matrices"): + torch.inverse(torch.randn(2, 3, 4, 3)) - rows = ri([[0, 0], - [1, 2]]) - columns = ri([1, 0]) - self.assertEqual(reference[rows, columns], torch.Tensor([[2, 1], - [4, 5]])) - rows = ri([[0, 0], - [1, 2]]) - columns = ri([[0, 1], - [1, 0]]) - self.assertEqual(reference[rows, columns], torch.Tensor([[1, 2], - [4, 5]])) + # correctness test + matrices = random_fullrank_matrix_distinct_singular_value(5, 3).to(device) + matrices_inverse = torch.inverse(matrices) + self.assertEqual(torch.matmul(matrices, matrices_inverse), identity.expand_as(matrices)) + self.assertEqual(torch.matmul(matrices_inverse, matrices), identity.expand_as(matrices)) - # setting values - reference[ri([0]), ri([1])] = -1 - self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1])) - reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4])) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1, - 2, -4])) - reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]])) - self.assertEqual(reference[rows, columns], - torch.Tensor([[4, 6], [2, 3]])) + # torch.inverse with out and batches + matrices = random_fullrank_matrix_distinct_singular_value(5, 3).to(device) + matrices_inverse = torch.empty(3, 5, 5).to(device) + torch.inverse(matrices, out=matrices_inverse) + self.assertEqual(torch.inverse(matrices), matrices_inverse) - # Verify still works with Transposed (i.e. non-contiguous) Tensors + # non-contiguous inputs + if not TEST_NUMPY: + return - reference = conv_fn(torch.Tensor([[0, 1, 2, 3], - [4, 5, 6, 7], - [8, 9, 10, 11]])).t_() + from numpy.linalg import inv + matrices = random_fullrank_matrix_distinct_singular_value(3, 2).to(device).permute(0, 2, 1) + assert not matrices.is_contiguous() + matrices_inverse = torch.inverse(matrices) + expected_inv = torch.as_tensor(inv(matrices.cpu().numpy())) + self.assertEqual(matrices_inverse, expected_inv.to(device)) - # Transposed: [[0, 4, 8], - # [1, 5, 9], - # [2, 6, 10], - # [3, 7, 11]] + def test_bitwise_not(self, device): + res = 0xffff - torch.arange(127, dtype=torch.int8, device=device) + for dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + if dtype == torch.bool: + a = torch.tensor([True, False], device=device) + expected_res = torch.tensor([False, True], device=device) + else: + a = torch.arange(127, dtype=dtype, device=device) + expected_res = res.type(dtype) + # new tensor + self.assertEqual(expected_res, a.bitwise_not()) + # out + b = torch.empty(0, dtype=dtype, device=device) + torch.bitwise_not(a, out=b) + self.assertEqual(expected_res, b) + # in-place + a.bitwise_not_() + self.assertEqual(expected_res, a) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([0, 1, - 2])) - self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([4, 5, - 6])) - self.assertEqual(reference[ri([0]), ri([0])], torch.Tensor([0])) - self.assertEqual(reference[ri([2]), ri([1])], torch.Tensor([6])) - self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([0, 4])) - self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], - torch.Tensor([4, 5, 5, 4, 7])) - self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - torch.Tensor([0, 4, 1, 1])) - - rows = ri([[0, 0], - [1, 2]]) - columns = [0], - self.assertEqual(reference[rows, columns], torch.Tensor([[0, 0], - [1, 2]])) + # test exceptions + for dtype in(torch.half, torch.float, torch.double): + a = torch.zeros(10, dtype=dtype, device=device) + # new tensor + with self.assertRaises(RuntimeError): + a.bitwise_not() + # out + b = torch.empty(0, dtype=dtype, device=device) + with self.assertRaises(RuntimeError): + torch.bitwise_not(a, out=b) + # in-place + with self.assertRaises(RuntimeError): + a.bitwise_not_() - rows = ri([[0, 0], - [1, 2]]) - columns = ri([1, 0]) - self.assertEqual(reference[rows, columns], torch.Tensor([[4, 0], - [5, 2]])) - rows = ri([[0, 0], - [1, 3]]) - columns = ri([[0, 1], - [1, 2]]) - self.assertEqual(reference[rows, columns], torch.Tensor([[0, 4], - [5, 11]])) + def test_logical_not(self, device): + for dtype in torch.testing.get_all_dtypes(): + a = torch.tensor([10, 1, 0], dtype=dtype, device=device) + if dtype == torch.bfloat16: + self.assertRaises(RuntimeError, lambda: a.logical_not()) + continue + expected_res = torch.tensor([0, 0, 1], dtype=dtype, device=device) + # new tensor + self.assertEqual(expected_res.bool(), a.logical_not()) + # out + for out_dtype in torch.testing.get_all_dtypes(): + b = torch.empty(0, dtype=out_dtype, device=device) + if out_dtype == torch.bfloat16: + self.assertRaises(RuntimeError, lambda: torch.logical_not(a, out=b)) + continue + torch.logical_not(a, out=b) + self.assertEqual(expected_res.bool(), b.bool()) + # in-place + a.logical_not_() + self.assertEqual(expected_res, a) - # setting values - reference[ri([0]), ri([1])] = -1 - self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1])) - reference[ri([0, 1, 2]), ri([0])] = conv_fn(torch.Tensor([-1, 2, -4])) - self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1, - 2, -4])) - reference[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]])) - self.assertEqual(reference[rows, columns], - torch.Tensor([[4, 6], [2, 3]])) + def test_logical_xor(self, device): + for dtype in (torch.bool,): # Will add more dtypes in the future + expected_res = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + a = torch.tensor([10, 0, 1, 0], dtype=dtype, device=device) + b = torch.tensor([1, 0, 0, 10], dtype=dtype, device=device) + # new tensor + self.assertEqual(expected_res, a.logical_xor(b)) + # out + c = torch.empty(0, dtype=dtype, device=device) + torch.logical_xor(a, b, out=c) + self.assertEqual(expected_res, c) + # out is not bool + c = torch.empty(0, dtype=torch.uint8, device=device) + with self.assertRaisesRegex(RuntimeError, + r"The output tensor of logical_xor must be a bool tensor\."): + torch.logical_xor(a, b, out=c) + # in-place + a.logical_xor_(b) + self.assertEqual(expected_res, a) - # stride != 1 + def test_isinf(self, device): + t1 = torch.Tensor([1, inf, 2, -inf, nan]).to(device) + t2 = torch.ByteTensor([1, 2, 3]).to(device) + t3 = torch.CharTensor([1, 2, 3]).to(device) + t4 = torch.ShortTensor([1, 2, 3]).to(device) + t5 = torch.IntTensor([1, 2, 3]).to(device) + t6 = torch.LongTensor([1, 2, 3]).to(device) + self.assertEqual(torch.isinf(t1), torch.ByteTensor([0, 1, 0, 1, 0]).to(device)) + self.assertEqual(torch.isinf(t2), torch.ByteTensor([0, 0, 0]).to(device)) + self.assertEqual(torch.isinf(t3), torch.ByteTensor([0, 0, 0]).to(device)) + self.assertEqual(torch.isinf(t4), torch.ByteTensor([0, 0, 0]).to(device)) + self.assertEqual(torch.isinf(t5), torch.ByteTensor([0, 0, 0]).to(device)) + self.assertEqual(torch.isinf(t6), torch.ByteTensor([0, 0, 0]).to(device)) + + def test_clamp(self, device): + m1 = torch.rand(100, device=device).mul(5).add(-2.5) # uniform in [-2.5, 2.5] + # just in case we're extremely lucky. + min_val = -1 + max_val = 1 + m1[1] = min_val + m1[2] = max_val - # strided is [[1 3 5 7], - # [9 11 13 15]] + res1 = m1.clone() + res1.clamp_(min_val, max_val) + res2 = m1.clone() + for i in iter_indices(res2): + res2[i] = max(min_val, min(max_val, res2[i])) + self.assertEqual(res1, res2) - reference = conv_fn(torch.arange(0., 24).view(3, 8)) - strided = conv_fn(torch.Tensor()) - strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), - stride=[8, 2]) + out = m1.clone() + torch.clamp(m1, min=min_val, max=max_val, out=out) + self.assertEqual(out, res1) - self.assertEqual(strided[ri([0, 1]), ri([0])], torch.Tensor([1, 9])) - self.assertEqual(strided[ri([0, 1]), ri([1])], torch.Tensor([3, 11])) - self.assertEqual(strided[ri([0]), ri([0])], torch.Tensor([1])) - self.assertEqual(strided[ri([1]), ri([3])], torch.Tensor([15])) - self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], torch.Tensor([1, 7])) - self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], - torch.Tensor([9, 11, 11, 9, 15])) - self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], - torch.Tensor([1, 3, 9, 9])) + res1 = torch.clamp(m1, min=min_val) + res2 = m1.clone() + for i in iter_indices(res2): + res2[i] = max(min_val, res2[i]) + self.assertEqual(res1, res2) - rows = ri([[0, 0], - [1, 1]]) - columns = [0], - self.assertEqual(strided[rows, columns], torch.Tensor([[1, 1], - [9, 9]])) + torch.clamp(m1, min=min_val, out=out) + self.assertEqual(out, res1) - rows = ri([[0, 1], - [1, 0]]) - columns = ri([1, 2]) - self.assertEqual(strided[rows, columns], torch.Tensor([[3, 13], - [11, 5]])) - rows = ri([[0, 0], - [1, 1]]) - columns = ri([[0, 1], - [1, 2]]) - self.assertEqual(strided[rows, columns], torch.Tensor([[1, 3], - [11, 13]])) + res1 = torch.clamp(m1, max=max_val) + res2 = m1.clone() + for i in iter_indices(res2): + res2[i] = min(max_val, res2[i]) + self.assertEqual(res1, res2) - # setting values + torch.clamp(m1, max=max_val, out=out) + self.assertEqual(out, res1) - # strided is [[10, 11], - # [17, 18]] + # if the tensor contains nan case + test_tens = torch.tensor([nan], device=device) - reference = conv_fn(torch.arange(0., 24).view(3, 8)) - strided = conv_fn(torch.Tensor()) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - stride=[7, 1]) - self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11])) - strided[ri([0]), ri([1])] = -1 - self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1])) + res1 = test_tens.clone() + res1.clamp_(min_val, max_val) + res2 = test_tens.clone() + for i in iter_indices(res2): + res2[i] = max(min(res2[i], max_val), min_val) + self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - reference = conv_fn(torch.arange(0., 24).view(3, 8)) - strided = conv_fn(torch.Tensor()) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - stride=[7, 1]) - self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([11, - 17])) - strided[ri([0, 1]), ri([1, 0])] = conv_fn(torch.Tensor([-1, 2])) - self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1, - 2])) + out = test_tens.clone() + torch.clamp(test_tens, min=min_val, max=max_val, out=out) + self.assertEqual(torch.isnan(out), torch.isnan(res1)) - reference = conv_fn(torch.arange(0., 24).view(3, 8)) - strided = conv_fn(torch.Tensor()) - strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), - stride=[7, 1]) + res1 = torch.clamp(test_tens, min=min_val) + res2 = test_tens.clone() + for i in iter_indices(res2): + res2[i] = max(res2[i], min_val) + self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - rows = ri([[0], - [1]]) - columns = ri([[0, 1], - [0, 1]]) - self.assertEqual(strided[rows, columns], - torch.Tensor([[10, 11], [17, 18]])) - strided[rows, columns] = conv_fn(torch.Tensor([[4, 6], [2, 3]])) - self.assertEqual(strided[rows, columns], - torch.Tensor([[4, 6], [2, 3]])) + torch.clamp(test_tens, min=min_val, out=out) + self.assertEqual(torch.isnan(out), torch.isnan(res1)) - # Tests using less than the number of dims, and ellipsis + res1 = torch.clamp(test_tens, max=max_val) + res2 = test_tens.clone() + for i in iter_indices(res2): + res2[i] = min(res2[i], max_val) + self.assertEqual(torch.isnan(res1), torch.isnan(res2)) - # reference is 1 2 - # 3 4 - # 5 6 - reference = conv_fn(consec((3, 2))) - self.assertEqual(reference[ri([0, 2]), ], torch.Tensor([[1, 2], [5, 6]])) - self.assertEqual(reference[ri([1]), ...], torch.Tensor([[3, 4]])) - self.assertEqual(reference[..., ri([1])], torch.Tensor([[2], [4], [6]])) + torch.clamp(test_tens, max=max_val, out=out) + self.assertEqual(torch.isnan(out), torch.isnan(res1)) - # verify too many indices fails - with self.assertRaises(IndexError): - reference[ri([1]), ri([0, 2]), ri([3])] + error_msg = 'At least one of \'min\' or \'max\' must not be None' + with self.assertRaisesRegex(RuntimeError, error_msg): + m1.clamp() + with self.assertRaisesRegex(RuntimeError, error_msg): + m1.clamp_() - # test invalid index fails - reference = conv_fn(torch.empty(10)) - # can't test cuda because it is a device assert - if not reference.is_cuda: - for err_idx in (10, -11): - with self.assertRaisesRegex(IndexError, r'out of'): - reference[err_idx] - with self.assertRaisesRegex(IndexError, r'out of'): - reference[conv_fn(torch.LongTensor([err_idx]))] - with self.assertRaisesRegex(IndexError, r'out of'): - reference[[err_idx]] + def test_cat_empty_legacy(self, device): + # FIXME: this is legacy behavior and should be removed + # when we support empty tensors with arbitrary sizes + dtype = torch.float32 - if TEST_NUMPY: - # we use numpy to compare against, to verify that our advanced - # indexing semantics are the same, and also for ease of test - # writing + x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) + empty = torch.randn((0,), dtype=dtype, device=device) - def tensor_indices_to_np(tensor, indices): - # convert the Torch Tensor to a numpy array - if (tensor.is_cuda): - tensor = tensor.cpu() - npt = tensor.numpy() + res1 = torch.cat([x, empty], dim=1) + res2 = torch.cat([empty, x], dim=1) + self.assertEqual(res1, res2) - # convert indices - idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else - i for i in indices) + conv = torch.nn.Conv2d(3, 3, kernel_size=1).float().to(device) + res1 = torch.cat([conv(x), empty], dim=1) + res2 = torch.cat([empty, conv(x)], dim=1) + self.assertEqual(res1, res2) - return npt, idxs + res1 = torch.cat([empty, empty], dim=1) + self.assertEqual(res1, empty) - def get_numpy(tensor, indices): - npt, idxs = tensor_indices_to_np(tensor, indices) + with self.assertRaisesRegex(RuntimeError, + 'non-empty list of Tensors'): + torch.cat([], dim=1) - # index and return as a Torch Tensor - return torch.Tensor(npt[idxs]) + def test_cat_empty(self, device): + dtype = torch.float32 - def set_numpy(tensor, indices, value): - if not isinstance(value, int): - if value.is_cuda: - value = value.cpu() - value = value.numpy() + x = torch.randn((4, 3, 32, 32), dtype=dtype, device=device) + empty = torch.randn((4, 0, 32, 32), dtype=dtype, device=device) - npt, idxs = tensor_indices_to_np(tensor, indices) - npt[idxs] = value - return npt + res1 = torch.cat([x, empty], dim=1) + res2 = torch.cat([empty, x], dim=1) + self.assertEqual(res1, res2) - def assert_get_eq(tensor, indexer): - self.assertEqual(tensor[indexer], - conv_fn(get_numpy(tensor, indexer))) + conv = torch.nn.Conv2d(3, 3, kernel_size=1).float().to(device) + res1 = torch.cat([conv(x), empty], dim=1) + res2 = torch.cat([empty, conv(x)], dim=1) + self.assertEqual(res1, res2) - def assert_set_eq(tensor, indexer, val): - pyt = tensor.clone() - numt = tensor.clone() - pyt[indexer] = val - numt = conv_fn(torch.Tensor(set_numpy(numt, indexer, val))) - self.assertEqual(pyt, numt) + res1 = torch.cat([empty, empty], dim=1) + self.assertEqual(res1, empty) - def assert_backward_eq(tensor, indexer): - cpu = tensor.float().clone().detach().requires_grad_(True) - outcpu = cpu[indexer] - gOcpu = torch.rand_like(outcpu) - outcpu.backward(gOcpu) - gpu = cpu.cuda().detach().requires_grad_(True) - outgpu = gpu[indexer] - outgpu.backward(gOcpu.cuda()) - self.assertEqual(cpu.grad, gpu.grad) + # check non-legacy-behavior (sizes don't match) + empty = torch.randn((4, 0, 31, 32), dtype=dtype, device=device) + self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) + self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) - def get_set_tensor(indexed, indexer): - set_size = indexed[indexer].size() - set_count = indexed[indexer].numel() - set_tensor = conv_fn(torch.randperm(set_count).view(set_size).double()) - return set_tensor + # check non-legacy-behavior (dimensions don't match) + empty = torch.randn((4, 0), dtype=dtype, device=device) + self.assertRaises(RuntimeError, lambda: torch.cat([x, empty], dim=1)) + self.assertRaises(RuntimeError, lambda: torch.cat([empty, x], dim=1)) - # Tensor is 0 1 2 3 4 - # 5 6 7 8 9 - # 10 11 12 13 14 - # 15 16 17 18 19 - reference = conv_fn(torch.arange(0., 20).view(4, 5)) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_inverse_many_batches(self, device): + from common_utils import random_fullrank_matrix_distinct_singular_value - indices_to_test = [ - # grab the second, fourth columns - [slice(None), [1, 3]], + matrices = random_fullrank_matrix_distinct_singular_value(5, 256, 256).to(device) + matrices_inverse = torch.inverse(matrices) + self.assertEqual(torch.matmul(matrices_inverse, matrices), + torch.eye(5).to(device).expand_as(matrices)) - # first, third rows, - [[0, 2], slice(None)], + matrices = random_fullrank_matrix_distinct_singular_value(3, 512, 512).to(device) + matrices_inverse = torch.inverse(matrices) + self.assertEqual(torch.matmul(matrices, matrices_inverse), + torch.eye(3).to(device).expand_as(matrices)) - # weird shape - [slice(None), [[0, 1], - [2, 3]]], - # negatives - [[-1], [0]], - [[0, 2], [-1]], - [slice(None), [-1]], - ] + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_pinverse(self, device): + from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank - # only test dupes on gets - get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] + def run_test(M): + # Testing against definition for pseudo-inverses + MPI = torch.pinverse(M) + if M.numel() > 0: + self.assertEqual(M, M.matmul(MPI).matmul(M), 1e-8, 'pseudo-inverse condition 1') + self.assertEqual(MPI, MPI.matmul(M).matmul(MPI), 1e-8, 'pseudo-inverse condition 2') + self.assertEqual(M.matmul(MPI), (M.matmul(MPI)).transpose(-2, -1), 1e-8, 'pseudo-inverse condition 3') + self.assertEqual(MPI.matmul(M), (MPI.matmul(M)).transpose(-2, -1), 1e-8, 'pseudo-inverse condition 4') + else: + self.assertEqual(M.shape, MPI.shape[:-2] + (MPI.shape[-1], MPI.shape[-2])) + for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5), # square matrices + (3, 2), (5, 3, 2), (7, 5, 3, 2), # fat matrices + (2, 3), (5, 2, 3), (7, 5, 2, 3), # thin matrices + (0, 0), (0, 2), (2, 0), (3, 0, 0), (0, 3, 0), (0, 0, 3)]: # zero numel matrices + M = torch.randn(*sizes, device=device) + run_test(M) - for indexer in get_indices_to_test: - assert_get_eq(reference, indexer) - if torch.cuda.is_available(): - assert_backward_eq(reference, indexer) + # Test inverse and pseudo-inverse for invertible matrix + for sizes in [(5, 5), (3, 5, 5), (3, 7, 5, 5)]: + matsize = sizes[-1] + batchdims = sizes[:-2] + M = fullrank(matsize, *batchdims).to(device=device) + self.assertEqual(torch.eye(matsize, device=device).expand(sizes), M.pinverse().matmul(M), + 1e-7, 'pseudo-inverse for invertible matrix') + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_matrix_rank(self, device): + a = torch.eye(10, device=device) + self.assertEqual(torch.matrix_rank(a).item(), 10) + self.assertEqual(torch.matrix_rank(a, True).item(), 10) - for indexer in indices_to_test: - assert_set_eq(reference, indexer, 44) - assert_set_eq(reference, - indexer, - get_set_tensor(reference, indexer)) + a[5, 5] = 0 + self.assertEqual(torch.matrix_rank(a).item(), 9) + self.assertEqual(torch.matrix_rank(a, True).item(), 9) - reference = conv_fn(torch.arange(0., 160).view(4, 8, 5)) + a = torch.randn(24, 42, device=device) + self.assertEqual(torch.matrix_rank(a), torch.matrix_rank(a.t())) + aaT = torch.mm(a, a.t()) + self.assertEqual(torch.matrix_rank(aaT), torch.matrix_rank(aaT, True)) + aTa = torch.mm(a.t(), a) + self.assertEqual(torch.matrix_rank(aTa), torch.matrix_rank(aTa, True)) - indices_to_test = [ - [slice(None), slice(None), [0, 3, 4]], - [slice(None), [2, 4, 5, 7], slice(None)], - [[2, 3], slice(None), slice(None)], - [slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), [0], [1, 2, 4]], - [slice(None), [0, 1, 3], [4]], - [slice(None), [[0, 1], [1, 0]], [[2, 3]]], - [slice(None), [[0, 1], [2, 3]], [[0]]], - [slice(None), [[5, 6]], [[0, 3], [4, 4]]], - [[0, 2, 3], [1, 3, 4], slice(None)], - [[0], [1, 2, 4], slice(None)], - [[0, 1, 3], [4], slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - [[[0, 1], [1, 0]], [[2, 3]], slice(None)], - [[[0, 1], [2, 3]], [[0]], slice(None)], - [[[2, 1]], [[0, 3], [4, 4]], slice(None)], - [[[2]], [[0, 3], [4, 1]], slice(None)], - # non-contiguous indexing subspace - [[0, 2, 3], slice(None), [1, 3, 4]], + if TEST_NUMPY: + from numpy.linalg import matrix_rank + a = torch.randn(35, 75, device=device) + self.assertEqual(torch.matrix_rank(a).item(), matrix_rank(a.cpu().numpy())) + self.assertEqual(torch.matrix_rank(a, 0.01).item(), matrix_rank(a.cpu().numpy(), 0.01)) - # less dim, ellipsis - [[0, 2], ], - [[0, 2], slice(None)], - [[0, 2], Ellipsis], - [[0, 2], slice(None), Ellipsis], - [[0, 2], Ellipsis, slice(None)], - [[0, 2], [1, 3]], - [[0, 2], [1, 3], Ellipsis], - [Ellipsis, [1, 3], [2, 3]], - [Ellipsis, [2, 3, 4]], - [Ellipsis, slice(None), [2, 3, 4]], - [slice(None), Ellipsis, [2, 3, 4]], + aaT = torch.mm(a, a.t()) + self.assertEqual(torch.matrix_rank(aaT).item(), matrix_rank(aaT.cpu().numpy())) + self.assertEqual(torch.matrix_rank(aaT, 0.01).item(), matrix_rank(aaT.cpu().numpy(), 0.01)) - # ellipsis counts for nothing - [Ellipsis, slice(None), slice(None), [0, 3, 4]], - [slice(None), Ellipsis, slice(None), [0, 3, 4]], - [slice(None), slice(None), Ellipsis, [0, 3, 4]], - [slice(None), slice(None), [0, 3, 4], Ellipsis], - [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], - [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], - ] + if np.lib.NumpyVersion(np.__version__) >= '1.14.0': + self.assertEqual(torch.matrix_rank(aaT, True).item(), matrix_rank(aaT.cpu().numpy(), True)) + self.assertEqual(torch.matrix_rank(aaT, 0.01, True).item(), + matrix_rank(aaT.cpu().numpy(), 0.01, True)) - for indexer in indices_to_test: - assert_get_eq(reference, indexer) - assert_set_eq(reference, indexer, 212) - assert_set_eq(reference, - indexer, - get_set_tensor(reference, indexer)) - if torch.cuda.is_available(): - assert_backward_eq(reference, indexer) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_matrix_power(self, device): + def run_test(M, sign=1): + if sign == -1: + M = M.inverse() + MP2 = torch.matrix_power(M, 2) + self.assertEqual(MP2, torch.matmul(M, M)) - reference = conv_fn(torch.arange(0., 1296).view(3, 9, 8, 6)) + MP3 = torch.matrix_power(M, 3) + self.assertEqual(MP3, torch.matmul(MP2, M)) - indices_to_test = [ - [slice(None), slice(None), slice(None), [0, 3, 4]], - [slice(None), slice(None), [2, 4, 5, 7], slice(None)], - [slice(None), [2, 3], slice(None), slice(None)], - [[1, 2], slice(None), slice(None), slice(None)], - [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), slice(None), [0], [1, 2, 4]], - [slice(None), slice(None), [0, 1, 3], [4]], - [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], - [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], - [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], - [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], - [slice(None), [0], [1, 2, 4], slice(None)], - [slice(None), [0, 1, 3], [4], slice(None)], - [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], - [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], - [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], - [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], - [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], - [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], - [[0], [1, 2, 4], slice(None), slice(None)], - [[0, 1, 2], [4], slice(None), slice(None)], - [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], - [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], - [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], - [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], - [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], - [slice(None), [2, 3, 4], [1, 3, 4], [4]], - [slice(None), [0, 1, 3], [4], [1, 3, 4]], - [slice(None), [6], [0, 2, 3], [1, 3, 4]], - [slice(None), [2, 3, 5], [3], [4]], - [slice(None), [0], [4], [1, 3, 4]], - [slice(None), [6], [0, 2, 3], [1]], - [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], - [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], - [[2, 0, 1], [1, 2, 3], [4], slice(None)], - [[0, 1, 2], [4], [1, 3, 4], slice(None)], - [[0], [0, 2, 3], [1, 3, 4], slice(None)], - [[0, 2, 1], [3], [4], slice(None)], - [[0], [4], [1, 3, 4], slice(None)], - [[1], [0, 2, 3], [1], slice(None)], - [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], + MP4 = torch.matrix_power(M, 4) + self.assertEqual(MP4, torch.matmul(MP2, MP2)) - # less dim, ellipsis - [Ellipsis, [0, 3, 4]], - [Ellipsis, slice(None), [0, 3, 4]], - [Ellipsis, slice(None), slice(None), [0, 3, 4]], - [slice(None), Ellipsis, [0, 3, 4]], - [slice(None), slice(None), Ellipsis, [0, 3, 4]], - [slice(None), [0, 2, 3], [1, 3, 4]], - [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], - [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], - [[0], [1, 2, 4]], - [[0], [1, 2, 4], slice(None)], - [[0], [1, 2, 4], Ellipsis], - [[0], [1, 2, 4], Ellipsis, slice(None)], - [[1], ], - [[0, 2, 1], [3], [4]], - [[0, 2, 1], [3], [4], slice(None)], - [[0, 2, 1], [3], [4], Ellipsis], - [Ellipsis, [0, 2, 1], [3], [4]], - ] + MP6 = torch.matrix_power(M, 6) + self.assertEqual(MP6, torch.matmul(MP3, MP3)) - for indexer in indices_to_test: - assert_get_eq(reference, indexer) - assert_set_eq(reference, indexer, 1333) - assert_set_eq(reference, - indexer, - get_set_tensor(reference, indexer)) - indices_to_test += [ - [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], - [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], - ] - for indexer in indices_to_test: - assert_get_eq(reference, indexer) - assert_set_eq(reference, indexer, 1333) - if torch.cuda.is_available(): - assert_backward_eq(reference, indexer) + MP0 = torch.matrix_power(M, 0) + self.assertEqual(MP0, torch.eye(M.size(-2)).expand_as(M)) - def test_advancedindex(self): - self._test_advancedindex(self, lambda x: x) + # Single matrix + M = torch.randn(5, 5, device=device) + run_test(M) - @staticmethod - def _test_advancedindex_big(self, conv_fn): - reference = conv_fn(torch.arange(0, 123344).int()) + # Batch matrices + M = torch.randn(3, 3, 3, device=device) + run_test(M) - self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ], - torch.LongTensor([0, 123, 44488, 68807, 123343])) + # Many batch matrices + M = torch.randn(2, 3, 3, 3, device=device) + run_test(M) - def test_advancedindex_big(self): - self._test_advancedindex_big(self, lambda x: x) + # This is for negative powers + from common_utils import random_fullrank_matrix_distinct_singular_value + M = random_fullrank_matrix_distinct_singular_value(5).to(device) + run_test(M, sign=-1) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_empty_storage_view(self): - # we should be able to "modify" slices of a 0-element - # array without an error being raised due to - # trying to resize its storage - t = torch.from_numpy(np.empty((0, 4))) - t[:, 1::2] *= 1 + M = random_fullrank_matrix_distinct_singular_value(3, 3).to(device) + run_test(M, sign=-1) - def test_atan2(self): - def _test_atan2_with_size(size, device): - a = torch.rand(size=size, device=device, dtype=torch.double) - b = torch.rand(size=size, device=device, dtype=torch.double) - actual = a.atan2(b) - x = a.view(-1) - y = b.view(-1) - expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], - device=device, dtype=torch.double) - self.assertTrue(torch.allclose(expected, actual.view(-1), rtol=0, atol=0.02)) - for device in torch.testing.get_all_device_types(): - _test_atan2_with_size((2, 2), device) - _test_atan2_with_size((3, 3), device) - _test_atan2_with_size((5, 5), device) + M = random_fullrank_matrix_distinct_singular_value(3, 2, 3).to(device) + run_test(M, sign=-1) - def test_atan2_edgecases(self): - def _test_atan2(x, y, expected, device, dtype): - expected_tensor = torch.tensor([expected], dtype=dtype, device=device) - x_tensor = torch.tensor([x], dtype=dtype, device=device) - y_tensor = torch.tensor([y], dtype=dtype, device=device) - actual = torch.atan2(y_tensor, x_tensor) - self.assertTrue(torch.allclose(expected_tensor, actual, rtol=0, atol=0.02)) - for device in torch.testing.get_all_device_types(): - for dtype in [torch.float, torch.double]: - _test_atan2(0, 0, 0, device, dtype) - _test_atan2(0, 1, math.pi / 2, device, dtype) - _test_atan2(0, -1, math.pi / -2, device, dtype) - _test_atan2(-1, 0, math.pi, device, dtype) - _test_atan2(1, 0, 0, device, dtype) - _test_atan2(-1, -1, math.pi * -3 / 4 , device, dtype) - _test_atan2(1, 1, math.pi / 4 , device, dtype) - _test_atan2(1, -1, math.pi / -4 , device, dtype) - _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype) + def test_chain_matmul(self, device): + def product(matrices): + for mat in matrices[1:]: + matrices[0] = matrices[0].mm(mat) + return matrices[0] - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_newaxis_numpy_comparison(self): - def run_test(tensor, *idx): - npt = tensor.numpy() - self.assertEqual(tensor[idx], npt[idx]) + def run_test(p, device): + matrices = [] + for (pi, pi_1) in zip(p[:-1], p[1:]): + matrices.append(torch.randn(pi, pi_1, device=device)) + self.assertEqual(torch.chain_matmul(*matrices), product(matrices)) - # 1D Tensor Tests - x = torch.arange(0, 10) - cases = [ - [None], - [None, None], - [Ellipsis, None], - [None, Ellipsis], - [2, None], - [None, 2], - [Ellipsis, None, 2], - [Ellipsis, 2, None], - [2, Ellipsis, None], - [2, None, Ellipsis], - [None, 2, Ellipsis], - [None, Ellipsis, 2], - ] + run_test([10, 20, 30, 5], device) + run_test([15, 5, 10, 20, 25], device) - for case in cases: - run_test(x, *case) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_det_logdet_slogdet(self, device): + def reference_slogdet(M): + if TEST_NUMPY: + sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy()) + return M.new_tensor(sdet), M.new_tensor(logabsdet) + else: + # naive row reduction + M = M.clone() + l = M.size(0) + multiplier = 1 + for i in range(l): + if M[i, 0].item() != 0: + if i != 0: + M[0], M[i] = M[i], M[0] + multiplier = -1 + break + else: + return 0 + for i in range(1, l): + row = M[i] + for j in range(i): + row -= row[j] / M[j, j] * M[j] + M[i] = row + sdet = M.diag().sign().prod() + logabsdet = M.diag().abs_().log_().sum().add_(math.log(multiplier)) + return sdet, logabsdet - # 2D Tensor Tests - x = torch.arange(0, 12).view(3, 4) - cases = [ - [None], - [None, None], - [None, None, None], - [Ellipsis, None], - [Ellipsis, None, None], - [None, Ellipsis], - [None, Ellipsis, None], - [None, None, Ellipsis], - [2, None], - [2, None, Ellipsis], - [2, Ellipsis, None], - [None, 2, Ellipsis], - [Ellipsis, 2, None], - [Ellipsis, None, 2], - [None, Ellipsis, 2], - [1, 2, None], - [1, 2, Ellipsis, None], - [1, Ellipsis, 2, None], - [Ellipsis, 1, None, 2], - [Ellipsis, 1, 2, None], - [1, None, 2, Ellipsis], - [None, 1, Ellipsis, 2], - [None, 1, 2, Ellipsis], - ] + def test_single_det(M, target, desc): + target_sdet, target_logabsdet = target - for case in cases: - run_test(x, *case) + det = M.det() + logdet = M.logdet() + sdet, logabsdet = M.slogdet() - def test_newindex(self): - reference = self._consecutive((3, 3, 3)) - # This relies on __index__() being correct - but we have separate tests for that + # Test det + self.assertEqual(det, target_sdet * target_logabsdet.exp(), 1e-7, '{} (det)'.format(desc)) - def checkPartialAssign(index): - reference = torch.zeros(3, 3, 3) - reference[index] = self._consecutive((3, 3, 3))[index] - self.assertEqual(reference[index], self._consecutive((3, 3, 3))[index], 0) - reference[index] = 0 - self.assertEqual(reference, torch.zeros(3, 3, 3), 0) + # Test slogdet + # Compare the overall value rather than individual parts because of + # precision issues when det is near zero. + self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), 1e-7, '{} (slogdet)'.format(desc)) - checkPartialAssign(0) - checkPartialAssign(1) - checkPartialAssign(2) - checkPartialAssign((0, 1)) - checkPartialAssign((1, 2)) - checkPartialAssign((0, 2)) - checkPartialAssign(torch.LongTensor((0, 2))) + # Test logdet + # Compare logdet against our own pytorch slogdet because they should + # be consistent, while it may behave slightly differently with other + # slogdet implementations when det is near zero due to precision + # issues. + if sdet.item() < 0: + self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc)) + else: + self.assertEqual(logdet.exp(), target_logabsdet.exp(), 1e-7, '{} (logdet non-negative case)'.format(desc)) - with self.assertRaises(IndexError): - reference[1, 1, 1, 1] = 1 - with self.assertRaises(IndexError): - reference[1, 1, 1, (1, 1)] = 1 - with self.assertRaises(IndexError): - reference[3, 3, 3, 3, 3, 3, 3, 3] = 1 - with self.assertRaises(IndexError): - reference[0.0] = 1 - with self.assertRaises(TypeError): - reference[0.0:2.0] = 1 - with self.assertRaises(IndexError): - reference[0.0, 0.0:2.0] = 1 - with self.assertRaises(IndexError): - reference[0.0, :, 0.0:2.0] = 1 - with self.assertRaises(IndexError): - reference[0.0, ..., 0.0:2.0] = 1 - with self.assertRaises(IndexError): - reference[0.0, :, 0.0] = 1 + eye = torch.eye(5, device=device) + test_single_det(eye, (torch.ones((), device=device), torch.zeros((), device=device)), 'identity') - def test_index_copy(self): - for device in torch.testing.get_all_device_types(): - num_copy, num_dest = 3, 20 - dest = torch.randn(num_dest, 4, 5, device=device) - src = torch.randn(num_copy, 4, 5, device=device) - idx = torch.randperm(num_dest, device=device).narrow(0, 0, num_copy) - dest2 = dest.clone() - dest.index_copy_(0, idx, src) - for i in range(idx.size(0)): - dest2[idx[i]] = src[i] - self.assertEqual(dest, dest2, 0) - - dest = torch.randn(num_dest, device=device) - src = torch.randn(num_copy, device=device) - idx = torch.randperm(num_dest, device=device).narrow(0, 0, num_copy) - dest2 = dest.clone() - dest.index_copy_(0, idx, src) - for i in range(idx.size(0)): - dest2[idx[i]] = src[i] - self.assertEqual(dest, dest2, 0) - - # Bool tensor - dest = torch.zeros(2, 2, dtype=torch.bool, device=device) - src = torch.tensor([[True, True], [True, True]], device=device) - index = torch.tensor([0, 1], device=device) - dest.index_copy_(0, index, src) - self.assertEqual(dest, torch.tensor([[True, True], [True, True]], device=device)) - - # Error cases - a = torch.randn(3, 5) - c = torch.zeros(3) - self.assertRaises(IndexError, lambda: a.index_copy_(dim=1, index=torch.tensor([3]), source=c)) + def test(M): + assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5' + M = M.to(device) - def test_index_add(self): - num_copy, num_dest = 3, 3 - dest = torch.randn(num_dest, 4, 5) - src = torch.randn(num_copy, 4, 5) - idx = torch.randperm(num_dest).narrow(0, 0, num_copy) - dest2 = dest.clone() - dest.index_add_(0, idx, src) - for i in range(idx.size(0)): - dest2[idx[i]] += src[i] - self.assertEqual(dest, dest2) + ref_M_sdet, ref_M_logabsdet = reference_slogdet(M) - dest = torch.randn(num_dest) - src = torch.randn(num_copy) - idx = torch.randperm(num_dest).narrow(0, 0, num_copy) - dest2 = dest.clone() - dest.index_add_(0, idx, src) - for i in range(idx.size(0)): - dest2[idx[i]] = dest2[idx[i]] + src[i] - self.assertEqual(dest, dest2) + test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic') + if ref_M_logabsdet.exp().item() >= 1e-6: # skip singular + M_inv = M.inverse() + test_single_det(M_inv, reference_slogdet(M_inv), 'inverse') - def test_index_fill(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - if dt == torch.half or dt == torch.bfloat16: - continue + test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose') - x = torch.tensor([[1, 2], [4, 5]], dtype=dt, device=device) - index = torch.tensor([0], device=device) - x.index_fill_(1, index, 0) - self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dt, device=device)) + for x in [0, 2, 4]: + for scale in [-2, -0.1, 0, 10]: + if scale > 0: + target = ref_M_sdet, ref_M_logabsdet + math.log(scale) + elif scale == 0: + target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) + else: + target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale) - def test_index_select(self): - for device in torch.testing.get_all_device_types(): - src = torch.randn(3, 4, 5, device=device) - # Index can be duplicated. - idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) - - # Check that 'out' is used correctly. - out = torch.randn(5 * 4 * 5, device=device) - dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5)) - self.assertEqual(dest.shape, (5, 4, 5)) - for i in range(idx.size(0)): - self.assertEqual(dest[i], src[idx[i]]) - out.fill_(0.123) - self.assertEqual(out, dest.view(-1)) # Must point to the same storage. - - # Bool tensor - src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool) - idx = torch.tensor([1], dtype=torch.long, device=device) - dest = torch.index_select(src, 0, idx) - self.assertEqual(torch.tensor([True]), dest) + # dim 0 + M_clone = M.clone() + M_clone[:, x] *= scale + test_single_det(M_clone, target, 'scale a row') + # dim 1 + M_clone = M.clone() + M_clone[x, :] *= scale + test_single_det(M_clone, target, 'scale a column') - def test_t(self): - # Test 0D tensors - x = torch.randn(()) - self.assertEqual(x, x.t()) - x = x.to_sparse() - self.assertEqual(x, x.t()) + for x1, x2 in [(0, 3), (4, 1), (3, 2)]: + assert x1 != x2, 'x1 and x2 needs to be different for this test' + target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) + # dim 0 + M_clone = M.clone() + M_clone[:, x2] = M_clone[:, x1] + test_single_det(M_clone, target, 'two rows are same') + # dim 1 + M_clone = M.clone() + M_clone[x2, :] = M_clone[x1, :] + test_single_det(M_clone, target, 'two columns are same') - # Test 1D tensors - x = torch.arange(4) - self.assertEqual(x, x.t()) - x = x.to_sparse() - self.assertEqual(x, x.t()) + for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]: + det_scale = scale1 * scale2 * -1 + if det_scale > 0: + target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale) + elif det_scale == 0: + target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf) + else: + target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale) - # Test 2D tensors - x = torch.rand((2, 2)) - self.assertEqual(x.t(), x.transpose(0, 1)) - x = x.to_sparse() - self.assertEqual(x.t(), x.transpose(0, 1)) + # dim 0 + M_clone = M.clone() + t = M_clone[:, x1] * scale1 + M_clone[:, x1] += M_clone[:, x2] * scale2 + M_clone[:, x2] = t + test_single_det(M_clone, target, 'exchanging rows') + # dim 1 + M_clone = M.clone() + t = M_clone[x1, :] * scale1 + M_clone[x1, :] += M_clone[x2, :] * scale2 + M_clone[x2, :] = t + test_single_det(M_clone, target, 'exchanging columns') - # Test 3D tensor - x = torch.rand((2, 2, 2)) - with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): - x.t() - x = x.to_sparse() - with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): - x.t() + def get_random_mat_scale(n): + # For matrices with values i.i.d. with 0 mean, unit variance, and + # subexponential tail, we have: + # E[log det(A^2)] \approx log((n-1)!) + # + # Notice: + # log Var[det(A)] = log E[det(A^2)] >= E[log det(A^2)] + # + # So: + # stddev[det(A)] >= sqrt( (n-1)! ) + # + # We use this as an intuitive guideline to scale random generated + # matrices so our closeness tests can work more robustly: + # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n)) + # + # source: https://arxiv.org/pdf/1112.0752.pdf - def test_take(self): - def check(src, idx): - expected = src.contiguous().view(-1).index_select( - 0, idx.contiguous().view(-1)).view_as(idx) - actual = src.take(idx) - self.assertEqual(actual.size(), idx.size()) - self.assertEqual(expected, actual) + # TODO: technically we need subexponential distn for this to hold, + # but we mostly use gaussian entries below. Consider switching + # to Chi-sq if this turns out not stable enough, since Chi-sq + # is easy enough to sample from. + return math.factorial(n - 1) ** (-1.0 / (2 * n)) - src = torch.randn(2, 3, 5) - idx = torch.LongTensor([[0, 2], [3, 4]]) - check(src, idx) - check(src.transpose(1, 2), idx) - check(src.bool(), idx) + for n in [5, 10, 25]: + scale = get_random_mat_scale(n) + test(torch.randn(n, n, device=device) * scale) + r = torch.randn(n, n, device=device) * scale + # symmetric psd + test(r.mm(r.t())) + # symmetric pd + r = torch.randn(n, n, device=device) * scale + test(r.mm(r.t()) + torch.eye(n, device=device) * 1e-6) + # symmetric + r = torch.randn(n, n, device=device) * scale + for i in range(n): + for j in range(i): + r[i, j] = r[j, i] + test(r) + # non-contiguous + test((torch.randn(n, n, n + 1, device=device) * scale)[:, 2, 1:]) + # det = 0 + r = torch.randn(n, n, device=device) * scale + u, s, v = r.svd() + if reference_slogdet(u)[0] < 0: + u = -u + if reference_slogdet(v)[0] < 0: + v = -v + s[0] *= -1 + s[-1] = 0 + test(u.mm(s.diag()).mm(v)) - def test_take_empty(self): - for device in torch.testing.get_all_device_types(): - for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: - for indices_shape in [(0,), (0, 1, 2, 0)]: - input = torch.empty(input_shape, device=device) - indices = torch.empty(indices_shape, dtype=torch.int64, device=device) - self.assertEqual(indices, torch.take(input, indices)) + # Small values to test numerical stability. Note that we don't scale + # this matrix. + r = torch.randn(512, 512, device=device) + u, s, v = r.svd() + s.fill_(1. / (100 * s.numel())) + test(u.mm(s.diag()).mm(v)) - def test_put_(self): - def check(dst, idx, value): - expected = dst.clone().view(-1).index_copy_( - 0, idx.contiguous().view(-1), value.contiguous().view(-1)) - expected = expected.view_as(dst) - dst.put_(idx, value) - self.assertEqual(expected, dst) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_det_logdet_slogdet_batched(self, device): + from common_utils import (random_symmetric_matrix, random_symmetric_psd_matrix, + random_symmetric_pd_matrix, random_square_matrix_of_rank) - dst = torch.randn(2, 3, 5) - idx = torch.LongTensor([[0, 2], [3, 4]]) - values = torch.randn(2, 2) - check(dst, idx, values) - check(dst.transpose(1, 2), idx, values) + # mat_chars denotes matrix characteristics + # possible values are: sym, sym_psd, sym_pd, sing, non_sym + def run_test(matsize, batchdims, mat_chars): + num_matrices = reduce(lambda x, y: x * y, batchdims, 1) + list_of_matrices = [] - values = torch.tensor([[False, False], [False, False]]) - check(dst.bool(), idx, values) + for idx in range(num_matrices): + mat_type = idx % len(mat_chars) + if mat_chars[mat_type] == 'sym': + list_of_matrices.append(random_symmetric_matrix(matsize).to(device=device)) + elif mat_chars[mat_type] == 'sym_psd': + list_of_matrices.append(random_symmetric_psd_matrix(matsize).to(device=device)) + elif mat_chars[mat_type] == 'sym_pd': + list_of_matrices.append(random_symmetric_pd_matrix(matsize).to(device=device)) + elif mat_chars[mat_type] == 'sing': + list_of_matrices.append(torch.ones(matsize, matsize, device=device)) + elif mat_chars[mat_type] == 'non_sing': + list_of_matrices.append(random_square_matrix_of_rank(matsize, matsize).to(device=device)) + full_tensor = torch.stack(list_of_matrices, dim=0).reshape(batchdims + (matsize, matsize)) + # Scaling adapted from `get_random_mat_scale` in _test_det_logdet_slogdet + full_tensor *= (math.factorial(matsize - 1) ** (-1.0 / (2 * matsize))) - def test_put_accumulate(self): - dst = torch.ones(2, 2) - idx = torch.LongTensor([[0, 1], [0, 1]]) - src = torch.Tensor([1, 2, 3, 4]) - dst.put_(idx, src, accumulate=True) - self.assertEqual(dst.tolist(), [[5, 7], [1, 1]]) + for fn in [torch.det, torch.logdet, torch.slogdet]: + expected_value = [] + actual_value = fn(full_tensor) + for full_idx in product(*map(lambda x: list(range(x)), batchdims)): + expected_value.append(fn(full_tensor[full_idx])) - def test_put_empty(self): - for device in torch.testing.get_all_device_types(): - for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: - for indices_shape in [(0,), (0, 1, 2, 0)]: - for accumulate in [False, True]: - dst = torch.randn(dst_shape, device=device) - indices = torch.empty(indices_shape, dtype=torch.int64, device=device) - src = torch.randn(indices_shape, device=device) - self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate)) + if fn == torch.slogdet: + sign_value = torch.stack([tup[0] for tup in expected_value], dim=0).reshape(batchdims) + expected_value = torch.stack([tup[1] for tup in expected_value], dim=0).reshape(batchdims) + self.assertEqual(sign_value, actual_value[0], allow_inf=True) + self.assertEqual(expected_value, actual_value[1], allow_inf=True) + else: + expected_value = torch.stack(expected_value, dim=0).reshape(batchdims) + self.assertEqual(actual_value, expected_value, allow_inf=True) - # Fill idx with valid indices. - @staticmethod - def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o): - for i in range(1 if dim == 0 else m): - for j in range(1 if dim == 1 else n): - for k in range(1 if dim == 2 else o): - ii = [i, j, k] - ii[dim] = slice(0, idx.size(dim) + 1) - idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row] + for matsize, batchdims in product([3, 5], [(3,), (5, 3)]): + run_test(matsize, batchdims, mat_chars=['sym_pd']) + run_test(matsize, batchdims, mat_chars=['sing']) + run_test(matsize, batchdims, mat_chars=['non_sing']) + run_test(matsize, batchdims, mat_chars=['sym', 'sym_pd', 'sym_psd']) + run_test(matsize, batchdims, mat_chars=['sing', 'non_sing']) - def test_flatten(self): - # Test that flatten returns 1-dim tensor when given a 0-dim tensor - zero_dim_tensor = torch.tensor(123) - flat0 = zero_dim_tensor.flatten() - one_dim_tensor = torch.tensor([123]) - flat1 = zero_dim_tensor.flatten() + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_solve(self, device): + from common_utils import solve_test_helper + for (k, n) in zip([2, 3, 5], [3, 5, 7]): + b, A = solve_test_helper((n,), (n, k), lambda t: t.to(device)) + x = torch.solve(b, A)[0] + self.assertLessEqual(b.dist(A.mm(x)), 1e-12) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_solve_batched(self, device): + from common_utils import solve_test_helper + + def solve_batch_helper(A_dims, b_dims, device): + b, A = solve_test_helper(A_dims, b_dims, lambda t: t.to(device)) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.solve(b[i], A[i])[0]) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.solve(b, A)[0] # Actual output + self.assertEqual(x_exp, x_act) # Equality check + self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 1e-12) # Correctness check - self.assertEqual(zero_dim_tensor.shape, torch.Size([])) - self.assertEqual(flat0.shape, torch.Size([1])) - self.assertEqual(one_dim_tensor.shape, torch.Size([1])) - self.assertEqual(flat1.shape, torch.Size([1])) - self.assertEqual(flat0, one_dim_tensor) - self.assertEqual(flat0, flat1) - self.assertEqual(flat0.shape, flat1.shape) + for batchsize in [1, 3, 4]: + solve_batch_helper((5, batchsize), (batchsize, 5, 10), device) - # Test both float tensor and quantized tensor - tensors = [torch.randn(5, 5, 5, 5), - torch._empty_affine_quantized([5, 5, 5, 5], - scale=2, - zero_point=3, - dtype=torch.quint8)] - for src in tensors: - flat = src.flatten(0, -1) - self.assertEqual(flat.shape, torch.Size([625])) - self.assertEqual(src.view(-1), flat.view(-1)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_solve_batched_non_contiguous(self, device): + from numpy.linalg import solve + from common_utils import random_fullrank_matrix_distinct_singular_value + A = random_fullrank_matrix_distinct_singular_value(2, 2).to(device).permute(1, 0, 2) + b = torch.randn(2, 2, 2, device=device).permute(2, 1, 0) + x, _ = torch.solve(b, A) + x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())).to(device) + self.assertEqual(x.data, x_exp) - flat = src.flatten(0, 2) - self.assertEqual(flat.shape, torch.Size([125, 5])) - self.assertEqual(src.view(-1), flat.view(-1)) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_solve_batched_many_batches(self, device): + from common_utils import solve_test_helper - flat = src.flatten(0, 1) - self.assertEqual(flat.shape, torch.Size([25, 5, 5])) - self.assertEqual(src.view(-1), flat.view(-1)) + b, A = solve_test_helper((5, 256, 256), (5, 1), lambda t: t.to(device)) + x, _ = torch.solve(b, A) + self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - flat = src.flatten(1, 2) - self.assertEqual(flat.shape, torch.Size([5, 25, 5])) - self.assertEqual(src.view(-1), flat.view(-1)) + b, A = solve_test_helper((3,), (512, 512, 3, 1), lambda t: t.to(device)) + x, _ = torch.solve(b, A) + self.assertEqual(torch.matmul(A, x), b) - flat = src.flatten(2, 3) - self.assertEqual(flat.shape, torch.Size([5, 5, 25])) - self.assertEqual(src.view(-1), flat.view(-1)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_solve_batched_broadcasting(self, device): + from numpy.linalg import solve + from common_utils import solve_test_helper - flat = src.flatten(-2, -1) - self.assertEqual(flat.shape, torch.Size([5, 5, 25])) - self.assertEqual(src.view(-1), flat.view(-1)) + def cast(t): + return t.to(device) - flat = src.flatten(2, 2) - self.assertEqual(flat, src) + def run_test(A_dims, b_dims, cast): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + b, A = solve_test_helper((A_matrix_size,) + A_batch_dims, b_dims, cast) + x, _ = torch.solve(b, A) + x_exp = torch.Tensor(solve(A.cpu().numpy(), b.cpu().numpy())) + self.assertEqual(x, cast(x_exp)) - # out of bounds index - with self.assertRaisesRegex(IndexError, 'Dimension out of range'): - src.flatten(5, 10) + # test against numpy.linalg.solve + for upper in [True, False]: + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), cast) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), cast) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast) # broadcasting A & b + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky_solve(self, device): + from common_utils import cholesky_solve_test_helper + for (k, n), upper in product(zip([2, 3, 5], [3, 5, 7]), [True, False]): + b, A, L = cholesky_solve_test_helper((n,), (n, k), lambda t: t.to(device), upper) + x = torch.cholesky_solve(b, L, upper=upper) + self.assertLessEqual(b.dist(A.mm(x)), 1e-12) - # invalid start and end - with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'): - src.flatten(2, 0) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky_solve_batched(self, device): + from common_utils import cholesky_solve_test_helper - @staticmethod - def _test_gather(self, cast, test_bounds=True): - m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20) - elems_per_row = random.randint(1, 10) - dim = random.randrange(3) + def cholesky_solve_batch_helper(A_dims, b_dims, cast, upper): + b, A, L = cholesky_solve_test_helper(A_dims, b_dims, cast, upper) + x_exp_list = [] + for i in range(b_dims[0]): + x_exp_list.append(torch.cholesky_solve(b[i], L[i], upper=upper)) + x_exp = torch.stack(x_exp_list) # Stacked output + x_act = torch.cholesky_solve(b, L, upper=upper) # Actual output + self.assertEqual(x_act, x_exp) # Equality check + self.assertLessEqual(b.dist(torch.matmul(A, x_act)), 2e-12) # Correctness check - src = torch.randn(m, n, o) - idx_size = [m, n, o] - idx_size[dim] = elems_per_row - idx = torch.LongTensor().resize_(*idx_size) - _TestTorchMixin._fill_indices(self, idx, dim, src.size(dim), elems_per_row, m, n, o) + for upper, batchsize in product([True, False], [1, 3, 4]): + cholesky_solve_batch_helper((5, batchsize), (batchsize, 5, 10), lambda t: t.to(device), upper) - src = cast(src) - idx = cast(idx) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_cholesky_solve_batched_non_contiguous(self, device): + from numpy.linalg import solve + from common_utils import random_symmetric_pd_matrix - actual = torch.gather(src, dim, idx) - expected = cast(torch.Tensor().resize_(*idx_size)) - for i in range(idx_size[0]): - for j in range(idx_size[1]): - for k in range(idx_size[2]): - ii = [i, j, k] - ii[dim] = idx[i, j, k] - expected[i, j, k] = src[tuple(ii)] - self.assertEqual(actual, expected, 0) + for upper in [True, False]: + A = random_symmetric_pd_matrix(2, 2) + b = torch.randn(2, 2, 2) + x_exp = torch.Tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device) + A = A.to(device).permute(0, 2, 1) + b = b.to(device).permute(2, 1, 0) + assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" + L = torch.cholesky(A, upper) + x = torch.cholesky_solve(b, L, upper=upper) + self.assertEqual(x, x_exp) - if test_bounds: - idx[0][0][0] = 23 - self.assertRaises(RuntimeError, lambda: torch.gather(src, dim, idx)) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky_solve_batched_many_batches(self, device): + from common_utils import cholesky_solve_test_helper - src = cast(torch.randn(3, 4, 5)) - expected, idx = src.max(2, True) - expected = cast(expected) - idx = cast(idx) - actual = torch.gather(src, 2, idx) - self.assertEqual(actual, expected, 0) - - # Bool test case - t = torch.tensor([[False, True], [True, True]]) - self.assertEqual(torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])), torch.tensor([[False, False], [True, True]])) - - def test_gather(self): - self._test_gather(self, lambda t: t) + for upper in [True, False]: + b, A, L = cholesky_solve_test_helper((5, 256, 256), (5, 10), lambda t: t.to(device), upper) + x = torch.cholesky_solve(b, L, upper) + self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 10))) - @staticmethod - def _test_scatter_base(self, cast, method, is_scalar=False, test_bounds=True): - m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20) - elems_per_row = random.randint(1, 10) - dim = random.randrange(3) + b, A, L = cholesky_solve_test_helper((5,), (512, 512, 5, 10), lambda t: t.to(device), upper) + x = torch.cholesky_solve(b, L, upper) + self.assertEqual(torch.matmul(A, x), b) - idx_size = [m, n, o] - idx_size[dim] = elems_per_row - idx = cast(torch.LongTensor().resize_(*idx_size)) - _TestTorchMixin._fill_indices(self, idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_cholesky_solve_batched_broadcasting(self, device): + from numpy.linalg import solve + from common_utils import random_symmetric_pd_matrix - if is_scalar: - src = random.random() - else: - src = cast(torch.Tensor(*idx_size).normal_()) + def cast(t): + return t.to(device) - base = cast(torch.randn(m, n, o)) - actual = getattr(base.clone(), method)(dim, idx, src) - expected = base.clone() - for i in range(idx_size[0]): - for j in range(idx_size[1]): - for k in range(idx_size[2]): - ii = [i, j, k] - ii[dim] = idx[i, j, k] - if method == 'scatter_' and not is_scalar: - expected[tuple(ii)] = src[i, j, k] - elif method == 'scatter_add_': - expected[tuple(ii)] += src[i, j, k] - else: - expected[tuple(ii)] = src - self.assertEqual(actual, expected, 0) + def run_test(A_dims, b_dims, cast, upper): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + A = random_symmetric_pd_matrix(A_matrix_size, *A_batch_dims) + b = torch.randn(*b_dims) + x_exp = torch.Tensor(solve(A.numpy(), b.numpy())) + A, b = cast(A), cast(b) + L = torch.cholesky(A, upper) + x = torch.cholesky_solve(b, L, upper=upper) + self.assertEqual(x, cast(x_exp)) - if test_bounds: - idx[0][0][0] = 34 - with self.assertRaises(RuntimeError): - getattr(base.clone(), method)(dim, idx, src) + # test against numpy.linalg.solve + for upper in [True, False]: + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), cast, upper) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), cast, upper) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), cast, upper) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper) # broadcasting A & b + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky_inverse(self, device): + from common_utils import random_symmetric_pd_matrix + a = random_symmetric_pd_matrix(5).to(device) - # test for empty index, should be a no-op - idx = cast(torch.LongTensor()) - actual = getattr(base.clone(), method)(dim, idx, src) - self.assertEqual(actual, base, 0) + # compute inverse directly + inv0 = torch.inverse(a) - def test_scatter(self): - self._test_scatter_base(self, lambda t: t, 'scatter_') + # default case + chol = torch.cholesky(a) + inv1 = torch.cholesky_inverse(chol, False) + self.assertLessEqual(inv0.dist(inv1), 1e-12) - def test_scatterAdd(self): - self._test_scatter_base(self, lambda t: t, 'scatter_add_') + # upper Triangular Test + chol = torch.cholesky(a, True) + inv1 = torch.cholesky_inverse(chol, True) + self.assertLessEqual(inv0.dist(inv1), 1e-12) - def test_scatterFill(self): - self._test_scatter_base(self, lambda t: t, 'scatter_', True) + # lower Triangular Test + chol = torch.cholesky(a, False) + inv1 = torch.cholesky_inverse(chol, False) + self.assertLessEqual(inv0.dist(inv1), 1e-12) - def test_scatter_bool(self): - for device in torch.testing.get_all_device_types(): - x = torch.tensor([[True, True, True], [True, True, True]], device=device) - res = torch.zeros(3, 3, dtype=torch.bool, device=device) - res = res.scatter_(0, torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), x) - self.assertEqual(res, torch.tensor([[True, False, False], - [False, True, False], - [False, False, True]], device=device)) - - def test_scatter_add_bool(self): - for device in torch.testing.get_all_device_types(): - x = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]], device=device) - res = torch.zeros(3, 5, dtype=torch.bool, device=device) - res = res.scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], device=device), x) - self.assertEqual(res, torch.tensor([[True, True, True, True, True], - [False, True, False, True, False], - [True, False, True, False, True]], device=device)) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky_batched_many_batches(self, device): + from common_utils import random_symmetric_pd_matrix - def test_masked_scatter(self): - with warnings.catch_warnings(record=True) as w: - for maskType in [torch.uint8, torch.bool]: - for dt in torch.testing.get_all_dtypes(): - num_copy, num_dest = 3, 10 - dest = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dt) - dest2 = dest.clone() - src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt) - mask = torch.tensor((0, 0, 0, 0, 1, 0, 1, 0, 1, 0), dtype=maskType) + def cholesky_test_helper(n, batchsize, device, upper): + A = random_symmetric_pd_matrix(n, batchsize).to(device) + chol_fact = torch.cholesky(A, upper=upper) + if upper: + # Correctness check + self.assertEqual(A, chol_fact.transpose(-2, -1).matmul(chol_fact)) + # Upper triangular check + self.assertEqual(chol_fact, chol_fact.triu()) + else: + # Correctness check + self.assertEqual(A, chol_fact.matmul(chol_fact.transpose(-2, -1))) + # Lower triangular check + self.assertEqual(chol_fact, chol_fact.tril()) - if dt == torch.bool: - # torch.bool is a special case and is being tested - # in a separate test - continue + for upper, batchsize in product([True, False], [262144, 524288]): + cholesky_test_helper(2, batchsize, device, upper) - if dt == torch.half: - self.assertRaises(RuntimeError, lambda: dest.masked_scatter_(mask, src)) - continue + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky_batched(self, device): + from common_utils import random_symmetric_pd_matrix - dest.masked_scatter_(mask, src) - j = 0 - for i in range(num_dest): - if mask[i]: - dest2[i] = src[j] - j += 1 - self.assertEqual(dest, dest2, 0) + def cholesky_test_helper(n, batch_dims, device, upper): + A = random_symmetric_pd_matrix(n, *batch_dims).to(device) + cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)]) + cholesky_exp = cholesky_exp.reshape_as(A) + self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper)) - # make source bigger than number of 1s in mask - src = torch.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=dt) - dest.masked_scatter_(mask, src) + for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]): + cholesky_test_helper(3, batchsize, device, upper) - # make src smaller. this should fail - src = torch.randn(num_copy - 1) - with self.assertRaises(RuntimeError): - dest.masked_scatter_(mask, src) - self.assertEqual(len(w), 25) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_cholesky(self, device): + x = torch.rand(10, 10, device=device) + 1e-1 + A = torch.mm(x, x.t()) - warn = 'masked_scatter_ received a mask with dtype torch.uint8,' - for wi in w: - self.assertEqual(str(wi.message)[0:55], str(warn)) + # default Case + C = torch.cholesky(A) + B = torch.mm(C, C.t()) + self.assertEqual(A, B, 1e-14) - def test_masked_scatter_bool_tensor(self): - for device in torch.testing.get_all_device_types(): - src = torch.tensor([True, True, True], device=device) - dst = torch.tensor([False, False, False], device=device) - mask = torch.tensor([False, True, False], device=device) + # test Upper Triangular + U = torch.cholesky(A, True) + B = torch.mm(U.t(), U) + self.assertEqual(A, B, 1e-14, 'cholesky (upper) did not allow rebuilding the original matrix') - dst.masked_scatter_(mask, src) - self.assertEqual(dst, torch.tensor([False, True, False], device=device)) + # test Lower Triangular + L = torch.cholesky(A, False) + B = torch.mm(L, L.t()) + self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix') - mask = torch.tensor([True, False, True], device=device) - dst = dst.masked_scatter(mask, src) - self.assertEqual(dst, torch.tensor([True, True, True], device=device)) + def test_view(self, device): + tensor = torch.rand(15, device=device) + template = torch.rand(3, 5, device=device) + empty = torch.empty(0, device=device) + target = template.size() + self.assertEqual(tensor.view_as(template).size(), target) + self.assertEqual(tensor.view(3, 5).size(), target) + self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) + self.assertEqual(tensor.view(-1, 5).size(), target) + self.assertEqual(tensor.view(3, -1).size(), target) + tensor_view = tensor.view(5, 3) + tensor_view.fill_(random.uniform(0, 1)) + self.assertEqual(empty.view_as(empty), empty) + self.assertEqual(empty.view(0), empty) + self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) + self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) - def test_masked_select(self): - for device in torch.testing.get_all_device_types(): - for dt in torch.testing.get_all_dtypes(): - with warnings.catch_warnings(record=True) as w: - for maskType in [torch.uint8, torch.bool]: - num_src = 10 - src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device) - mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(maskType) - - if dt == torch.bfloat16 and device == 'cuda': - # remove once bfloat16 implemented on CUDA - self.assertRaises(RuntimeError, lambda: src.masked_select(mask)) - continue - - if dt == torch.half and device == 'cpu': - self.assertRaises(RuntimeError, lambda: src.masked_select(mask)) - continue - - dst = src.masked_select(mask) - dst2 = [] - for i in range(num_src): - if mask[i]: - dst2 += [src[i]] - self.assertEqual(dst, torch.tensor(dst2), 0) - - dst3 = torch.empty_like(src, device=device) - torch.masked_select(src, mask, out=dst3) - self.assertEqual(dst3, torch.Tensor(dst2), 0) - self.assertEqual(len(w), 1) + # test size inference with empty tensors + self.assertEqual(empty.view(-1).size(), torch.Size([0])) + self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) - warn = 'masked_select received a mask with dtype torch.uint8,' - self.assertEqual(str(w[0].message)[0:53], str(warn)) + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(-1, 0) - def test_masked_fill(self): - with warnings.catch_warnings(record=True) as w: - for dt in torch.testing.get_all_dtypes(): - for dtype in [torch.uint8, torch.bool]: - num_dest = 10 - dst = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt) - mask = torch.rand(num_dest).mul(2).floor().to(dtype) - val = random.random() - dst2 = dst.clone() + with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): + empty.view(3, 0, -1, 0) - if dt == torch.half: - self.assertRaises(RuntimeError, lambda: dst.masked_fill_(mask, val)) - continue + self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) + self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) + self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) - dst.masked_fill_(mask, val) - for i in range(num_dest): - if mask[i]: - dst2[i] = val - self.assertEqual(dst, dst2, 0) + # test view when tensor is not contiguous in every dimension, but only + # contiguous dimensions are touched. + tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3) + # size: [ 4, 2, 3, 9, 6, 2, 1, 5] + # stride: [3840, 1620, 1, 3, 54, 27, 324, 324] + # contiguous dim chunks: [__________, ____, ____, __________, ____, ____] + # merging 1 to chunk after: [__________, ____, ____, __________, __________] + contig_tensor = tensor.clone() + # [4, 2] => [8, 1] + # [3] => [3] + # [9] => [3, 3] + # [6, 2] => [4, 1, 3] + # [1, 5] => [5] + view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] + self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + # [4, 2] => [2, 4] + # [3] => [3] + # [9] => [1, 9] + # [6, 2] => [2, 2, 3] + # [1, 5] => [5, 1] + view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] + self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + # adding size 1 dims + view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] + self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) - # test non-contiguous case - dst = torch.randn(num_dest, num_dest, num_dest).permute((2, 0, 1)) - dst2 = dst.clone() - dst.masked_fill_((dst > 0).to(dtype), val) - dst2.masked_fill_((dst2 > 0).to(dtype), val) - self.assertEqual(dst, dst2, 0) - self.assertEqual(len(w), 28) + # invalid views + self.assertRaises(RuntimeError, lambda: tensor.view(-1)) + # crossing [4, 2], [3] + self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) + # crossing [6, 2], [1, 5] + self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) + # crossing [9], [6, 2] + self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) - warn = 'masked_fill_ received a mask with dtype torch.uint8,' - for wi in w: - self.assertEqual(str(wi.message)[0:52], str(warn)) + # view with stride 0 dims + tensor = torch.empty(1, 1, device=device).expand(3, 4) # all dims are contiguous + contig_tensor = tensor.clone() + self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) + self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) + self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) + self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) + self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) - def test_masked_fill_bool_tensor(self): - for device in torch.testing.get_all_device_types(): - dst = torch.tensor([True, False, True], device=device) - mask = torch.tensor([False, True, False], device=device) + def test_flip(self, device): + data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) - dst.masked_fill_(mask, True) - self.assertEqual(dst, torch.tensor([True, True, True], device=device)) + self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) + self.assertEqual(torch.tensor([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2), data.flip(1)) + self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(2)) + self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1)) + self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2)) - dst = dst.masked_fill(mask, False) - self.assertEqual(dst, torch.tensor([True, False, True], device=device)) + # check for wrap dim + self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1)) + # check for permute + self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2)) + self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) - def test_abs(self): - def _test_abs(tensors_dict): - for _category, tensors in tensors_dict.items(): - for data in tensors: - _test_abs_single(data) + # not allow flip on the same dim more than once + self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) + # not allow empty list as input + self.assertRaises(TypeError, lambda: data.flip()) - def _test_abs_single(data): - switch = torch.rand(data.size()).mul(2).floor().mul(2).add(-1).type(data.dtype) - res = torch.mul(data, switch) - self.assertTensorsSlowEqual(res.abs(), data, 1e-16) + # not allow size of flip dim > total dims + self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) + # not allow dim > max dim + self.assertRaises(IndexError, lambda: data.flip(3)) - shapes = [(3, 4), (3, 5, 7), (2, 2, 5, 8, 2, 3), (1000,), (10, 10, 10)] + # test for non-contiguous case + expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) + transposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1) + self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0)) + self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), transposed_data.flip(0, 1, 2)) - for shape in shapes: - # Test all except char/byte - _test_abs(self._make_tensors(shape, val_range=(0, 1000))) + # test for shape + data = torch.randn(2, 3, 4, device=device) + size = [2, 3, 4] + test_dims = [] + for i in range(1, 3): + test_dims += combinations(range(len(size)), i) - # Test char - _test_abs_single(torch.CharTensor(*shape).random_(0, 100)) + for ds in test_dims: + self.assertEqual(size, list(data.flip(ds).size())) - # Test byte - byte_tensor = torch.ByteTensor(*shape).random_(0, 100) - self.assertTensorsSlowEqual(byte_tensor, byte_tensor.abs(), 1e-16) + # test rectangular case + data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device) + flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device) + flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device) - # Checking that the right abs function is called for LongTensor - bignumber = 2 ^ 31 + 1 - res = torch.LongTensor((-bignumber,)) - self.assertGreater(res.abs()[0], 0) + self.assertEqual(flip0_result, data.flip(0)) + self.assertEqual(flip1_result, data.flip(1)) - # One of - rec = torch.randn(2, 2, 3, 7, 6, 2).type(torch.float64).clamp(0, 1) - val1 = rec.select(-1, -1).data[0][0][0].sum() - val2 = rec.select(-1, -1).data.abs()[0][0][0].sum() - self.assertEqual(val1, val2, 1e-8, 'absolute value') + # test empty tensor, should just return an empty tensor of the same shape + data = torch.tensor([]) + self.assertEqual(data, data.flip(0)) - # Both abs(0.0) and abs(-0.0) should result in 0.0 - for dtype in (torch.float, torch.double): - abs_zeros = torch.tensor([0.0, -0.0], dtype=dtype).abs().tolist() - for num in abs_zeros: - self.assertGreater(math.copysign(1.0, num), 0.0) + def test_rot90(self, device): + data = torch.arange(1, 5, device=device).view(2, 2) + self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1])) + self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1])) + self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) + self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) - def test_hardshrink(self): - data_original = torch.tensor([1, 0.5, 0.3, 0.6]).view(2, 2) - float_types = [ - 'torch.DoubleTensor', - 'torch.FloatTensor' - ] - for t in float_types: - data = data_original.type(t) - self.assertEqual(torch.tensor([1, 0.5, 0, 0.6]).view(2, 2), data.hardshrink(0.3)) - self.assertEqual(torch.tensor([1, 0, 0, 0.6]).view(2, 2), data.hardshrink(0.5)) + # test for default args k=1, dims=[0, 1] + self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) - # test default lambd=0.5 - self.assertEqual(data.hardshrink(), data.hardshrink(0.5)) + # test for reversed order of dims + self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) - # test non-contiguous case - self.assertEqual(torch.tensor([1, 0, 0.5, 0.6]).view(2, 2), data.t().hardshrink(0.3)) + # test for modulo of k + self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1])) + self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) + self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) - def test_hardshrink_edge_cases(self): - def h(t, values, l_expected): - for l, expected in l_expected.items(): - values_tensor = torch.tensor([float(v) for v in values]).type(t) - expected_tensor = torch.tensor([float(v) for v in expected]).type(t) - self.assertEqual(expected_tensor == values_tensor.hardshrink(l), - torch.ones_like(values_tensor)) + # test for dims out-of-range error + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2])) - def test_helper(t, min, max): - h(t, [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - {0.0: [0.0, min, -min, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - min: [0.0, 0.0, 0.0, 0.1, -0.1, 1.0, -1.0, max, -max, inf, -inf], - 0.1: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, -1.0, max, -max, inf, -inf], - 1.0: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, max, -max, inf, -inf], - max: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, inf, -inf], - inf: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}) + # test tensor with more than 2D + data = torch.arange(1, 9, device=device).view(2, 2, 2) + self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) + self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) - test_helper(torch.DoubleTensor, - torch.finfo(torch.double).tiny, torch.finfo(torch.double).max) - test_helper(torch.FloatTensor, - torch.finfo(torch.float).tiny, torch.finfo(torch.float).max) + # test for errors + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) + self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) - def test_unbiased(self): - tensor = torch.randn(100) - self.assertEqual(tensor.var(0), tensor.var(0, unbiased=True)) - self.assertEqual(tensor.var(), tensor.var(unbiased=True)) - self.assertEqual(tensor.var(unbiased=False), tensor.var(0, unbiased=False)) + def test_signal_window_functions(self, device): + if not TEST_SCIPY: + raise unittest.SkipTest('Scipy not found') - tensor = torch.FloatTensor([1.0, 2.0]) - self.assertEqual(tensor.var(unbiased=True), 0.5) - self.assertEqual(tensor.var(unbiased=False), 0.25) + def test(name): + torch_method = getattr(torch, name + '_window') + for size in [1, 2, 5, 10, 50, 100, 1024, 2048]: + for periodic in [True, False]: + res = torch_method(size, periodic=periodic, device=device) + ref = torch.from_numpy(signal.get_window(name, size, fftbins=periodic)) + self.assertEqual(res, ref) + with self.assertRaisesRegex(RuntimeError, r'not implemented for sparse types'): + torch_method(3, layout=torch.sparse_coo) + with self.assertRaisesRegex(RuntimeError, r'floating point'): + torch_method(3, dtype=torch.long) + self.assertTrue(torch_method(3, requires_grad=True).requires_grad) + self.assertFalse(torch_method(3).requires_grad) - tensor = torch.FloatTensor([1.0, 2.0, 3.0]) - self.assertEqual(tensor.var(unbiased=True), 1.0) - self.assertEqual(tensor.var(unbiased=False), 2.0 / 3.0) + for window in ['hann', 'hamming', 'bartlett', 'blackman']: + test(window) - tensor = torch.randn(100) - self.assertEqual(tensor.std(0), tensor.std(0, unbiased=True)) - self.assertEqual(tensor.std(), tensor.std(unbiased=True)) - self.assertEqual(tensor.std(unbiased=False), tensor.std(0, unbiased=False)) + def test_broadcast(self, device): - @skipIfRocm - def test_structseq_repr(self): - a = torch.arange(250).reshape(5, 5, 10) - expected = """ - torch.return_types.max( - values=tensor([[ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49], - [ 90, 91, 92, 93, 94, 95, 96, 97, 98, 99], - [140, 141, 142, 143, 144, 145, 146, 147, 148, 149], - [190, 191, 192, 193, 194, 195, 196, 197, 198, 199], - [240, 241, 242, 243, 244, 245, 246, 247, 248, 249]]), - indices=tensor([[4, 4, 4, 4, 4, 4, 4, 4, 4, 4], - [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], - [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], - [4, 4, 4, 4, 4, 4, 4, 4, 4, 4], - [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]]))""" - self.assertEqual(repr(a.max(1)), textwrap.dedent(expected).strip()) + # all functions + fns = { + "dist", "atan2", "pow", "lerp", "add", + "sub", "mul", "div", "fmod", "remainder", + "eq", "ge", "gt", "le", "lt", "max", "min", "ne", + "addcdiv", "addcmul", "masked_scatter", "masked_select", "masked_fill", + "map", "map2", "copy" + } + # functions with three tensor arguments + fns_3_args = {"addcdiv", "addcmul", "map2"} - def test_var_stability(self): - tensor = torch.FloatTensor([2281.5, 2281.25]) - self.assertEqual(tensor.var(dim=0), 0.03125) - self.assertEqual(tensor.var(), 0.03125) + for fn in fns: + (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() + full1d = torch.randn(*dims_full, device=device).flatten().float() + small = torch.randn(*dims_small, device=device).float() + large = torch.randn(*dims_large, device=device).float() + small_expanded = small.expand(*dims_full) + large_expanded = large.expand(*dims_full) + small2 = None + small2_expanded = None + if fn in fns_3_args: + # create another smaller tensor + (dims_small2, _, _) = self._select_broadcastable_dims(dims_full) + small2 = torch.randn(*dims_small2, device=device).float() + small2_expanded = small2.expand(*dims_full) - @staticmethod - def _test_view(self, cast): - tensor = cast(torch.rand(15)) - template = cast(torch.rand(3, 5)) - empty = cast(torch.empty(0)) - target = template.size() - self.assertEqual(tensor.view_as(template).size(), target) - self.assertEqual(tensor.view(3, 5).size(), target) - self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) - self.assertEqual(tensor.view(-1, 5).size(), target) - self.assertEqual(tensor.view(3, -1).size(), target) - tensor_view = tensor.view(5, 3) - tensor_view.fill_(random.uniform(0, 1)) - self.assertEqual(empty.view_as(empty), empty) - self.assertEqual(empty.view(0), empty) - self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) - self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) + if small.is_cuda and fn in ['map', 'map2']: + # map and map2 are not implementd on CUDA tensors + continue - # test size inference with empty tensors - self.assertEqual(empty.view(-1).size(), torch.Size([0])) - self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) + if hasattr(large_expanded, fn): + # run through tensor versions of functions + # and verify fully expanded inputs give same results + expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} - with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): - empty.view(-1, 0) + def tensorfn(myfn, t1, t2): + if fn == "lerp": + return myfn(t1, 0.5) + elif fn == "masked_select": + return myfn(t1 < 0) + elif fn == "masked_scatter": + return myfn(t1 < 0.5, full1d) + elif fn == "masked_fill": + return myfn(t1 < 0.5, 1.0) + elif fn in fns_3_args: + return myfn(1, t1, t2) + else: + return myfn(t1) - with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): - empty.view(3, 0, -1, 0) + # test various orders + for first, second, third in [(large, small, small2), (small, large, small2), + (small2, small, large), (small2, large, small)]: + if first is None: + break # ignore last iter when small2 is None + method_expanded = getattr(expanded[first], fn) + method = getattr(first, fn) + r1 = tensorfn(method_expanded, expanded[second], expanded[third]) + r2 = tensorfn(method, second, third) + self.assertEqual(r1, r2) - self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) - self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) - self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) + # now for torch. versions of functions + if hasattr(torch, fn): + fntorch = getattr(torch, fn) + expanded = {large: large_expanded, small: small_expanded, small2: small2_expanded} - # test view when tensor is not contiguous in every dimension, but only - # contiguous dimensions are touched. - tensor = cast(torch.rand(4, 2, 5, 1, 6, 2, 9, 3)).transpose(-1, 2).transpose(-2, 3) - # size: [ 4, 2, 3, 9, 6, 2, 1, 5] - # stride: [3840, 1620, 1, 3, 54, 27, 324, 324] - # contiguous dim chunks: [__________, ____, ____, __________, ____, ____] - # merging 1 to chunk after: [__________, ____, ____, __________, __________] - contig_tensor = tensor.clone() - # [4, 2] => [8, 1] - # [3] => [3] - # [9] => [3, 3] - # [6, 2] => [4, 1, 3] - # [1, 5] => [5] - view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] - self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) - # [4, 2] => [2, 4] - # [3] => [3] - # [9] => [1, 9] - # [6, 2] => [2, 2, 3] - # [1, 5] => [5, 1] - view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] - self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) - # adding size 1 dims - view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] - self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) + def torchfn(t1, t2, t3): + if fn == "lerp": + return fntorch(t1, t2, 0.5) + elif fn == "masked_select": + return fntorch(t1, t2 < 0) + elif fn == "masked_scatter": + return fntorch(t1, t2 < 0.5, full1d) + elif fn == "masked_fill": + return fntorch(t1, t2 < 0.5, 1.0) + elif fn in fns_3_args: + return fntorch(t1, 1.0, t2, t3) + else: + return fntorch(t1, t2) - # invalid views - self.assertRaises(RuntimeError, lambda: tensor.view(-1)) - # crossing [4, 2], [3] - self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) - # crossing [6, 2], [1, 5] - self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) - # crossing [9], [6, 2] - self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) + # test various orders + for first, second, third in [(large, small, small2), (small, large, small2), + (small2, small, large), (small2, large, small)]: + if first is None: + break # ignore last iter when small2 is None + r1 = torchfn(expanded[first], expanded[second], expanded[third]) + r2 = torchfn(first, second, third) + self.assertEqual(r1, r2) - # view with stride 0 dims - tensor = cast(torch.empty(1, 1)).expand(3, 4) # all dims are contiguous - contig_tensor = tensor.clone() - self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) - self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) - self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) - self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) - self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) + # now for in place functions + # in-place tensor is not broadcastable; test only guaranteed + # to work by broadcasting other argument(s) + if not hasattr(large_expanded, fn + "_"): + continue - def test_view(self): - _TestTorchMixin._test_view(self, lambda x: x) + # need to clone largeExpanded so we can reuse, since functions are in-place + large_expanded_clone = large_expanded.clone() - def test_view_empty(self): - x = torch.randn(0, 6) - self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) + def tensorfn_inplace(t0, t1, t2=None): + t0_fn = getattr(t0, fn + "_") + if fn == "lerp": + return t0_fn(t1, 0.5) + elif fn == "masked_scatter": + return t0_fn(t1 < 0.5, full1d) + elif fn == "masked_fill": + return t0_fn(t1 < 0.5, 1.0) + elif fn == "map": + return t0_fn(t1, lambda x, y: x + y) + elif fn == "map2": + return t0_fn(t1, t2, lambda x, y, z: x + y + z) + elif fn in fns_3_args: + return t0_fn(1.0, t1, t2) + else: + return t0_fn(t1) + # in-place pointwise operations don't actually work if the in-place + # tensor is 0-strided (numpy has the same issue) + if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()): + r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded) + r2 = tensorfn_inplace(large_expanded_clone, small, small2) + self.assertEqual(r1, r2) - def test_reshape(self): - x = torch.randn(3, 3) - self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) - self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) - self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) - self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) + def broadcastable(t0, t1, t2=None): + try: + t1.expand_as(t0) + if t2 is not None: + t2.expand_as(t0) + except RuntimeError: + return False + return True - y = torch.randn(4, 4, 4)[:, 0, :] - self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) - self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) - self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) + def _test_in_place_broadcastable(t0, t1, t2=None): + if not broadcastable(t0, t1, t2): + same_size = t0.numel() == t1.numel() and (t0.numel() == t2.numel() if t2 is not None else True) + if not same_size: + self.assertRaises(RuntimeError, lambda: tensorfn_inplace(t0, t1, t2)) + else: + tensorfn_inplace(t0, t1, t2) - s = torch.randn(()) - self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) - self.assertEqual(s.reshape(-1).shape, (1,)) - self.assertRaises(RuntimeError, lambda: s.reshape(2)) + if fn not in fns_3_args: + _test_in_place_broadcastable(small, large_expanded) + _test_in_place_broadcastable(small, large) + else: + _test_in_place_broadcastable(small2, small_expanded, large_expanded) + _test_in_place_broadcastable(small2, small, large) - empty = torch.tensor([]) - self.assertEqual(empty, empty.reshape(-1)) - self.assertEqual(empty, empty.reshape([0])) - # TODO: fix these once we have multi-dimensional empty tensors - self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) - self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) - self.assertRaises(RuntimeError, lambda: empty.reshape(1)) + def test_broadcast_fused_matmul(self, device): + fns = ["baddbmm", "addbmm", "addmm", "addmv", "addr"] - x = torch.randn(3, 3) - self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) - self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) - self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10))) + for fn in fns: + batch_dim = random.randint(1, 8) + n_dim = random.randint(1, 8) + m_dim = random.randint(1, 8) + p_dim = random.randint(1, 8) - def test_empty_reshape(self): - x = torch.randn(0, 6) - self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) - # should be viewable -- i.e. data_ptr is the same. - self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) + def dims_full_for_fn(): + if fn == "baddbmm": + return ([batch_dim, n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) + elif fn == "addbmm": + return ([n_dim, p_dim], [batch_dim, n_dim, m_dim], [batch_dim, m_dim, p_dim]) + elif fn == "addmm": + return ([n_dim, p_dim], [n_dim, m_dim], [m_dim, p_dim]) + elif fn == "addmv": + return ([n_dim], [n_dim, m_dim], [m_dim]) + elif fn == "addr": + return ([n_dim, m_dim], [n_dim], [m_dim]) + else: + raise AssertionError("unknown function") - # match NumPy semantics -- don't infer the size of dimension with a degree of freedom - self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) + (t0_dims_full, t1_dims, t2_dims) = dims_full_for_fn() + (t0_dims_small, _, _) = self._select_broadcastable_dims(t0_dims_full) - def test_tensor_shape_empty(self): - for device in torch.testing.get_all_device_types(): - x = torch.randn((0, 1, 3, 0), device=device) - # flatten - self.assertEqual((0,), torch.flatten(x, 0, 3).shape) - self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape) - self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape) - - # squeeze, unsqueeze - self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape) - self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape) - self.assertEqual((0, 3, 0), torch.squeeze(x).shape) - - # transpose, t - self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape) - y = torch.randn((5, 0), device=device) - self.assertEqual((0, 5), y.t().shape) - - # select - self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape) - - # repeat, permute - self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape) - self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape) - - # diagonal, diagflat - self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape) - self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape) - # off the end offsets are valid - self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape) - self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape) - # check non-zero sized offsets off the end - self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape) - self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape) - - self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape) - self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([], device=device), offset=1)) - self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape) - self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([[]], device=device), offset=1)) - - # stack, split, chunk - self.assertEqual((4, 0, 1, 3, 0), torch.stack((x, x, x, x)).shape) - self.assertEqual([(0, 1, 3, 0)], - [z.shape for z in torch.chunk(x, 1, dim=0)]) - - self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)]) - self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)]) - - # NOTE: split_with_sizes behaves differently than NumPy in that it - # takes sizes rather than offsets - self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)], - [z.shape for z in torch.split(x, (0, 1, 2), dim=2)]) - - self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1)) - # This is strange because the split size is larger than the dim size, but consistent with - # how split handles that case generally (when no 0s are involved). - self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)]) - self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)]) + t0_small = torch.randn(*t0_dims_small, device=device).float() + t1 = torch.randn(*t1_dims, device=device).float() + t2 = torch.randn(*t2_dims, device=device).float() - # functions that operate over a dimension but don't reduce. - def test_dim_function_empty(self): - for device in torch.testing.get_all_device_types(): - shape = (0, 1, 2, 0) - x = torch.randn(shape, device=device) - - # size stride - self.assertEqual(0, x.size(3)) - self.assertEqual(2, x.size(2)) - self.assertEqual(2, x.stride(0)) - self.assertEqual(1, x.stride(2)) - - self.assertEqual(x, torch.nn.functional.glu(x, 0)) - self.assertEqual((0, 1, 1, 0), torch.nn.functional.glu(x, 2).shape) - - # softmax, logsoftmax - self.assertEqual(x, torch.nn.functional.softmax(x, 0)) - self.assertEqual(x, torch.nn.functional.softmax(x, 2)) - self.assertEqual(x, torch.nn.functional.softmax(x, 3)) - - self.assertEqual(x, torch.nn.functional.log_softmax(x, 0)) - self.assertEqual(x, torch.nn.functional.log_softmax(x, 2)) - self.assertEqual(x, torch.nn.functional.log_softmax(x, 3)) - - # cumsum, cumprod - self.assertEqual(shape, torch.cumsum(x, 0).shape) - self.assertEqual(shape, torch.cumsum(x, 2).shape) - self.assertEqual(shape, torch.cumprod(x, 0).shape) - self.assertEqual(shape, torch.cumprod(x, 2).shape) - - # flip - self.assertEqual(x, x.flip(0)) - self.assertEqual(x, x.flip(2)) - - # roll - self.assertEqual(x, x.roll(0, 1).roll(0, -1)) - self.assertEqual(x, x.roll(1, x.size(1))) - self.assertEqual(x, x.roll(1)) - self.assertEqual(x, x.roll((1, 1), (3, 1))) - - # unbind - self.assertEqual((), x.unbind(0)) - self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)), - x.unbind(2)) - - # cross - y = torch.randn((0, 1, 3, 0), device=device) - self.assertEqual(y.shape, torch.cross(y, y).shape) - - # renorm - self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape) - self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape) - - # sort - self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)]) - self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)]) - - # topk - self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)]) - self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)]) - - y = torch.randn((2, 3, 4), device=device) - self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)]) - - # gather - self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape) - self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape) - larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device) - self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape) - smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device) - self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape) - y = torch.randn((2, 3, 4), device=device) - self.assertEqual((0, 3, 4), - torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape) - - # scatter, scatter_add - for dim in [0, 2]: - y = torch.randn(shape, device=device) - y_src = torch.randn(shape, device=device) - ind = torch.empty(shape, dtype=torch.int64, device=device) - self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape) - self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape) - - z = torch.randn((2, 3, 4), device=device) - z_src = torch.randn((2, 3, 4), device=device) - self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) - self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) - - # index_fill, index_copy, index_add - c = x.clone() - c_clone = c.clone() - ind_empty = torch.tensor([], dtype=torch.int64, device=device) - ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device) - self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) - self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1)) - self.assertEqual(c_clone, c.index_fill_(2, torch.tensor([0, 1], dtype=torch.int64, device=device), -1)) - self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) - self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) - self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) - self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) - self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) - self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) - - c = torch.randn((0, 1, 2), device=device) - c_clone = c.clone() - self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) - self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) - self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) - self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) - self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) - self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) - - # index fill/copy/add non-empty - z = torch.randn((2, 3, 4), device=device) - self.assertEqual(z, z.index_fill_(0, ind_empty, -1)) - z = torch.randn((2, 3, 4), device=device) - self.assertEqual(z, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device))) - z = torch.randn((2, 3, 4), device=device) - self.assertEqual(z, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device))) - - # index_select - self.assertEqual(x, x.index_select(0, ind_empty)) - self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape) - self.assertEqual(x, x.index_select(2, ind_01)) - z = torch.randn((2, 3, 4), device=device) # non-empty - self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape) - c = torch.randn((0, 1, 2), device=device) - self.assertEqual(c, c.index_select(0, ind_empty)) - c = torch.randn((0, 1, 2), device=device) - self.assertEqual(c, c.index_select(0, ind_empty)) - - @skipIfRocm - def test_blas_empty(self): - for device in torch.testing.get_all_device_types(): + t0_full = t0_small.expand(*t0_dims_full).to(device) - def fn(torchfn, *args): - return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape - for shape in args)) - - # mm, addmm - self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) - self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) - self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) - self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) - self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) - - self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) - self.assertEqual((5, 6), fn(torch.addmm, (5, 6), (5, 0), (0, 6)).shape) - - # mv, addmv - self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) - self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) - self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) - - self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) - self.assertEqual((3,), fn(torch.addmv, (3,), (3, 0), (0,)).shape) - - # ger, addr - self.assertEqual((0, 0), fn(torch.ger, (0,), (0,)).shape) - self.assertEqual((5, 0), fn(torch.ger, (5,), (0,)).shape) - self.assertEqual((0, 4), fn(torch.ger, (0,), (4,)).shape) - - self.assertEqual((0, 0), fn(torch.addr, (0, 0), (0,), (0,)).shape) - self.assertEqual((5, 0), fn(torch.addr, (5, 0), (5,), (0,)).shape) - self.assertEqual((0, 4), fn(torch.addr, (0, 4), (0,), (4,)).shape) - - # bmm, baddbmm - self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) - self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) - self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) - - self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) - self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape) - self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) - - # addbmm - self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) - self.assertEqual((5, 6), fn(torch.addbmm, (5, 6), (0, 5, 0), (0, 0, 6)).shape) - - # matmul - self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) - self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) - self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) - self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) - self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) - - # dot - self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) - - if torch._C.has_lapack: - # lu - A_LU, pivots = fn(torch.lu, (0, 5, 5)) - self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape]) - A_LU, pivots = fn(torch.lu, (0, 0, 0)) - self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape]) - A_LU, pivots = fn(torch.lu, (2, 0, 0)) - self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) + fntorch = getattr(torch, fn) + r0 = fntorch(t0_small, t1, t2) + r1 = fntorch(t0_full, t1, t2) + self.assertEqual(r0, r1) - def check_single_matmul(self, x, y, shape): - a = np.array(x, copy=False) - b = np.array(y, copy=False) - expected = np.matmul(a, b) + def test_broadcast_batched_matmul(self, device): + n_dim = random.randint(1, 8) + m_dim = random.randint(1, 8) + p_dim = random.randint(1, 8) + full_batch_dims = [random.randint(1, 3) for i in range(random.randint(1, 3))] + (batch_dims_small, _, _) = self._select_broadcastable_dims(full_batch_dims) - ans = torch.matmul(x, y) - self.assertTrue(ans.is_contiguous()) - self.assertTrue(np.array_equal(ans, expected)) + def verify_batched_matmul(full_lhs, one_dimensional): + if not one_dimensional: + lhs_dims = [n_dim, m_dim] + rhs_dims = [m_dim, p_dim] + result_dims = [n_dim, p_dim] + else: + lhs_dims = [n_dim, m_dim] if full_lhs else [m_dim] + rhs_dims = [m_dim, p_dim] if not full_lhs else [m_dim] + result_dims = [n_dim] if full_lhs else [p_dim] - out = torch.zeros(*shape, dtype=torch.int64) - ans = torch.matmul(x, y, out=out) - self.assertIs(ans, out) - self.assertTrue(ans.is_contiguous()) - self.assertTrue(np.array_equal(ans, expected)) + lhs_mat_dims = lhs_dims if len(lhs_dims) != 1 else [1, m_dim] + rhs_mat_dims = rhs_dims if len(rhs_dims) != 1 else [m_dim, 1] + full_mat_dims = lhs_mat_dims if full_lhs else rhs_mat_dims + dim0_dims = rhs_dims if full_lhs else lhs_dims + small_dims = batch_dims_small + (rhs_mat_dims if full_lhs else lhs_mat_dims) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_matmul_small_brute_force_1d_Nd(self): - # Issue #20452: range(0, 10) does not work. - n = 1 - for m in range(1, 8): - for p in range(1, 8): - for o in range(1, 5): - # 1d, 3d, inner dimensions C - x = torch.arange(m) - y = torch.arange(o * m * p).reshape(o, m, p) - self.check_single_matmul(x, y, (o, n, p)) + small = torch.randn(*(small_dims), device=device).float() + dim0 = torch.randn(*(dim0_dims), device=device).float() + full = torch.randn(*(full_batch_dims + full_mat_dims), device=device).float() + if not one_dimensional: + (lhsTensors, rhsTensors) = ((full,), (small, dim0)) if full_lhs else ((small, dim0), (full,)) + else: + (lhsTensors, rhsTensors) = ((full,), (dim0,)) if full_lhs else ((dim0,), (full,)) - # 1d, 3d, inner dimensions Fortran - x = torch.arange(m) - y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (o, n, p)) + def maybe_squeeze_result(l, r, result): + if len(lhs_dims) == 1 and l.dim() != 1: + return result.squeeze(-2) + elif len(rhs_dims) == 1 and r.dim() != 1: + return result.squeeze(-1) + else: + return result - # 1d, 3d, inner dimensions non-contiguous - x = torch.arange(2 * m)[::2] - y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2] - self.check_single_matmul(x, y, (o, n, p)) + for lhs in lhsTensors: + lhs_expanded = lhs.expand(*(torch.Size(full_batch_dims) + torch.Size(lhs_mat_dims))) + lhs_expanded_matmul_fn = lhs_expanded.matmul + for rhs in rhsTensors: + rhs_expanded = ((rhs if len(rhs_dims) != 1 else rhs.unsqueeze(-1)). + expand(*(torch.Size(full_batch_dims) + torch.Size(rhs_mat_dims)))) + truth = maybe_squeeze_result(lhs_expanded, rhs_expanded, lhs_expanded_matmul_fn(rhs_expanded)) + for l in (lhs, lhs_expanded): + for r in (rhs, rhs_expanded): + l_matmul_fn = l.matmul + result = maybe_squeeze_result(l, r, l_matmul_fn(r)) + self.assertEqual(truth, result) + # test torch.matmul function as well + torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r)) + self.assertEqual(truth, torch_result) + # test torch.matmul with out + out = torch.zeros_like(torch_result) + torch.matmul(l, r, out=out) + self.assertEqual(truth, maybe_squeeze_result(l, r, out)) - for r in range(1, 5): - # 1d, 4d, inner dimensions C - x = torch.arange(m) - y = torch.arange(r * o * m * p).reshape(r, o, m, p) - self.check_single_matmul(x, y, (r, o, n, p)) + # compare to bmm + bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims), + rhs_expanded.contiguous().view(-1, *rhs_mat_dims))) + self.assertEqual(truth.view(-1, *result_dims), bmm_result.view(-1, *result_dims)) - # 1d, 4d, inner dimensions Fortran - x = torch.arange(m) - y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (r, o, n, p)) - - # 1d, 4d, inner dimensions non-contiguous - x = torch.arange(2 * m)[::2] - y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2] - self.check_single_matmul(x, y, (r, o, n, p)) + for indices in product((True, False), repeat=2): + verify_batched_matmul(*indices) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_matmul_small_brute_force_2d_Nd(self): - # Issue #20452: range(0, 10) does not work. - for n in range(1, 5): - for m in range(1, 5): - for p in range(1, 5): - for o in range(1, 3): - # 2d, 3d, inner dimensions C - x = torch.arange(n * m).reshape(n, m) - y = torch.arange(o * m * p).reshape(o, m, p) - self.check_single_matmul(x, y, (o, n, p)) + def test_contiguous(self, device): + x = torch.randn(1, 16, 5, 5, device=device) + self.assertTrue(x.is_contiguous()) + stride = list(x.stride()) + stride[0] = 20 + # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 + x.set_(x.storage(), 0, x.size(), stride) + self.assertTrue(x.is_contiguous()) - # 2d, 3d, inner dimensions Fortran - x = torch.arange(m * n).reshape(m, n).transpose(-1, -2) - y = torch.arange(o * p * m).reshape(o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (o, n, p)) + def test_index(self, device): - # 2d, 3d, inner dimensions non-contiguous - x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2] - y = torch.arange(o * m * 2 * p).reshape(o, m, 2 * p)[:, :, ::2] - self.check_single_matmul(x, y, (o, n, p)) + def consec(size, start=1): + sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0) + sequence.add_(start - 1) + return sequence.view(*size) - for r in range(1, 2): - # 2d, 4d, inner dimensions C - x = torch.arange(n * m).reshape(n, m) - y = torch.arange(r * o * m * p).reshape(r, o, m, p) - self.check_single_matmul(x, y, (r, o, n, p)) + reference = consec((3, 3, 3)).to(device) - # 2d, 4d, inner dimensions Fortran - x = torch.arange(m * n).reshape(m, n).transpose(-1, -2) - y = torch.arange(r * o * p * m).reshape(r, o, p, m).transpose(-1, -2) - self.check_single_matmul(x, y, (r, o, n, p)) + # empty tensor indexing + self.assertEqual(reference[torch.LongTensor().to(device)], reference.new(0, 3, 3)) - # 2d, 4d, inner dimensions non-contiguous - x = torch.arange(n * 2 * m).reshape(n, 2 * m)[:, ::2] - y = torch.arange(r * o * m * 2 * p).reshape(r, o, m, 2 * p)[:, :, :, ::2] - self.check_single_matmul(x, y, (r, o, n, p)) + self.assertEqual(reference[0], consec((3, 3)), 0) + self.assertEqual(reference[1], consec((3, 3), 10), 0) + self.assertEqual(reference[2], consec((3, 3), 19), 0) + self.assertEqual(reference[0, 1], consec((3,), 4), 0) + self.assertEqual(reference[0:2], consec((2, 3, 3)), 0) + self.assertEqual(reference[2, 2, 2], 27, 0) + self.assertEqual(reference[:], consec((3, 3, 3)), 0) - @skipIfRocm - def test_blas_alpha_beta_empty(self): - for device in torch.testing.get_all_device_types(): - # ensure beta is respected - value = 11 - input = torch.full((2,), value, device=device) - mat = torch.ones((2, 0), device=device) - vec = torch.ones((0,), device=device) - out = torch.randn((2,), device=device) - alpha = 6 - beta = 3 - self.assertEqual(torch.full((2,), beta * value, device=device), - torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta)) - self.assertEqual(torch.full((2,), beta * value, device=device), - torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out)) - - # torch.addmm - input = torch.full((2, 3), value, device=device) - mat2 = torch.ones((0, 3), device=device) - out = torch.randn((2, 3), device=device) - self.assertEqual(torch.full((2, 3), beta * value, device=device), - torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta)) - self.assertEqual(torch.full((2, 3), beta * value, device=device), - torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) + # indexing with Ellipsis + self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9], + [12, 15, 18], + [21, 24, 27]]), 0) + self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), 0) + self.assertEqual(reference[..., 2], reference[:, :, 2], 0) + self.assertEqual(reference[0, ..., 2], reference[0, :, 2], 0) + self.assertEqual(reference[0, 2, ...], reference[0, 2], 0) + self.assertEqual(reference[..., 2, 2, 2], 27, 0) + self.assertEqual(reference[2, ..., 2, 2], 27, 0) + self.assertEqual(reference[2, 2, ..., 2], 27, 0) + self.assertEqual(reference[2, 2, 2, ...], 27, 0) + self.assertEqual(reference[...], reference, 0) - @skipIfNoLapack - def test_lapack_empty(self): - # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here. - # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although - # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing" - # (e.g. lu). We often name our functions identically to the lapack function, so it will take work - # to name / migrate-to better wrappers. - for device in torch.testing.get_all_device_types(): + reference_5d = consec((3, 3, 3, 3, 3)).to(device) + self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], 0) + self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], 0) + self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], 0) + self.assertEqual(reference_5d[...], reference_5d, 0) - # need to init cuda to check has_magma - empty = torch.randn((0, 0), device=device) - if device == 'cuda' and not torch.cuda.has_magma: - continue + # LongTensor indexing + reference = consec((5, 5, 5)).to(device) + idx = torch.LongTensor([2, 4]).to(device) + self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]])) + # TODO: enable one indexing is implemented like in numpy + # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]])) + # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1]) - def fn(torchfn, *args): - return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape - for shape in args)) - - # inverse, pinverse - self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape) - self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape) - self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape) - self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape) - - # det, logdet, slogdet - self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0))) - self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0))) - self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), - fn(torch.slogdet, (0, 0))) - - # eig, symeig - evalues, evectors = fn(torch.eig, (0, 0), True) - self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape]) - evalues, evectors = fn(torch.symeig, (0, 0), True) - self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape]) - - # qr - q, r = fn(torch.qr, (3, 0), True) - self.assertEqual([(3, 0), (0, 0)], [q.shape, r.shape]) - q, r = fn(torch.qr, (0, 3), True) - self.assertEqual([(0, 0), (0, 3)], [q.shape, r.shape]) - q, r = fn(torch.qr, (3, 0), False) - self.assertEqual([(3, 3), (3, 0)], [q.shape, r.shape]) - - # lstsq - self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0))) - self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0))) + # None indexing + self.assertEqual(reference[2, None], reference[2].unsqueeze(0)) + self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0)) + self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1)) + self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0)) + self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2)) - def test_expand(self): - tensor = torch.rand(1, 8, 1) - tensor2 = torch.rand(5) - template = torch.rand(4, 8, 5) - target = template.size() - self.assertEqual(tensor.expand_as(template).size(), target) - self.assertEqual(tensor.expand(4, 8, 5).size(), target) - self.assertEqual(tensor.expand(target).size(), target) - self.assertEqual(tensor2.expand_as(template).size(), target) - self.assertEqual(tensor2.expand(4, 8, 5).size(), target) - self.assertEqual(tensor2.expand(target).size(), target) + # indexing 0-length slice + self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)]) + self.assertEqual(torch.empty(0, 5), reference[slice(0), 2]) + self.assertEqual(torch.empty(0, 5), reference[2, slice(0)]) + self.assertEqual(torch.tensor([]), reference[2, 1:1, 2]) - # test double expand - self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) + # indexing with step + reference = consec((10, 10, 10)).to(device) + self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0)) + self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0)) + self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0)) + self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1)) + self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0)) + self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0)) + self.assertEqual(reference[:, 2, 1:6:2], + torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1)) - # test non-contiguous - noncontig = torch.randn(5, 2, 1, 3)[:, 0] - self.assertFalse(noncontig.is_contiguous()) - self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) + lst = [list(range(i, i + 10)) for i in range(0, 100, 10)] + tensor = torch.DoubleTensor(lst).to(device) + for _i in range(100): + idx1_start = random.randrange(10) + idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) + idx1_step = random.randrange(1, 8) + idx1 = slice(idx1_start, idx1_end, idx1_step) + if random.randrange(2) == 0: + idx2_start = random.randrange(10) + idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1) + idx2_step = random.randrange(1, 8) + idx2 = slice(idx2_start, idx2_end, idx2_step) + lst_indexed = list(map(lambda l: l[idx2], lst[idx1])) + tensor_indexed = tensor[idx1, idx2] + else: + lst_indexed = lst[idx1] + tensor_indexed = tensor[idx1] + self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed) - # make sure it's compatible with unsqueeze - expanded = tensor2.expand(1, 1, 5) - unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) - self.assertEqual(expanded, unsqueezed) - self.assertEqual(expanded.stride(), unsqueezed.stride()) + self.assertRaises(ValueError, lambda: reference[1:9:0]) + self.assertRaises(ValueError, lambda: reference[1:9:-1]) - # test -1 as target size - self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) - self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) + self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1]) + self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1]) + self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3]) - # test expanding empty to empty - self.assertEqual(torch.zeros(0).expand((0,)), torch.zeros(0)) + self.assertRaises(IndexError, lambda: reference[0.0]) + self.assertRaises(TypeError, lambda: reference[0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0]) + self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0]) - def test_repeat(self): - initial_shape = (8, 4) - tensor = torch.rand(*initial_shape) + def delitem(): + del reference[0] - size = (3, 1, 1) - torchSize = torch.Size(size) - target = [3, 8, 4] - self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat') - self.assertEqual(tensor.repeat(torchSize).size(), target, - 'Error in repeat using LongStorage') - result = tensor.repeat(*size) - self.assertEqual(result.size(), target, 'Error in repeat using result') - result = tensor.repeat(torchSize) - self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage') - self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)') + self.assertRaises(TypeError, delitem) - zeroDimTarget = torch.Size([24, 0]) - self.assertEqual(tensor.repeat((3, 0)).size(), zeroDimTarget, "Error when calling with 0 repeats") + @skipCUDANonDefaultStreamIf(True) + def test_advancedindex(self, device): + # Tests for Integer Array Indexing, Part I - Purely integer array + # indexing - def test_repeat_interleave(self): - x = torch.tensor([0, 1, 2, 3]) - expected = torch.tensor([1, 2, 2, 3, 3, 3]) - self.assertEqual(torch.repeat_interleave(x), expected) + def consec(size, start=1): + numel = reduce(lambda x, y: x * y, size, 1) + sequence = torch.ones(numel).cumsum(0) + sequence.add_(start - 1) + return sequence.view(*size) - with self.assertRaises(RuntimeError): - torch.repeat_interleave(torch.arange(4).reshape(2, 2)) + # pick a random valid indexer type + def ri(indices): + choice = random.randint(0, 2) + if choice == 0: + return torch.LongTensor(indices).to(device) + elif choice == 1: + return list(indices) + else: + return tuple(indices) - with self.assertRaises(RuntimeError): - torch.repeat_interleave(torch.arange(4.0)) + def validate_indexing(x): + self.assertEqual(x[[0]], consec((1,))) + self.assertEqual(x[ri([0]), ], consec((1,))) + self.assertEqual(x[ri([3]), ], consec((1,), 4)) + self.assertEqual(x[[2, 3, 4]], consec((3,), 3)) + self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3)) + self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([1, 3, 5])) - with self.assertRaises(RuntimeError): - torch.repeat_interleave(torch.tensor([1, 2, -1, 3, 4])) + def validate_setting(x): + dtype = x.type() + x[[0]] = -2 + self.assertEqual(x[[0]], torch.Tensor([-2]).type(dtype)) + x[[0]] = -1 + self.assertEqual(x[ri([0]), ], torch.Tensor([-1]).type(dtype)) + x[[2, 3, 4]] = 4 + self.assertEqual(x[[2, 3, 4]], torch.Tensor([4, 4, 4]).type(dtype)) + x[ri([2, 3, 4]), ] = 3 + self.assertEqual(x[ri([2, 3, 4]), ], torch.Tensor([3, 3, 3]).type(dtype)) + x[ri([0, 2, 4]), ] = torch.Tensor([5, 4, 3]).type(dtype).to(device) + self.assertEqual(x[ri([0, 2, 4]), ], torch.Tensor([5, 4, 3]).type(dtype)) - y = torch.tensor([[1, 2], [3, 4]]) + # First, we will test indexing to generate return values - y1_v1 = torch.repeat_interleave(y, 2) - y1_v2 = torch.repeat_interleave(y, torch.tensor(2)) - y1_v3 = torch.repeat_interleave(y, torch.tensor([2])) - y1_expect = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4]) - self.assertEqual(y1_v1, y1_expect) - self.assertEqual(y1_v2, y1_expect) - self.assertEqual(y1_v3, y1_expect) + # Case 1: Purely Integer Array Indexing + reference = consec((10,)).to(device) + validate_indexing(reference) + validate_indexing(reference.type(torch.half)) - y2 = torch.repeat_interleave(y, 3, dim=1) - y2_expect = torch.tensor([[1, 1, 1, 2, 2, 2], - [3, 3, 3, 4, 4, 4]]) - self.assertEqual(y2, y2_expect) + # setting values + validate_setting(reference) + validate_setting(reference.type(torch.half)) - y3 = torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) - y3_expect = torch.tensor([[1, 2], - [3, 4], - [3, 4]]) - self.assertEqual(y3, y3_expect) + # Tensor with stride != 1 - with self.assertRaises(RuntimeError): - torch.repeat_interleave(y, torch.tensor([1, 2, 3]), dim=0) + # strided is [1, 3, 5, 7] + reference = consec((10,)).to(device) + strided = torch.Tensor().to(device) + strided.set_(reference.storage(), storage_offset=0, + size=torch.Size([4]), stride=[2]) - with self.assertRaises(RuntimeError): - torch.repeat_interleave(y, torch.arange(9).reshape(3, 3), dim=0) + self.assertEqual(strided[[0]], torch.Tensor([1])) + self.assertEqual(strided[ri([0]), ], torch.Tensor([1])) + self.assertEqual(strided[ri([3]), ], torch.Tensor([7])) + self.assertEqual(strided[[1, 2]], torch.Tensor([3, 5])) + self.assertEqual(strided[ri([1, 2]), ], torch.Tensor([3, 5])) + self.assertEqual(strided[ri([[2, 1], [0, 3]]), ], + torch.Tensor([[5, 3], [1, 7]])) - # test zero sized dimension - x = torch.zeros((5, 0)) - y = torch.repeat_interleave(x, repeats=3, dim=1) - self.assertEqual(y, x.new_zeros(5, 0)) + # stride is [4, 8] + strided = torch.Tensor().to(device) + strided.set_(reference.storage(), storage_offset=4, + size=torch.Size([2]), stride=[4]) + self.assertEqual(strided[[0]], torch.Tensor([5])) + self.assertEqual(strided[ri([0]), ], torch.Tensor([5])) + self.assertEqual(strided[ri([1]), ], torch.Tensor([9])) + self.assertEqual(strided[[0, 1]], torch.Tensor([5, 9])) + self.assertEqual(strided[ri([0, 1]), ], torch.Tensor([5, 9])) + self.assertEqual(strided[ri([[0, 1], [1, 0]]), ], + torch.Tensor([[5, 9], [9, 5]])) - x = torch.tensor([], dtype=torch.int64) - y = torch.repeat_interleave(x, x) - self.assertEqual(y, x) + # reference is 1 2 + # 3 4 + # 5 6 + reference = consec((3, 2)).to(device) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([1, 3, 5])) + self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([2, 4, 6])) + self.assertEqual(reference[ri([0]), ri([0])], consec((1,))) + self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6)) + self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([1, 2])) + self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]], + torch.Tensor([2, 4, 4, 2, 6])) + self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + torch.Tensor([1, 2, 3, 3])) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_repeat_tile(self): + rows = ri([[0, 0], + [1, 2]]) + columns = [0], + self.assertEqual(reference[rows, columns], torch.Tensor([[1, 1], + [3, 5]])) - initial_shape = (8, 4) + rows = ri([[0, 0], + [1, 2]]) + columns = ri([1, 0]) + self.assertEqual(reference[rows, columns], torch.Tensor([[2, 1], + [4, 5]])) + rows = ri([[0, 0], + [1, 2]]) + columns = ri([[0, 1], + [1, 0]]) + self.assertEqual(reference[rows, columns], torch.Tensor([[1, 2], + [4, 5]])) - repeats = ((3, 1, 1), - (3, 3, 3), - (1, 2, 1), - (2, 2, 2, 2)) + # setting values + reference[ri([0]), ri([1])] = -1 + self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1])) + reference[ri([0, 1, 2]), ri([0])] = torch.Tensor([-1, 2, -4]).to(device) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1, + 2, -4])) + reference[rows, columns] = torch.Tensor([[4, 6], [2, 3]]).to(device) + self.assertEqual(reference[rows, columns], + torch.Tensor([[4, 6], [2, 3]])) - def _generate_noncontiguous_input(): + # Verify still works with Transposed (i.e. non-contiguous) Tensors - out = np.broadcast_to(np.random.random((1, 4)), - initial_shape) + reference = torch.Tensor([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11]]).to(device).t_() - assert not (out.flags.c_contiguous or out.flags.f_contiguous) + # Transposed: [[0, 4, 8], + # [1, 5, 9], + # [2, 6, 10], + # [3, 7, 11]] - return out + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([0, 1, + 2])) + self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.Tensor([4, 5, + 6])) + self.assertEqual(reference[ri([0]), ri([0])], torch.Tensor([0])) + self.assertEqual(reference[ri([2]), ri([1])], torch.Tensor([6])) + self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.Tensor([0, 4])) + self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]], + torch.Tensor([4, 5, 5, 4, 7])) + self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + torch.Tensor([0, 4, 1, 1])) - for repeat in repeats: - for tensor in (torch.from_numpy(np.random.random(initial_shape)), - torch.from_numpy(_generate_noncontiguous_input()),): + rows = ri([[0, 0], + [1, 2]]) + columns = [0], + self.assertEqual(reference[rows, columns], torch.Tensor([[0, 0], + [1, 2]])) - self.assertEqual(tensor.repeat(*repeat).numpy(), - np.tile(tensor.numpy(), repeat)) + rows = ri([[0, 0], + [1, 2]]) + columns = ri([1, 0]) + self.assertEqual(reference[rows, columns], torch.Tensor([[4, 0], + [5, 2]])) + rows = ri([[0, 0], + [1, 3]]) + columns = ri([[0, 1], + [1, 2]]) + self.assertEqual(reference[rows, columns], torch.Tensor([[0, 4], + [5, 11]])) - def test_is_same_size(self): - t1 = torch.Tensor(3, 4, 9, 10) - t2 = torch.Tensor(3, 4) - t3 = torch.Tensor(1, 9, 3, 3) - t4 = torch.Tensor(3, 4, 9, 10) + # setting values + reference[ri([0]), ri([1])] = -1 + self.assertEqual(reference[ri([0]), ri([1])], torch.Tensor([-1])) + reference[ri([0, 1, 2]), ri([0])] = torch.Tensor([-1, 2, -4]).to(device) + self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.Tensor([-1, + 2, -4])) + reference[rows, columns] = torch.Tensor([[4, 6], [2, 3]]).to(device) + self.assertEqual(reference[rows, columns], + torch.Tensor([[4, 6], [2, 3]])) - self.assertFalse(t1.is_same_size(t2)) - self.assertFalse(t1.is_same_size(t3)) - self.assertTrue(t1.is_same_size(t4)) + # stride != 1 - def test_is_set_to(self): - t1 = torch.Tensor(3, 4, 9, 10) - t2 = torch.Tensor(3, 4, 9, 10) - t3 = torch.Tensor().set_(t1) - t4 = t3.clone().resize_(12, 90) - self.assertFalse(t1.is_set_to(t2)) - self.assertTrue(t1.is_set_to(t3)) - self.assertTrue(t3.is_set_to(t1), "is_set_to should be symmetric") - self.assertFalse(t1.is_set_to(t4)) - self.assertFalse(torch.Tensor().is_set_to(torch.Tensor()), - "Tensors with no storages should not appear to be set " - "to each other") + # strided is [[1 3 5 7], + # [9 11 13 15]] - t1 = torch.tensor([True, True], dtype=torch.bool) - t2 = torch.tensor([0], dtype=torch.bool).set_(t1) - self.assertTrue(t1.is_set_to(t2)) + reference = torch.arange(0., 24).view(3, 8).to(device) + strided = torch.Tensor().to(device) + strided.set_(reference.storage(), 1, size=torch.Size([2, 4]), + stride=[8, 2]) - def test_tensor_set(self): - t1 = torch.Tensor() - t2 = torch.Tensor(3, 4, 9, 10).uniform_() - t1.set_(t2) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - size = torch.Size([9, 3, 4, 10]) - t1.set_(t2.storage(), 0, size) - self.assertEqual(t1.size(), size) - t1.set_(t2.storage(), 0, tuple(size)) - self.assertEqual(t1.size(), size) - self.assertEqual(t1.stride(), (120, 40, 10, 1)) - stride = (10, 360, 90, 1) - t1.set_(t2.storage(), 0, size, stride) - self.assertEqual(t1.stride(), stride) - t1.set_(t2.storage(), 0, size=size, stride=stride) - self.assertEqual(t1.size(), size) - self.assertEqual(t1.stride(), stride) + self.assertEqual(strided[ri([0, 1]), ri([0])], torch.Tensor([1, 9])) + self.assertEqual(strided[ri([0, 1]), ri([1])], torch.Tensor([3, 11])) + self.assertEqual(strided[ri([0]), ri([0])], torch.Tensor([1])) + self.assertEqual(strided[ri([1]), ri([3])], torch.Tensor([15])) + self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]], torch.Tensor([1, 7])) + self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]], + torch.Tensor([9, 11, 11, 9, 15])) + self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]], + torch.Tensor([1, 3, 9, 9])) - # test argument names - t1 = torch.Tensor() - # 1. case when source is tensor - t1.set_(source=t2) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - # 2. case when source is storage - t1.set_(source=t2.storage()) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) - # 3. case when source is storage, and other args also specified - t1.set_(source=t2.storage(), storage_offset=0, size=size, stride=stride) - self.assertEqual(t1.size(), size) - self.assertEqual(t1.stride(), stride) + rows = ri([[0, 0], + [1, 1]]) + columns = [0], + self.assertEqual(strided[rows, columns], torch.Tensor([[1, 1], + [9, 9]])) - t1 = torch.tensor([True, True], dtype=torch.bool) - t2 = torch.tensor([False, False], dtype=torch.bool) - t1.set_(t2) - self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + rows = ri([[0, 1], + [1, 0]]) + columns = ri([1, 2]) + self.assertEqual(strided[rows, columns], torch.Tensor([[3, 13], + [11, 5]])) + rows = ri([[0, 0], + [1, 1]]) + columns = ri([[0, 1], + [1, 2]]) + self.assertEqual(strided[rows, columns], torch.Tensor([[1, 3], + [11, 13]])) - def test_tensor_set_errors(self): - f_cpu = torch.randn((2, 3), dtype=torch.float32) - d_cpu = torch.randn((2, 3), dtype=torch.float64) + # setting values - # change dtype - self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage())) - self.assertRaises(RuntimeError, - lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride())) - self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu)) + # strided is [[10, 11], + # [17, 18]] - # change device - if torch.cuda.is_available(): - f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda') + reference = torch.arange(0., 24).view(3, 8).to(device) + strided = torch.Tensor().to(device) + strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + stride=[7, 1]) + self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([11])) + strided[ri([0]), ri([1])] = -1 + self.assertEqual(strided[ri([0]), ri([1])], torch.Tensor([-1])) - # cpu -> cuda - self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage())) - self.assertRaises(RuntimeError, - lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride())) - self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda)) + reference = torch.arange(0., 24).view(3, 8).to(device) + strided = torch.Tensor().to(device) + strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + stride=[7, 1]) + self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([11, + 17])) + strided[ri([0, 1]), ri([1, 0])] = torch.Tensor([-1, 2]).to(device) + self.assertEqual(strided[ri([0, 1]), ri([1, 0])], torch.Tensor([-1, + 2])) - # cuda -> cpu - self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage())) - self.assertRaises(RuntimeError, - lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride())) - self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu)) + reference = torch.arange(0., 24).view(3, 8).to(device) + strided = torch.Tensor().to(device) + strided.set_(reference.storage(), 10, size=torch.Size([2, 2]), + stride=[7, 1]) - @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected') - def test_tensor_set_errors_multigpu(self): - f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device='cuda:0') - f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device='cuda:1') + rows = ri([[0], + [1]]) + columns = ri([[0, 1], + [0, 1]]) + self.assertEqual(strided[rows, columns], + torch.Tensor([[10, 11], [17, 18]])) + strided[rows, columns] = torch.Tensor([[4, 6], [2, 3]]).to(device) + self.assertEqual(strided[rows, columns], + torch.Tensor([[4, 6], [2, 3]])) - self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage())) - self.assertRaises(RuntimeError, - lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride())) - self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1)) + # Tests using less than the number of dims, and ellipsis - def test_equal(self): - # Contiguous, 1D - t1 = torch.Tensor((3, 4, 9, 10)) - t2 = t1.contiguous() - t3 = torch.Tensor((1, 9, 3, 10)) - t4 = torch.Tensor((3, 4, 9)) - t5 = torch.Tensor() - self.assertTrue(t1.equal(t2)) - self.assertFalse(t1.equal(t3)) - self.assertFalse(t1.equal(t4)) - self.assertFalse(t1.equal(t5)) - self.assertTrue(torch.equal(t1, t2)) - self.assertFalse(torch.equal(t1, t3)) - self.assertFalse(torch.equal(t1, t4)) - self.assertFalse(torch.equal(t1, t5)) + # reference is 1 2 + # 3 4 + # 5 6 + reference = consec((3, 2)).to(device) + self.assertEqual(reference[ri([0, 2]), ], torch.Tensor([[1, 2], [5, 6]])) + self.assertEqual(reference[ri([1]), ...], torch.Tensor([[3, 4]])) + self.assertEqual(reference[..., ri([1])], torch.Tensor([[2], [4], [6]])) - # Non contiguous, 2D - s = torch.Tensor(((1, 2, 3, 4), (5, 6, 7, 8))) - s1 = s[:, 1:3] - s2 = s1.clone() - s3 = torch.Tensor(((2, 3), (6, 7))) - s4 = torch.Tensor(((0, 0), (0, 0))) + # verify too many indices fails + with self.assertRaises(IndexError): + reference[ri([1]), ri([0, 2]), ri([3])] - self.assertFalse(s1.is_contiguous()) - self.assertTrue(s1.equal(s2)) - self.assertTrue(s1.equal(s3)) - self.assertFalse(s1.equal(s4)) - self.assertTrue(torch.equal(s1, s2)) - self.assertTrue(torch.equal(s1, s3)) - self.assertFalse(torch.equal(s1, s4)) + # test invalid index fails + reference = torch.empty(10, device=device) + # can't test cuda because it is a device assert + if not reference.is_cuda: + for err_idx in (10, -11): + with self.assertRaisesRegex(IndexError, r'out of'): + reference[err_idx] + with self.assertRaisesRegex(IndexError, r'out of'): + reference[torch.LongTensor([err_idx]).to(device)] + with self.assertRaisesRegex(IndexError, r'out of'): + reference[[err_idx]] - def test_element_size(self): - byte = torch.ByteStorage().element_size() - char = torch.CharStorage().element_size() - short = torch.ShortStorage().element_size() - int = torch.IntStorage().element_size() - long = torch.LongStorage().element_size() - float = torch.FloatStorage().element_size() - double = torch.DoubleStorage().element_size() - bool = torch.BoolStorage().element_size() - bfloat16 = torch.BFloat16Storage().element_size() + if TEST_NUMPY: + # we use numpy to compare against, to verify that our advanced + # indexing semantics are the same, and also for ease of test + # writing - self.assertEqual(byte, torch.ByteTensor().element_size()) - self.assertEqual(char, torch.CharTensor().element_size()) - self.assertEqual(short, torch.ShortTensor().element_size()) - self.assertEqual(int, torch.IntTensor().element_size()) - self.assertEqual(long, torch.LongTensor().element_size()) - self.assertEqual(float, torch.FloatTensor().element_size()) - self.assertEqual(double, torch.DoubleTensor().element_size()) - self.assertEqual(bool, torch.BoolTensor().element_size()) + def tensor_indices_to_np(tensor, indices): + # convert the Torch Tensor to a numpy array + if (tensor.is_cuda): + tensor = tensor.cpu() + npt = tensor.numpy() - self.assertGreater(byte, 0) - self.assertGreater(char, 0) - self.assertGreater(short, 0) - self.assertGreater(int, 0) - self.assertGreater(long, 0) - self.assertGreater(float, 0) - self.assertGreater(double, 0) - self.assertGreater(bool, 0) - self.assertGreater(bfloat16, 0) + # convert indices + idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else + i for i in indices) - # These tests are portable, not necessarily strict for your system. - self.assertEqual(byte, 1) - self.assertEqual(char, 1) - self.assertEqual(bool, 1) - self.assertGreaterEqual(short, 2) - self.assertGreaterEqual(int, 2) - self.assertGreaterEqual(int, short) - self.assertGreaterEqual(long, 4) - self.assertGreaterEqual(long, int) - self.assertGreaterEqual(double, float) + return npt, idxs - def test_split(self): - tensor = torch.rand(7, 4) - split_size = 3 - dim = 0 - target_sizes = ([3, 4], [3, 4], [1, 4]) - splits = tensor.split(split_size, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) - start = start + target_size[dim] + def get_numpy(tensor, indices): + npt, idxs = tensor_indices_to_np(tensor, indices) - # Variable sections split - tensor = torch.randn(20, 10) - dim = 0 - split_sizes = [5, 5, 10] - target_sizes = ([[5, 10], [5, 10], [10, 10]]) - splits = tensor.split(split_sizes, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) - start = start + target_size[dim] + # index and return as a Torch Tensor + return torch.Tensor(npt[idxs]) - split_sizes = [2, 2, 6] - target_sizes = ([20, 2], [20, 2], [20, 6]) - dim = 1 - splits = tensor.split(split_sizes, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) - start = start + target_size[dim] + def set_numpy(tensor, indices, value): + if not isinstance(value, int): + if value.is_cuda: + value = value.cpu() + value = value.numpy() - def test_chunk(self): - tensor = torch.rand(4, 7) - num_chunks = 3 - dim = 1 - target_sizes = ([4, 3], [4, 3], [4, 1]) - splits = tensor.chunk(num_chunks, dim) - start = 0 - for target_size, split in zip(target_sizes, splits): - self.assertEqual(split.size(), target_size) - self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0) - start = start + target_size[dim] + npt, idxs = tensor_indices_to_np(tensor, indices) + npt[idxs] = value + return npt - # Invalid chunk sizes - error_regex = 'chunk expects.*greater than 0' - with self.assertRaisesRegex(RuntimeError, error_regex): - tensor.chunk(0) - with self.assertRaisesRegex(RuntimeError, error_regex): - tensor.chunk(-2) + def assert_get_eq(tensor, indexer): + self.assertEqual(tensor[indexer], + get_numpy(tensor, indexer).to(device)) - def test_tolist(self): - list0D = [] - tensor0D = torch.Tensor(list0D) - self.assertEqual(tensor0D.tolist(), list0D) + def assert_set_eq(tensor, indexer, val): + pyt = tensor.clone() + numt = tensor.clone() + pyt[indexer] = val + numt = torch.Tensor(set_numpy(numt, indexer, val)).to(device) + self.assertEqual(pyt, numt) - table1D = [1, 2, 3] - tensor1D = torch.Tensor(table1D) - storage = torch.Storage(table1D) - self.assertEqual(tensor1D.tolist(), table1D) - self.assertEqual(storage.tolist(), table1D) - self.assertEqual(tensor1D.tolist(), table1D) - self.assertEqual(storage.tolist(), table1D) + def assert_backward_eq(tensor, indexer): + cpu = tensor.float().clone().detach().requires_grad_(True) + outcpu = cpu[indexer] + gOcpu = torch.rand_like(outcpu) + outcpu.backward(gOcpu) + gpu = cpu.cuda().detach().requires_grad_(True) + outgpu = gpu[indexer] + outgpu.backward(gOcpu.cuda()) + self.assertEqual(cpu.grad, gpu.grad) - table2D = [[1, 2], [3, 4]] - tensor2D = torch.Tensor(table2D) - self.assertEqual(tensor2D.tolist(), table2D) + def get_set_tensor(indexed, indexer): + set_size = indexed[indexer].size() + set_count = indexed[indexer].numel() + set_tensor = torch.randperm(set_count).view(set_size).double().to(device) + return set_tensor - tensor3D = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) - tensorNonContig = tensor3D.select(1, 1) - self.assertFalse(tensorNonContig.is_contiguous()) - self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]]) + # Tensor is 0 1 2 3 4 + # 5 6 7 8 9 + # 10 11 12 13 14 + # 15 16 17 18 19 + reference = torch.arange(0., 20).view(4, 5).to(device) - def test_permute(self): - orig = [1, 2, 3, 4, 5, 6, 7] - perm = torch.randperm(7).tolist() - x = torch.Tensor(*orig).fill_(0) - new = list(map(lambda x: x - 1, x.permute(*perm).size())) - self.assertEqual(perm, new) - self.assertEqual(x.size(), orig) + indices_to_test = [ + # grab the second, fourth columns + [slice(None), [1, 3]], - @staticmethod - def _test_flip(self, use_cuda=False): - device = torch.device('cuda') if use_cuda else torch.device('cpu') - data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2) + # first, third rows, + [[0, 2], slice(None)], - self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0)) - self.assertEqual(torch.tensor([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2), data.flip(1)) - self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(2)) - self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1)) - self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2)) + # weird shape + [slice(None), [[0, 1], + [2, 3]]], + # negatives + [[-1], [0]], + [[0, 2], [-1]], + [slice(None), [-1]], + ] - # check for wrap dim - self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1)) - # check for permute - self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2)) - self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0)) + # only test dupes on gets + get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]] - # not allow flip on the same dim more than once - self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1)) - # not allow empty list as input - self.assertRaises(TypeError, lambda: data.flip()) + for indexer in get_indices_to_test: + assert_get_eq(reference, indexer) + if torch.cuda.is_available(): + assert_backward_eq(reference, indexer) - # not allow size of flip dim > total dims - self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3)) - # not allow dim > max dim - self.assertRaises(IndexError, lambda: data.flip(3)) + for indexer in indices_to_test: + assert_set_eq(reference, indexer, 44) + assert_set_eq(reference, + indexer, + get_set_tensor(reference, indexer)) - # test for non-contiguous case - expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2) - transposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1) - self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0)) - self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), transposed_data.flip(0, 1, 2)) + reference = torch.arange(0., 160).view(4, 8, 5).to(device) - # test for shape - data = torch.randn(2, 3, 4, device=device) - size = [2, 3, 4] - test_dims = [] - for i in range(1, 3): - test_dims += combinations(range(len(size)), i) - - for ds in test_dims: - self.assertEqual(size, list(data.flip(ds).size())) - - # test rectangular case - data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3) - flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]) - flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]) - if use_cuda: - data = data.cuda() - flip0_result = flip0_result.cuda() - flip1_result = flip1_result.cuda() - self.assertEqual(flip0_result, data.flip(0)) - self.assertEqual(flip1_result, data.flip(1)) - - # test empty tensor, should just return an empty tensor of the same shape - data = torch.tensor([]) - self.assertEqual(data, data.flip(0)) - - def test_flip(self): - self._test_flip(self, use_cuda=False) - - def test_roll(self): - for device in torch.testing.get_all_device_types(): - numbers = torch.arange(1, 9, device=device) - - single_roll = numbers.roll(1, 0) - expected = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device) - self.assertEqual(single_roll, expected, "{} did not equal expected result".format(single_roll)) - - roll_backwards = numbers.roll(-2, 0) - expected = torch.tensor([3, 4, 5, 6, 7, 8, 1, 2], device=device) - self.assertEqual(roll_backwards, expected, "{} did not equal expected result".format(roll_backwards)) - - data = numbers.view(2, 2, 2) - rolled = data.roll(1, 0) - expected = torch.tensor([5, 6, 7, 8, 1, 2, 3, 4], device=device).view(2, 2, 2) - self.assertEqual(expected, rolled, "{} did not equal expected result: {}".format(rolled, expected)) - - data = data.view(2, 4) - # roll a loop until back where started - loop_rolled = data.roll(2, 0).roll(4, 1) - self.assertEqual(data, loop_rolled, "{} did not equal the original: {}".format(loop_rolled, data)) - # multiple inverse loops - self.assertEqual(data, data.roll(-20, 0).roll(-40, 1)) - self.assertEqual(torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device), numbers.roll(1, 0)) - - # test non-contiguous - # strided equivalent to numbers.as_strided(size=(4, 2), stride=(1, 4)) - strided = numbers.view(2, 4).transpose(0, 1) - self.assertFalse(strided.is_contiguous(), "this test needs a non-contiguous tensor") - expected = torch.tensor([4, 8, 1, 5, 2, 6, 3, 7]).view(4, 2) - rolled = strided.roll(1, 0) - self.assertEqual(expected, rolled, - "non contiguous tensor rolled to {} instead of {} ".format(rolled, expected)) - - # test roll with no dimension specified - expected = numbers.roll(1, 0).view(2, 4) - self.assertEqual(expected, data.roll(1), "roll with no dims should flatten and roll.") - self.assertEqual(expected, data.roll(1, dims=None), "roll with no dims should flatten and roll.") - - # test roll over multiple dimensions - expected = torch.tensor([[7, 8, 5, 6], [3, 4, 1, 2]], device=device) - double_rolled = data.roll(shifts=(2, -1), dims=(1, 0)) - self.assertEqual(double_rolled, expected, - "should be able to roll over two dimensions, got {}".format(double_rolled)) - - self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=())) - self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=1)) - # shifts/dims should align - self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1, 2), dims=(1,))) - self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1,), dims=(1, 2))) - - def test_reversed(self): - val = torch.arange(0, 10) - self.assertEqual(reversed(val), torch.arange(9, -1, -1)) - - val = torch.arange(1, 10).view(3, 3) - self.assertEqual(reversed(val), torch.tensor([[7, 8, 9], [4, 5, 6], [1, 2, 3]])) - - val = torch.tensor(42) - self.assertEqual(reversed(val), torch.tensor(42)) - - def test_contains(self): - x = torch.arange(0, 10) - self.assertEqual(4 in x, True) - self.assertEqual(12 in x, False) - - x = torch.arange(1, 10).view(3, 3) - val = torch.arange(1, 4) - self.assertEqual(val in x, True) - val += 10 - self.assertEqual(val in x, False) - - @staticmethod - def _test_rot90(self, use_cuda=False): - device = torch.device("cuda" if use_cuda else "cpu") - data = torch.arange(1, 5, device=device).view(2, 2) - self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1])) - self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1])) - self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1])) - self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1])) - - # test for default args k=1, dims=[0, 1] - self.assertEqual(data.rot90(), data.rot90(1, [0, 1])) - - # test for reversed order of dims - self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0])) - - # test for modulo of k - self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1])) - self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1])) - self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1])) - - # test for dims out-of-range error - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2])) - - # test tensor with more than 2D - data = torch.arange(1, 9, device=device).view(2, 2, 2) - self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2])) - self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2])) - - # test for errors - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2])) - self.assertRaises(RuntimeError, lambda: data.rot90(1, [0])) - - def test_rot90(self): - self._test_rot90(self, use_cuda=False) - - def test_storage(self): - v = torch.randn(3, 5) - self.assertEqual(v.storage()[0], v.data[0][0]) - self.assertEqual(v.storage()[14], v.data[2][4]) - - def test_nonzero(self): - devices = torch.testing.get_all_device_types() - num_srcs = [ - 12, 12, 12, 12, 12, 125, - ] - - types = [ - 'torch.ByteTensor', - 'torch.CharTensor', - 'torch.ShortTensor', - 'torch.IntTensor', - 'torch.FloatTensor', - 'torch.DoubleTensor', - 'torch.LongTensor', - ] - - shapes = [ - torch.Size((12,)), - torch.Size((12, 1)), - torch.Size((1, 12)), - torch.Size((6, 2)), - torch.Size((3, 2, 2)), - torch.Size((5, 5, 5)), - ] - - def is_lexicographically_sorted(inds): - """Check sorted ascending with - i -> j -> k changing slowest to fastest""" - assert inds.size(1) == 3 - if inds.size(0) > 1: - i0, j0, k0 = inds[:-1].t() - i1, j1, k1 = inds[+1:].t() - i_ok = (i1 >= i0) - j_ok = (j1 >= j0) | (i1 > i0) - k_ok = (k1 >= k0) | (j1 > j0) | (i1 > i0) - lex = torch.stack((i_ok, j_ok, k_ok), dim=1) - return lex - return torch.full_like(inds, 1) - - def gen_nontrivial_input(num_src, dtype, device): - while True: - tensor = torch.rand(num_src).mul(2).floor().type(dtype).to(device) - if tensor.sum() > 0: - return tensor - - for device in devices: - for dtype in types: - for shape, num_src in zip(shapes, num_srcs): - tensor = gen_nontrivial_input(num_src, dtype, device) - tensor = tensor.clone().resize_(shape) - dst1 = torch.nonzero(tensor) - dst2 = tensor.nonzero() - dst3 = torch.LongTensor().to(device) - torch.nonzero(tensor, out=dst3) - - self.assertRaisesRegex( - TypeError, - "received an invalid combination of arguments", - lambda: torch.nonzero(tensor, as_tuple=True, out=dst3)) - if len(shape) == 1: - dst = [] - for i in range(num_src): - if tensor[i] != 0: - dst += [i] - dst = torch.LongTensor(dst).to(device) - self.assertEqual(dst1.select(1, 0), dst, 0) - self.assertEqual(dst2.select(1, 0), dst, 0) - self.assertEqual(dst3.select(1, 0), dst, 0) - elif len(shape) == 2: - # This test will allow through some False positives. It only checks - # that the elements flagged positive are indeed non-zero. - for i in range(dst1.size(0)): - self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1]].item(), 0) - elif len(shape) == 3: - # This test will allow through some False positives. It only checks - # that the elements flagged positive are indeed non-zero. - for i in range(dst1.size(0)): - self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1], dst1[i, 2]].item(), 0) - lex = is_lexicographically_sorted(dst1) - self.assertEqual(torch.ones_like(lex), lex) - if TEST_NUMPY: - tup1 = torch.nonzero(tensor, as_tuple=True) - tup2 = tensor.nonzero(as_tuple=True) - tup3 = torch.where(tensor) - np1 = tensor.cpu().numpy().nonzero() - for t in (tup1, tup2, tup3): - self.assertEqual(len(t), len(np1)) - for i in range(len(t)): - self.assertEqual(t[i].cpu().numpy(), np1[i]) - - def test_nonzero_empty(self): - def assert_tuple_empty(tup, dim): - self.assertEqual(dim, len(tup)) - for t in tup: - self.assertEqual(torch.Size([0]), t.shape) - for device in torch.testing.get_all_device_types(): - x = torch.randn(0, 2, 0, 5, 0, device=device) - y = torch.nonzero(x) - z = torch.nonzero(x, as_tuple=True) - - self.assertEqual(0, y.numel()) - self.assertEqual(torch.Size([0, 5]), y.shape) - assert_tuple_empty(z, 5) - - x = torch.tensor(0.5, device=device) - y = torch.nonzero(x) - # nonzero with as_tuple returns a - # tuple of len 1 for a zero-dim tensor. - # This is done to match Numpy behavior. - z = torch.nonzero(x, as_tuple=True) - self.assertEqual(1, len(z)) - self.assertEqual(torch.zeros(1, dtype=torch.long), z[0]) - - x = torch.zeros((), device=device) - y = torch.nonzero(x) - z = torch.nonzero(x, as_tuple=True) - self.assertEqual(torch.Size([0, 0]), y.shape) - self.assertEqual(1, len(z)) - self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) - - def test_deepcopy(self): - from copy import deepcopy - a = torch.randn(5, 5) - b = torch.randn(5, 5) - c = a.view(25) - q = [a, [a.storage(), b.storage()], b, c] - w = deepcopy(q) - self.assertEqual(w[0], q[0], 0) - self.assertEqual(w[1][0], q[1][0], 0) - self.assertEqual(w[1][1], q[1][1], 0) - self.assertEqual(w[1], q[1], 0) - self.assertEqual(w[2], q[2], 0) - - # Check that deepcopy preserves sharing - w[0].add_(1) - for i in range(a.numel()): - self.assertEqual(w[1][0][i], q[1][0][i] + 1) - self.assertEqual(w[3], c + 1) - w[2].sub_(1) - for i in range(a.numel()): - self.assertEqual(w[1][1][i], q[1][1][i] - 1) - - def test_deepcopy_scalar(self): - from copy import deepcopy - a = torch.tensor(5) - self.assertEqual(a.size(), deepcopy(a).size()) - self.assertEqual(a, deepcopy(a)) - - def test_deepcopy_parameter(self): - from copy import deepcopy - l = torch.nn.Linear(10, 1) - s = l.state_dict(keep_vars=True) - self.assertEqual(torch.nn.Parameter, type(s['weight'])) - self.assertEqual(torch.nn.Parameter, type(s['bias'])) - - s2 = deepcopy(s) - self.assertEqual(torch.nn.Parameter, type(s2['weight'])) - self.assertEqual(torch.nn.Parameter, type(s2['bias'])) - - def test_pickle(self): - if sys.version_info[0] == 2: - import cPickle as pickle - else: - import pickle - a = torch.randn(5, 5) - serialized = pickle.dumps(a) - b = pickle.loads(serialized) - self.assertEqual(a, b) - - def test_pickle_parameter(self): - if sys.version_info[0] == 2: - import cPickle as pickle - else: - import pickle - a = torch.nn.Parameter(torch.randn(5, 5)) - serialized = pickle.dumps(a) - b = pickle.loads(serialized) - self.assertTrue(isinstance(b, torch.nn.Parameter)) - self.assertEqual(a.requires_grad, b.requires_grad) - self.assertEqual(a, b) - - def test_pickle_parameter_no_requires_grad(self): - if sys.version_info[0] == 2: - import cPickle as pickle - else: - import pickle - a = torch.nn.Parameter(torch.randn(5, 5), requires_grad=False) - serialized = pickle.dumps(a) - b = pickle.loads(serialized) - self.assertTrue(isinstance(b, torch.nn.Parameter)) - self.assertEqual(a.requires_grad, b.requires_grad) - self.assertEqual(a, b) - - def test_pickle_dtype(self): - t = torch.float32 - serialized = pickle.dumps(t) - b = pickle.loads(serialized) - self.assertTrue(isinstance(b, torch.dtype)) - self.assertEqual(id(b), id(t)) - - def test_pickle_size(self): - a = torch.rand(10).size() - serialized = pickle.dumps(a) - b = pickle.loads(serialized) - self.assertTrue(isinstance(b, torch.Size)) - self.assertEqual(a, b) - - def test_norm_fastpaths(self): - x = torch.randn(3, 5) - - # slow path - result = torch.norm(x, 4.5, 1) - expected = torch.pow(x.abs().pow(4.5).sum(1), 1.0 / 4.5) - self.assertEqual(result, expected) - - # fast 0-norm - result = torch.norm(x, 0, 1) - expected = (x != 0).type_as(x).sum(1) - self.assertEqual(result, expected) - - # fast 1-norm - result = torch.norm(x, 1, 1) - expected = x.abs().sum(1) - self.assertEqual(result, expected) - - # fast 2-norm - result = torch.norm(x, 2, 1) - expected = torch.sqrt(x.pow(2).sum(1)) - self.assertEqual(result, expected) - - # fast 3-norm - result = torch.norm(x, 3, 1) - expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) - self.assertEqual(result, expected) - - @staticmethod - def _test_bernoulli(self, t_dtype, p_dtype, device): - for trivial_p in ([0, 1], [1, 0, 1, 1, 0, 1]): - x = torch.tensor(trivial_p, dtype=p_dtype, device=device) - self.assertEqual(x.bernoulli().tolist(), trivial_p) - - def isBinary(t): - return torch.ne(t, 0).mul_(torch.ne(t, 1)).sum().item() == 0 - - p = torch.rand(5, 5, dtype=p_dtype, device=device) - self.assertTrue(isBinary(p.bernoulli())) - - p = torch.rand(5, dtype=p_dtype, device=device).expand(5, 5) - self.assertTrue(isBinary(p.bernoulli())) - - p = torch.rand(5, 5, dtype=p_dtype, device=device) - torch.bernoulli(torch.rand_like(p), out=p) - self.assertTrue(isBinary(p)) - - p = torch.rand(5, dtype=p_dtype, device=device).expand(5, 5) - torch.bernoulli(torch.rand_like(p), out=p) - self.assertTrue(isBinary(p)) - - t = torch.empty(10, 10, dtype=t_dtype, device=device) - - t.fill_(2) - t.bernoulli_(0.5) - self.assertTrue(isBinary(t)) - - p = torch.rand(10, dtype=p_dtype, device=device).expand(10, 10) - t.fill_(2) - t.bernoulli_(p) - self.assertTrue(isBinary(t)) - - t.fill_(2) - torch.bernoulli(torch.rand_like(t, dtype=p_dtype), out=t) - self.assertTrue(isBinary(t)) - - t.fill_(2) - t.bernoulli_(torch.rand_like(t, dtype=p_dtype)) - self.assertTrue(isBinary(t)) - - def test_bernoulli(self): - self._test_bernoulli(self, torch.float32, torch.float64, 'cpu') - # test that it works with integral tensors - self._test_bernoulli(self, torch.uint8, torch.float64, 'cpu') - # test that it works with bool tensors - self._test_bernoulli(self, torch.bool, torch.float32, 'cpu') - - def test_normal(self): - for device in torch.testing.get_all_device_types(): - q = torch.empty(100, 100, device=device).normal_() - self.assertEqual(q.mean(), 0, 0.2) - self.assertEqual(q.std(), 1, 0.2) - - q.normal_(2, 3) - self.assertEqual(q.mean(), 2, 0.3) - self.assertEqual(q.std(), 3, 0.3) - - q = torch.empty(100, 100, device=device) - q_row1 = q[0:1].clone() - q[99:100].normal_() - self.assertEqual(q[99:100].mean(), 0, 0.2) - self.assertEqual(q[99:100].std(), 1, 0.2) - self.assertEqual(q[0:1].clone(), q_row1) - - mean = torch.empty(100, 100, device=device) - std = torch.empty(100, 100, device=device) - mean[:50] = 0 - mean[50:] = 1 - std[:, :50] = 4 - std[:, 50:] = 1 - - r = torch.normal(mean) - self.assertEqual(r[:50].mean(), 0, 0.2) - self.assertEqual(r[50:].mean(), 1, 0.2) - self.assertEqual(r.std(), 1, 0.2) - - r = torch.normal(mean, 3) - self.assertEqual(r[:50].mean(), 0, 0.2) - self.assertEqual(r[50:].mean(), 1, 0.2) - self.assertEqual(r.std(), 3, 0.2) - - r = torch.normal(2, std) - self.assertEqual(r.mean(), 2, 0.2) - self.assertEqual(r[:, :50].std(), 4, 0.3) - self.assertEqual(r[:, 50:].std(), 1, 0.2) - - r = torch.normal(mean, std) - self.assertEqual(r[:50].mean(), 0, 0.2) - self.assertEqual(r[50:].mean(), 1, 0.2) - self.assertEqual(r[:, :50].std(), 4, 0.3) - self.assertEqual(r[:, 50:].std(), 1, 0.2) - - r = torch.normal(2, 3, (100, 100)) - self.assertEqual(r.mean(), 2, 0.2) - self.assertEqual(r.std(), 3, 0.2) - - def test_generator_cpu(self): - # test default generators are equal - self.assertEqual(torch.default_generator, torch.default_generator) - - # tests Generator API - # manual_seed, seed, initial_seed, get_state, set_state - g1 = torch.Generator() - g2 = torch.Generator() - g1.manual_seed(12345) - g2.manual_seed(12345) - self.assertEqual(g1.initial_seed(), g2.initial_seed()) - - g1.seed() - g2.seed() - self.assertNotEqual(g1.initial_seed(), g2.initial_seed()) - - g1 = torch.Generator() - g2_state = g2.get_state() - g2_randn = torch.randn(1, generator=g2) - g1.set_state(g2_state) - g1_randn = torch.randn(1, generator=g1) - self.assertEqual(g1_randn, g2_randn) - - default_state = torch.default_generator.get_state() - q = torch.Tensor(100) - g1_normal = q.normal_() - g2 = torch.Generator() - g2.set_state(default_state) - g2_normal = q.normal_(generator=g2) - self.assertEqual(g1_normal, g2_normal) - - def test_sobolengine_unscrambled_lowdim(self): - engine_1d = torch.quasirandom.SobolEngine(1) - expected_1d = torch.tensor([0.5, 0.75, 0.25, 0.375, 0.875, 0.625, 0.125, 0.1875, 0.6875, 0.9375]) - actual_1d = engine_1d.draw(10) - self.assertEqual(actual_1d.view(-1), expected_1d) - self.assertEqual(actual_1d.size(), torch.Size([10, 1])) - - # Test out kwarg - engine_1d.reset() - actual_1d_out = torch.Tensor().float() - engine_1d.draw(10, out=actual_1d_out) - self.assertEqual(actual_1d.view(-1), expected_1d) - - engine_3d = torch.quasirandom.SobolEngine(3) - expected_3d = torch.tensor([0.5, 0.75, 0.25, 0.625, 0.125, 0.375, 0.875, 0.3125, 0.8125, 0.5625]) - actual_3d = engine_3d.draw(10) - self.assertEqual(actual_3d[:, 2], expected_3d) - self.assertEqual(actual_3d[:, 0], expected_1d) - self.assertEqual(actual_3d.size(), torch.Size([10, 3])) - - engine_3d = torch.quasirandom.SobolEngine(3) - draws = torch.cat([engine_3d.draw() for _ in range(0, 10)]) - self.assertEqual(draws, actual_3d) - - engine_3d = torch.quasirandom.SobolEngine(3).fast_forward(5) - draws = engine_3d.draw(5) - self.assertEqual(draws, actual_3d[5:]) - engine_3d.reset() - self.assertEqual(engine_3d.draw(3), actual_3d[:3]) - engine_3d.fast_forward(2) - self.assertEqual(engine_3d.draw(5), actual_3d[5:]) - - def test_sobolengine_unscrambled_highdim(self): - from collections import Counter - engine = torch.quasirandom.SobolEngine(1111) - count1 = dict(Counter(engine.draw().view(-1).tolist())) - count2 = dict(Counter(engine.draw().view(-1).tolist())) - count3 = dict(Counter(engine.draw().view(-1).tolist())) - self.assertTrue(count1 == {0.5: 1111}) - self.assertTrue(count2 == {0.25: 580, 0.75: 531}) - self.assertTrue(count3 == {0.25: 531, 0.75: 580}) - - engine = torch.quasirandom.SobolEngine(1111) - draws = engine.draw(1000) - self.assertTrue(torch.all(draws <= 1)) - self.assertTrue(torch.all(draws >= 0)) - - def test_sobolengine_scrambled_lowdim(self): - engine_1d = torch.quasirandom.SobolEngine(1, scramble=True, seed=1729) - expected_1d = [0.16478512, 0.43221009, 0.84261382, 0.99750268, 0.27460563, - 0.01084163, 0.73373985, 0.65039611, 0.12329865, 0.35587373] - actual_1d = engine_1d.draw(10) - self.assertEqual(actual_1d.flatten(), torch.tensor(expected_1d)) - self.assertEqual(actual_1d.size(), torch.Size([10, 1])) - # make sure random seed if chosen if none is provided - engine_1d_a = torch.quasirandom.SobolEngine(1, scramble=True) - engine_1d_b = torch.quasirandom.SobolEngine(1, scramble=True) - self.assertNotEqual(engine_1d_a.draw(2), engine_1d_b.draw(2)) - - engine_3d = torch.quasirandom.SobolEngine(3, scramble=True, seed=1729) - expected_3d = [0.32642800, 0.17881306, 0.68837059, 0.46492538, 0.91789097, - 0.58075899, 0.03642474, 0.68229187, 0.20051685, 0.30083340] - actual_3d = engine_3d.draw(10) - self.assertEqual(actual_3d[:, 2], torch.tensor(expected_3d)) - self.assertEqual(actual_3d.size(), torch.Size([10, 3])) - - engine_3d = torch.quasirandom.SobolEngine(3, scramble=True, seed=1729) - draws = torch.cat([engine_3d.draw() for _ in range(0, 10)]) - self.assertEqual(draws, actual_3d) - - engine_3d = torch.quasirandom.SobolEngine(3, scramble=True, seed=1729) - engine_3d.fast_forward(5) - draws = engine_3d.draw(5) - self.assertEqual(draws, actual_3d[5:]) - engine_3d.reset() - self.assertEqual(engine_3d.draw(3), actual_3d[:3]) - engine_3d.fast_forward(2) - self.assertEqual(engine_3d.draw(5), actual_3d[5:]) - - def test_sobolengine_scrambled_highdim(self): - engine = torch.quasirandom.SobolEngine(1111, scramble=True) - draws = engine.draw(1000) - self.assertTrue(torch.all(draws <= 1)) - self.assertTrue(torch.all(draws >= 0)) - - def test_parsing_int64(self): - # accepts integer arguments - x = torch.cumsum(torch.ones(5, 5), 0) - self.assertEqual(x, torch.cumsum(torch.ones(5, 5), torch.tensor(0))) - # doesn't accept floating point variables - self.assertRaises(TypeError, lambda: torch.cumsum(torch.ones(5, 5), torch.tensor(0.))) - - def test_parsing_double(self): - # accepts floating point and integer arguments - x = torch.randn(2, 3) - torch.isclose(x, x, 1, 1) - self.assertTrue(torch.isclose(x, x, 1, 1).all()) - self.assertTrue(torch.isclose(x, x, 1.5, 1.).all()) - # accepts floating point and integer tensors - self.assertTrue(torch.isclose(x, x, torch.tensor(1), torch.tensor(1)).all()) - self.assertTrue(torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1.)).all()) - # doesn't accept variables with requires_grad - self.assertRaises(TypeError, - lambda: torch.isclose(x, x, torch.tensor(1.5), torch.tensor(1., requires_grad=True)).all()) - - def test_parsing_intlist(self): - # parse with integer variables - self.assertEqual(torch.Size([3, 4]), torch.ones((torch.tensor(3), torch.tensor(4))).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones(torch.tensor(3), torch.tensor(4)).shape) - # parse with numpy integers - if TEST_NUMPY: - self.assertEqual(torch.Size([3, 4]), torch.ones((np.array(3), np.int64(4))).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones(np.array(3), np.int64(4)).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones((np.int64(3), np.array(4))).shape) - self.assertEqual(torch.Size([3, 4]), torch.ones(np.int64(3), np.array(4)).shape) - - # fail parse with float variables - self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3.), torch.tensor(4)))) - # fail parse with numpy floats - if TEST_NUMPY: - self.assertRaises(TypeError, lambda: torch.ones((np.float(3.), torch.tensor(4)))) - self.assertRaises(TypeError, lambda: torch.ones((np.array(3.), torch.tensor(4)))) + indices_to_test = [ + [slice(None), slice(None), [0, 3, 4]], + [slice(None), [2, 4, 5, 7], slice(None)], + [[2, 3], slice(None), slice(None)], + [slice(None), [0, 2, 3], [1, 3, 4]], + [slice(None), [0], [1, 2, 4]], + [slice(None), [0, 1, 3], [4]], + [slice(None), [[0, 1], [1, 0]], [[2, 3]]], + [slice(None), [[0, 1], [2, 3]], [[0]]], + [slice(None), [[5, 6]], [[0, 3], [4, 4]]], + [[0, 2, 3], [1, 3, 4], slice(None)], + [[0], [1, 2, 4], slice(None)], + [[0, 1, 3], [4], slice(None)], + [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], + [[[0, 1], [1, 0]], [[2, 3]], slice(None)], + [[[0, 1], [2, 3]], [[0]], slice(None)], + [[[2, 1]], [[0, 3], [4, 4]], slice(None)], + [[[2]], [[0, 3], [4, 1]], slice(None)], + # non-contiguous indexing subspace + [[0, 2, 3], slice(None), [1, 3, 4]], - # fail parse with > 1 element variables - self.assertRaises(TypeError, lambda: torch.ones(torch.tensor(3, 3))) - self.assertRaises(TypeError, lambda: torch.ones((torch.tensor(3, 3)))) - if TEST_NUMPY: - self.assertRaises(TypeError, lambda: torch.ones(np.array(3, 3))) - self.assertRaises(TypeError, lambda: torch.ones((np.array(3, 3)))) + # less dim, ellipsis + [[0, 2], ], + [[0, 2], slice(None)], + [[0, 2], Ellipsis], + [[0, 2], slice(None), Ellipsis], + [[0, 2], Ellipsis, slice(None)], + [[0, 2], [1, 3]], + [[0, 2], [1, 3], Ellipsis], + [Ellipsis, [1, 3], [2, 3]], + [Ellipsis, [2, 3, 4]], + [Ellipsis, slice(None), [2, 3, 4]], + [slice(None), Ellipsis, [2, 3, 4]], - # fail parse with additional positional args after intlist arg - self.assertRaisesRegex(TypeError, - "received an invalid combination of arguments", - lambda: torch.LongTensor((6, 0), 1, 1, 0)) - self.assertRaisesRegex(TypeError, - "missing 1 required positional arguments", - lambda: torch.tensor().new_zeros((5, 5), 0)) + # ellipsis counts for nothing + [Ellipsis, slice(None), slice(None), [0, 3, 4]], + [slice(None), Ellipsis, slice(None), [0, 3, 4]], + [slice(None), slice(None), Ellipsis, [0, 3, 4]], + [slice(None), slice(None), [0, 3, 4], Ellipsis], + [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)], + [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)], + [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis], + ] - def _test_serialization_data(self): - a = [torch.randn(5, 5).float() for i in range(2)] - b = [a[i % 2] for i in range(4)] # 0-3 - b += [a[0].storage()] # 4 - b += [a[0].reshape(-1)[1:4].storage()] # 5 - b += [torch.arange(1, 11).int()] # 6 - t1 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) - t2 = torch.FloatTensor().set_(a[0].reshape(-1)[1:4].clone().storage(), 0, (3,), (1,)) - b += [(t1.storage(), t1.storage(), t2.storage())] # 7 - b += [a[0].reshape(-1)[0:2].storage()] # 8 - return b + for indexer in indices_to_test: + assert_get_eq(reference, indexer) + assert_set_eq(reference, indexer, 212) + assert_set_eq(reference, + indexer, + get_set_tensor(reference, indexer)) + if torch.cuda.is_available(): + assert_backward_eq(reference, indexer) - def _test_serialization_assert(self, b, c): - self.assertEqual(b, c, 0) - self.assertTrue(isinstance(c[0], torch.FloatTensor)) - self.assertTrue(isinstance(c[1], torch.FloatTensor)) - self.assertTrue(isinstance(c[2], torch.FloatTensor)) - self.assertTrue(isinstance(c[3], torch.FloatTensor)) - self.assertTrue(isinstance(c[4], torch.FloatStorage)) - c[0].fill_(10) - self.assertEqual(c[0], c[2], 0) - self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0) - c[1].fill_(20) - self.assertEqual(c[1], c[3], 0) - # I have to do it in this roundabout fashion, because there's no - # way to slice storages - for i in range(4): - self.assertEqual(c[4][i + 1], c[5][i]) + reference = torch.arange(0., 1296).view(3, 9, 8, 6).to(device) - # check that serializing the same storage view object unpickles - # it as one object not two (and vice versa) - views = c[7] - self.assertEqual(views[0]._cdata, views[1]._cdata) - self.assertEqual(views[0], views[2]) - self.assertNotEqual(views[0]._cdata, views[2]._cdata) + indices_to_test = [ + [slice(None), slice(None), slice(None), [0, 3, 4]], + [slice(None), slice(None), [2, 4, 5, 7], slice(None)], + [slice(None), [2, 3], slice(None), slice(None)], + [[1, 2], slice(None), slice(None), slice(None)], + [slice(None), slice(None), [0, 2, 3], [1, 3, 4]], + [slice(None), slice(None), [0], [1, 2, 4]], + [slice(None), slice(None), [0, 1, 3], [4]], + [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]], + [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]], + [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]], + [slice(None), [0, 2, 3], [1, 3, 4], slice(None)], + [slice(None), [0], [1, 2, 4], slice(None)], + [slice(None), [0, 1, 3], [4], slice(None)], + [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)], + [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)], + [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)], + [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)], + [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)], + [[0, 1, 2], [1, 3, 4], slice(None), slice(None)], + [[0], [1, 2, 4], slice(None), slice(None)], + [[0, 1, 2], [4], slice(None), slice(None)], + [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)], + [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)], + [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)], + [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)], + [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]], + [slice(None), [2, 3, 4], [1, 3, 4], [4]], + [slice(None), [0, 1, 3], [4], [1, 3, 4]], + [slice(None), [6], [0, 2, 3], [1, 3, 4]], + [slice(None), [2, 3, 5], [3], [4]], + [slice(None), [0], [4], [1, 3, 4]], + [slice(None), [6], [0, 2, 3], [1]], + [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]], + [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)], + [[2, 0, 1], [1, 2, 3], [4], slice(None)], + [[0, 1, 2], [4], [1, 3, 4], slice(None)], + [[0], [0, 2, 3], [1, 3, 4], slice(None)], + [[0, 2, 1], [3], [4], slice(None)], + [[0], [4], [1, 3, 4], slice(None)], + [[1], [0, 2, 3], [1], slice(None)], + [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)], - rootview = c[8] - self.assertEqual(rootview.data_ptr(), c[0].data_ptr()) + # less dim, ellipsis + [Ellipsis, [0, 3, 4]], + [Ellipsis, slice(None), [0, 3, 4]], + [Ellipsis, slice(None), slice(None), [0, 3, 4]], + [slice(None), Ellipsis, [0, 3, 4]], + [slice(None), slice(None), Ellipsis, [0, 3, 4]], + [slice(None), [0, 2, 3], [1, 3, 4]], + [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis], + [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)], + [[0], [1, 2, 4]], + [[0], [1, 2, 4], slice(None)], + [[0], [1, 2, 4], Ellipsis], + [[0], [1, 2, 4], Ellipsis, slice(None)], + [[1], ], + [[0, 2, 1], [3], [4]], + [[0, 2, 1], [3], [4], slice(None)], + [[0, 2, 1], [3], [4], Ellipsis], + [Ellipsis, [0, 2, 1], [3], [4]], + ] - def test_serialization(self): - # Test serialization with a real file - b = self._test_serialization_data() - for use_name in (False, True): - # Passing filename to torch.save(...) will cause the file to be opened twice, - # which is not supported on Windows - if sys.platform == "win32" and use_name: - continue - with tempfile.NamedTemporaryFile() as f: - handle = f if not use_name else f.name - torch.save(b, handle) - f.seek(0) - c = torch.load(handle) - self._test_serialization_assert(b, c) - # test non-ascii encoding of bytes arrays/strings - # The following bytes are produced by serializing - # [b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc', torch.zeros(1, dtype=torch.float), 2] - # in Python 2.7.12 and PyTorch 0.4.1, where the first element contains - # bytes of some utf-8 characters (i.e., `utf8_str.encode('utf-8')`). - serialized = ( - b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.' - b'\x80\x02}q\x01(U\x10protocol_versionq\x02M\xe9\x03U\n' - b'type_sizesq\x03}q\x04(U\x03intq\x05K\x04U\x05shortq\x06K\x02U' - b'\x04longq\x07K\x04uU\rlittle_endianq\x08\x88u.\x80\x02]q' - b'\x01(U\x0e\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85' - b'\xc5\xbcq\x02ctorch._utils\n_rebuild_tensor_v2\nq\x03((U' - b'\x07storageq\x04ctorch\nFloatStorage\nq\x05U\x0845640624q' - b'\x06U\x03cpuq\x07\x8a\x01\x01NtQK\x00K\x01\x85K\x01\x85' - b'\x89NtRq\x08K\x02e.\x80\x02]q\x01U\x0845640624q\x02a.\x01\x00' - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' - ) - buf = io.BytesIO(serialized) - utf8_bytes = b'\xc5\xbc\xc4\x85\xc4\x85\xc3\xb3\xc5\xbc\xc4\x85\xc5\xbc' - utf8_str = utf8_bytes.decode('utf-8') - if PY3: - with self.assertRaisesRegex(UnicodeDecodeError, "'ascii' codec can't decode byte"): - loaded = torch.load(buf) - buf.seek(0) - loaded_utf8 = torch.load(buf, encoding='utf-8') - self.assertEqual(loaded_utf8, [utf8_str, torch.zeros(1, dtype=torch.float), 2]) - buf.seek(0) - loaded_bytes = torch.load(buf, encoding='bytes') - else: - loaded_bytes = torch.load(buf) - self.assertEqual(loaded_bytes, [utf8_bytes, torch.zeros(1, dtype=torch.float), 2]) + for indexer in indices_to_test: + assert_get_eq(reference, indexer) + assert_set_eq(reference, indexer, 1333) + assert_set_eq(reference, + indexer, + get_set_tensor(reference, indexer)) + indices_to_test += [ + [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]], + [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]], + ] + for indexer in indices_to_test: + assert_get_eq(reference, indexer) + assert_set_eq(reference, indexer, 1333) + if torch.cuda.is_available(): + assert_backward_eq(reference, indexer) - def test_serialization_filelike(self): - # Test serialization (load and save) with a filelike object - b = self._test_serialization_data() - with BytesIOContext() as f: - torch.save(b, f) - f.seek(0) - c = torch.load(f) - self._test_serialization_assert(b, c) + def test_advancedindex_big(self, device): + reference = torch.arange(0, 123344).int().to(device) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") - def test_serialization_fake_zip(self): - data = [ - ord('P'), - ord('K'), - 5, - 6 - ] - for i in range(0, 100): - data.append(0) - t = torch.tensor(data, dtype=torch.uint8) + self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ], + torch.LongTensor([0, 123, 44488, 68807, 123343])) - with tempfile.NamedTemporaryFile() as f: - torch.save(t, f.name) + def test_kthvalue(self, device): + SIZE = 50 + x = torch.rand(SIZE, SIZE, SIZE, device=device) + x0 = x.clone() - # If this check is False for all Python versions (i.e. the fix - # has been backported), this test and torch.serialization._is_zipfile - # can be deleted - self.assertTrue(zipfile.is_zipfile(f)) - self.assertFalse(torch.serialization._is_zipfile(f)) - self.assertEqual(torch.load(f.name), t) + k = random.randint(1, SIZE) + res1val, res1ind = torch.kthvalue(x, k, keepdim=False) + res2val, res2ind = torch.sort(x) - def test_serialization_gzip(self): - # Test serialization with gzip file - b = self._test_serialization_data() - f1 = tempfile.NamedTemporaryFile(delete=False) - f2 = tempfile.NamedTemporaryFile(delete=False) - torch.save(b, f1) - with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) + # test use of result tensors + k = random.randint(1, SIZE) + res1val = torch.tensor([], device=device) + res1ind = torch.tensor([], dtype=torch.long, device=device) + torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind)) + res2val, res2ind = torch.sort(x) + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) - with gzip.open(f2.name, 'rb') as f: - c = torch.load(f) - self._test_serialization_assert(b, c) + # test non-default dim + k = random.randint(1, SIZE) + res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False) + res2val, res2ind = torch.sort(x, 0) + self.assertEqual(res1val, res2val[k - 1], 0) + self.assertEqual(res1ind, res2ind[k - 1], 0) - def test_serialization_offset(self): - a = torch.randn(5, 5) - b = torch.randn(2, 2) - m = torch.nn.Conv2d(1, 1, (1, 3)) - i, j = 41, 43 - with tempfile.NamedTemporaryFile() as f: - pickle.dump(i, f) - torch.save(a, f) - pickle.dump(j, f) - torch.save(b, f) - torch.save(m, f) - f.seek(0) - i_loaded = pickle.load(f) - a_loaded = torch.load(f) - j_loaded = pickle.load(f) - b_loaded = torch.load(f) - m_loaded = torch.load(f) - self.assertTrue(torch.equal(a, a_loaded)) - self.assertTrue(torch.equal(b, b_loaded)) - self.assertTrue(m.kernel_size == m_loaded.kernel_size) - self.assertEqual(i, i_loaded) - self.assertEqual(j, j_loaded) + # non-contiguous + y = x.narrow(1, 0, 1) + y0 = y.contiguous() + k = random.randint(1, SIZE) + res1val, res1ind = torch.kthvalue(y, k) + res2val, res2ind = torch.kthvalue(y0, k) + self.assertEqual(res1val, res2val, 0) + self.assertEqual(res1ind, res2ind, 0) - def test_serialization_offset_filelike(self): - a = torch.randn(5, 5) - b = torch.randn(2, 3) - i, j = 41, 43 - with BytesIOContext() as f: - pickle.dump(i, f) - torch.save(a, f) - pickle.dump(j, f) - torch.save(b, f) - f.seek(0) - i_loaded = pickle.load(f) - a_loaded = torch.load(f) - j_loaded = pickle.load(f) - b_loaded = torch.load(f) - self.assertTrue(torch.equal(a, a_loaded)) - self.assertTrue(torch.equal(b, b_loaded)) - self.assertEqual(i, i_loaded) - self.assertEqual(j, j_loaded) + # check that the input wasn't modified + self.assertEqual(x, x0, 0) - def test_serialization_offset_gzip(self): - a = torch.randn(5, 5) - i = 41 - f1 = tempfile.NamedTemporaryFile(delete=False) - f2 = tempfile.NamedTemporaryFile(delete=False) - with open(f1.name, 'wb') as f: - pickle.dump(i, f) - torch.save(a, f) - with open(f1.name, 'rb') as f_in, gzip.open(f2.name, 'wb') as f_out: - shutil.copyfileobj(f_in, f_out) + # simple test case (with repetitions) + y = torch.tensor((3., 5, 4, 1, 1, 5), device=device) + self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0) + self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0) - with gzip.open(f2.name, 'rb') as f: - j = pickle.load(f) - b = torch.load(f) - self.assertTrue(torch.equal(a, b)) - self.assertEqual(i, j) + # simple test case (with NaN) + SIZE = 50 + x = torch.rand(SIZE, SIZE, SIZE, device=device) + x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan + ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1] + res2val, res2ind = torch.sort(x) + for k in ks: + res1val, res1ind = torch.kthvalue(x, k, keepdim=False) + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) - def test_half_tensor(self): - x = torch.randn(5, 5).float() - y = torch.randn(5, 5).float() - xh, yh = x.half(), y.half() + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_lu_solve_batched_non_contiguous(self, device): + from numpy.linalg import solve + from common_utils import random_fullrank_matrix_distinct_singular_value - self.assertEqual(x.half().float(), x, 1e-3) + A = random_fullrank_matrix_distinct_singular_value(2, 2) + b = torch.randn(2, 2, 2) + x_exp = torch.as_tensor(solve(A.permute(0, 2, 1).numpy(), b.permute(2, 1, 0).numpy())).to(device) + A = A.to(device).permute(0, 2, 1) + b = b.to(device).permute(2, 1, 0) + assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs" + LU_data, LU_pivots = torch.lu(A) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertEqual(x, x_exp) - z = torch.Tensor(5, 5) - self.assertEqual(z.copy_(xh), x, 1e-3) + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_lu_solve_batched_many_batches(self, device): + from common_utils import lu_solve_test_helper - with tempfile.NamedTemporaryFile() as f: - torch.save(xh, f) - f.seek(0) - xh2 = torch.load(f) - self.assertEqual(xh.float(), xh2.float()) + def cast(t): + return t.to(device) - def test_serialize_device(self): - device_str = ['cpu', 'cpu:0', 'cuda', 'cuda:0'] - device_obj = [torch.device(d) for d in device_str] - for device in device_obj: - device_copied = copy.deepcopy(device) - self.assertEqual(device, device_copied) + def run_test(A_dims, b_dims, cast): + b, A, LU_data, LU_pivots = lu_solve_test_helper(self, A_dims, b_dims, cast, True) + x = torch.lu_solve(b, LU_data, LU_pivots) + b_ = torch.matmul(A, x) + self.assertEqual(b_, b.expand_as(b_)) - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_half_tensor_cuda(self): - x = torch.randn(5, 5).half() - self.assertEqual(x.cuda(), x) + run_test((5, 65536), (65536, 5, 10), cast) + run_test((5, 262144), (262144, 5, 10), cast) - xc = x.cuda() - with tempfile.NamedTemporaryFile() as f: - torch.save(xc, f) - f.seek(0) - xc2 = torch.load(f) - self.assertIsInstance(xc2, type(xc)) - self.assertEqual(xc.float(), xc2.float()) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + def test_lu_solve_batched_broadcasting(self, device): + from numpy.linalg import solve + from common_utils import random_fullrank_matrix_distinct_singular_value - def _test_serialization_cuda(self, filecontext_lambda): - device_count = torch.cuda.device_count() - t0 = torch.cuda.FloatTensor(5).fill_(1) - torch.cuda.set_device(device_count - 1) - tn = torch.cuda.FloatTensor(3).fill_(2) - torch.cuda.set_device(0) - b = (t0, tn) - with filecontext_lambda() as f: - torch.save(b, f) - f.seek(0) - c = torch.load(f) - self.assertEqual(b, c, 0) - u0, un = c - self.assertEqual(u0.get_device(), 0) - self.assertEqual(un.get_device(), device_count - 1) + def run_test(A_dims, b_dims, device, pivot=True): + A_matrix_size = A_dims[-1] + A_batch_dims = A_dims[:-2] + A = random_fullrank_matrix_distinct_singular_value(A_matrix_size, *A_batch_dims) + b = torch.randn(*b_dims) + x_exp = torch.as_tensor(solve(A.numpy(), b.numpy())).to(device) + A, b = A.to(device), b.to(device) + LU_data, LU_pivots = torch.lu(A, pivot=pivot) + x = torch.lu_solve(b, LU_data, LU_pivots) + self.assertEqual(x, x_exp) - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_serialization_cuda(self): - self._test_serialization_cuda(tempfile.NamedTemporaryFile) + # test against numpy.linalg.solve + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), device) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), device) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device) # broadcasting A & b - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_serialization_cuda_filelike(self): - self._test_serialization_cuda(BytesIOContext) + def test_dim_reduction(self, device): + example = [[-1, 2, 1], [5, 3, 6]] - def test_serialization_backwards_compat(self): - a = [torch.arange(1 + i, 26 + i).view(5, 5).float() for i in range(2)] - b = [a[i % 2] for i in range(4)] - b += [a[0].storage()] - b += [a[0].reshape(-1)[1:4].clone().storage()] - path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt') - c = torch.load(path) - self.assertEqual(b, c, 0) - self.assertTrue(isinstance(c[0], torch.FloatTensor)) - self.assertTrue(isinstance(c[1], torch.FloatTensor)) - self.assertTrue(isinstance(c[2], torch.FloatTensor)) - self.assertTrue(isinstance(c[3], torch.FloatTensor)) - self.assertTrue(isinstance(c[4], torch.FloatStorage)) - c[0].fill_(10) - self.assertEqual(c[0], c[2], 0) - self.assertEqual(c[4], torch.FloatStorage(25).fill_(10), 0) - c[1].fill_(20) - self.assertEqual(c[1], c[3], 0) + types = [torch.double, + torch.float, + torch.int64, + torch.int32, + torch.int16] - # test some old tensor serialization mechanism - class OldTensorBase(object): - def __init__(self, new_tensor): - self.new_tensor = new_tensor + # This won't test for 256bit instructions, since we usually + # only work on 1 cacheline (1024bit) at a time and these + # examples aren't big enough to trigger that. + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.sum().item(), 16) + self.assertEqual(x.sum(0), torch.FloatTensor([4, 5, 7])) + self.assertEqual(x.sum(1), torch.FloatTensor([2, 14])) + y = torch.tensor(example, device=device, dtype=dtype) + torch.sum(x, 0, out=y) + self.assertEqual(x.sum(0), y) - def __getstate__(self): - return (self.new_tensor.storage(), - self.new_tensor.storage_offset(), - tuple(self.new_tensor.size()), - self.new_tensor.stride()) + # Mean not supported for Int types + for dtype in types[:2]: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.mean().item(), 16.0 / 6) + self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2])) + self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3])) + self.assertEqual(x.mean(), x.mean((0, 1))) - class OldTensorV1(OldTensorBase): - def __reduce__(self): - return (torch.Tensor, (), self.__getstate__()) + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.prod().item(), -180) + self.assertEqual(x.prod(0), torch.FloatTensor([-5, 6, 6])) + self.assertEqual(x.prod(1), torch.FloatTensor([-2, 90])) - class OldTensorV2(OldTensorBase): - def __reduce__(self): - return (_rebuild_tensor, self.__getstate__()) + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.max().item(), 6) + self.assertEqual(x.max(0), (torch.FloatTensor([5, 3, 6]), torch.FloatTensor([1, 1, 1]))) + self.assertEqual(x.max(1), (torch.FloatTensor([2, 6]), torch.FloatTensor([1, 2]))) - x = torch.randn(30).as_strided([2, 3], [9, 3], 2) - for old_cls in [OldTensorV1, OldTensorV2]: - with tempfile.NamedTemporaryFile() as f: - old_x = old_cls(x) - torch.save(old_x, f) - f.seek(0) - load_x = torch.load(f) - self.assertEqual(x.storage(), load_x.storage()) - self.assertEqual(x.storage_offset(), load_x.storage_offset()) - self.assertEqual(x.size(), load_x.size()) - self.assertEqual(x.stride(), load_x.stride()) + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.min().item(), -1) + self.assertEqual(x.min(0), (torch.FloatTensor([-1, 2, 1]), torch.FloatTensor([0, 0, 0]))) + self.assertEqual(x.min(1), (torch.FloatTensor([-1, 3]), torch.FloatTensor([0, 1]))) - # unique_key is necessary because on Python 2.7, if a warning passed to - # the warning module is the same, it is not raised again. - def _test_serialization_container(self, unique_key, filecontext_lambda): - tmpmodule_name = 'tmpmodule{}'.format(unique_key) + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.argmax().item(), 5) + self.assertEqual(x.argmax(dim=None).item(), 5) + self.assertEqual(x.argmax(dim=0), torch.FloatTensor([1, 1, 1])) + self.assertEqual(x.argmax(dim=1), torch.FloatTensor([1, 2])) + self.assertEqual(x.argmax(dim=0, keepdim=True), torch.FloatTensor([[1, 1, 1]])) + # test that non-contiguous tensors work + self.assertEqual(x[:, :2].argmax().item(), 2) - def import_module(name, filename): - if sys.version_info >= (3, 5): - import importlib.util - spec = importlib.util.spec_from_file_location(name, filename) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - else: - import imp - module = imp.load_source(name, filename) - sys.modules[module.__name__] = module - return module + for dtype in types: + x = torch.tensor(example, device=device, dtype=dtype) + self.assertEqual(x.argmin().item(), 0) + self.assertEqual(x.argmin(dim=None).item(), 0) + self.assertEqual(x.argmin(dim=0), torch.FloatTensor([0, 0, 0])) + self.assertEqual(x.argmin(dim=1), torch.FloatTensor([0, 1])) + self.assertEqual(x.argmin(dim=1, keepdim=True), torch.FloatTensor([[0], [1]])) + # test that non-contiguous tensors work + self.assertEqual(x[:, :2].argmin().item(), 0) - with filecontext_lambda() as checkpoint: - fname = get_file_path_2(os.path.dirname(__file__), 'data', 'network1.py') - module = import_module(tmpmodule_name, fname) - torch.save(module.Net(), checkpoint) + dim_red_fns = [ + "mean", "median", "mode", "norm", "prod", + "std", "sum", "var", "max", "min"] - # First check that the checkpoint can be loaded without warnings - checkpoint.seek(0) - with warnings.catch_warnings(record=True) as w: - loaded = torch.load(checkpoint) - self.assertTrue(isinstance(loaded, module.Net)) - if can_retrieve_source: - self.assertEquals(len(w), 0) + def normfn_attr(t, dim, keepdim=False, out=None): + attr = torch.norm + return attr(t, 2, dim, keepdim, out=out) - # Replace the module with different source - fname = get_file_path_2(os.path.dirname(__file__), 'data', 'network2.py') - module = import_module(tmpmodule_name, fname) - checkpoint.seek(0) - with warnings.catch_warnings(record=True) as w: - loaded = torch.load(checkpoint) - self.assertTrue(isinstance(loaded, module.Net)) - if can_retrieve_source: - self.assertEquals(len(w), 1) - self.assertTrue(w[0].category, 'SourceChangeWarning') + for fn_name in dim_red_fns: + fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr - def test_serialization_container(self): - self._test_serialization_container('file', tempfile.NamedTemporaryFile) + def fn(x, dim, keepdim=False, out=None): + ans = fn_attr(x, dim, keepdim=keepdim, out=out) + return ans if not istuple(ans) else ans[0] - def test_serialization_container_filelike(self): - self._test_serialization_container('filelike', BytesIOContext) + def fn_tuple(x, dim, keepdim=False, out=None): + return fn_attr(x, dim, keepdim=keepdim, out=out) - def test_serialization_map_location(self): - test_file_path = download_file('https://download.pytorch.org/test_data/gpu_tensors.pt') + def test_multidim(x, dim): + self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True)) + self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension()) + self.assertEqual(x.ndimension(), fn(x, dim, keepdim=True).ndimension()) - def map_location(storage, loc): - return storage + # general case + x = torch.randn(3, 4, 5, device=device) + dim = random.randint(0, 2) + test_multidim(x, dim) - def load_bytes(): - with open(test_file_path, 'rb') as f: - return io.BytesIO(f.read()) + # check 1-d behavior + x = torch.randn(1, device=device) + dim = 0 + self.assertEqual(fn(x, dim).shape, ()) + self.assertEqual(fn(x, dim, keepdim=True).shape, (1,)) - fileobject_lambdas = [lambda: test_file_path, load_bytes] - cpu_map_locations = [ - map_location, - {'cuda:0': 'cpu'}, - 'cpu', - torch.device('cpu'), - ] - gpu_0_map_locations = [ - {'cuda:0': 'cuda:0'}, - 'cuda', - 'cuda:0', - torch.device('cuda'), - torch.device('cuda', 0) - ] - gpu_last_map_locations = [ - 'cuda:{}'.format(torch.cuda.device_count() - 1), - ] + # check reducing of a singleton dimension + dims = [3, 4, 5] + singleton_dim = random.randint(0, 2) + dims[singleton_dim] = 1 + x = torch.randn(dims, device=device) + test_multidim(x, singleton_dim) - def check_map_locations(map_locations, tensor_class, intended_device): - for fileobject_lambda in fileobject_lambdas: - for map_location in map_locations: - tensor = torch.load(fileobject_lambda(), map_location=map_location) + # check reducing with output kwargs + if fn_name in ['median', 'mode', 'max', 'min']: + y = torch.randn(5, 3, device=device) + values = torch.randn(5, 3, device=device) + indices = torch.zeros(5, 3, device=device).long() - 1 + fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1])) + values_expected, indices_expected = fn_tuple(y, 1, keepdim=False) + self.assertEqual(values[:, 1], values_expected, + '{} values with out= kwarg'.format(fn_name)) + self.assertEqual(indices[:, 1], indices_expected, + '{} indices with out= kwarg'.format(fn_name)) + continue - self.assertEqual(tensor.device, intended_device) - self.assertIsInstance(tensor, tensor_class) - self.assertEqual(tensor, tensor_class([[1.0, 2.0], [3.0, 4.0]])) + x = torch.randn(5, 3, device=device) + y = torch.randn(5, 3, device=device) + fn(y, 1, keepdim=False, out=x[:, 1]) + expected = fn(y, 1, keepdim=False) + self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name)) - check_map_locations(cpu_map_locations, torch.FloatTensor, torch.device('cpu')) - if torch.cuda.is_available(): - check_map_locations(gpu_0_map_locations, torch.cuda.FloatTensor, torch.device('cuda', 0)) - check_map_locations( - gpu_last_map_locations, - torch.cuda.FloatTensor, - torch.device('cuda', torch.cuda.device_count() - 1) - ) + def test_remainder_overflow(self, device): + # Check Integer Overflows + x = torch.tensor(23500, dtype=torch.int64, device=device) + q = 392486996410368 + self.assertEqual(x % q, x) + self.assertEqual(-x % q, q - x) + self.assertEqual(x % -q, x - q) + self.assertEqual(-x % -q, -x) - @unittest.skipIf(torch.cuda.is_available(), "Testing torch.load on CPU-only machine") - @unittest.skipIf(not PY3, "Test tensors were serialized using python 3") - def test_load_nonexistent_device(self): - # Setup: create a serialized file object with a 'cuda:0' restore location - # The following was generated by saving a torch.randn(2, device='cuda') tensor. - serialized = (b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9' - b'\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq' - b'\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n' - b'\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq' - b'\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00' - b'longq\x07K\x04uu.\x80\x02ctorch._utils\n_rebuild_tensor_v2' - b'\nq\x00((X\x07\x00\x00\x00storageq\x01ctorch\nFloatStorage' - b'\nq\x02X\x0e\x00\x00\x0094919395964320q\x03X\x06\x00\x00' - b'\x00cuda:0q\x04K\x02Ntq\x05QK\x00K\x02\x85q\x06K\x01\x85q' - b'\x07\x89Ntq\x08Rq\t.\x80\x02]q\x00X\x0e\x00\x00\x00' - b'94919395964320q\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\xbb' - b'\x1f\x82\xbe\xea\x81\xd1>') + def test_rpow(self, device): + m = torch.randn(10, 10, device=device) + self.assertEqual(torch.pow(2, m), 2**m) - buf = io.BytesIO(serialized) + # test with scalar + m = torch.randn(1, device=device).squeeze() + assert m.dim() == 0, "m is intentionally a scalar" + self.assertEqual(torch.pow(2, m), 2**m) - error_msg = r'Attempting to deserialize object on a CUDA device' - with self.assertRaisesRegex(RuntimeError, error_msg): - _ = torch.load(buf) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_symeig(self, device): + from common_utils import random_symmetric_matrix - def test_serialization_filelike_api_requirements(self): - filemock = FilelikeMock(b'', has_readinto=False) - tensor = torch.randn(3, 5) - torch.save(tensor, filemock) - expected_superset = {'write', 'flush'} - self.assertTrue(expected_superset.issuperset(filemock.calls)) + def run_test(dims, eigenvectors, upper): + x = random_symmetric_matrix(*dims).to(device) + oute = torch.empty(dims[1:] + dims[:1], device=device) + outv = torch.empty(dims[1:] + dims[:1] * 2, device=device) + torch.symeig(x, eigenvectors=eigenvectors, upper=upper, out=(oute, outv)) - # Reset between save and load - filemock.seek(0) - filemock.calls.clear() + if eigenvectors: + x_recon = torch.matmul(torch.matmul(outv, torch.diag_embed(oute)), outv.transpose(-2, -1)) + self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using V @ diag(e) @ V.T') + else: + eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) + self.assertEqual(eigvals, oute, 'Eigenvalues mismatch') + self.assertEqual(torch.zeros_like(outv), outv, 'Eigenvector matrix not zero') - _ = torch.load(filemock) - expected_superset = {'read', 'readline', 'seek', 'tell'} - self.assertTrue(expected_superset.issuperset(filemock.calls)) + rese, resv = x.symeig(eigenvectors=eigenvectors, upper=upper) + self.assertEqual(rese, oute, "outputs of symeig and symeig with out don't match") + self.assertEqual(resv, outv, "outputs of symeig and symeig with out don't match") - def _test_serialization_filelike(self, tensor, mock, desc): - f = mock(b'') - torch.save(tensor, f) - f.seek(0) - data = mock(f.read()) + # test non-contiguous + x = random_symmetric_matrix(*dims).to(device) + n_dim = len(dims) + 1 + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + rese, resv = torch.symeig(x, eigenvectors=eigenvectors, upper=upper) + if eigenvectors: + x_recon = torch.matmul(torch.matmul(resv, torch.diag_embed(rese)), resv.transpose(-2, -1)) + self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using V @ diag(e) @ V.T') + else: + eigvals, _ = torch.symeig(x, eigenvectors=True, upper=upper) + self.assertEqual(eigvals, rese, 'Eigenvalues mismatch') + self.assertEqual(torch.zeros_like(resv), resv, 'Eigenvector matrix not zero') - msg = 'filelike serialization with {}' + batch_dims_set = [(), (3,), (3, 5), (5, 3, 5)] + for batch_dims, eigenvectors, upper in product(batch_dims_set, (True, False), (True, False)): + run_test((5,) + batch_dims, eigenvectors, upper) - b = torch.load(data) - self.assertTrue(torch.equal(tensor, b), msg.format(desc)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_svd(self, device): + def run_test(dims, some, compute_uv): + x = torch.randn(*dims, device=device) + outu, outs, outv = torch.Tensor().to(device), torch.Tensor().to(device), torch.Tensor().to(device) + torch.svd(x, some=some, compute_uv=compute_uv, out=(outu, outs, outv)) + + if compute_uv: + if some: + x_recon = torch.matmul(outu, torch.matmul(outs.diag_embed(), outv.transpose(-2, -1))) + self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = outu[..., :min(*dims[-2:])] + narrow_v = outv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(outs.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, outs, 'Singular values mismatch') + self.assertEqual(outu, torch.zeros_like(outu), 'U not zero') + self.assertEqual(outv, torch.zeros_like(outv), 'V not zero') - def test_serialization_filelike_missing_attrs(self): - # Test edge cases where filelike objects are missing attributes. - # The Python io docs suggests that these attributes should really exist - # and throw io.UnsupportedOperation, but that isn't always the case. - mocks = [ - ('no readinto', lambda x: FilelikeMock(x)), - ('has readinto', lambda x: FilelikeMock(x, has_readinto=True)), - ('no fileno', lambda x: FilelikeMock(x, has_fileno=False)), - ] + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + self.assertEqual(resu, outu, 'outputs of svd and svd with out differ') + self.assertEqual(ress, outs, 'outputs of svd and svd with out differ') + self.assertEqual(resv, outv, 'outputs of svd and svd with out differ') - to_serialize = torch.randn(3, 10) - for desc, mock in mocks: - self._test_serialization_filelike(to_serialize, mock, desc) + # test non-contiguous + x = torch.randn(*dims, device=device) + n_dim = len(dims) + # Reverse the batch dimensions and the matrix dimensions and then concat them + x = x.permute(tuple(range(n_dim - 3, -1, -1)) + (n_dim - 1, n_dim - 2)) + assert not x.is_contiguous(), "x is intentionally non-contiguous" + resu, ress, resv = torch.svd(x, some=some, compute_uv=compute_uv) + if compute_uv: + if some: + x_recon = torch.matmul(resu, torch.matmul(ress.diag_embed(), resv.transpose(-2, -1))) + self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') + else: + narrow_u = resu[..., :min(*dims[-2:])] + narrow_v = resv[..., :min(*dims[-2:])] + x_recon = torch.matmul(narrow_u, torch.matmul(ress.diag_embed(), narrow_v.transpose(-2, -1))) + self.assertEqual(x, x_recon, 1e-8, 'Incorrect reconstruction using U @ diag(S) @ V.T') + else: + _, singvals, _ = torch.svd(x, compute_uv=True) + self.assertEqual(singvals, ress, 'Singular values mismatch') + self.assertEqual(resu, torch.zeros_like(resu), 'U not zero') + self.assertEqual(resv, torch.zeros_like(resv), 'V not zero') - def test_serialization_filelike_stress(self): - a = torch.randn(11 * (2 ** 9) + 1, 5 * (2 ** 9)) + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7)] # thin matrices + for dims, some, compute_uv in product(shapes, [True, False], [True, False]): + run_test(dims, some, compute_uv) - # This one should call python read multiple times - self._test_serialization_filelike(a, lambda x: FilelikeMock(x, has_readinto=False), - 'read() stress test') - self._test_serialization_filelike(a, lambda x: FilelikeMock(x, has_readinto=True), - 'readinto() stress test') + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_svd_no_singularvectors(self, device): + for size in [(5, 5), (5, 20), (20, 5)]: + a = torch.randn(*size, device=device) + u, s_expect, v = torch.svd(a) + u, s_actual, v = torch.svd(a, compute_uv=False) + self.assertEqual(s_expect, s_actual, "Singular values don't match") - def test_serialization_filelike_uses_readinto(self): - # For maximum effiency, when reading a file-like object, - # ensure the C API calls readinto instead of read. - a = torch.randn(5, 4) + def test_lerp(self, device): + start_end_shapes = [(), (5,), (5, 5), (5, 5, 5)] + for shapes in product(start_end_shapes, start_end_shapes): + start = torch.randn(shapes[0], device=device) + end = torch.randn(shapes[1], device=device) - f = io.BytesIO() - torch.save(a, f) - f.seek(0) - data = FilelikeMock(f.read(), has_readinto=True) + # Tensor weights + for weight in [torch.randn(shapes[0], device=device), random.random()]: + actual = torch.lerp(start, end, weight) + actual_method = start.lerp(end, weight) + self.assertEqual(actual, actual_method) + actual_out = torch.Tensor().to(device) + torch.lerp(start, end, weight, out=actual_out) + self.assertEqual(actual, actual_out) + expected = start + weight * (end - start) + self.assertEqual(expected, actual) - b = torch.load(data) - self.assertTrue(data.was_called('readinto')) + def test_diagflat(self, device): + dtype = torch.float32 + # Basic sanity test + x = torch.randn((100,), dtype=dtype, device=device) + result = torch.diagflat(x) + expected = torch.diag(x) + self.assertEqual(result, expected) - def test_serialization_storage_slice(self): - # Generated using: - # - # t = torch.zeros(2); - # s1 = t.storage()[:1] - # s2 = t.storage()[1:] - # torch.save((s1, s2), 'foo.ser') - # - # with PyTorch 0.3.1 - serialized = (b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03' - b'.\x80\x02}q\x00(X\n\x00\x00\x00type_sizesq\x01}q\x02(X\x03' - b'\x00\x00\x00intq\x03K\x04X\x05\x00\x00\x00shortq\x04K\x02X' - b'\x04\x00\x00\x00longq\x05K\x04uX\x10\x00\x00\x00protocol_versionq' - b'\x06M\xe9\x03X\r\x00\x00\x00little_endianq\x07\x88u.\x80\x02' - b'(X\x07\x00\x00\x00storageq\x00ctorch\nFloatStorage\nq\x01X\x0e' - b'\x00\x00\x0094279043900432q\x02X\x03\x00\x00\x00cpuq\x03K\x02' - b'X\x0e\x00\x00\x0094279029750368q\x04K\x00K\x01\x87q\x05tq\x06' - b'Q(h\x00h\x01X\x0e\x00\x00\x0094279043900432q\x07h\x03K\x02X' - b'\x0e\x00\x00\x0094279029750432q\x08K\x01K\x01\x87q\ttq\nQ' - b'\x86q\x0b.\x80\x02]q\x00X\x0e\x00\x00\x0094279043900432q' - b'\x01a.\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' - b'\x00\x00\x00\x00') + # Test offset + x = torch.randn((100,), dtype=dtype, device=device) + result = torch.diagflat(x, 17) + expected = torch.diag(x, 17) + self.assertEqual(result, expected) - buf = io.BytesIO(serialized) - (s1, s2) = torch.load(buf) - self.assertEqual(s1[0], 0) - self.assertEqual(s2[0], 0) - self.assertEqual(s1.data_ptr() + 4, s2.data_ptr()) + # Test where input has more than one dimension + x = torch.randn((2, 3, 4), dtype=dtype, device=device) + result = torch.diagflat(x) + expected = torch.diag(x.contiguous().view(-1)) + self.assertEqual(result, expected) - def test_load_error_msg(self): - expected_err_msg = (".*You can only torch.load from a file that is seekable. " + - "Please pre-load the data into a buffer like io.BytesIO and " + - "try to load from it instead.") + # Noncontig input + x = torch.randn((2, 3, 4), dtype=dtype, device=device).transpose(2, 0) + self.assertFalse(x.is_contiguous()) + result = torch.diagflat(x) + expected = torch.diag(x.contiguous().view(-1)) + self.assertEqual(result, expected) - resource = FilelikeMock(data=b"data") - delattr(resource, "tell") - delattr(resource, "seek") - self.assertRaisesRegex(AttributeError, expected_err_msg, lambda: torch.load(resource)) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_norm(self, device): + # full reduction + x = torch.randn(25, device=device) + xn = x.cpu().numpy() + for p in [0, 1, 2, 3, 4, inf, -inf]: + res = x.norm(p).item() + expected = np.linalg.norm(xn, p) + self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p)) - def test_from_buffer(self): - a = bytearray([1, 2, 3, 4]) - self.assertEqual(torch.ByteStorage.from_buffer(a).tolist(), [1, 2, 3, 4]) - shorts = torch.ShortStorage.from_buffer(a, 'big') - self.assertEqual(shorts.size(), 2) - self.assertEqual(shorts.tolist(), [258, 772]) - ints = torch.IntStorage.from_buffer(a, 'little') - self.assertEqual(ints.size(), 1) - self.assertEqual(ints[0], 67305985) - f = bytearray([0x40, 0x10, 0x00, 0x00]) - floats = torch.FloatStorage.from_buffer(f, 'big') - self.assertEqual(floats.size(), 1) - self.assertEqual(floats[0], 2.25) + # one dimension + x = torch.randn(25, 25, device=device) + xn = x.cpu().numpy() + for p in [0, 1, 2, 3, 4, inf, -inf]: + res = x.norm(p, 1).cpu().numpy() + expected = np.linalg.norm(xn, p, 1) + self.assertEqual(res.shape, expected.shape) + self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p)) - f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) - bools = torch.BoolStorage.from_buffer(f, 'big') - self.assertEqual(bools.size(), 8) - self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True]) - self.assertEqual(bools.type(), 'torch.BoolStorage') + # matrix norm + for p in ['fro', 'nuc']: + res = x.norm(p).cpu().numpy() + expected = np.linalg.norm(xn, p) + self.assertEqual(res.shape, expected.shape) + self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p)) - f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9') - bools = torch.BoolStorage.from_buffer(f, 'big') - self.assertEqual(bools.size(), 19) + # larger tensor sanity check + self.assertEqual(2 * torch.norm(torch.ones(10000)), torch.norm(torch.ones(40000))) - f = bytearray(b'\0x4A') - bools = torch.BoolStorage.from_buffer(f, 'big') - self.assertEqual(bools.size(), 4) - self.assertEqual(bools.tolist(), [False, True, True, True]) + @skipCUDAIfNoMagma + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_nuclear_norm_axes_small_brute_force(self, device): + def check_single_nuclear_norm(x, axes): + if x.is_cuda and randrange(100) < 95: + return # too many cpu <==> gpu copies - def test_storage_casts(self): - storage = torch.IntStorage([-1, 0, 1, 2, 3, 4]) - self.assertEqual(storage.size(), 6) - self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertEqual(storage.type(), 'torch.IntStorage') - self.assertIs(storage.dtype, torch.int32) + a = np.array(x.cpu(), copy=False) + expected = np.linalg.norm(a, "nuc", axis=axes) - floatStorage = storage.float() - self.assertEqual(floatStorage.size(), 6) - self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertEqual(floatStorage.type(), 'torch.FloatStorage') - self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(floatStorage.dtype, torch.float32) + ans = torch.norm(x, "nuc", dim=axes) + self.assertTrue(ans.is_contiguous()) + self.assertEqual(ans.shape, expected.shape) + self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)) - halfStorage = storage.half() - self.assertEqual(halfStorage.size(), 6) - self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertEqual(halfStorage.type(), 'torch.HalfStorage') - self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(halfStorage.dtype, torch.float16) + out = torch.zeros(expected.shape, dtype=x.dtype, device=x.device) + ans = torch.norm(x, "nuc", dim=axes, out=out) + self.assertIs(ans, out) + self.assertTrue(ans.is_contiguous()) + self.assertEqual(ans.shape, expected.shape) + self.assertTrue(np.allclose(ans.cpu(), expected, rtol=1e-02, atol=1e-03, equal_nan=True)) - bfloat16Storage = storage.bfloat16() - self.assertEqual(bfloat16Storage.size(), 6) - self.assertEqual(bfloat16Storage.tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertEqual(bfloat16Storage.type(), 'torch.BFloat16Storage') - self.assertEqual(bfloat16Storage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(bfloat16Storage.dtype, torch.bfloat16) + for n in range(1, 3): + for m in range(1, 3): + for axes in permutations([0, 1], 2): + # 2d, inner dimensions C + x = torch.randn(n, m, device=device) + check_single_nuclear_norm(x, axes) - longStorage = storage.long() - self.assertEqual(longStorage.size(), 6) - self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertEqual(longStorage.type(), 'torch.LongStorage') - self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(longStorage.dtype, torch.int64) + # 2d, inner dimensions Fortran + x = torch.randn(m, n, device=device).transpose(-1, -2) + check_single_nuclear_norm(x, axes) - shortStorage = storage.short() - self.assertEqual(shortStorage.size(), 6) - self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertEqual(shortStorage.type(), 'torch.ShortStorage') - self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(shortStorage.dtype, torch.int16) + # 2d, inner dimensions non-contiguous + x = torch.randn(n, 2 * m, device=device)[:, ::2] + check_single_nuclear_norm(x, axes) - doubleStorage = storage.double() - self.assertEqual(doubleStorage.size(), 6) - self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) - self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage') - self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(doubleStorage.dtype, torch.float64) + # 2d, all dimensions non-contiguous + x = torch.randn(7 * n, 2 * m, device=device)[::7, ::2] + check_single_nuclear_norm(x, axes) - charStorage = storage.char() - self.assertEqual(charStorage.size(), 6) - self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0]) - self.assertEqual(charStorage.type(), 'torch.CharStorage') - self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4]) - self.assertIs(charStorage.dtype, torch.int8) + for o in range(1, 3): + for axes in permutations([0, 1, 2], 2): + # 3d, inner dimensions C + x = torch.randn(o, n, m, device=device) + check_single_nuclear_norm(x, axes) - byteStorage = storage.byte() - self.assertEqual(byteStorage.size(), 6) - self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4]) - self.assertEqual(byteStorage.type(), 'torch.ByteStorage') - self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4]) - self.assertIs(byteStorage.dtype, torch.uint8) + # 3d, inner dimensions Fortran + x = torch.randn(o, m, n, device=device).transpose(-1, -2) + check_single_nuclear_norm(x, axes) - boolStorage = storage.bool() - self.assertEqual(boolStorage.size(), 6) - self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True]) - self.assertEqual(boolStorage.type(), 'torch.BoolStorage') - self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1]) - self.assertIs(boolStorage.dtype, torch.bool) + # 3d, inner dimensions non-contiguous + x = torch.randn(o, n, 2 * m, device=device)[:, :, ::2] + check_single_nuclear_norm(x, axes) - def test_storage_device(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: - x = torch.tensor([], device=device) - self.assertEqual(x.dtype, x.storage().dtype) + # 3d, all dimensions non-contiguous + x = torch.randn(7 * o, 5 * n, 2 * m, device=device)[::7, ::5, ::2] + check_single_nuclear_norm(x, axes) - @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected') - def test_storage_multigpu(self): - devices = ['cuda:0', 'cuda:1'] - for device in devices: - x = torch.tensor([], device=device) - self.assertEqual(x.dtype, x.storage().dtype) + for r in range(1, 3): + for axes in permutations([0, 1, 2, 3], 2): + # 4d, inner dimensions C + x = torch.randn(r, o, n, m, device=device) + check_single_nuclear_norm(x, axes) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") - def test_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.FloatStorage.from_file(f.name, True, size) - t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) + # 4d, inner dimensions Fortran + x = torch.randn(r, o, n, m, device=device).transpose(-1, -2) + check_single_nuclear_norm(x, axes) - # check mapping - s2 = torch.FloatStorage.from_file(f.name, True, size) - t2 = torch.FloatTensor(s2) - self.assertEqual(t1, t2, 0) + # 4d, inner dimensions non-contiguous + x = torch.randn(r, o, n, 2 * m, device=device)[:, :, :, ::2] + check_single_nuclear_norm(x, axes) - # check changes to t1 from t2 - rnum = random.uniform(-1, 1) - t1.fill_(rnum) - self.assertEqual(t1, t2, 0) + # 4d, all dimensions non-contiguous + x = torch.randn(7 * r, 5 * o, 11 * n, 2 * m, device=device)[::7, ::5, ::11, ::2] + check_single_nuclear_norm(x, axes) - # check changes to t2 from t1 - rnum = random.uniform(-1, 1) - t2.fill_(rnum) - self.assertEqual(t1, t2, 0) + @skipCUDAIfNoMagma + def test_nuclear_norm_exceptions(self, device): + for lst in [], [1], [1, 2]: + for axes in (), (0,), (0, 1): + x = torch.tensor(lst, dtype=torch.double, device=device) + self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) - @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") - def test_torch_from_file(self): - size = 10000 - with tempfile.NamedTemporaryFile() as f: - s1 = torch.from_file(f.name, True, size, dtype=torch.float) - t1 = torch.FloatTensor(s1).copy_(torch.randn(size)) + x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) + self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) + self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 2)) - # check mapping - s2 = torch.from_file(f.name, True, size, dtype=torch.float) - t2 = torch.FloatTensor(s2) - self.assertEqual(t1, t2, 0) + def test_dist(self, device): + def run_test(x, y): + for p in [0, 1, 2, 3, 4, inf, -inf]: + dist_xy = torch.dist(x, y, p) + dist_xy_norm = torch.norm(x - y, p) + self.assertEqual(dist_xy, dist_xy_norm) - # check changes to t1 from t2 - rnum = random.uniform(-1, 1) - t1.fill_(rnum) - self.assertEqual(t1, t2, 0) + run_test(torch.randn(5, device=device), torch.randn(5, device=device)) - # check changes to t2 from t1 - rnum = random.uniform(-1, 1) - t2.fill_(rnum) - self.assertEqual(t1, t2, 0) + x = torch.zeros(3, device=device) + y = torch.zeros(3, device=device) + y[1] = 1. + run_test(x, y) - def test_print(self): - default_type = torch.Tensor().type() - for t in torch._tensor_classes: - if t == torch.HalfTensor: - continue # HalfTensor does not support fill - if t.is_sparse: - continue - if t.is_cuda and not torch.cuda.is_available(): - continue - if t == torch.cuda.BFloat16Tensor: - self.assertRaises(RuntimeError, lambda: t(100, 100).fill_(1)) - continue - obj = t(100, 100).fill_(1) - obj.__repr__() - str(obj) - # test half tensor - obj = torch.rand(100, 100, device='cpu').half() - obj.__repr__() - str(obj) - for t in torch._storage_classes: - if t == torch.BFloat16Storage: - continue # Fix once fill is enabled for bfloat16 - if t.is_cuda and not torch.cuda.is_available(): - continue - if t == torch.BoolStorage or t == torch.cuda.BoolStorage: - obj = t(100).fill_(True) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_geqrf(self, device): + a = torch.randn(5, 5, device=device) + b, c = torch.geqrf(a) + b_placeholder, c_placeholder = torch.empty_like(b), torch.empty_like(c) + torch.geqrf(a, out=(b_placeholder, c_placeholder)) + self.assertEqual(b, b_placeholder) + self.assertEqual(c, c_placeholder) + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_triangular_solve(self, device): + from common_utils import triangular_solve_test_helper + for (k, n), (upper, unitriangular, transpose) in product(zip([2, 3, 5], [3, 5, 7]), + product([True, False], repeat=3)): + b, A = triangular_solve_test_helper((n, n), (n, k), lambda t: t.to(device), upper, unitriangular) + x = torch.triangular_solve(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0] + if transpose: + self.assertLessEqual(b.dist(A.t().mm(x)), 4e-12) else: - obj = t(100).fill_(1) - obj.__repr__() - str(obj) + self.assertLessEqual(b.dist(A.mm(x)), 4e-12) - # test big integer - x = torch.tensor(2341234123412341) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor(2341234123412341)''') + @slowTest + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_triangular_solve_batched_many_batches(self, device): + from common_utils import triangular_solve_test_helper - # test scientific notation - x = torch.tensor([1e28, 1e-28]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1.0000e+28, 1.0000e-28])''') + def cast(t): + return t.to(device) - # test scientific notation using set_printoptions - x = torch.tensor([1e2, 1e-2]) - torch.set_printoptions(sci_mode=True) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1.0000e+02, 1.0000e-02])''') - torch.set_printoptions(sci_mode=False) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([ 100.0000, 0.0100])''') - torch.set_printoptions(sci_mode=None) # reset to the default value + for upper, transpose, unitriangular in product([True, False], repeat=3): + b, A = triangular_solve_test_helper((256, 256, 5, 5), (5, 1), cast, upper, unitriangular) + x, _ = torch.triangular_solve(b, A, + upper=upper, transpose=transpose, unitriangular=unitriangular) + if transpose: + A = A.transpose(-2, -1) + self.assertEqual(torch.matmul(A, x), b.expand(A.shape[:-2] + (5, 1))) - # test no leading space if all elements positive - x = torch.tensor([1, 2]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1, 2])''') + b, A = triangular_solve_test_helper((3, 3), (512, 512, 3, 1), cast, upper, unitriangular) + x, _ = torch.triangular_solve(b, A, + upper=upper, transpose=transpose, unitriangular=unitriangular) + if transpose: + A = A.transpose(-2, -1) + self.assertEqual(torch.matmul(A, x), b) - # test for leading space if there are negative elements - x = torch.tensor([1, -2]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([ 1, -2])''') + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_triangular_solve_batched_broadcasting(self, device): + from scipy.linalg import solve_triangular as tri_solve + from common_utils import triangular_solve_test_helper - # test inf and nan - x = torch.tensor([4, inf, 1.5, -inf, 0, nan, 1]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([4.0000, inf, 1.5000, -inf, 0.0000, nan, 1.0000])''') + def scipy_tri_solve_batched(A, B, upper, trans, diag): + batch_dims_A, batch_dims_B = A.shape[:-2], B.shape[:-2] + single_dim_A, single_dim_B = A.shape[-2:], B.shape[-2:] + expand_dims = tuple(torch._C._infer_size(torch.Size(batch_dims_A), + torch.Size(batch_dims_B))) + expand_A = np.broadcast_to(A, expand_dims + single_dim_A) + expand_B = np.broadcast_to(B, expand_dims + single_dim_B) + flat_A = expand_A.reshape((-1,) + single_dim_A) + flat_B = expand_B.reshape((-1,) + single_dim_B) + flat_X = np.vstack([tri_solve(a, b, lower=(not upper), trans=int(trans), unit_diagonal=diag) + for a, b in zip(flat_A, flat_B)]) + return flat_X.reshape(expand_B.shape) - # test dtype - torch.set_default_dtype(torch.float) - x = torch.tensor([1e-324, 1e-323, 1e-322, 1e307, 1e308, 1e309], dtype=torch.float64) - self.assertEqual(x.__repr__(), str(x)) - expected_str = '''\ -tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, - inf], dtype=torch.float64)''' - self.assertExpectedInline(str(x), expected_str) + def run_test(A_dims, b_dims, device, upper, transpose, unitriangular): + b, A = triangular_solve_test_helper(A_dims, b_dims, lambda t: t.to(device), upper, unitriangular) + x_exp = torch.as_tensor(scipy_tri_solve_batched(A.cpu().numpy(), b.cpu().numpy(), + upper, transpose, unitriangular)) + x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0] - # test changing default dtype - torch.set_default_dtype(torch.float64) - self.assertEqual(x.__repr__(), str(x)) - expected_str = '''\ -tensor([ 0.0000e+00, 9.8813e-324, 9.8813e-323, 1.0000e+307, 1.0000e+308, - inf])''' - self.assertExpectedInline(str(x), expected_str) + self.assertEqual(x, x_exp.to(device)) - # test summary - x = torch.zeros(10000) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([0., 0., 0., ..., 0., 0., 0.])''') + for upper, transpose, unitriangular in product([True, False], repeat=3): + # test against scipy.linalg.solve_triangular + run_test((2, 1, 3, 4, 4), (2, 1, 3, 4, 6), device, upper, transpose, unitriangular) # no broadcasting + run_test((2, 1, 3, 4, 4), (4, 6), device, upper, transpose, unitriangular) # broadcasting b + run_test((4, 4), (2, 1, 3, 4, 2), device, upper, transpose, unitriangular) # broadcasting A + run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), device, upper, transpose, unitriangular) # broadcasting A & b + + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_lstsq(self, device): + def cast_fn(tensor): + return tensor.to(device=device) - # test internal summary function - x = torch.rand(1, 20, 5, 30) - summary = torch._tensor_str.get_summarized_data(x) - self.assertEqual(summary.shape, (1, 6, 5, 6)) - first_and_last = [0, 1, 2, -3, -2, -1] - self.assertEqual(summary, x[:, first_and_last][..., first_and_last]) + def _test_underdetermined(a, b, expectedNorm): + # underdetermined systems are not supported on the GPU + if 'cuda' in device: + return - # test device - if torch.cuda.is_available(): - x = torch.tensor([123], device='cuda:0') - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''') + m = a.size()[0] + n = a.size()[1] + assert(m <= n) + + a_copy = a.clone() + b_copy = b.clone() + res1 = torch.lstsq(b, a)[0] + self.assertEqual(a, a_copy, 0) + self.assertEqual(b, b_copy, 0) + self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8) + + ta = cast_fn(torch.Tensor()) + tb = cast_fn(torch.Tensor()) + res2 = torch.lstsq(b, a, out=(tb, ta))[0] + self.assertEqual(a, a_copy, 0) + self.assertEqual(b, b_copy, 0) + self.assertEqual((torch.mm(a, res1) - b).norm(), expectedNorm, 1e-8) - # test changing default to cuda - torch.set_default_tensor_type(torch.cuda.FloatTensor) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([123])''') + res3 = torch.lstsq(b, a, out=(b, a))[0] + self.assertEqual((torch.mm(a_copy, b) - b_copy).norm(), expectedNorm, 1e-8) + self.assertEqual(res1, tb, 0) + self.assertEqual(res1, b, 0) + self.assertEqual(res1, res2, 0) + self.assertEqual(res1, res3, 0) - # test printing a tensor on a different gpu than current one. - if torch.cuda.device_count() >= 2: - with torch.cuda.device(1): - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([123], device='cuda:0')''') + def _test_overdetermined(a, b, expectedNorm): + m = a.size()[0] + n = a.size()[1] + assert(m > n) - # test printing cpu tensor when default device is cuda - y = torch.tensor([123], device='cpu') - self.assertEqual(y.__repr__(), str(y)) - self.assertExpectedInline(str(y), '''tensor([123], device='cpu')''') - torch.set_default_tensor_type(default_type) + def check_norm(a, b, expected_norm, gels_result): + # Checks |ax - b| and the residual info from the result - # test integral floats and requires_grad - x = torch.tensor([123.], requires_grad=True) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([123.], requires_grad=True)''') + # The first n rows is the least square solution. + # Rows n to m-1 contain residual information. + x = gels_result[:n] + resid_info = gels_result[n:] - # test non-contiguous print - # sliced tensor should have > PRINT_OPTS.threshold elements - x = torch.ones(100, 2, 2, 10) - y = x.as_strided(size=(100, 2, 10), stride=(2 * 2 * 10, 2 * 10, 1)) - self.assertEqual(str(y), y.__repr__()) - expected_str = '''\ -tensor([[[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]], + resid_norm = (torch.mm(a, x) - b).norm() + self.assertEqual(resid_norm, expectedNorm, 1e-8) + self.assertEqual(resid_info.norm(), resid_norm, 1e-8) - [[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]], + a_copy = a.clone() + b_copy = b.clone() + res1 = torch.lstsq(b, a)[0] + self.assertEqual(a, a_copy, 0) + self.assertEqual(b, b_copy, 0) + check_norm(a, b, expectedNorm, res1) - [[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]], + ta = cast_fn(torch.Tensor()) + tb = cast_fn(torch.Tensor()) + res2 = torch.lstsq(b, a, out=(tb, ta))[0] + self.assertEqual(a, a_copy, 0) + self.assertEqual(b, b_copy, 0) + check_norm(a, b, expectedNorm, res2) - ..., + res3 = torch.lstsq(b, a, out=(b, a))[0] + check_norm(a_copy, b_copy, expectedNorm, res3) - [[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]], + self.assertEqual(res1, tb, 0) + self.assertEqual(res1, b, 0) + self.assertEqual(res1, res2, 0) + self.assertEqual(res1, res3, 0) - [[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]], + # basic test + expectedNorm = 0 + a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55, 8.34), + (-7.84, -0.28, 3.24, 8.09), + (-4.39, -3.24, 6.27, 5.28), + (4.53, 3.83, -6.64, 2.06)))).t() + b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28), + (9.35, -4.43, -0.70, -0.26)))).t() + _test_underdetermined(a, b, expectedNorm) - [[1., 1., 1., ..., 1., 1., 1.], - [1., 1., 1., ..., 1., 1., 1.]]])\ -''' + # test overdetermined + expectedNorm = 17.390200628863 + a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55, 8.34, 7.08, -5.45), + (-7.84, -0.28, 3.24, 8.09, 2.52, -5.70), + (-4.39, -3.24, 6.27, 5.28, 0.74, -1.19), + (4.53, 3.83, -6.64, 2.06, -2.47, 4.70)))).t() + b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28, 5.72, 8.93), + (9.35, -4.43, -0.70, -0.26, -7.36, -2.52)))).t() + _test_overdetermined(a, b, expectedNorm) - self.assertExpectedInline(str(y), expected_str) + # test underdetermined + expectedNorm = 0 + a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55), + (-7.84, -0.28, 3.24), + (-4.39, -3.24, 6.27), + (4.53, 3.83, -6.64)))).t() + b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48), + (9.35, -4.43, -0.70)))).t() + _test_underdetermined(a, b, expectedNorm) - # test print 0-dim tensor: there's no 0-dim in Numpy, we match arrayprint style - x = torch.tensor(0.00002) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor(2.0000e-05)''') + # test reuse + expectedNorm = 0 + a = cast_fn(torch.Tensor(((1.44, -9.96, -7.55, 8.34), + (-7.84, -0.28, 3.24, 8.09), + (-4.39, -3.24, 6.27, 5.28), + (4.53, 3.83, -6.64, 2.06)))).t() + b = cast_fn(torch.Tensor(((8.58, 8.26, 8.48, -5.28), + (9.35, -4.43, -0.70, -0.26)))).t() + ta = cast_fn(torch.Tensor()) + tb = cast_fn(torch.Tensor()) + torch.lstsq(b, a, out=(tb, ta)) + self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) + torch.lstsq(b, a, out=(tb, ta)) + self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) + torch.lstsq(b, a, out=(tb, ta)) + self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, 1e-8) - # test print boolean tensor - x = torch.tensor([True]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([True])''') + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_qr(self, device): + def run_test(tensor_dims, some): + A = torch.randn(*tensor_dims, device=device) + Q, R = torch.qr(A, some=some) - x = torch.tensor(True) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor(True)''') + # Check0: Q[-2:] = (m, n_columns), R[-2:] = (n_columns, n) + m, n = tensor_dims[-2:] + n_columns = m if (not some) and m > n else min(m, n) + self.assertEqual(Q.size(-2), m) + self.assertEqual(R.size(-1), n) + self.assertEqual(Q.size(-1), n_columns) - # [Numpy] test print float in sci_mode when min < 0.0001. - x = torch.tensor([0.00002]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([2.0000e-05])''') + # Check1: A = QR + self.assertEqual(A, torch.matmul(Q, R)) - # [Numpy] test print float in sci_mode when max > 1e8. - # TODO: Pytorch uses fixed precision to print, while Numpy uses dragon4_scientific - # to do automatic trimming and padding. - x = torch.tensor([123456789.]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1.2346e+08])''') + # Check2: A = QR (with out) + Q_out, R_out = torch.Tensor().to(device), torch.Tensor().to(device) + torch.qr(A, some=some, out=(Q_out, R_out)) + self.assertEqual(A, torch.matmul(Q_out, R_out)) - # [Numpy] test print float in sci_mode when max / min > 1000. - x = torch.tensor([0.01, 11]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1.0000e-02, 1.1000e+01])''') + # Check3: Q == Q_out, R == R_out + self.assertEqual(Q, Q_out) + self.assertEqual(R, R_out) - # [Numpy] test print int max / min > 1000, no sci_mode - x = torch.tensor([1, 1010]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([ 1, 1010])''') + # Check4: Q^{T}Q = I, triu(R) = R + self.assertEqual(torch.matmul(Q.transpose(-2, -1), Q), + torch.eye(n_columns, device=device).expand(Q.shape[:-2] + (n_columns, n_columns))) + self.assertEqual(R.triu(), R) - # [Numpy] test print int > 1e8, no sci_mode - x = torch.tensor([1000000000]) # 1e9 - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1000000000])''') + tensor_dims_list = [(3, 5), (5, 5), (5, 3), # Single matrix + (7, 3, 5), (7, 5, 5), (7, 5, 3), # 3-dim Tensors + (7, 5, 3, 5), (7, 5, 5, 5), (7, 5, 5, 3)] # 4-dim Tensors + for tensor_dims, some in product(tensor_dims_list, [True, False]): + run_test(tensor_dims, some) - # [Numpy] test printing float in int_mode - x = torch.tensor([1., 1000.]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([ 1., 1000.])''') + def test_randperm(self, device): + if device == 'cpu': + rng_device = None + else: + rng_device = [0] - # [Numpy] test printing float in int_mode in sci format when max / min > 1000. - x = torch.tensor([1., 1010.]) - self.assertEqual(x.__repr__(), str(x)) - self.assertExpectedInline(str(x), '''tensor([1.0000e+00, 1.0100e+03])''') + # Test core functionality. On CUDA, for small n, randperm is offloaded to CPU instead. For large n, randperm is + # executed on GPU. + for n in (100, 50000, 100000): + # Ensure both integer and floating-point numbers are tested. Half follows an execution path that is + # different from others on CUDA. + for dtype in (torch.long, torch.half, torch.float): + if n > 2049 and dtype == torch.half: # Large n for torch.half will raise an exception, do not test here. + continue + with torch.random.fork_rng(devices=rng_device): + res1 = torch.randperm(n, dtype=dtype, device=device) + res2 = torch.empty(0, dtype=dtype, device=device) + torch.randperm(n, out=res2, dtype=dtype, device=device) + self.assertEqual(res1, res2, 0) - def test_sizeof(self): - sizeof_empty = torch.randn(0).storage().__sizeof__() - sizeof_10 = torch.randn(10).storage().__sizeof__() - sizeof_100 = torch.randn(100).storage().__sizeof__() - self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) - self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) + # Default type is long + for n in (100, 10000): + self.assertEqual(torch.randperm(n, device=device).dtype, torch.long) - sizeof_empty = torch.randn(0).type(torch.ByteTensor).storage().__sizeof__() - sizeof_10 = torch.randn(10).type(torch.ByteTensor).storage().__sizeof__() - sizeof_100 = torch.randn(100).type(torch.ByteTensor).storage().__sizeof__() - self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10) - self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0) + # randperm of 0 elements is an empty tensor + res1 = torch.randperm(0) + res2 = torch.tensor(5, dtype=dtype, device=device) + torch.randperm(0, out=res2) + self.assertEqual(res1.numel(), 0) + self.assertEqual(res2.numel(), 0) - def test_unsqueeze(self): - x = torch.randn(2, 3, 4) - y = x.unsqueeze(1) - self.assertEqual(y, x.view(2, 1, 3, 4)) - y = x.clone().unsqueeze_(2) - self.assertEqual(y, x.view(2, 3, 1, 4)) + # Test exceptions when n is too large for a floating point type + for dtype, small_n, large_n in ((torch.half, 2**11 + 1, 2**11 + 2), + (torch.float, 2**24 + 1, 2**24 + 2), + (torch.double, 2**25, # 2**53 + 1 is too large to run + 2**53 + 2)): + res = torch.empty(0, dtype=dtype, device=device) + torch.randperm(small_n, out=res) # No exception expected + self.assertRaises(RuntimeError, lambda: torch.randperm(large_n, out=res, device=device)) - x = x[:, 1] - self.assertFalse(x.is_contiguous()) - y = x.unsqueeze(1) - self.assertEqual(y, x.contiguous().view(2, 1, 4)) - y = x.clone().unsqueeze_(2) - self.assertEqual(y, x.contiguous().view(2, 4, 1)) + # Test non-contiguous tensors + for n in (4, 5, 6, 10, 20): + non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t() + self.assertFalse(non_contiguous_tensor.is_contiguous()) + with torch.random.fork_rng(devices=rng_device): + res = torch.randperm(n, dtype=torch.long, device=device) + torch.randperm(n, out=non_contiguous_tensor) + self.assertEqual(non_contiguous_tensor, res) + + def test_random_neg_values(self, device): + signed_types = ['torch.DoubleTensor', 'torch.FloatTensor', 'torch.LongTensor', + 'torch.IntTensor', 'torch.ShortTensor'] + for tname in signed_types: + res = torch.rand(SIZE, SIZE).type(tname).to(device) + res.random_(-10, -1) + self.assertLessEqual(res.max().item(), 9) + self.assertGreaterEqual(res.min().item(), -10) - def test_iter(self): - x = torch.randn(5, 5) - for i, sub in enumerate(x): - self.assertEqual(sub, x[i]) + def test_triu_tril(self, device): + def gen_mask(shape, diagonal, device, upper): + mask = torch.zeros(*shape[-2:]).byte() + for i in range(shape[-2]): + for j in range(shape[-1]): + cond = j - i < diagonal if upper else j - i > diagonal + if cond: + mask[i, j] = 1 + return mask.expand(*shape).to(device) - x = torch.Tensor() - self.assertEqual(list(x), []) + torch_functions = {True: torch.triu, False: torch.tril} + if TEST_NUMPY: + numpy_functions = {True: np.triu, False: np.tril} - def test_accreal_type(self): - x = torch.ones(2, 3, 4) - self.assertIsInstance(x.double().sum().item(), float) - self.assertIsInstance(x.float().sum().item(), float) - self.assertIsInstance(x.long().sum().item(), int) - self.assertIsInstance(x.int().sum().item(), int) - self.assertIsInstance(x.short().sum().item(), int) - self.assertIsInstance(x.char().sum().item(), int) - self.assertIsInstance(x.byte().sum().item(), int) + # TODO: remove this when bool and half are supported for torch.where + def bool_half_compat_where(pred, true_tensor, false_tensor, dtype): + if dtype == torch.bool or dtype == torch.half: + return torch.where(pred.byte(), true_tensor.byte(), false_tensor.byte()).to(dtype=dtype) + else: + return torch.where(pred, true_tensor, false_tensor) - def test_assertEqual(self): - x = torch.FloatTensor([0]) - self.assertEqual(x, 0) - xv = torch.autograd.Variable(x) - self.assertEqual(xv, 0) - self.assertEqual(x, xv) - self.assertEqual(xv, x) + def run_test(shape, device, diagonal, dtype): + x = torch.empty(*shape, device=device, dtype=dtype).fill_(2) - def test_new(self): - x = torch.autograd.Variable(torch.Tensor()) - y = torch.autograd.Variable(torch.randn(4, 4)) - z = torch.autograd.Variable(torch.IntTensor([1, 2, 3])) - self.assertEqual(x.new().shape, [0]) - self.assertEqual(x.new(), x) - self.assertEqual(x.new(1, 2).shape, [1, 2]) - self.assertEqual(x.new(torch.Size([3, 4])).shape, [3, 4]) - self.assertEqual(x.new([3, 4]).shape, [2]) - self.assertEqual(x.new([3, 4]).tolist(), [3, 4]) - self.assertEqual(x.new((3, 4)).tolist(), [3, 4]) - if TEST_NUMPY: - self.assertEqual(x.new([np.int32(3), np.float64(4)]).tolist(), [3, 4]) - self.assertEqual(x.new(np.array((3, 4))).tolist(), [3, 4]) - self.assertEqual(x.new([z[2], z[0] + 3]).tolist(), [3, 4]) - self.assertEqual(x.new(size=(3, 4)).shape, [3, 4]) - self.assertEqual(x.new(()).shape, [0]) - self.assertEqual(x.new(y.storage()).data_ptr(), y.data_ptr()) - self.assertEqual(x.new(y).data_ptr(), y.data_ptr()) - self.assertIsNot(x.new(y), y) + for upper in [True, False]: + # normal test with mask + torch_tri_func = torch_functions[upper] + res1 = torch_tri_func(x, diagonal=diagonal) + res2 = torch.empty(0, device=device, dtype=dtype) + torch_tri_func(x, diagonal=diagonal, out=res2) + exp_mask = gen_mask(shape, diagonal, device, upper) + expected = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x, dtype) + self.assertEqual(res1, res2, 0) + self.assertEqual(expected, res1, 0) - self.assertRaises(TypeError, lambda: x.new(z)) - # TypeError would be better - self.assertRaises(RuntimeError, lambda: x.new(z.storage())) + # non-contiguous and expanded tensors test + if 0 not in shape: + for s in range(-len(shape), -1): + # non-contiguous tensors + x_nc = x.clone().transpose(s, s + 1) + exp_mask = gen_mask(x_nc.size(), diagonal, device, upper) + if 1 not in shape: + assert not x_nc.is_contiguous(), "x is intentionally non-contiguous" + exp_nc = bool_half_compat_where(exp_mask, torch.tensor(0).type_as(x), x_nc, dtype) + self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0) + x_nc_is_contiguous = x_nc.is_contiguous() + if upper: + self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0) + else: + self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0) - def test_empty_like(self): - x = torch.autograd.Variable(torch.Tensor()) - y = torch.autograd.Variable(torch.randn(4, 4)) - z = torch.autograd.Variable(torch.IntTensor([1, 2, 3])) - for a in (x, y, z): - self.assertEqual(torch.empty_like(a).shape, a.shape) - self.assertEqual(torch.empty_like(a).type(), a.type()) + self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous, + "contiguity of x_nc should not be changed") - def test_empty_strided(self): - for device in torch.testing.get_all_device_types(): - for shape in [(2, 3, 4), (0, 2, 0)]: - # some of these cases are pretty strange, just verifying that if as_strided - # allows them then empty_strided can as well. - for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]: - empty_strided = torch.empty_strided(shape, strides, device=device) - # as_strided checks the storage size is big enough to support such a strided tensor; - # instead of repeating this calculation, we just use empty_strided which does the same - # calculation when setting the storage size. - as_strided = torch.empty(empty_strided.storage().size(), - device=device).as_strided(shape, strides) - self.assertEqual(empty_strided.shape, as_strided.shape) - self.assertEqual(empty_strided.stride(), as_strided.stride()) + # expanded tensors + expanded_size = (x.size(0),) + x.size() + x_expanded = x.clone().expand(*expanded_size) + if x.size(0) != 1: + assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride" + output = torch_tri_func(x_expanded, diagonal) + self.assertEqual(output, expected.expand(expanded_size), 0) + if x.size(0) != 1: + self.assertTrue(0 in x_expanded.stride(), + "geometry of x_expanded should be the same") + if upper: + self.assertEqual(output, x_expanded.triu_(diagonal), 0) + else: + self.assertEqual(output, x_expanded.tril_(diagonal), 0) - def test_pin_memory(self): - x = torch.randn(3, 5) - self.assertFalse(x.is_pinned()) - if not torch.cuda.is_available(): - self.assertRaises(RuntimeError, lambda: x.pin_memory()) - else: - pinned = x.pin_memory() - self.assertTrue(pinned.is_pinned()) - self.assertEqual(pinned, x) - self.assertNotEqual(pinned.data_ptr(), x.data_ptr()) - # test that pin_memory on already pinned tensor has no effect - self.assertIs(pinned, pinned.pin_memory()) - self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) + if not TEST_NUMPY: + continue - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_pin_memory_from_constructor(self): + # numpy test + numpy_tri_func = numpy_functions[upper] + self.assertEqual(numpy_tri_func(x.to('cpu').numpy(), diagonal), res1.cpu().numpy()) - def _get_like(t, **kwargs): - return [ - torch.rand_like(t, **kwargs), - torch.randn_like(t, **kwargs), - torch.empty_like(t, **kwargs), - torch.full_like(t, 4, **kwargs), - torch.zeros_like(t, **kwargs), - torch.ones_like(t, **kwargs), - ] + diagonals = [-2, -1, 0, 1, 2] + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7), # thin matrices + (3, 0), (0, 3, 3), (3, 3, 0, 0), # no numel matrices + (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices + (1, 3), (5, 1, 3), (7, 5, 1, 3), # very thin matrices + (1, 3, 3, 3), (3, 1, 3, 3, 3)] # unsqueezed batch dimensions + dtypes = [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.bfloat16] + for s, d, dtype in product(shapes, diagonals, dtypes): + run_test(s, device, d, dtype) - def _get_tensors(**kwargs): - return [ - torch.tensor([10, 11], **kwargs), - torch.randn(3, 5, **kwargs), - torch.rand(3, **kwargs), - # torch.randint(3, 5, **kwargs), // unsupported - torch.zeros(3, **kwargs), - torch.randperm(3, **kwargs), - torch.empty(6, **kwargs), - torch.ones(6, **kwargs), - torch.eye(6, **kwargs), - torch.arange(3, 5, **kwargs)] + @skipCUDANonDefaultStreamIf(True) + def test_multinomial_alias(self, device): + # Get probs vector to use in setup + def get_probs(length, is_contiguous): + probs = torch.softmax(torch.randn(length), 0) + if not is_contiguous: + probs = torch.softmax(torch.randn(length, 2), 0)[:, 1] + assert not (is_contiguous ^ probs.is_contiguous()), "contiguity requirement not met" + return probs.to(device) - pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True) - for x in pinned_tensors: - self.assertTrue(x.is_pinned()) + for is_contiguous in [True, False]: + probs = get_probs(4, is_contiguous) + alias_table, prob_table = torch._multinomial_alias_setup(probs) + for n_samples in [-1, 1, 10]: + if n_samples > 0: + samples = torch._multinomial_alias_draw(prob_table, alias_table, n_samples) + self.assertEqual(prob_table.size(), torch.Size([4]), "size mismatch: probability table") + self.assertEqual(alias_table.size(), torch.Size([4]), "size mismatch: alias table") + self.assertEqual(samples.size(), torch.Size([n_samples]), "wrong number of samples") + else: + with self.assertRaisesRegex(RuntimeError, "cannot sample <= 0 samples"): + torch._multinomial_alias_draw(prob_table, alias_table, n_samples) - tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True)) - for x in tensors: - self.assertFalse(x.is_pinned()) + with self.assertRaisesRegex(RuntimeError, "expected 1-D"): + probs = probs.view(2, 2) + torch._multinomial_alias_setup(probs) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_unresizable(self): - x = np.zeros((2, 2)) - y = torch.from_numpy(x) - with self.assertRaises(ValueError): - x.resize((5, 5)) + with self.assertRaisesRegex(RuntimeError, "expected 1-D"): + a_t, p_t = torch._multinomial_alias_setup(probs) + torch._multinomial_alias_draw(p_t.view(2, 2), a_t.view(2, 2)) - z = torch.randn(5, 5) - w = z.numpy() - with self.assertRaises(RuntimeError): - z.resize_(10, 10) - with self.assertRaises(ValueError): - w.resize((10, 10)) + MAX_SAMPLES = 200000 + for probs in [get_probs(4, True), + torch.tensor([0.8, 0.2], device=device), + torch.tensor([0.7, 0.2, 0.1], device=device)]: + # Check how different the alias distribution and the original distribution are + alias_dist = torch.zeros_like(probs) + alias_table, prob_table = torch._multinomial_alias_setup(probs) + alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES) + alias_dist = torch.unique(alias_samples, return_counts=True)[1].to(dtype=probs.dtype) / MAX_SAMPLES + self.assertTrue(torch.allclose(alias_dist, probs, rtol=0.02, atol=0.0), + "Actual: {}\nExpected: {}".format(alias_dist, probs)) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_to_numpy(self): - def get_castable_tensor(shape, tp): - dtype = tp.dtype - if dtype.is_floating_point: - dtype_info = torch.finfo(dtype) - # can't directly use min and max, because for double, max - min - # is greater than double range and sampling always gives inf. - low = max(dtype_info.min, -1e10) - high = min(dtype_info.max, 1e10) - t = torch.empty(shape, dtype=torch.float64).uniform_(low, high) - else: - # can't directly use min and max, because for int64_t, max - min - # is greater than int64_t range and triggers UB. - dtype_info = torch.iinfo(dtype) - low = max(dtype_info.min, int(-1e10)) - high = min(dtype_info.max, int(1e10)) - dtype_info = torch.iinfo(dtype) - t = torch.empty(shape, dtype=torch.int64).random_(low, high) - return t.to(dtype) + for probs in [torch.tensor([0.2501, 0.25, 0.2499, 0.25], device=device), + torch.tensor([0.8, 0.199, 0.001], device=device), + torch.tensor([0.25001, 0.25, 0.24999, 0.25], device=device), + torch.tensor([0.33, 0.34, 0.33], device=device), + torch.tensor([0.8, 0.1999, 0.0001], device=device)]: + # Check the difference between the original probabilities and the reconstructed + # probabilities from the alias and probability tables output by _multinomial_alias_setup + alias_table, prob_table = torch._multinomial_alias_setup(probs) + actual = torch.zeros_like(probs) + for i, vals in enumerate(zip(alias_table, prob_table)): + idx, p = vals + actual[i] += p + actual[idx] += 1. - p + actual = actual / len(probs) + self.assertEqual(actual, probs, 1e-6) - types = [ - torch.ByteTensor, - torch.CharTensor, - torch.ShortTensor, - torch.IntTensor, - torch.HalfTensor, - torch.FloatTensor, - torch.DoubleTensor, - torch.LongTensor, - ] - for tp in types: - # 1D - sz = 10 - x = get_castable_tensor(sz, tp) - y = x.numpy() - for i in range(sz): - self.assertEqual(x[i], y[i]) + # Some special cases + test_cases = [torch.tensor([1.0, 0.0, 0.0], device=device), torch.tensor([0.0, 1.0], device=device)] + for probs in test_cases: + alias_table, prob_table = torch._multinomial_alias_setup(probs) + alias_samples = torch._multinomial_alias_draw(prob_table, alias_table, MAX_SAMPLES) + self.assertEqual(alias_samples.unique(), probs.nonzero().squeeze(-1)) - # 1D > 0 storage offset - xm = get_castable_tensor(sz * 2, tp) - x = xm.narrow(0, sz - 1, sz) - self.assertTrue(x.storage_offset() > 0) - y = x.numpy() - for i in range(sz): - self.assertEqual(x[i], y[i]) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_lapack_empty(self, device): + # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here. + # The LAPACK functions themselves generally do NOT work with zero sized dimensions, although + # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing" + # (e.g. lu). We often name our functions identically to the lapack function, so it will take work + # to name / migrate-to better wrappers. + def fn(torchfn, *args): + return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape + for shape in args)) + + # inverse, pinverse + self.assertEqual((0, 0), fn(torch.inverse, (0, 0)).shape) + self.assertEqual((5, 0), fn(torch.pinverse, (0, 5)).shape) + self.assertEqual((0, 5), fn(torch.pinverse, (5, 0)).shape) + self.assertEqual((0, 0), fn(torch.pinverse, (0, 0)).shape) + + # det, logdet, slogdet + self.assertEqual(torch.tensor(1., device=device), fn(torch.det, (0, 0))) + self.assertEqual(torch.tensor(0., device=device), fn(torch.logdet, (0, 0))) + self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), + fn(torch.slogdet, (0, 0))) + + # eig, symeig + evalues, evectors = fn(torch.eig, (0, 0), True) + self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape]) + evalues, evectors = fn(torch.symeig, (0, 0), True) + self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape]) + + # qr + q, r = fn(torch.qr, (3, 0), True) + self.assertEqual([(3, 0), (0, 0)], [q.shape, r.shape]) + q, r = fn(torch.qr, (0, 3), True) + self.assertEqual([(0, 0), (0, 3)], [q.shape, r.shape]) + q, r = fn(torch.qr, (3, 0), False) + self.assertEqual([(3, 3), (3, 0)], [q.shape, r.shape]) + + # lstsq + self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0))) + self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0))) + + def test_roll(self, device): + numbers = torch.arange(1, 9, device=device) + + single_roll = numbers.roll(1, 0) + expected = torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device) + self.assertEqual(single_roll, expected, "{} did not equal expected result".format(single_roll)) + + roll_backwards = numbers.roll(-2, 0) + expected = torch.tensor([3, 4, 5, 6, 7, 8, 1, 2], device=device) + self.assertEqual(roll_backwards, expected, "{} did not equal expected result".format(roll_backwards)) + + data = numbers.view(2, 2, 2) + rolled = data.roll(1, 0) + expected = torch.tensor([5, 6, 7, 8, 1, 2, 3, 4], device=device).view(2, 2, 2) + self.assertEqual(expected, rolled, "{} did not equal expected result: {}".format(rolled, expected)) + + data = data.view(2, 4) + # roll a loop until back where started + loop_rolled = data.roll(2, 0).roll(4, 1) + self.assertEqual(data, loop_rolled, "{} did not equal the original: {}".format(loop_rolled, data)) + # multiple inverse loops + self.assertEqual(data, data.roll(-20, 0).roll(-40, 1)) + self.assertEqual(torch.tensor([8, 1, 2, 3, 4, 5, 6, 7], device=device), numbers.roll(1, 0)) - def check2d(x, y): - for i in range(sz1): - for j in range(sz2): - self.assertEqual(x[i][j], y[i][j]) + # test non-contiguous + # strided equivalent to numbers.as_strided(size=(4, 2), stride=(1, 4)) + strided = numbers.view(2, 4).transpose(0, 1) + self.assertFalse(strided.is_contiguous(), "this test needs a non-contiguous tensor") + expected = torch.tensor([4, 8, 1, 5, 2, 6, 3, 7]).view(4, 2) + rolled = strided.roll(1, 0) + self.assertEqual(expected, rolled, + "non contiguous tensor rolled to {} instead of {} ".format(rolled, expected)) + + # test roll with no dimension specified + expected = numbers.roll(1, 0).view(2, 4) + self.assertEqual(expected, data.roll(1), "roll with no dims should flatten and roll.") + self.assertEqual(expected, data.roll(1, dims=None), "roll with no dims should flatten and roll.") + + # test roll over multiple dimensions + expected = torch.tensor([[7, 8, 5, 6], [3, 4, 1, 2]], device=device) + double_rolled = data.roll(shifts=(2, -1), dims=(1, 0)) + self.assertEqual(double_rolled, expected, + "should be able to roll over two dimensions, got {}".format(double_rolled)) + + self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=())) + self.assertRaisesRegex(RuntimeError, "required", lambda: data.roll(shifts=(), dims=1)) + # shifts/dims should align + self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1, 2), dims=(1,))) + self.assertRaisesRegex(RuntimeError, "align", lambda: data.roll(shifts=(1,), dims=(1, 2))) + + def test_nonzero_empty(self, device): + def assert_tuple_empty(tup, dim): + self.assertEqual(dim, len(tup)) + for t in tup: + self.assertEqual(torch.Size([0]), t.shape) - # empty - x = torch.Tensor().type(tp) - y = x.numpy() - self.assertEqual(y.size, 0) + x = torch.randn(0, 2, 0, 5, 0, device=device) + y = torch.nonzero(x) + z = torch.nonzero(x, as_tuple=True) + + self.assertEqual(0, y.numel()) + self.assertEqual(torch.Size([0, 5]), y.shape) + assert_tuple_empty(z, 5) + + x = torch.tensor(0.5, device=device) + y = torch.nonzero(x) + # nonzero with as_tuple returns a + # tuple of len 1 for a zero-dim tensor. + # This is done to match Numpy behavior. + z = torch.nonzero(x, as_tuple=True) + self.assertEqual(1, len(z)) + self.assertEqual(torch.zeros(1, dtype=torch.long), z[0]) + + x = torch.zeros((), device=device) + y = torch.nonzero(x) + z = torch.nonzero(x, as_tuple=True) + self.assertEqual(torch.Size([0, 0]), y.shape) + self.assertEqual(1, len(z)) + self.assertEqual(torch.empty(0, dtype=torch.long), z[0]) + + def test_normal(self, device): + q = torch.empty(100, 100, device=device).normal_() + self.assertEqual(q.mean(), 0, 0.2) + self.assertEqual(q.std(), 1, 0.2) + + q.normal_(2, 3) + self.assertEqual(q.mean(), 2, 0.3) + self.assertEqual(q.std(), 3, 0.3) + + q = torch.empty(100, 100, device=device) + q_row1 = q[0:1].clone() + q[99:100].normal_() + self.assertEqual(q[99:100].mean(), 0, 0.2) + self.assertEqual(q[99:100].std(), 1, 0.2) + self.assertEqual(q[0:1].clone(), q_row1) + + mean = torch.empty(100, 100, device=device) + std = torch.empty(100, 100, device=device) + mean[:50] = 0 + mean[50:] = 1 + std[:, :50] = 4 + std[:, 50:] = 1 + + r = torch.normal(mean) + self.assertEqual(r[:50].mean(), 0, 0.2) + self.assertEqual(r[50:].mean(), 1, 0.2) + self.assertEqual(r.std(), 1, 0.2) + + r = torch.normal(mean, 3) + self.assertEqual(r[:50].mean(), 0, 0.2) + self.assertEqual(r[50:].mean(), 1, 0.2) + self.assertEqual(r.std(), 3, 0.2) - # contiguous 2D - sz1 = 3 - sz2 = 5 - x = get_castable_tensor((sz1, sz2), tp) - y = x.numpy() - check2d(x, y) - self.assertTrue(y.flags['C_CONTIGUOUS']) + r = torch.normal(2, std) + self.assertEqual(r.mean(), 2, 0.2) + self.assertEqual(r[:, :50].std(), 4, 0.3) + self.assertEqual(r[:, 50:].std(), 1, 0.2) - # with storage offset - xm = get_castable_tensor((sz1 * 2, sz2), tp) - x = xm.narrow(0, sz1 - 1, sz1) - y = x.numpy() - self.assertTrue(x.storage_offset() > 0) - check2d(x, y) - self.assertTrue(y.flags['C_CONTIGUOUS']) + r = torch.normal(mean, std) + self.assertEqual(r[:50].mean(), 0, 0.2) + self.assertEqual(r[50:].mean(), 1, 0.2) + self.assertEqual(r[:, :50].std(), 4, 0.3) + self.assertEqual(r[:, 50:].std(), 1, 0.2) - # non-contiguous 2D - x = get_castable_tensor((sz2, sz1), tp).t() - y = x.numpy() - check2d(x, y) - self.assertFalse(y.flags['C_CONTIGUOUS']) + r = torch.normal(2, 3, (100, 100)) + self.assertEqual(r.mean(), 2, 0.2) + self.assertEqual(r.std(), 3, 0.2) - # with storage offset - xm = get_castable_tensor((sz2 * 2, sz1), tp) - x = xm.narrow(0, sz2 - 1, sz2).t() - y = x.numpy() - self.assertTrue(x.storage_offset() > 0) - check2d(x, y) + def test_empty_strided(self, device): + for shape in [(2, 3, 4), (0, 2, 0)]: + # some of these cases are pretty strange, just verifying that if as_strided + # allows them then empty_strided can as well. + for strides in [(12, 4, 1), (2, 4, 6), (0, 0, 0)]: + empty_strided = torch.empty_strided(shape, strides, device=device) + # as_strided checks the storage size is big enough to support such a strided tensor; + # instead of repeating this calculation, we just use empty_strided which does the same + # calculation when setting the storage size. + as_strided = torch.empty(empty_strided.storage().size(), + device=device).as_strided(shape, strides) + self.assertEqual(empty_strided.shape, as_strided.shape) + self.assertEqual(empty_strided.stride(), as_strided.stride()) + + def test_sign(self, device): + for dtype in torch.testing.get_all_math_dtypes(device): + + # Include NaN for floating point numbers + if dtype.is_floating_point: + dt_info = torch.finfo(dtype) - # non-contiguous 2D with holes - xm = get_castable_tensor((sz2 * 2, sz1 * 2), tp) - x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t() - y = x.numpy() - self.assertTrue(x.storage_offset() > 0) - check2d(x, y) + # Create tensor (with NaN checking) + a = torch.tensor([float('nan'), -12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) + a_target = torch.tensor([0, -1, 0, 1, -1, 1], device=device, dtype=dtype) - if tp != torch.HalfTensor: - # check writeable - x = get_castable_tensor((3, 4), tp) - y = x.numpy() - self.assertTrue(y.flags.writeable) - y[0][1] = 3 - self.assertTrue(x[0][1] == 3) - y = x.t().numpy() - self.assertTrue(y.flags.writeable) - y[0][1] = 3 - self.assertTrue(x[0][1] == 3) + else: + dt_info = torch.iinfo(dtype) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_to_numpy_bool(self): - x = torch.tensor([True, False], dtype=torch.bool) - self.assertEqual(x.dtype, torch.bool) + # If unsigned type, everything should be >= 0 + if dt_info.min == 0: + a = torch.tensor([12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) + a_target = torch.tensor([1, 0, 1, 0, 1], device=device, dtype=dtype) + else: + a = torch.tensor([-12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) + a_target = torch.tensor([-1, 0, 1, -1, 1], device=device, dtype=dtype) - y = x.numpy() - self.assertEqual(y.dtype, np.bool) - for i in range(len(x)): - self.assertEqual(x[i], y[i]) + self.assertEqual(a.sign(), a_target, 'sign device={} dtype={}'.format(device, dtype)) + self.assertEqual(torch.sign(a), a_target, 'sign device={} dtype={}'.format(device, dtype)) - x = torch.tensor([True], dtype=torch.bool) - self.assertEqual(x.dtype, torch.bool) + out = torch.empty_like(a) + torch.sign(a, out=out) + self.assertEqual(out, a_target, 'sign_out device={} dtype={}'.format(device, dtype)) - y = x.numpy() - self.assertEqual(y.dtype, np.bool) - self.assertEqual(x[0], y[0]) + a.sign_() + self.assertEqual(a, a_target, 'sign_ device={} dtype={}'.format(device, dtype)) - def test_dlpack_conversion(self): - x = torch.randn(1, 2, 3, 4).type('torch.FloatTensor') - z = from_dlpack(to_dlpack(x)) - self.assertEqual(z, x) + # Include test for bool dtype + a_bool = torch.tensor([True, True, False, float('nan')], device=device).bool() + a_bool_target = torch.tensor([True, True, False, True], device=device).bool() + self.assertEqual(a_bool.sign(), a_bool_target, 'sign device={} dtype=bool'.format(device)) + self.assertEqual(torch.sign(a_bool), a_bool_target, 'sign device={} dtype=bool'.format(device)) - @unittest.skipIf(not torch.cuda.is_available(), "No CUDA") - def test_dlpack_cuda(self): - x = torch.randn(1, 2, 3, 4).cuda() - z = from_dlpack(to_dlpack(x)) - self.assertEqual(z, x) + a_out = torch.empty_like(a_bool) + torch.sign(a_bool, out=a_out) + self.assertEqual(a_out, a_bool_target, 'sign_out device={} dtype=bool'.format(device)) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_from_numpy(self): - dtypes = [ - np.double, - np.float, - np.float16, - np.int64, - np.int32, - np.int16, - np.int8, - np.uint8, - np.longlong, - np.bool, - ] - for dtype in dtypes: - array = np.array([1, 2, 3, 4], dtype=dtype) - tensor_from_array = torch.from_numpy(array) - # TODO: change to tensor equality check once HalfTensor - # implements `==` - for i in range(len(array)): - self.assertEqual(tensor_from_array[i], array[i]) - # This is a special test case for Windows - # https://github.com/pytorch/pytorch/issues/22615 - array2 = array % 2 - tensor_from_array2 = torch.from_numpy(array2) - for i in range(len(array2)): - self.assertEqual(tensor_from_array2[i], array2[i]) + a_bool.sign_() + self.assertEqual(a_bool, a_bool_target, 'sign_ device={} dtype=bool'.format(device)) - # Test unsupported type - array = np.array([1, 2, 3, 4], dtype=np.complex) - with self.assertRaises(TypeError): - tensor_from_array = torch.from_numpy(array) + def test_logical_any(self, device): + x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device) - # check storage offset - x = np.linspace(1, 125, 125) - x.shape = (5, 5, 5) - x = x[1] - expected = torch.arange(1, 126).view(5, 5, 5)[1] - self.assertEqual(torch.from_numpy(x), expected) + self.assertEqual( + torch.tensor(0, dtype=torch.uint8, device=device), + x.any()) - # check noncontiguous - x = np.linspace(1, 25, 25) - x.shape = (5, 5) - expected = torch.arange(1, 26).view(5, 5).t() - self.assertEqual(torch.from_numpy(x.T), expected) + self.assertEqual( + torch.zeros([1, 3, 400], dtype=torch.uint8, device=device), + x.any(0, keepdim=True)) - # check noncontiguous with holes - x = np.linspace(1, 125, 125) - x.shape = (5, 5, 5) - x = x[:, 1] - expected = torch.arange(1, 126).view(5, 5, 5)[:, 1] - self.assertEqual(torch.from_numpy(x), expected) + self.assertEqual( + torch.zeros([2, 1, 400], dtype=torch.uint8, device=device), + x.any(1, keepdim=True)) - # check zero dimensional - x = np.zeros((0, 2)) - self.assertEqual(torch.from_numpy(x).shape, (0, 2)) - x = np.zeros((2, 0)) - self.assertEqual(torch.from_numpy(x).shape, (2, 0)) + self.assertEqual( + torch.zeros([2, 3, 1], dtype=torch.uint8, device=device), + x.any(2, keepdim=True)) - # check ill-sized strides raise exception - x = np.array([3., 5., 8.]) - x.strides = (3,) - self.assertRaises(ValueError, lambda: torch.from_numpy(x)) + # set the last element to 0 + x[-1][-1][-1] = 1 - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_ctor_with_numpy_array(self): - correct_dtypes = [ - np.double, - np.float, - np.float16, - np.int64, - np.int32, - np.int16, - np.int8, - np.uint8, - np.bool, - ] + self.assertEqual( + torch.tensor(1, dtype=torch.uint8, device=device), + x.any()) - incorrect_byteorder = '>' if sys.byteorder == 'little' else '<' - incorrect_dtypes = map(lambda t: incorrect_byteorder + t, ['d', 'f']) + y = torch.zeros([1, 3, 400], dtype=torch.uint8, device=device) + y[-1][-1][-1] = 1 + self.assertEqual(y, x.any(0, keepdim=True)) - for dtype in correct_dtypes: - array = np.array([1, 2, 3, 4], dtype=dtype) + y = torch.zeros([2, 1, 400], dtype=torch.uint8, device=device) + y[-1][-1][-1] = 1 + self.assertEqual(y, x.any(1, keepdim=True)) - # Upcast - tensor = torch.DoubleTensor(array) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) + y = torch.zeros([2, 3, 1], dtype=torch.uint8, device=device) + y[-1][-1][-1] = 1 + self.assertEqual(y, x.any(2, keepdim=True)) - if torch.cuda.is_available(): - tensor = torch.cuda.DoubleTensor(array) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) + def test_logical_all(self, device): + x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device) - # Downcast (sometimes) - tensor = torch.FloatTensor(array) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) + self.assertEqual( + torch.tensor(1, dtype=torch.uint8, device=device), + x.all()) - tensor = torch.HalfTensor(array) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) + self.assertEqual( + torch.ones([1, 3, 400], dtype=torch.uint8, device=device), + x.all(0, keepdim=True)) - if torch.cuda.is_available(): - tensor = torch.cuda.FloatTensor(array) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) + self.assertEqual( + torch.ones([2, 1, 400], dtype=torch.uint8, device=device), + x.all(1, keepdim=True)) - tensor = torch.cuda.HalfTensor(array) - for i in range(len(array)): - self.assertEqual(tensor[i], array[i]) + self.assertEqual( + torch.ones([2, 3, 1], dtype=torch.uint8, device=device), + x.all(2, keepdim=True)) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_ctor_with_numpy_scalar_ctor(self): - dtypes = [ - np.double, - np.float, - np.float16, - np.int64, - np.int32, - np.int16, - np.uint8, - np.bool, - ] - for dtype in dtypes: - self.assertEqual(dtype(42), torch.tensor(dtype(42)).item()) + # set the last element to 0 + x[-1][-1][-1] = 0 - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_index(self): - i = np.int32([0, 1, 2]) - x = torch.randn(5, 5) - for idx in i: - self.assertFalse(isinstance(idx, int)) - self.assertEqual(x[idx], x[int(idx)]) + self.assertEqual( + torch.tensor(0, dtype=torch.uint8, device=device), + x.all()) + + y = torch.ones([1, 3, 400], dtype=torch.uint8, device=device) + y[-1][-1][-1] = 0 + self.assertEqual(y, x.all(0, keepdim=True)) + + y = torch.ones([2, 1, 400], dtype=torch.uint8, device=device) + y[-1][-1][-1] = 0 + self.assertEqual(y, x.all(1, keepdim=True)) + + y = torch.ones([2, 3, 1], dtype=torch.uint8, device=device) + y[-1][-1][-1] = 0 + self.assertEqual(y, x.all(2, keepdim=True)) + + def test_log_normal(self, device): + a = torch.tensor([10], dtype=torch.float, device=device).log_normal_() + self.assertEqual(a.dtype, torch.float) + self.assertEqual(a.size(), torch.Size([1])) + + def test_geometric(self, device): + a = torch.tensor([10], dtype=torch.float, device=device).geometric_(0.5) + self.assertEqual(a.dtype, torch.float) + self.assertEqual(a.size(), torch.Size([1])) + + def test_pairwise_distance_empty(self, device): + shape = (2, 0) + x = torch.randn(shape, device=device) + y = torch.randn(shape, device=device) + + self.assertEqual(torch.zeros(2, device=device), torch.pairwise_distance(x, y)) + self.assertEqual(torch.zeros((2, 1), device=device), torch.pairwise_distance(x, y, keepdim=True)) + + shape = (0, 2) + x = torch.randn(shape, device=device) + y = torch.randn(shape, device=device) + self.assertEqual(torch.zeros(0, device=device), torch.pairwise_distance(x, y)) + self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True)) + + def test_pdist_empty(self, device): + shape = (0, 2) + x = torch.randn(shape, device=device) + self.assertEqual(torch.empty(0, device=device), torch.pdist(x)) + + shape = (1, 2) + x = torch.randn(shape, device=device) + self.assertEqual(torch.empty(0, device=device), torch.pdist(x)) + + shape = (3, 0) + x = torch.randn(shape, device=device) + self.assertEqual(torch.zeros(3, device=device), torch.pdist(x)) + + def test_cdist_empty(self, device): + x = torch.randn((0, 5), device=device) + y = torch.randn((4, 5), device=device) + self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y)) + + x = torch.randn((2, 5), device=device) + y = torch.randn((0, 5), device=device) + self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y)) + + x = torch.randn((2, 0), device=device) + y = torch.randn((3, 0), device=device) + self.assertEqual(torch.zeros(2, 3, device=device), torch.cdist(x, y)) + + x = torch.randn((2, 0), device=device) + y = torch.randn((0, 0), device=device) + self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y)) + + def test_cdist_norm(self, device): + for r1 in [3, 4, 5, 6]: + for m in [2, 3, 4, 10]: + for r2 in [4, 6, 7, 8]: + for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: + x = torch.randn(r1, m, device=device) + y = torch.randn(r2, m, device=device) + actual = torch.cdist(x, y, p=p) + expected = brute_cdist(x, y, p=p) + self.assertTrue(torch.allclose(expected, actual)) + + def test_cdist_norm_batch(self, device): + for r1 in [3, 4, 5, 6]: + for m in [2, 3, 4, 10]: + for r2 in [4, 6, 7, 8]: + for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: + x = torch.randn(2, 3, 6, r1, m, device=device) + y = torch.randn(2, 3, 6, r2, m, device=device) + actual = torch.cdist(x, y, p=p) + expected = brute_cdist(x, y, p=p) + self.assertTrue(torch.allclose(expected, actual)) + + def test_cdist_large(self, device): + x = torch.randn(1000, 10, device=device) + y = torch.randn(1000, 10, device=device) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertTrue(torch.allclose(expected, actual)) + + def test_cdist_large_batch(self, device): + x = torch.randn(4, 3, 1000, 10, device=device) + y = torch.randn(4, 3, 1000, 10, device=device) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertTrue(torch.allclose(expected, actual)) + + def test_cdist_non_contiguous(self, device): + x = torch.randn(5, 7, device=device).transpose(-1, -2) + y = torch.randn(5, 3, device=device).transpose(-1, -2) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertTrue(torch.allclose(expected, actual)) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_numpy_array_interface(self): - types = [ - torch.DoubleTensor, - torch.FloatTensor, - torch.HalfTensor, - torch.LongTensor, - torch.IntTensor, - torch.ShortTensor, - torch.ByteTensor, - ] - dtypes = [ - np.float64, - np.float32, - np.float16, - np.int64, - np.int32, - np.int16, - np.uint8, - ] - for tp, dtype in zip(types, dtypes): - if np.dtype(dtype).kind == 'u': - x = torch.Tensor([1, 2, 3, 4]).type(tp) - array = np.array([1, 2, 3, 4], dtype=dtype) - else: - x = torch.Tensor([1, -2, 3, -4]).type(tp) - array = np.array([1, -2, 3, -4], dtype=dtype) + x = torch.randn(7, 5, device=device) + y = torch.randn(5, 3, device=device).t() + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertTrue(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertTrue(torch.allclose(expected, actual)) - # Test __array__ w/o dtype argument - asarray = np.asarray(x) - self.assertIsInstance(asarray, np.ndarray) - self.assertEqual(asarray.dtype, dtype) - for i in range(len(x)): - self.assertEqual(asarray[i], x[i]) + x = torch.randn(5, 7, device=device).t() + y = torch.randn(3, 5, device=device) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertTrue(y.is_contiguous()) + self.assertTrue(torch.allclose(expected, actual)) + + def test_cdist_non_contiguous_batch(self, device): + x = torch.randn(4, 3, 2, 5, 7, device=device).transpose(-1, -2) + y = torch.randn(4, 3, 2, 5, 3, device=device).transpose(-1, -2) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertTrue(torch.allclose(expected, actual)) + + x = torch.randn(7, 2, 7, 5, device=device) + y = torch.randn(7, 2, 5, 3, device=device).transpose(-1, -2) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertTrue(x.is_contiguous()) + self.assertFalse(y.is_contiguous()) + self.assertTrue(torch.allclose(expected, actual)) + + x = torch.randn(4, 5, 7, device=device).transpose(-1, -2) + y = torch.randn(4, 3, 5, device=device) + actual = torch.cdist(x, y, p=2) + expected = brute_cdist(x, y, p=2) + self.assertFalse(x.is_contiguous()) + self.assertTrue(y.is_contiguous()) + self.assertTrue(torch.allclose(expected, actual)) + + def test_multinomial_constraints(self, device): + x = torch.empty(1, 2, 3, dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, "prob_dist must be 1 or 2 dim", + lambda: torch.multinomial(x, 2)) + x = torch.empty(1, 2, dtype=torch.long, device=device) + self.assertRaisesRegex( + RuntimeError, "multinomial only supports floating-point dtypes for input", + lambda: torch.multinomial(x, 2)) + x = torch.empty(1, 2, dtype=torch.double, device=device) + y = torch.empty(1, 2, dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, "multinomial expects Long tensor out", + lambda: torch.multinomial(x, 2, out=y)) + x = torch.empty(2, dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, "cannot sample n_sample <= 0 samples", + lambda: torch.multinomial(x, 0)) + x = torch.empty(2, dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, "cannot sample n_sample <= 0 samples", + lambda: torch.multinomial(x, -1)) + x = torch.empty(2, dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, "cannot sample n_sample > prob_dist", + lambda: torch.multinomial(x, 3, False)) + x = torch.empty(16777217, dtype=torch.double, device=device) + self.assertRaisesRegex( + RuntimeError, "number of categories cannot exceed", + lambda: torch.multinomial(x, 3)) - # Test __array_wrap__, same dtype - abs_x = np.abs(x) - abs_array = np.abs(array) - self.assertIsInstance(abs_x, tp) - for i in range(len(x)): - self.assertEqual(abs_x[i], abs_array[i]) + def test_add(self, device): + # [res] torch.add([res,] tensor1, tensor2) + m1 = torch.randn(100, 100, device=device) + v1 = torch.randn(100, device=device) - # Test __array__ with dtype argument - for dtype in dtypes: - x = torch.IntTensor([1, -2, 3, -4]) - asarray = np.asarray(x, dtype=dtype) - self.assertEqual(asarray.dtype, dtype) - if np.dtype(dtype).kind == 'u': - wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) - for i in range(len(x)): - self.assertEqual(asarray[i], wrapped_x[i]) - else: - for i in range(len(x)): - self.assertEqual(asarray[i], x[i]) + # contiguous + res1 = torch.add(m1[4], v1) + res2 = res1.clone().zero_() + for i in range(m1.size(1)): + res2[i] = m1[4, i] + v1[i] + self.assertEqual(res1, res2) - # Test some math functions with float types - float_types = [torch.DoubleTensor, torch.FloatTensor] - float_dtypes = [np.float64, np.float32] - for tp, dtype in zip(float_types, float_dtypes): - x = torch.Tensor([1, 2, 3, 4]).type(tp) - array = np.array([1, 2, 3, 4], dtype=dtype) - for func in ['sin', 'sqrt', 'ceil']: - ufunc = getattr(np, func) - res_x = ufunc(x) - res_array = ufunc(array) - self.assertIsInstance(res_x, tp) - for i in range(len(x)): - self.assertEqual(res_x[i], res_array[i]) + m1 = torch.randn(100, 100, device=device) + v1 = torch.randn(100, device=device) - # Test functions with boolean return value - for tp, dtype in zip(types, dtypes): - x = torch.Tensor([1, 2, 3, 4]).type(tp) - array = np.array([1, 2, 3, 4], dtype=dtype) - geq2_x = np.greater_equal(x, 2) - geq2_array = np.greater_equal(array, 2).astype('uint8') - self.assertIsInstance(geq2_x, torch.ByteTensor) - for i in range(len(x)): - self.assertEqual(geq2_x[i], geq2_array[i]) + # non-contiguous + res1 = torch.add(m1[:, 4], v1) + res2 = res1.clone().zero_() + for i in range(m1.size(0)): + res2[i] = m1[i, 4] + v1[i] + self.assertEqual(res1, res2) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_multiplication_numpy_scalar(self): - for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]: - for t_dtype in [torch.float, torch.double]: - np_sc = np_dtype(2.0) - t = torch.ones(2, requires_grad=True, dtype=t_dtype) - r1 = t * np_sc - self.assertIsInstance(r1, torch.Tensor) - self.assertTrue(r1.dtype == t_dtype) - self.assertTrue(r1.requires_grad) - r2 = np_sc * t - self.assertIsInstance(r2, torch.Tensor) - self.assertTrue(r2.dtype == t_dtype) - self.assertTrue(r2.requires_grad) + # [res] torch.add([res,] tensor, value) + m1 = torch.randn(10, 10, device=device) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_trapz(self): - def test_dx(sizes, dim, dx, device): - t = torch.randn(sizes, device=device) - actual = torch.trapz(t, dx=dx, dim=dim) - expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim) - self.assertEqual(expected.shape, actual.shape) - self.assertTrue(np.allclose(expected, actual.cpu().numpy())) + # contiguous + res1 = m1.clone() + res1[3].add_(2) + res2 = m1.clone() + for i in range(m1.size(1)): + res2[3, i] = res2[3, i] + 2 + self.assertEqual(res1, res2) - def test_x(sizes, dim, x, device): - t = torch.randn(sizes, device=device) - actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim) - expected = np.trapz(t.cpu().numpy(), x=x, axis=dim) - self.assertEqual(expected.shape, actual.shape) - self.assertTrue(np.allclose(expected, actual.cpu().numpy())) + # non-contiguous + m1 = torch.randn(10, 10, device=device) + res1 = m1.clone() + res1[:, 3].add_(2) + res2 = m1.clone() + for i in range(m1.size(0)): + res2[i, 3] = res2[i, 3] + 2 + self.assertEqual(res1, res2) - for device in torch.testing.get_all_device_types(): - test_dx((2, 3, 4), 1, 1, device) - test_dx((10, 2), 0, 0.1, device) - test_dx((1, 10), 0, 2.3, device) - test_dx((0, 2), 0, 1.0, device) - test_dx((0, 2), 1, 1.0, device) - test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) - test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) - test_x((1, 10), 0, [1.0], device) - test_x((0, 2), 0, [], device) - test_x((0, 2), 1, [1.0, 2.0], device) - with self.assertRaisesRegex( - IndexError, - 'Dimension out of range'): - test_x((2, 3), 2, [], device) - test_dx((2, 3), 2, 1.0, device) - with self.assertRaisesRegex( - RuntimeError, - 'There must be one `x` value for each sample point'): - test_x((2, 3), 1, [1.0, 2.0], device) - test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) + # inter-type + m1 = torch.randn(10, 10, device=device) + self.assertEqual(m1 + 3, m1 + torch.tensor(3)) + self.assertEqual(3 + m1, torch.tensor(3) + m1) + one = torch.tensor(1, dtype=torch.uint8, device=device) + self.assertEqual(torch.add(one, 1), 2) + self.assertEqual(torch.add(one, 1).dtype, torch.uint8) + + # contiguous + non-contiguous + m1 = torch.randn(10, 10, device=device) + m2 = torch.randn(10, 10, device=device).t() + res = m1 + m2 + self.assertTrue(res.is_contiguous()) + self.assertEqual(res, m1 + m2.contiguous()) + + # 1d + empty + m1 = torch.tensor([1.0], dtype=torch.float, device=device) + m2 = torch.tensor([], dtype=torch.float, device=device) + self.assertEqual(m1 + m2, []) + + # bool + m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) + m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) + expected = torch.tensor([True, True, False, True, False, True], dtype=torch.bool, device=device) + self.assertEqual(m1 + m2, expected) + + # fused multiply add + a = torch.zeros(2, 3, dtype=torch.bool, device=device) + res = torch.add(a, a, alpha=0) + expected = torch.zeros(2, 3, device=device).bool() + self.assertEqual(res, expected) + + # bfloat16 + m1 = torch.tensor([1., 2.], dtype=torch.bfloat16) + m2 = torch.tensor([3., 4.], dtype=torch.bfloat16) + self.assertEqual(m1 + m2, torch.tensor([4., 6.], dtype=torch.bfloat16)) + + def test_bool_sub(self, device): + m1 = torch.tensor([True, False, False, True, False, False], dtype=torch.bool, device=device) + m2 = torch.tensor([True, True, False, False, False, True], dtype=torch.bool, device=device) + self.assertRaisesRegex(RuntimeError, + r"Subtraction, the `\-` operator, with two bool tensors is not supported. " + r"Use the `\^` or `logical_xor\(\)` operator instead.", + lambda: m1 - m2) + self.assertRaisesRegex(RuntimeError, + r"Subtraction, the `\-` operator, with a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: 1 - m1) + self.assertRaisesRegex(RuntimeError, + r"Subtraction, the `\-` operator, with a bool tensor is not supported. " + r"If you are trying to invert a mask, use the `\~` or `logical_not\(\)` operator instead.", + lambda: m2 - 1) + + def test_mul(self, device): + m1 = torch.randn(10, 10, device=device) + res1 = m1.clone() + res1[:, 3].mul_(2) + res2 = m1.clone() + for i in range(res1.size(0)): + res2[i, 3] = res2[i, 3] * 2 + self.assertEqual(res1, res2) - def test_error_msg_type_translation(self): - with self.assertRaisesRegex( - RuntimeError, - # message includes both Double and Long - '(?=.*Double)(?=.*Long)'): + a1 = torch.tensor([True, False, False, True], dtype=torch.bool, device=device) + a2 = torch.tensor([True, False, True, False], dtype=torch.bool, device=device) + self.assertEqual(a1 * a2, torch.tensor([True, False, False, False], dtype=torch.bool, device=device)) - # Calls model with a DoubleTensor input but LongTensor weights - input = torch.autograd.Variable(torch.randn(1, 1, 1, 6).double()) - weight = torch.zeros(1, 1, 1, 3).long() - model = torch.nn.Conv2d(1, 1, (1, 3), stride=1, padding=0, bias=False) - model.weight.data = weight - out = model(input) + if device == 'cpu': + a1 = torch.tensor([0.1, 0.1], dtype=torch.bfloat16, device=device) + a2 = torch.tensor([1.1, 0.1], dtype=torch.bfloat16, device=device) + self.assertEqual(a1 * a2, torch.tensor([0.11, 0.01], dtype=torch.bfloat16, device=device), 0.01) + self.assertEqual(a1.mul(a2), a1 * a2) + + def test_cumsum(self, device): + x = torch.rand(100, 100, device=device) + res1 = torch.cumsum(x, 1) + res2 = torch.Tensor().to(device) + torch.cumsum(x, 1, out=res2) + self.assertEqual(res1, res2) - def test_tensor_from_sequence(self): - class MockSequence(object): - def __init__(self, lst): - self.lst = lst + a = torch.tensor([[True, False, True], + [False, False, False], + [True, True, True]], device=device) + b = a.byte() + aRes = torch.cumsum(a, 0) + bRes = torch.cumsum(b, 0) + self.assertEqual(aRes, bRes) + self.assertEqual(aRes, torch.tensor([[1, 0, 1], + [1, 0, 1], + [2, 1, 2]])) + + aRes = torch.cumsum(a, 1) + bRes = torch.cumsum(b, 1) + self.assertEqual(aRes, bRes) + self.assertEqual(aRes, torch.tensor([[1, 1, 2], + [0, 0, 0], + [1, 2, 3]])) + + def test_cumprod(self, device): + x = torch.rand(100, 100, device=device) + res1 = torch.cumprod(x, 1) + res2 = torch.Tensor().to(device) + torch.cumprod(x, 1, out=res2) + self.assertEqual(res1, res2) - def __len__(self): - return len(self.lst) + a = torch.tensor([[True, False, True], + [False, False, False], + [True, True, True]], dtype=torch.bool, device=device) + b = a.byte() + aRes = torch.cumprod(a, 0) + bRes = torch.cumprod(b, 0) + self.assertEqual(aRes, bRes) + self.assertEqual(aRes, torch.tensor([[1, 0, 1], + [0, 0, 0], + [0, 0, 0]])) + + aRes = torch.cumprod(a, 1) + bRes = torch.cumprod(b, 1) + self.assertEqual(aRes, bRes) + self.assertEqual(aRes, torch.tensor([[1, 0, 0], + [0, 0, 0], + [1, 1, 1]])) + + def test_std_mean(self, device): + x = torch.rand(100, 50, 20, device=device) + for dim in range(x.dim()): + for unbiased in [False, True]: + for keepdim in [False, True]: + std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(std1, std2) + self.assertEqual(mean1, mean2) + + def test_std_mean_all_dims(self, device): + x = torch.rand(100, 50, 20, device=device) + for unbiased in [False, True]: + std1, mean1 = torch.std_mean(x, unbiased=unbiased) + std2 = x.std(unbiased=unbiased) + mean2 = x.mean() + self.assertEqual(std1, std2) + self.assertEqual(mean1, mean2) + + def test_var_mean(self, device): + x = torch.rand(100, 300, 50, device=device) + for dim in range(x.dim()): + for unbiased in [False, True]: + for keepdim in [False, True]: + var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(var1, var2) + self.assertEqual(mean1, mean2) + + def test_var_mean_all_dims(self, device): + x = torch.rand(100, 50, 20, device=device) + for unbiased in [False, True]: + var1, mean1 = torch.var_mean(x, unbiased=unbiased) + var2 = x.var(unbiased=unbiased) + mean2 = x.mean() + self.assertEqual(var1, var2) + self.assertEqual(mean1, mean2) - def __getitem__(self, item): - raise TypeError + def test_std_mean_some_dims(self, device): + sizes = (4, 6, 7, 5, 3) + dims = len(sizes) + x = torch.rand(sizes, device=device) + for num_of_dims in range(2, dims): + dim_list = list(combinations(list(range(dims)), r=num_of_dims)) + for dim in dim_list: + for unbiased in [False, True]: + for keepdim in [False, True]: + std1, mean1 = torch.std_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + std2 = x.std(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(std1, std2) + self.assertEqual(mean1, mean2) - class GoodMockSequence(MockSequence): - def __getitem__(self, item): - return self.lst[item] + def test_zeros_like(self, device): + expected = torch.zeros((100, 100,), device=device) - bad_mock_seq = MockSequence([1.0, 2.0, 3.0]) - good_mock_seq = GoodMockSequence([1.0, 2.0, 3.0]) - with self.assertRaisesRegex(ValueError, 'could not determine the shape'): - torch.Tensor(bad_mock_seq) - self.assertEqual(torch.Tensor([1.0, 2.0, 3.0]), torch.Tensor(good_mock_seq)) + res1 = torch.zeros_like(expected) + self.assertEqual(res1, expected) - def test_comparison_ops(self): - x = torch.randn(5, 5) - y = torch.randn(5, 5) + def test_histc(self, device): + # negative nbins throws + with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'): + torch.histc(torch.tensor([1], dtype=torch.float, device=device), bins=-1) + + # without nbins + actual = torch.histc( + torch.tensor([2, 5], dtype=torch.float, device=device)) + expected = torch.zeros(100, dtype=torch.float, device=device) + expected.data[0] = 1 + expected.data[99] = 1 + self.assertEqual(expected, actual) + # tensor with the same element + actual = torch.histc(torch.ones(5, dtype=torch.float, device=device), bins=5) + self.assertEqual( + torch.tensor([0, 0, 5, 0, 0], dtype=torch.float, device=device), + actual) + # no element falls between [min, max] + actual = torch.histc( + torch.ones(5, dtype=torch.float, device=device), bins=5, min=2, max=3) + self.assertEqual( + torch.tensor([0, 0, 0, 0, 0], dtype=torch.float, device=device), + actual) + # element falls below min + integral bin size and + actual = torch.histc( + torch.tensor([2, 4, 2, 2, 5, 4], dtype=torch.float, device=device), + bins=5, min=1, max=5) + self.assertEqual( + torch.tensor([0, 3, 0, 2, 1], dtype=torch.float, device=device), + actual) + # non-integral bin size + actual = torch.histc( + torch.tensor([1, 2, 1], dtype=torch.float, device=device), + bins=4, min=0, max=3) + self.assertEqual( + torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), + actual) + # double input + actual = torch.histc( + torch.tensor([1, 2, 1], dtype=torch.double, device=device), bins=4, min=0, max=3) + self.assertEqual( + torch.tensor([0, 2, 1, 0], dtype=torch.double, device=device), + actual) + self.assertEqual(actual.dtype, torch.double) + # mixed input + actual = torch.histc( + torch.tensor([1., 2, 1], dtype=torch.float, device=device), + bins=4, min=0, max=3) + self.assertEqual( + torch.tensor([0, 2, 1, 0], dtype=torch.float, device=device), + actual) + self.assertEqual(actual.dtype, torch.float) + # scalar input and 1 bin -- should return a 1-dimensional tensor, not a scalar. + actual = torch.histc( + torch.tensor(0, dtype=torch.float, device=device), + bins=1, min=0, max=3) + self.assertEqual( + torch.tensor([1], dtype=torch.float, device=device), + actual) - eq = x == y - for idx in iter_indices(x): - self.assertEqual(x[idx] == y[idx], eq[idx] == 1) + # test against numpy.histogram() + def test_against_np(tensor, bins=100, min=0, max=0): + if min == 0 and max == 0: + min = tensor.min().item() + max = tensor.max().item() + nparr = tensor.cpu().numpy() + actual = torch.histc(tensor, bins=bins, min=min, max=max) + expected = torch.from_numpy(np.histogram(nparr, bins=bins, range=(min, max))[0]) + self.assertEqual(actual.cpu(), expected) - ne = x != y - for idx in iter_indices(x): - self.assertEqual(x[idx] != y[idx], ne[idx] == 1) + if TEST_NUMPY: + test_against_np(torch.tensor([1., 2, 1], device=device)) + test_against_np(torch.randn(5000, device=device)) - lt = x < y - for idx in iter_indices(x): - self.assertEqual(x[idx] < y[idx], lt[idx] == 1) + # Test bins arg + test_against_np(torch.randn(301, device=device), bins=10) - le = x <= y - for idx in iter_indices(x): - self.assertEqual(x[idx] <= y[idx], le[idx] == 1) + # Test truncated range + test_against_np(torch.randn(201, device=device), min=0.1, max=1) - gt = x > y - for idx in iter_indices(x): - self.assertEqual(x[idx] > y[idx], gt[idx] == 1) + noncontig = torch.randn(100, 3, device=device)[:, 2] + test_against_np(noncontig) - ge = x >= y - for idx in iter_indices(x): - self.assertEqual(x[idx] >= y[idx], ge[idx] == 1) + multidim = torch.randn(3, 5, 7, 2, device=device) + test_against_np(multidim) - def test_bitwise_ops(self): - x = torch.randn(5, 5).gt(0) - y = torch.randn(5, 5).gt(0) + expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2) + test_against_np(expanded) - and_result = x & y - for idx in iter_indices(x): - if and_result[idx]: - self.assertTrue(x[idx] and y[idx]) - else: - self.assertFalse(x[idx] and y[idx]) + def test_bool_tensor_comparison_ops(self, device): + a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) + b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) + self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) + self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.bool, device=device)) + self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.bool, device=device)) + self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.bool, device=device)) + self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), + torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.bool, device=device)) + self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), + torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool, device=device)) + self.assertFalse(a.equal(b)) + + def test_bool_tensor_value_change(self, device): + x = torch.tensor([True, False], dtype=torch.bool, device=device) + x[0] = False + x[1] = True + self.assertEqual(x, torch.tensor([False, True], dtype=torch.bool, device=device)) - or_result = x | y - for idx in iter_indices(x): - if or_result[idx]: - self.assertTrue(x[idx] or y[idx]) - else: - self.assertFalse(x[idx] or y[idx]) + def test_unfold_all_devices_and_dtypes(self, device): + for dt in torch.testing.get_all_dtypes(): + if dt == torch.bfloat16: + self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device)) + continue - xor_result = x ^ y - for idx in iter_indices(x): - if xor_result[idx]: - self.assertTrue(x[idx] ^ y[idx]) + if dt == torch.half and device == 'cpu': + # fix once random is implemented for Half on CPU + self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device)) else: - self.assertFalse(x[idx] ^ y[idx]) + x = torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device) + self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) - x_clone = x.clone() - x_clone &= y - self.assertEqual(x_clone, and_result) + def test_copy_all_dtypes_and_devices(self, device): + from copy import copy + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device) + x_clone = x.clone() + if (device == 'cuda' and dt == torch.bfloat16): + self.assertRaises(RuntimeError, lambda: copy(x)) + continue + y = copy(x) + y.fill_(1) + # copy is a shallow copy, only copies the tensor view, + # not the data + self.assertEqual(x, y) + + def test_resize_all_dtypes_and_devices(self, device): + shape = (2, 2) + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + x.resize_(shape) + self.assertEqual(shape, x.shape) + + def test_resize_as_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) + x.resize_as_(y) + self.assertEqual(y.shape, x.shape) + + def test_view_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + if (device == 'cuda' and dt == torch.bfloat16): + self.assertRaises(RuntimeError, lambda: x.view(6)) + continue + self.assertEqual(x.view(6).shape, [6]) + + def test_fill_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor((1, 1), dtype=dt, device=device) + if (device == 'cuda' and dt == torch.bfloat16): + self.assertRaises(RuntimeError, lambda: x.fill_(1)) + continue + x.fill_(1) + + self.assertEqual(x, torch.tensor([1, 1], dtype=dt, device=device)) + self.assertEqual(dt, x.dtype) + + def test_clone_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor((1, 1), dtype=dt, device=device) + y = x.clone() + if (device == 'cuda' and dt == torch.bfloat16): + # `x - y` is used inside of the assertEqual + self.assertRaises(RuntimeError, lambda: x - y) + continue + self.assertEqual(x, y) + + def test_cat_all_dtypes_and_devices(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) + if (device == 'cuda' and dt == torch.bfloat16): + self.assertRaises(RuntimeError, lambda: torch.cat((x, x), 0)) + continue + + expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) + self.assertEqual(torch.cat((x, x), 0), expected1) + + expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device) + self.assertEqual(torch.cat((x, x), 1), expected2) + + def test_tensor_factories_empty(self, device): + # ensure we can create empty tensors from each factory function + shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)] + + for shape in shapes: + for dt in torch.testing.get_all_dtypes(): - x_clone = x.clone() - x_clone |= y - self.assertEqual(x_clone, or_result) + if (device == 'cuda' and dt == torch.bfloat16): + self.assertRaises(RuntimeError, lambda: torch.zeros(shape, device=device, dtype=dt).shape) + self.assertRaises(RuntimeError, lambda: torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertRaises(RuntimeError, lambda: torch.full(shape, 3, device=device, dtype=dt).shape) + self.assertRaises(RuntimeError, lambda: torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3)) + self.assertRaises(RuntimeError, lambda: torch.ones(shape, device=device, dtype=dt).shape) + self.assertRaises(RuntimeError, lambda: torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertRaises(RuntimeError, lambda: torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape) + else: + self.assertEqual(shape, torch.zeros(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertEqual(shape, torch.full(shape, 3, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3).shape) + self.assertEqual(shape, torch.ones(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertEqual(shape, torch.empty(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape) + + if dt == torch.half and device == "cpu": + # update once random is implemented for half on CPU + self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape) + else: + if dt == torch.bfloat16: + self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt)) + continue # Remove once random is supported for bfloat16 on cuda + self.assertEqual(shape, torch.randint(6, shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 6).shape) + + if dt != torch.double and dt != torch.float and dt != torch.half: + self.assertRaises(RuntimeError, lambda: torch.rand(shape, device=device, dtype=dt).shape) + + if dt == torch.double or dt == torch.float: + self.assertEqual(shape, torch.randn(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device, dtype=dt)).shape) + + self.assertEqual((0,), torch.arange(0, device=device).shape) + self.assertEqual((0, 0), torch.eye(0, device=device).shape) + self.assertEqual((0, 0), torch.eye(0, 0, device=device).shape) + self.assertEqual((5, 0), torch.eye(5, 0, device=device).shape) + self.assertEqual((0, 5), torch.eye(0, 5, device=device).shape) + self.assertEqual((0,), torch.linspace(1, 1, 0, device=device).shape) + self.assertEqual((0,), torch.logspace(1, 1, 0, device=device).shape) + self.assertEqual((0,), torch.randperm(0, device=device).shape) + self.assertEqual((0,), torch.bartlett_window(0, device=device).shape) + self.assertEqual((0,), torch.bartlett_window(0, periodic=False, device=device).shape) + self.assertEqual((0,), torch.hamming_window(0, device=device).shape) + self.assertEqual((0,), torch.hann_window(0, device=device).shape) + self.assertEqual((1, 1, 0), torch.tensor([[[]]], device=device).shape) + self.assertEqual((1, 1, 0), torch.as_tensor([[[]]], device=device).shape) + + def test_eye(self, device): + for dtype in torch.testing.get_all_dtypes(): + if dtype == torch.bfloat16: + continue - x_clone = x.clone() - x_clone ^= y - self.assertEqual(x_clone, xor_result) + for n, m in product([3, 5, 7], repeat=2): + # Construct identity using diagonal and fill + res1 = torch.eye(n, m, device=device, dtype=dtype) + naive_eye = torch.zeros(n, m, dtype=dtype, device=device) + naive_eye.diagonal(dim1=-2, dim2=-1).fill_(1) + self.assertEqual(naive_eye, res1) - def test_op_invert(self): - res = 0xffff - torch.arange(127, dtype=torch.int8) - for dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - a = torch.arange(127, dtype=dtype) - self.assertEqual(res.to(dtype), ~a) + # Check eye_out outputs + res2 = torch.empty(0, device=device, dtype=dtype) + torch.eye(n, m, out=res2) + self.assertEqual(res1, res2) - self.assertEqual(torch.tensor([True, False]), - ~torch.tensor([False, True])) + def test_addcmul(self, device): + def rand_tensor(size, dtype, device): + if dtype.is_floating_point: + return torch.rand(size=size, dtype=dtype, device=device) + if dtype == torch.uint8: + return torch.randint(1, 5, size=size, dtype=dtype, device=device) + else: + return torch.randint(-5, 5, size=size, dtype=dtype, device=device) - # test exceptions - for dtype in(torch.half, torch.float, torch.double): - a = torch.zeros(10, dtype=dtype) - with self.assertRaises(TypeError): - b = ~a + for dtype in torch.testing.get_all_math_dtypes(device): + a = rand_tensor((2, 2), dtype=dtype, device=device) + b = rand_tensor((2, 2), dtype=dtype, device=device) + c = rand_tensor((2, 2), dtype=dtype, device=device) + if dtype.is_floating_point: + alpha = 0.1 + else: + alpha = 3 + actual = torch.addcmul(a, alpha, b, c) + expected = a + alpha * b * c + self.assertTrue(torch.allclose(expected, actual)) - def test_apply(self): - x = torch.arange(1, 6) - res = x.clone().apply_(lambda k: k + k) - self.assertEqual(res, x * 2) - self.assertRaises(TypeError, lambda: x.apply_(lambda k: "str")) + def test_empty_tensor_props(self, device): + sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)] + for size in sizes: + x = torch.empty(tuple(size), device=device) + self.assertEqual(size, x.shape) + self.assertTrue(x.is_contiguous()) + size_ones_instead_of_zeros = (x if x != 0 else 1 for x in size) + y = torch.empty(tuple(size_ones_instead_of_zeros), device=device) + self.assertEqual(x.stride(), y.stride()) - def test_map(self): - x = torch.autograd.Variable(torch.randn(3, 3)) - y = torch.autograd.Variable(torch.randn(3)) - res = x.clone() - res.map_(y, lambda a, b: a + b) - self.assertEqual(res, x + y) - self.assertRaisesRegex(TypeError, "not callable", lambda: res.map_(y, "str")) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_tensordot(self, device): + a = torch.arange(60., device=device).reshape(3, 4, 5) + b = torch.arange(24., device=device).reshape(4, 3, 2) + c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), + axes=([1, 0], [0, 1]))) + self.assertEqual(c, cn) + a = torch.randn(2, 3, 4, 5, device=device) + b = torch.randn(4, 5, 6, 7, device=device) + c = torch.tensordot(a, b, dims=2).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), + axes=2)) + self.assertEqual(c, cn) + c = torch.tensordot(a, b).cpu() + cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) + self.assertEqual(c, cn) + + def test_narrow_empty(self, device): + x = torch.randn(2, 3, 4, device=device) + for d in range(x.dim()): + y = x.narrow(d, x.size(d), 0) + sz = list(x.size()) + sz[d] = 0 + self.assertEqual(sz, y.size()) + + def test_linspace(self, device): + _from = random.random() + to = _from + random.random() + res1 = torch.linspace(_from, to, 137, device=device) + res2 = torch.tensor((), device=device) + torch.linspace(_from, to, 137, out=res2) + self.assertEqual(res1, res2, 0) + self.assertRaises(RuntimeError, lambda: torch.linspace(0, 1, -1, device=device)) + self.assertEqual(torch.linspace(0, 1, 1, device=device), torch.zeros(1, device=device), 0) - def test_map2(self): - x = torch.autograd.Variable(torch.randn(3, 3)) - y = torch.autograd.Variable(torch.randn(3)) - z = torch.autograd.Variable(torch.randn(1, 3)) - res = x.clone() - res.map2_(y, z, lambda a, b, c: a + b * c) - self.assertEqual(res, x + y * z) - z.requires_grad = True - self.assertRaisesRegex( - RuntimeError, "requires grad", - lambda: res.map2_(y, z, lambda a, b, c: a + b * c)) + # Check linspace for generating with start > end. + self.assertEqual(torch.linspace(2, 0, 3, device=device), torch.tensor((2, 1, 0), device=device), 0) - def test_Size(self): - x = torch.Size([1, 2, 3]) - self.assertIsInstance(x, tuple) - self.assertEqual(x[0], 1) - self.assertEqual(x[1], 2) - self.assertEqual(x[2], 3) - self.assertEqual(len(x), 3) - self.assertRaises(TypeError, lambda: torch.Size(torch.ones(3))) + # Check linspace for non-contiguous tensors. + x = torch.zeros(2, 3, device=device) + y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2)) + self.assertEqual(x, torch.tensor(((0, 0, 1), (0, 2, 3)), device=device), 0) - self.assertIsInstance(x * 2, torch.Size) - self.assertIsInstance(x[:-1], torch.Size) - self.assertIsInstance(x + x, torch.Size) + def test_logical(self, device): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([1, 2, 3, 4], device=device, dtype=dt) + b = torch.tensor([2], device=device, dtype=dt) - def test_Size_scalar(self): - three = torch.tensor(3) - two = torch.tensor(2) - x = torch.Size([0, 1, two, three, 4]) - for i in range(1, 5): - self.assertEqual(x[i], i) + if dt == torch.half and device == 'cpu': + self.assertRaises(RuntimeError, lambda: x.lt(2)) + continue - def test_Size_iter(self): - for sizes in [iter([1, 2, 3, 4, 5]), range(1, 6)]: - x = torch.Size(sizes) - for i in range(0, 5): - self.assertEqual(x[i], i + 1) + if dt == torch.bool: + # torch.bool is a special case and is being tested later + # in this test + continue - def test_t_not_2d_error(self): - self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t()) - self.assertRaises(RuntimeError, lambda: torch.randn(2, 3, 4).t_()) + if device == 'cuda' and dt == torch.bfloat16: + self.assertRaises(RuntimeError, lambda: x > b) + self.assertRaises(RuntimeError, lambda: x < b) + self.assertRaises(RuntimeError, lambda: x == b) + self.assertRaises(RuntimeError, lambda: x != b) + self.assertRaises(RuntimeError, lambda: x >= b) + self.assertRaises(RuntimeError, lambda: x <= b) + continue - # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") - def test_big_transpose(self): - t = torch.rand(456, 789) - t1 = t.t().contiguous() - t2 = torch.from_numpy(t.numpy().transpose()) - self.assertEqual(t1, t2) + self.assertEqual(x.lt(2), torch.tensor([True, False, False, False])) + self.assertEqual(x.le(2), torch.tensor([True, True, False, False])) + self.assertEqual(x.ge(2), torch.tensor([False, True, True, True])) + self.assertEqual(x.gt(2), torch.tensor([False, False, True, True])) + self.assertEqual(x.eq(2), torch.tensor([False, True, False, False])) + self.assertEqual(x.ne(2), torch.tensor([True, False, True, True])) + + self.assertEqual(x.lt(b), torch.tensor([True, False, False, False])) + self.assertEqual(x.le(b), torch.tensor([True, True, False, False])) + self.assertEqual(x.ge(b), torch.tensor([False, True, True, True])) + self.assertEqual(x.gt(b), torch.tensor([False, False, True, True])) + self.assertEqual(x.eq(b), torch.tensor([False, True, False, False])) + self.assertEqual(x.ne(b), torch.tensor([True, False, True, True])) + + with warnings.catch_warnings(record=True) as warningsCount: + byteRes = torch.empty_like(x, device=device).byte() + boolRes = torch.empty_like(x, device=device).bool() + + torch.lt(x, b, out=byteRes) + torch.lt(x, b, out=boolRes) + self.assertEqual(byteRes.bool(), boolRes) + + torch.le(x, b, out=byteRes) + torch.le(x, b, out=boolRes) + self.assertEqual(byteRes.bool(), boolRes) + + torch.ge(x, b, out=byteRes) + torch.ge(x, b, out=boolRes) + self.assertEqual(byteRes.bool(), boolRes) + + torch.gt(x, b, out=byteRes) + torch.gt(x, b, out=boolRes) + self.assertEqual(byteRes.bool(), boolRes) + + torch.eq(x, b, out=byteRes) + torch.eq(x, b, out=boolRes) + self.assertEqual(byteRes.bool(), boolRes) + + torch.ne(x, b, out=byteRes) + torch.ne(x, b, out=boolRes) + self.assertEqual(byteRes.bool(), boolRes) + + self.assertEquals(len(warningsCount), 6) + + # Bool Tensor + x = torch.tensor([True, False, True, False], device=device) + self.assertEqual(x.lt(True), torch.tensor([False, True, False, True])) + self.assertEqual(x.le(True), torch.tensor([True, True, True, True])) + self.assertEqual(x.ge(True), torch.tensor([True, False, True, False])) + self.assertEqual(x.gt(True), torch.tensor([False, False, False, False])) + self.assertEqual(x.eq(True), torch.tensor([True, False, True, False])) + self.assertEqual(x.ne(True), torch.tensor([False, True, False, True])) + + def test_index_copy(self, device): + num_copy, num_dest = 3, 20 + dest = torch.randn(num_dest, 4, 5, device=device) + src = torch.randn(num_copy, 4, 5, device=device) + idx = torch.randperm(num_dest, device=device).narrow(0, 0, num_copy) + dest2 = dest.clone() + dest.index_copy_(0, idx, src) + for i in range(idx.size(0)): + dest2[idx[i]] = src[i] + self.assertEqual(dest, dest2, 0) - def test_inplace_division(self): - t = torch.rand(5, 5) - id_before = id(t) - t /= 2 - id_after = id(t) - self.assertEqual(id_before, id_after) + dest = torch.randn(num_dest, device=device) + src = torch.randn(num_copy, device=device) + idx = torch.randperm(num_dest, device=device).narrow(0, 0, num_copy) + dest2 = dest.clone() + dest.index_copy_(0, idx, src) + for i in range(idx.size(0)): + dest2[idx[i]] = src[i] + self.assertEqual(dest, dest2, 0) + + # Bool tensor + dest = torch.zeros(2, 2, dtype=torch.bool, device=device) + src = torch.tensor([[True, True], [True, True]], device=device) + index = torch.tensor([0, 1], device=device) + dest.index_copy_(0, index, src) + self.assertEqual(dest, torch.tensor([[True, True], [True, True]], device=device)) + + # Error cases + a = torch.randn(3, 5) + c = torch.zeros(3) + self.assertRaises(IndexError, lambda: a.index_copy_(dim=1, index=torch.tensor([3]), source=c)) + + def test_index_fill(self, device): + for dt in torch.testing.get_all_dtypes(): + if dt == torch.half or dt == torch.bfloat16: + continue - def test_simple_scalar_cast(self): - ok = [torch.Tensor([1.5]), torch.zeros(1, 1, 1, 1)] - ok_values = [1.5, 0] + x = torch.tensor([[1, 2], [4, 5]], dtype=dt, device=device) + index = torch.tensor([0], device=device) + x.index_fill_(1, index, 0) + self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dt, device=device)) + + def test_index_select(self, device): + src = torch.randn(3, 4, 5, device=device) + # Index can be duplicated. + idx = torch.tensor([2, 1, 0, 1, 2], dtype=torch.long, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) - not_ok = map(torch.Tensor, [[], [1, 2], [[1, 2], [3, 4]]]) + # Check that 'out' is used correctly. + out = torch.randn(5 * 4 * 5, device=device) + dest = torch.index_select(src, 0, idx, out=out.view(5, 4, 5)) + self.assertEqual(dest.shape, (5, 4, 5)) + for i in range(idx.size(0)): + self.assertEqual(dest[i], src[idx[i]]) + out.fill_(0.123) + self.assertEqual(out, dest.view(-1)) # Must point to the same storage. + + # Bool tensor + src = torch.tensor([False, True, False, False], device=device, dtype=torch.bool) + idx = torch.tensor([1], dtype=torch.long, device=device) + dest = torch.index_select(src, 0, idx) + self.assertEqual(torch.tensor([True]), dest) + + def test_take_empty(self, device): + for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: + for indices_shape in [(0,), (0, 1, 2, 0)]: + input = torch.empty(input_shape, device=device) + indices = torch.empty(indices_shape, dtype=torch.int64, device=device) + self.assertEqual(indices, torch.take(input, indices)) + + def test_put_empty(self, device): + for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: + for indices_shape in [(0,), (0, 1, 2, 0)]: + for accumulate in [False, True]: + dst = torch.randn(dst_shape, device=device) + indices = torch.empty(indices_shape, dtype=torch.int64, device=device) + src = torch.randn(indices_shape, device=device) + self.assertEqual(dst, dst.put_(indices, src, accumulate=accumulate)) + + def test_scatter_to_large_input(self, device): + input = torch.zeros(4, 4, device=device) + src = torch.ones(2, 2, device=device) + index = torch.tensor([[1], [2]], device=device, dtype=torch.long) + input.scatter_(0, index, src) + self.assertEqual(input, torch.tensor([[0, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0]], device=device)) + + def test_scatter_add_to_large_input(self, device): + input = torch.zeros(4, 4, device=device) + src = torch.ones(2, 2, device=device) + index = torch.tensor([[1], [2]], device=device, dtype=torch.long) + input.scatter_add_(0, index, src) + self.assertEqual(input, torch.tensor([[0, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0], + [0, 0, 0, 0]], device=device)) + + def test_scatter_bool(self, device): + x = torch.tensor([[True, True, True], [True, True, True]], device=device) + res = torch.zeros(3, 3, dtype=torch.bool, device=device) + res = res.scatter_(0, torch.tensor([[0, 1, 2], [0, 1, 2]], device=device), x) + self.assertEqual(res, torch.tensor([[True, False, False], + [False, True, False], + [False, False, True]], device=device)) + + def test_scatter_add_bool(self, device): + x = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]], device=device) + res = torch.zeros(3, 5, dtype=torch.bool, device=device) + res = res.scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], device=device), x) + self.assertEqual(res, torch.tensor([[True, True, True, True, True], + [False, True, False, True, False], + [True, False, True, False, True]], device=device)) + + def test_masked_scatter_bool_tensor(self, device): + src = torch.tensor([True, True, True], device=device) + dst = torch.tensor([False, False, False], device=device) + mask = torch.tensor([False, True, False], device=device) + + dst.masked_scatter_(mask, src) + self.assertEqual(dst, torch.tensor([False, True, False], device=device)) + + mask = torch.tensor([True, False, True], device=device) + dst = dst.masked_scatter(mask, src) + self.assertEqual(dst, torch.tensor([True, True, True], device=device)) + + def test_masked_select(self, device): + for dt in torch.testing.get_all_dtypes(): + with warnings.catch_warnings(record=True) as w: + for maskType in [torch.uint8, torch.bool]: + num_src = 10 + src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dt, device=device) + mask = torch.rand(num_src, device=device).clamp(0, 1).mul(2).floor().to(maskType) + + if dt == torch.bfloat16 and device == 'cuda': + # remove once bfloat16 implemented on CUDA + self.assertRaises(RuntimeError, lambda: src.masked_select(mask)) + continue - for tensor, value in zip(ok, ok_values): - self.assertEqual(int(tensor), int(value)) - self.assertEqual(float(tensor), float(value)) - if sys.version_info[0] < 3: - self.assertEqual(long(tensor), long(value)) + if dt == torch.half and device == 'cpu': + self.assertRaises(RuntimeError, lambda: src.masked_select(mask)) + continue - for tensor in not_ok: - self.assertRaises(ValueError, lambda: int(tensor)) - self.assertRaises(ValueError, lambda: float(tensor)) - if sys.version_info[0] < 3: - self.assertRaises(ValueError, lambda: long(tensor)) + dst = src.masked_select(mask) + dst2 = [] + for i in range(num_src): + if mask[i]: + dst2 += [src[i]] + self.assertEqual(dst, torch.tensor(dst2), 0) + + dst3 = torch.empty_like(src, device=device) + torch.masked_select(src, mask, out=dst3) + self.assertEqual(dst3, torch.Tensor(dst2), 0) + self.assertEqual(len(w), 1) + + warn = 'masked_select received a mask with dtype torch.uint8,' + self.assertEqual(str(w[0].message)[0:53], str(warn)) + + def test_masked_fill_bool_tensor(self, device): + dst = torch.tensor([True, False, True], device=device) + mask = torch.tensor([False, True, False], device=device) + + dst.masked_fill_(mask, True) + self.assertEqual(dst, torch.tensor([True, True, True], device=device)) + + dst = dst.masked_fill(mask, False) + self.assertEqual(dst, torch.tensor([True, False, True], device=device)) + + def test_tensor_shape_empty(self, device): + x = torch.randn((0, 1, 3, 0), device=device) + # flatten + self.assertEqual((0,), torch.flatten(x, 0, 3).shape) + self.assertEqual((0, 0), torch.flatten(x, 0, 2).shape) + self.assertEqual((0, 3, 0), torch.flatten(x, 1, 2).shape) + + # squeeze, unsqueeze + self.assertEqual((0, 1, 1, 3, 0), torch.unsqueeze(x, 1).shape) + self.assertEqual((0, 3, 0), torch.squeeze(x, 1).shape) + self.assertEqual((0, 3, 0), torch.squeeze(x).shape) + + # transpose, t + self.assertEqual((0, 0, 3, 1), torch.transpose(x, 1, 3).shape) + y = torch.randn((5, 0), device=device) + self.assertEqual((0, 5), y.t().shape) + + # select + self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape) + + # repeat, permute + self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape) + self.assertEqual((3, 0, 0, 1), x.permute(2, 3, 0, 1).shape) + + # diagonal, diagflat + self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device)).shape) + self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device)).shape) + # off the end offsets are valid + self.assertEqual((0,), torch.diagonal(torch.randn((5, 0), device=device), offset=1).shape) + self.assertEqual((0,), torch.diagonal(torch.randn((0, 5), device=device), offset=1).shape) + # check non-zero sized offsets off the end + self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=45252).shape) + self.assertEqual((5, 6, 0), torch.diagonal(torch.randn((3, 4, 5, 6), device=device), offset=-45252).shape) + + self.assertEqual((0, 0), torch.diagflat(torch.tensor([], device=device)).shape) + self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([], device=device), offset=1)) + self.assertEqual((0, 0), torch.diagflat(torch.tensor([[]], device=device)).shape) + self.assertEqual(torch.zeros(1, 1), torch.diagflat(torch.tensor([[]], device=device), offset=1)) + + # stack, split, chunk + self.assertEqual((4, 0, 1, 3, 0), torch.stack((x, x, x, x)).shape) + self.assertEqual([(0, 1, 3, 0)], + [z.shape for z in torch.chunk(x, 1, dim=0)]) + + self.assertEqual([(0, 1, 3, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=0)]) + self.assertEqual([(0, 1, 1, 0), ] * 3, [z.shape for z in torch.chunk(x, 3, dim=2)]) + + # NOTE: split_with_sizes behaves differently than NumPy in that it + # takes sizes rather than offsets + self.assertEqual([(0, 1, 0, 0), (0, 1, 1, 0), (0, 1, 2, 0)], + [z.shape for z in torch.split(x, (0, 1, 2), dim=2)]) + + self.assertRaises(RuntimeError, lambda: torch.split(x, 0, dim=1)) + # This is strange because the split size is larger than the dim size, but consistent with + # how split handles that case generally (when no 0s are involved). + self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)]) + self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)]) - def test_offset_scalar_cast(self): - x = torch.Tensor([1, 2, 3]) - y = x[2:] - self.assertEqual(int(y), 3) + # functions that operate over a dimension but don't reduce. + def test_dim_function_empty(self, device): + shape = (0, 1, 2, 0) + x = torch.randn(shape, device=device) + + # size stride + self.assertEqual(0, x.size(3)) + self.assertEqual(2, x.size(2)) + self.assertEqual(2, x.stride(0)) + self.assertEqual(1, x.stride(2)) + + self.assertEqual(x, torch.nn.functional.glu(x, 0)) + self.assertEqual((0, 1, 1, 0), torch.nn.functional.glu(x, 2).shape) + + # softmax, logsoftmax + self.assertEqual(x, torch.nn.functional.softmax(x, 0)) + self.assertEqual(x, torch.nn.functional.softmax(x, 2)) + self.assertEqual(x, torch.nn.functional.softmax(x, 3)) + + self.assertEqual(x, torch.nn.functional.log_softmax(x, 0)) + self.assertEqual(x, torch.nn.functional.log_softmax(x, 2)) + self.assertEqual(x, torch.nn.functional.log_softmax(x, 3)) + + # cumsum, cumprod + self.assertEqual(shape, torch.cumsum(x, 0).shape) + self.assertEqual(shape, torch.cumsum(x, 2).shape) + self.assertEqual(shape, torch.cumprod(x, 0).shape) + self.assertEqual(shape, torch.cumprod(x, 2).shape) + + # flip + self.assertEqual(x, x.flip(0)) + self.assertEqual(x, x.flip(2)) + + # roll + self.assertEqual(x, x.roll(0, 1).roll(0, -1)) + self.assertEqual(x, x.roll(1, x.size(1))) + self.assertEqual(x, x.roll(1)) + self.assertEqual(x, x.roll((1, 1), (3, 1))) + + # unbind + self.assertEqual((), x.unbind(0)) + self.assertEqual((torch.empty((0, 1, 0), device=device), torch.empty((0, 1, 0), device=device)), + x.unbind(2)) + + # cross + y = torch.randn((0, 1, 3, 0), device=device) + self.assertEqual(y.shape, torch.cross(y, y).shape) + + # renorm + self.assertEqual(shape, torch.renorm(x, 1, 0, 5).shape) + self.assertEqual(shape, torch.renorm(x, 1, 2, 5).shape) + + # sort + self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=0)]) + self.assertEqual([shape, shape], [z.shape for z in torch.sort(x, dim=2)]) + + # topk + self.assertEqual([shape, shape], [z.shape for z in torch.topk(x, 0, dim=0)]) + self.assertEqual([(0, 1, 1, 0), (0, 1, 1, 0)], [z.shape for z in torch.topk(x, 1, dim=2)]) + + y = torch.randn((2, 3, 4), device=device) + self.assertEqual([(2, 3, 0), (2, 3, 0)], [z.shape for z in torch.topk(y, 0)]) + + # gather + self.assertEqual(shape, torch.gather(x, 0, torch.empty(shape, dtype=torch.int64, device=device)).shape) + self.assertEqual(shape, torch.gather(x, 2, torch.empty(shape, dtype=torch.int64, device=device)).shape) + larger_shape = torch.empty((0, 1, 3, 0), dtype=torch.int64, device=device) + self.assertEqual(larger_shape.shape, torch.gather(x, 2, larger_shape).shape) + smaller_shape = torch.empty((0, 1, 0, 0), dtype=torch.int64, device=device) + self.assertEqual(smaller_shape.shape, torch.gather(x, 2, smaller_shape).shape) + y = torch.randn((2, 3, 4), device=device) + self.assertEqual((0, 3, 4), + torch.gather(y, 0, torch.empty((0, 3, 4), dtype=torch.int64, device=device)).shape) + + # scatter, scatter_add + for dim in [0, 2]: + y = torch.randn(shape, device=device) + y_src = torch.randn(shape, device=device) + ind = torch.empty(shape, dtype=torch.int64, device=device) + self.assertEqual(shape, y.scatter_(dim, ind, y_src).shape) + self.assertEqual(shape, y.scatter_add_(dim, ind, y_src).shape) + + z = torch.randn((2, 3, 4), device=device) + z_src = torch.randn((2, 3, 4), device=device) + self.assertEqual(z, z.scatter_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) + self.assertEqual(z, z.scatter_add_(2, torch.empty((2, 3, 0), dtype=torch.int64, device=device), z_src)) + + # index_fill, index_copy, index_add + c = x.clone() + c_clone = c.clone() + ind_empty = torch.tensor([], dtype=torch.int64, device=device) + ind_01 = torch.tensor([0, 1], dtype=torch.int64, device=device) + self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) + self.assertEqual(c_clone, c.index_fill_(2, ind_empty, -1)) + self.assertEqual(c_clone, c.index_fill_(2, torch.tensor([0, 1], dtype=torch.int64, device=device), -1)) + self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) + self.assertEqual(c_clone, c.index_copy_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) + self.assertEqual(c_clone, c.index_copy_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) + self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2, 0), device=device))) + self.assertEqual(c_clone, c.index_add_(2, ind_empty, torch.empty((0, 1, 0, 0), device=device))) + self.assertEqual(c_clone, c.index_add_(2, ind_01, torch.empty((0, 1, 2, 0), device=device))) + + c = torch.randn((0, 1, 2), device=device) + c_clone = c.clone() + self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) + self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + self.assertEqual(c_clone, c.index_fill_(0, ind_empty, -1)) + self.assertEqual(c_clone, c.index_copy_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + self.assertEqual(c_clone, c.index_add_(0, ind_empty, torch.empty((0, 1, 2), device=device))) + + # index fill/copy/add non-empty + z = torch.randn((2, 3, 4), device=device) + self.assertEqual(z, z.index_fill_(0, ind_empty, -1)) + z = torch.randn((2, 3, 4), device=device) + self.assertEqual(z, z.index_copy_(0, ind_empty, torch.empty((0, 3, 4), device=device))) + z = torch.randn((2, 3, 4), device=device) + self.assertEqual(z, z.index_add_(0, ind_empty, torch.empty((0, 3, 4), device=device))) + + # index_select + self.assertEqual(x, x.index_select(0, ind_empty)) + self.assertEqual((0, 1, 0, 0), x.index_select(2, ind_empty).shape) + self.assertEqual(x, x.index_select(2, ind_01)) + z = torch.randn((2, 3, 4), device=device) # non-empty + self.assertEqual((0, 3, 4), z.index_select(0, ind_empty).shape) + c = torch.randn((0, 1, 2), device=device) + self.assertEqual(c, c.index_select(0, ind_empty)) + c = torch.randn((0, 1, 2), device=device) + self.assertEqual(c, c.index_select(0, ind_empty)) + + def test_nonzero(self, device): + num_srcs = [ + 12, 12, 12, 12, 12, 125, + ] - # skip this test for now as it affects all tests - @unittest.skipIf(True, "flush_denormal not supported") - def test_set_flush_denormal(self): - tiny_float = 1e-42 - tiny_double = 1e-320 - float_tensor = torch.FloatTensor([1.0, tiny_float]) - double_tensor = torch.DoubleTensor([1.0, tiny_float, tiny_double]) + types = [ + 'torch.ByteTensor', + 'torch.CharTensor', + 'torch.ShortTensor', + 'torch.IntTensor', + 'torch.FloatTensor', + 'torch.DoubleTensor', + 'torch.LongTensor', + ] - self.assertEqual(float_tensor[0], 1.0, prec=0.0) - self.assertEqual(float_tensor[1], tiny_float, prec=tiny_float / 16) - self.assertEqual(double_tensor[0], 1.0, prec=0.0) - self.assertEqual(double_tensor[1], tiny_float, prec=0.0) - self.assertEqual(double_tensor[2], tiny_double, prec=0.0) + shapes = [ + torch.Size((12,)), + torch.Size((12, 1)), + torch.Size((1, 12)), + torch.Size((6, 2)), + torch.Size((3, 2, 2)), + torch.Size((5, 5, 5)), + ] - torch.set_flush_denormal(True) - self.assertEqual(float_tensor[0], 1.0, prec=0.0) - self.assertEqual(float_tensor[1], 0.0, prec=0.0) # tiny_float to zero - self.assertEqual(double_tensor[0], 1.0, prec=0.0) - # tiny_float is not converted to zero in double type - self.assertEqual(double_tensor[1], tiny_float, prec=0.0) - self.assertEqual(double_tensor[2], 0.0, prec=0.0) # tiny_double to zero - torch.set_flush_denormal(False) + def is_lexicographically_sorted(inds): + """Check sorted ascending with + i -> j -> k changing slowest to fastest""" + assert inds.size(1) == 3 + if inds.size(0) > 1: + i0, j0, k0 = inds[:-1].t() + i1, j1, k1 = inds[+1:].t() + i_ok = (i1 >= i0) + j_ok = (j1 >= j0) | (i1 > i0) + k_ok = (k1 >= k0) | (j1 > j0) | (i1 > i0) + lex = torch.stack((i_ok, j_ok, k_ok), dim=1) + return lex + return torch.full_like(inds, 1) + + def gen_nontrivial_input(num_src, dtype, device): + while True: + tensor = torch.rand(num_src).mul(2).floor().type(dtype).to(device) + if tensor.sum() > 0: + return tensor - def test_unique(self): - def run_test(device): - x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], device=device) - expected_unique = torch.tensor([1, 2, 3, 5, 8], device=device) - expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) - expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) + for dtype in types: + for shape, num_src in zip(shapes, num_srcs): + tensor = gen_nontrivial_input(num_src, dtype, device) + tensor = tensor.clone().resize_(shape) + dst1 = torch.nonzero(tensor) + dst2 = tensor.nonzero() + dst3 = torch.LongTensor().to(device) + torch.nonzero(tensor, out=dst3) + + self.assertRaisesRegex( + TypeError, + "received an invalid combination of arguments", + lambda: torch.nonzero(tensor, as_tuple=True, out=dst3)) + if len(shape) == 1: + dst = [] + for i in range(num_src): + if tensor[i] != 0: + dst += [i] + dst = torch.LongTensor(dst).to(device) + self.assertEqual(dst1.select(1, 0), dst, 0) + self.assertEqual(dst2.select(1, 0), dst, 0) + self.assertEqual(dst3.select(1, 0), dst, 0) + elif len(shape) == 2: + # This test will allow through some False positives. It only checks + # that the elements flagged positive are indeed non-zero. + for i in range(dst1.size(0)): + self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1]].item(), 0) + elif len(shape) == 3: + # This test will allow through some False positives. It only checks + # that the elements flagged positive are indeed non-zero. + for i in range(dst1.size(0)): + self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1], dst1[i, 2]].item(), 0) + lex = is_lexicographically_sorted(dst1) + self.assertEqual(torch.ones_like(lex), lex) + if TEST_NUMPY: + tup1 = torch.nonzero(tensor, as_tuple=True) + tup2 = tensor.nonzero(as_tuple=True) + tup3 = torch.where(tensor) + np1 = tensor.cpu().numpy().nonzero() + for t in (tup1, tup2, tup3): + self.assertEqual(len(t), len(np1)) + for i in range(len(t)): + self.assertEqual(t[i].cpu().numpy(), np1[i]) + + def test_pdist_norm(self, device): + def test_pdist_single(shape, device, p, dtype, trans): + x = torch.randn(shape, dtype=dtype, device=device) + if trans: + x.transpose_(-2, -1) + actual = torch.pdist(x, p=p) + expected = brute_pdist(x, p=p) + self.assertEqual(expected.shape, actual.shape) + self.assertTrue(torch.allclose(expected, actual)) - x_unique = torch.unique(x) - self.assertEqual( - expected_unique.tolist(), sorted(x_unique.tolist())) + for shape in [(4, 5), (3, 2), (2, 1)]: + for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: + for trans in [False, True]: + for dtype in [torch.float32, torch.float64]: + test_pdist_single(shape, device, p, dtype, trans) - x_unique, x_inverse = x.unique(return_inverse=True) - self.assertEqual( - expected_unique.tolist(), sorted(x_unique.tolist())) - self.assertEqual(expected_inverse.numel(), x_inverse.numel()) + # do a simplified comparison with big inputs, see: + # https://github.com/pytorch/pytorch/issues/15511 + for dtype in [torch.float32, torch.float64]: + test_pdist_single((1000, 2), device, 2, dtype, False) - x_unique = x.unique(sorted=True) - self.assertEqual(expected_unique, x_unique) + def test_atan2(self, device): + def _test_atan2_with_size(size, device): + a = torch.rand(size=size, device=device, dtype=torch.double) + b = torch.rand(size=size, device=device, dtype=torch.double) + actual = a.atan2(b) + x = a.view(-1) + y = b.view(-1) + expected = torch.tensor([math.atan2(x[i].item(), y[i].item()) for i in range(x.numel())], + device=device, dtype=torch.double) + self.assertTrue(torch.allclose(expected, actual.view(-1), rtol=0, atol=0.02)) - x_unique, x_counts = torch.unique(x, sorted=True, return_counts=True) - self.assertEqual(expected_counts, x_counts) + _test_atan2_with_size((2, 2), device) + _test_atan2_with_size((3, 3), device) + _test_atan2_with_size((5, 5), device) - x_unique, x_inverse = torch.unique( - x, sorted=True, return_inverse=True) - self.assertEqual(expected_unique, x_unique) - self.assertEqual(expected_inverse, x_inverse) + def test_atan2_edgecases(self, device): + def _test_atan2(x, y, expected, device, dtype): + expected_tensor = torch.tensor([expected], dtype=dtype, device=device) + x_tensor = torch.tensor([x], dtype=dtype, device=device) + y_tensor = torch.tensor([y], dtype=dtype, device=device) + actual = torch.atan2(y_tensor, x_tensor) + self.assertTrue(torch.allclose(expected_tensor, actual, rtol=0, atol=0.02)) - x_unique, x_inverse, x_counts = torch.unique( - x, sorted=True, return_inverse=True, return_counts=True) - self.assertEqual(expected_unique, x_unique) - self.assertEqual(expected_inverse, x_inverse) - self.assertEqual(expected_counts, x_counts) - - # Tests per-element unique on a higher rank tensor. - y = x.view(2, 2, 2) - y_unique, y_inverse = y.unique(sorted=True, return_inverse=True) - self.assertEqual(expected_unique, y_unique) - self.assertEqual(expected_inverse.view(y.size()), y_inverse) - - y_unique, y_inverse, y_counts = torch.unique( - y, sorted=True, return_inverse=True, return_counts=True) - self.assertEqual(expected_unique, y_unique) - self.assertEqual(expected_inverse.view(y.size()), y_inverse) - self.assertEqual(expected_counts, y_counts) - - # Tests unique on other types. - int_unique, int_inverse, int_counts = torch.unique( - torch.tensor([2, 1, 2], dtype=torch.int, device=device), - sorted=True, - return_inverse=True, - return_counts=True - ) - self.assertEqual(torch.tensor([1, 2], dtype=torch.int, device=device), int_unique) - self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse) - self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts) + for dtype in [torch.float, torch.double]: + _test_atan2(0, 0, 0, device, dtype) + _test_atan2(0, 1, math.pi / 2, device, dtype) + _test_atan2(0, -1, math.pi / -2, device, dtype) + _test_atan2(-1, 0, math.pi, device, dtype) + _test_atan2(1, 0, 0, device, dtype) + _test_atan2(-1, -1, math.pi * -3 / 4 , device, dtype) + _test_atan2(1, 1, math.pi / 4 , device, dtype) + _test_atan2(1, -1, math.pi / -4 , device, dtype) + _test_atan2(-1, 1, math.pi * 3 / 4 , device, dtype) - double_unique, double_inverse, double_counts = torch.unique( - torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device), - sorted=True, - return_inverse=True, - return_counts=True - ) - self.assertEqual(torch.tensor([1.5, 2., 2.1], dtype=torch.double, device=device), double_unique) - self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse) - self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_trapz(self, device): + def test_dx(sizes, dim, dx, device): + t = torch.randn(sizes, device=device) + actual = torch.trapz(t, dx=dx, dim=dim) + expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim) + self.assertEqual(expected.shape, actual.shape) + self.assertTrue(np.allclose(expected, actual.cpu().numpy())) - byte_unique, byte_inverse, byte_counts = torch.unique( - torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device), - sorted=True, - return_inverse=True, - return_counts=True - ) - self.assertEqual(torch.tensor([7, 42, 128, 133], dtype=torch.uint8, device=device), byte_unique) - self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse) - self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts) + def test_x(sizes, dim, x, device): + t = torch.randn(sizes, device=device) + actual = torch.trapz(t, x=torch.tensor(x, device=device), dim=dim) + expected = np.trapz(t.cpu().numpy(), x=x, axis=dim) + self.assertEqual(expected.shape, actual.shape) + self.assertTrue(np.allclose(expected, actual.cpu().numpy())) - # test consecutive version - z = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], device=device) - expected_z_unique = torch.tensor([1, 2, 5, 2, 3], device=device) - expected_z_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) - expected_z_counts = torch.tensor([1, 3, 2, 2, 1], device=device) + test_dx((2, 3, 4), 1, 1, device) + test_dx((10, 2), 0, 0.1, device) + test_dx((1, 10), 0, 2.3, device) + test_dx((0, 2), 0, 1.0, device) + test_dx((0, 2), 1, 1.0, device) + test_x((2, 3, 4), 1, [1.0, 2.0, 3.0], device) + test_x((10, 2), 0, [2.0, 3.0, 4.0, 7.0, 11.0, 14.0, 22.0, 26.0, 26.1, 30.3], device) + test_x((1, 10), 0, [1.0], device) + test_x((0, 2), 0, [], device) + test_x((0, 2), 1, [1.0, 2.0], device) + with self.assertRaisesRegex( + IndexError, + 'Dimension out of range'): + test_x((2, 3), 2, [], device) + test_dx((2, 3), 2, 1.0, device) + with self.assertRaisesRegex( + RuntimeError, + 'There must be one `x` value for each sample point'): + test_x((2, 3), 1, [1.0, 2.0], device) + test_x((2, 3), 1, [1.0, 2.0, 3.0, 4.0], device) - z_unique = torch.unique_consecutive(z) - self.assertEqual(z_unique, expected_z_unique) + def test_reduction_empty(self, device): + fns_to_test = [ + # name, function, identity + ('max', torch.max, None), + ('kthvalue', lambda *args, **kwargs: torch.kthvalue(*args, k=1, **kwargs), None), + ('argmax', torch.argmax, None), + ('min', torch.min, None), + ('argmin', torch.argmin, None), + ('mode', torch.mode, None), + ('median', torch.median, None), - z_unique, z_inverse = torch.unique_consecutive(z, return_inverse=True) - self.assertEqual(z_unique, expected_z_unique) - self.assertEqual(z_inverse, expected_z_inverse) + ('prod', torch.prod, 1), + ('sum', torch.sum, 0), + ('norm', torch.norm, 0), + ('mean', torch.mean, nan), + ('var', torch.var, nan), + ('std', torch.std, nan), + ('logsumexp', torch.logsumexp, -inf), + ] - z_unique, z_counts = torch.unique_consecutive(z, return_counts=True) - self.assertEqual(z_unique, expected_z_unique) - self.assertEqual(z_counts, expected_z_counts) + shape = (2, 0, 4) + x = torch.randn(shape, device=device) + + for fn in [torch.max, torch.min]: + ident_err = 'operation does not have an identity' + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x)) + + for item in fns_to_test: + name, fn, identity = item + if identity is None: + ident_err = 'does not have an identity' + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=2, keepdim=True)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1)) + self.assertRaisesRegex(RuntimeError, ident_err, lambda: fn(x, dim=1, keepdim=True)) + else: + self.assertEqual(torch.empty((2, 0), device=device), fn(x, dim=2)) + self.assertEqual(torch.empty((2, 0, 1), device=device), fn(x, dim=2, keepdim=True)) + # assertEqual doesn't work with inf, -inf, nan and two tensors. + check = (torch.testing.assert_allclose if math.isnan(identity) or math.isinf(identity) else + self.assertEqual) + check(torch.full((2, 4), identity, device=device), fn(x, dim=1)) + check(torch.full((2, 1, 4), identity, device=device), fn(x, dim=1, keepdim=True)) + try: + check(torch.full((), identity, device=device), fn(x)) + except TypeError as err: + # ignore if there is no allreduce. + self.assertTrue('dim' in str(err)) + + # any + xb = x.to(torch.uint8) + yb = x.to(torch.uint8) + self.assertEqual((2, 0), xb.any(2).shape) + self.assertEqual((2, 0, 1), xb.any(2, keepdim=True).shape) + self.assertEqual(torch.zeros((2, 4), device=device), xb.any(1)) + self.assertEqual(torch.zeros((2, 1, 4), device=device), xb.any(1, keepdim=True)) + self.assertEqual(torch.zeros((), device=device), xb.any()) + + # all + self.assertEqual((2, 0), xb.all(2).shape) + self.assertEqual((2, 0, 1), xb.all(2, keepdim=True).shape) + self.assertEqual(torch.ones((2, 4), device=device), xb.all(1)) + self.assertEqual(torch.ones((2, 1, 4), device=device), xb.all(1, keepdim=True)) + self.assertEqual(torch.ones((), device=device), xb.all()) + + def test_addcdiv(self, device): + def _test_addcdiv(a, alpha, b, c): + actual = torch.addcdiv(a, alpha, b, c) + # implementation of addcdiv downcasts alpha. arithmetic ops don't. + if not actual.dtype.is_floating_point: + alpha = int(alpha) + expected = a + (alpha * b) / c + self.assertTrue(torch.allclose(expected, actual, equal_nan=True)) - z_unique, z_inverse, z_counts = torch.unique_consecutive(z, return_inverse=True, return_counts=True) - self.assertEqual(z_unique, expected_z_unique) - self.assertEqual(z_inverse, expected_z_inverse) - self.assertEqual(z_counts, expected_z_counts) + def non_zero_rand(size, dtype, device): + if dtype.is_floating_point: + a = torch.rand(size=size, dtype=dtype, device=device) + elif dtype == torch.uint8: + a = torch.randint(1, 5, size=size, dtype=dtype, device=device) + else: + a = torch.randint(-5, 5, size=size, dtype=dtype, device=device) + return a + (a == 0).type(dtype) - run_test(torch.device('cpu')) - if torch.cuda.is_available(): - run_test(torch.device('cuda')) + for dtype in torch.testing.get_all_math_dtypes(device): + _test_addcdiv( + non_zero_rand((2, 2), dtype=dtype, device=device), + 0.5, + non_zero_rand((2, 2), dtype=dtype, device=device), + non_zero_rand((2, 2), dtype=dtype, device=device)) - @skipIfRocm - def test_unique_dim(self): - self.assertFalse(hasattr(torch, 'unique_dim')) + def test_unary_out_op_mem_overlap(self, device): + sz = 3 + doubles = torch.randn(2 * sz, device=device) + positives = torch.randint(1, 100, (2 * sz,), device=device).double() + ints = torch.randint(-100, 100, (2 * sz,), device=device) + unary_mem_overlap_cases = [ + ("abs", doubles, True, True, 'cpu'), + ("abs", doubles, False, True, 'cuda'), + ("acos", doubles, True, True, 'cpu'), + ("acos", doubles, False, True, 'cuda'), + ("asin", doubles, True, True, 'cpu'), + ("asin", doubles, False, True, 'cuda'), + ("atan", doubles, True, True, 'cpu'), + ("atan", doubles, False, True, 'cuda'), + ("bitwise_not", ints, True, True, 'cpu'), + ("bitwise_not", ints, True, True, 'cuda'), + ("ceil", doubles, True, True, 'cpu'), + ("ceil", doubles, True, True, 'cuda'), + ("cos", doubles, True, True, 'cpu'), + ("cos", doubles, False, True, 'cuda'), + ("cosh", doubles, True, True, 'cpu'), + ("cosh", doubles, False, True, 'cuda'), + ("digamma", doubles, True, True, 'cpu'), + ("erf", doubles, True, True, 'cpu'), + ("erf", doubles, False, True, 'cuda'), + ("erfc", doubles, True, True, 'cpu'), + ("erfc", doubles, False, True, 'cuda'), + ("erfinv", doubles, True, True, 'cpu'), + ("erfinv", doubles, True, True, 'cuda'), + ("exp", doubles, True, True, 'cpu'), + ("exp", doubles, False, True, 'cuda'), + ("expm1", doubles, True, True, 'cpu'), + ("expm1", doubles, False, True, 'cuda'), + ("floor", doubles, True, True, 'cpu'), + ("floor", doubles, False, True, 'cuda'), + ("frac", doubles, True, True, 'cpu'), + ("frac", doubles, False, True, 'cuda'), + ("log", positives, True, True, 'cpu'), + ("log", positives, False, True, 'cuda'), + ("log10", positives, True, True, 'cpu'), + ("log10", positives, False, True, 'cuda'), + ("log1p", positives, True, True, 'cpu'), + ("log1p", positives, False, True, 'cuda'), + ("log2", positives, True, True, 'cpu'), + ("log2", positives, False, True, 'cuda'), + ("neg", doubles, True, True, 'cpu'), + ("neg", doubles, True, True, 'cuda'), + ("reciprocal", doubles, True, True, 'cpu'), + ("reciprocal", doubles, False, True, 'cuda'), + ("round", doubles, True, True, 'cpu'), + ("round", doubles, True, True, 'cuda'), + ("rsqrt", positives, True, True, 'cpu'), + ("rsqrt", positives, True, True, 'cuda'), + ("sin", doubles, True, True, 'cpu'), + ("sin", doubles, False, True, 'cuda'), + ("sinh", doubles, True, True, 'cpu'), + ("sinh", doubles, False, True, 'cuda'), + ("sigmoid", doubles, True, True, 'cpu'), + ("sigmoid", doubles, False, False, 'cuda'), + ("sqrt", doubles, True, True, 'cpu'), + ("sqrt", doubles, False, True, 'cuda'), + ("tan", doubles, True, True, 'cpu'), + ("tan", doubles, False, True, 'cuda'), + ("tanh", doubles, True, True, 'cpu'), + ("tanh", doubles, False, True, 'cuda'), + ("trunc", doubles, True, True, 'cpu'), + ("trunc", doubles, False, True, 'cuda') + ] - def run_test(dtype=torch.float, device=torch.device('cpu')): - x = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) - x_empty = torch.empty(5, 0, dtype=dtype, device=device) - x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) - x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) - expected_unique_dim0 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) - expected_inverse_dim0 = torch.tensor([0, 0]) - expected_counts_dim0 = torch.tensor([2]) - expected_unique_dim1 = torch.tensor([[[0., 1.], - [1., 1.], - [2., 1.]], - [[0., 1.], - [1., 1.], - [2., 1.]]], - dtype=dtype, - device=device) - expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) - expected_counts_dim1 = torch.tensor([2, 1, 1]) - expected_unique_dim2 = torch.tensor([[[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]], - [[1., 1.], - [0., 1.], - [2., 1.], - [0., 1.]]], - dtype=dtype, - device=device) - expected_inverse_dim2 = torch.tensor([0, 1]) - expected_counts_dim2 = torch.tensor([1, 1]) - expected_unique_empty = torch.tensor([], dtype=dtype, device=device) - expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) - expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) - # dim0 - x_unique = torch.unique(x, dim=0) - self.assertEqual(expected_unique_dim0, x_unique) + for (fn, inputs, has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, dev) in unary_mem_overlap_cases: + if dev != device: + continue + out_fn = getattr(torch, fn) + in_fn = getattr(torch.Tensor, fn + '_') - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_inverse_dim0, x_inverse) + self.unary_check_input_output_mem_overlap(inputs, sz, out_fn, + expected_failure=not has_input_output_mem_overlap_check) - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_counts_dim0, x_counts) + self.check_internal_mem_overlap(in_fn, num_inputs=1, device=dev, + expected_failure=not has_internal_mem_overlap_check) - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=0) - self.assertEqual(expected_unique_dim0, x_unique) - self.assertEqual(expected_inverse_dim0, x_inverse) - self.assertEqual(expected_counts_dim0, x_counts) + def test_binary_op_mem_overlap(self, device): + ops = [ + ("add", True, True, 'cpu'), + ("add", True, True, 'cuda'), + ("mul", True, True, 'cpu'), + ("mul", True, True, 'cuda'), + ("sub", True, True, 'cpu'), + ("sub", True, True, 'cuda'), + ("div", True, True, 'cpu'), + ("div", True, True, 'cuda'), + ("pow", True, True, 'cpu'), + ("pow", True, True, 'cuda') + ] - # dim1 - x_unique = torch.unique(x, dim=1) - self.assertEqual(expected_unique_dim1, x_unique) + for (fn, has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, dev) in ops: + if dev != device: + continue + out_op = getattr(torch, fn) + inplace_op = getattr(torch.Tensor, fn + '_') + self.check_internal_mem_overlap( + inplace_op, num_inputs=2, device=device, + expected_failure=not has_internal_mem_overlap_check) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=1) - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_inverse_dim1, x_inverse) + self.binary_check_input_output_mem_overlap(out_op, device, + expected_failure=not has_input_output_mem_overlap_check) - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=1) - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_counts_dim1, x_counts) + def test_ternary_op_mem_overlap(self, device): + ops = [ + ("addcmul", True, True, 'cpu'), + ("addcmul", True, True, 'cuda'), + ("addcdiv", True, True, 'cpu'), + ("addcdiv", True, True, 'cuda'), + ("lerp", True, True, 'cpu'), + ("lerp", False, False, 'cuda') + ] - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=1) - self.assertEqual(expected_unique_dim1, x_unique) - self.assertEqual(expected_inverse_dim1, x_inverse) - self.assertEqual(expected_counts_dim1, x_counts) + for (fn, has_input_output_mem_overlap_check, + has_internal_mem_overlap_check, dev) in ops: + if dev != device: + continue + out_op = getattr(torch, fn) + inplace_op = getattr(torch.Tensor, fn + '_') + self.check_internal_mem_overlap( + inplace_op, num_inputs=3, device=device, + expected_failure=not has_internal_mem_overlap_check) + self.ternary_check_input_output_mem_overlap(out_op, dev, + expected_failure=not has_input_output_mem_overlap_check) - # dim2 - x_unique = torch.unique(x, dim=2) - self.assertEqual(expected_unique_dim2, x_unique) + def test_copy_mem_overlap(self, device): + self.check_internal_mem_overlap( + torch.Tensor.copy_, num_inputs=2, device=device) + sz = 3 + doubles = torch.randn(2 * sz, device=device) + self.unary_check_input_output_mem_overlap( + doubles, sz, lambda input, out: out.copy_(input)) - x_unique, x_inverse = torch.unique( - x, - return_inverse=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_inverse_dim2, x_inverse) + def test_pow_scalar_overloads_mem_overlap(self, device): + sz = 3 + doubles = torch.randn(2 * sz, device=device) + self.check_internal_mem_overlap( + lambda t: t.pow_(42), num_inputs=1, device=device) + self.unary_check_input_output_mem_overlap( + doubles, sz, lambda input, out: torch.pow(input, 42, out=out)) + self.unary_check_input_output_mem_overlap( + doubles, sz, lambda input, out: torch.pow(42, input, out=out)) - x_unique, x_counts = torch.unique( - x, - return_inverse=False, - return_counts=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_counts_dim2, x_counts) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_int_pow(self, device): - x_unique, x_inverse, x_counts = torch.unique( - x, - return_inverse=True, - return_counts=True, - dim=2) - self.assertEqual(expected_unique_dim2, x_unique) - self.assertEqual(expected_inverse_dim2, x_inverse) - self.assertEqual(expected_counts_dim2, x_counts) + def _test_integral_pow(dt, range, dev): + tensor = torch.tensor((3, 3), dtype=dt, device=dev).random_(*range) + exps = [0, 1, 2, 4, + torch.tensor((3, 3), dtype=dt, device=dev).random_(0, 5)] + for exp in exps: + self._test_pow(tensor, exp) - # test empty tensor - x_unique, x_inverse, x_counts = torch.unique( - x_empty, - return_inverse=True, - return_counts=True, - dim=1) - self.assertEqual(expected_unique_empty, x_unique) - self.assertEqual(expected_inverse_empty, x_inverse) - self.assertEqual(expected_counts_empty, x_counts) + _test_integral_pow(torch.int8, (-3, 4), device) + _test_integral_pow(torch.uint8, (0, 4), device) + _test_integral_pow(torch.int16, (-5, 5), device) + _test_integral_pow(torch.int64, (-10, 10), device) + _test_integral_pow(torch.int32, (-10, 10), device) - # test not a well formed tensor - # Checking for runtime error, as this is the expected behaviour - with self.assertRaises(RuntimeError): - torch.unique( - x_ill_formed_empty, - return_inverse=True, - return_counts=True, - dim=1) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_int_tensor_pow_neg_ints(self, device): + ints = [torch.iinfo(torch.int32).min, + -3, -2, -1, 0, 1, 2, 3, + torch.iinfo(torch.int32).max] + neg_ints = [torch.iinfo(torch.int32).min, -3, -2, -1] + tensor = torch.tensor(ints, dtype=torch.int32, device=device) + for pow in neg_ints: + self._test_pow(tensor, pow) - # test along dim2 - with self.assertRaises(RuntimeError): - torch.unique( - x_ill_formed_empty_another, - return_inverse=True, - return_counts=True, - dim=2) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_long_tensor_pow_floats(self, device): + ints = [0, 1, 23, 4567] + floats = [0.0, 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] + tensor = torch.tensor(ints, dtype=torch.int64, device=device) + for pow in floats: + self._test_pow(tensor, pow) - # test consecutive version - y = torch.tensor( - [[0, 1], - [0, 1], - [0, 1], - [1, 2], - [1, 2], - [3, 4], - [0, 1], - [0, 1], - [3, 4], - [1, 2]], - dtype=dtype, - device=device - ) - expected_y_unique = torch.tensor( - [[0, 1], - [1, 2], - [3, 4], - [0, 1], - [3, 4], - [1, 2]], - dtype=dtype, - device=device - ) - expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=dtype, device=device) - expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=dtype, device=device) - y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) - self.assertEqual(expected_y_inverse, y_inverse) - self.assertEqual(expected_y_counts, y_counts) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_float_scalar_pow_float_tensor(self, device): + floats = [2.0, -3 / 2, -1.0, -1 / 2, -1 / 3, 0.0, + 1 / 3, 1 / 2, 1.0, 3 / 2, 2.0] + tensor = torch.tensor(floats, dtype=torch.float32, device=device) + for base in floats: + self._test_pow(base, tensor) - run_test(torch.float) - run_test(torch.double) - run_test(torch.long) - run_test(torch.uint8) - if torch.cuda.is_available(): - run_test(torch.float, torch.device('cuda')) - run_test(torch.double, torch.device('cuda')) - run_test(torch.long, torch.device('cuda')) - run_test(torch.uint8, torch.device('cuda')) + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') + def test_tensor_pow_tensor(self, dev): + def rotate(l, n): + return l[-n:] + l[:-n] - def test_show_config(self): - # We can't usefully test the output; just make sure this doesn't crash - torch.__config__.show() + def test_tensor_pow_tensor(values, torch_type, numpy_type): + vals_tensor = torch.tensor(values, dtype=torch_type, device=dev) + for i in range(len(values)): + pows = rotate(values, i) + pows_tensor = torch.tensor(pows, dtype=torch_type, device=dev) + self._test_pow(vals_tensor, pows_tensor) - def test_parallel_info(self): - torch.__config__.parallel_info() + ints = [0, 1, 2, 3] + test_tensor_pow_tensor(ints, torch.int32, np.int32) + test_tensor_pow_tensor(ints, torch.int64, np.int64) - @staticmethod - def _test_bincount(self, device): - # negative input throws - with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): - torch.bincount(torch.tensor([1, -1], device=device)) - # n-d input, with n > 1 throws - with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): - torch.bincount(torch.tensor([[1, 2], [3, 4]], device=device)) - # floating input type throws - with self.assertRaisesRegex(RuntimeError, 'not implemented'): - torch.bincount(torch.tensor([1., 0.3], device=device)) - # minlength < 0 throws - with self.assertRaisesRegex(RuntimeError, 'minlength should be >= 0'): - torch.bincount(torch.tensor([1, 3], device=device), - torch.tensor([.2, .2], device=device), - minlength=-1) - # input and weights dim mismatch - with self.assertRaisesRegex(RuntimeError, 'same length'): - torch.bincount(torch.tensor([1, 0], device=device), - torch.tensor([1., 0.3, 0.5], device=device)) - # 1-d input with no elements and default minlength - self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long)), - torch.zeros(0, dtype=torch.long, device=device)) - # 1-d input with no elements and specified minlength - self.assertEqual(torch.bincount(torch.tensor([], device=device, dtype=torch.long), minlength=10), - torch.zeros(10, dtype=torch.long, device=device)) + floats = [-3.0, -2.0, -1.0, -1 / 2, -1 / 3, + 0.0, + 1 / 3, 1 / 2, 1.0, 2.0, 3.0] + test_tensor_pow_tensor(floats, torch.float32, np.float32) + test_tensor_pow_tensor(floats, torch.float64, np.float64) - # test tensor method without weights - long_counts = torch.tensor( - [0, 3, 2, 1, 3], dtype=torch.uint8, device=device).bincount() - self.assertEqual( - torch.tensor([1, 1, 1, 2], dtype=torch.int64, device=device), - long_counts) - # test minlength functionality - int_counts = torch.bincount( - torch.tensor([1, 1, 1, 1], device=device), minlength=5) - self.assertEqual( - torch.tensor([0, 4, 0, 0, 0], dtype=torch.int64, device=device), - int_counts) - # test weights - byte_counts = torch.bincount( - torch.tensor([0, 1, 1, 1, 4], device=device), - torch.tensor([.1, .2, .3, .4, .5], device=device)) - self.assertEqual( - torch.tensor([0.1, 0.9, 0, 0, 0.5], device=device), byte_counts) - byte_counts = torch.bincount( - torch.tensor([0, 1, 1, 1, 4], device=device), - torch.tensor([1, 2, 3, 4, 5], dtype=torch.int8, device=device)) - self.assertEqual( - torch.tensor([1, 9, 0, 0, 5], device=device), byte_counts) - # test non-contiguous inputs and weights - inputs = torch.tensor([[0, 0], [3, 1], [2, 1], [1, 1], [3, 4]], device=device) - weights = torch.tensor([[.1, 1], [.2, 2], [.3, 3], [.4, 4], [.5, 5]], device=device) - for i in [0, 1]: - assert not inputs[:, i].is_contiguous(), "Inputs are supposed to be non-contiguous" - assert not weights[:, i].is_contiguous(), "Weights are supposed to be non-contiguous" - # inputs are non-contiguous but weights are contiguous - self.assertEqual(inputs[:, 0].bincount(), torch.tensor([1, 1, 1, 2])) - # inputs and weights are non-contiguous - self.assertEqual(inputs[:, 1].bincount(weights[:, 1]), torch.tensor([1, 9, 0, 0, 5])) - # weights are non-contiguous but inputs are contiguous - self.assertEqual(inputs[:, 1].contiguous().bincount(weights[:, 1]), - torch.tensor([1, 9, 0, 0, 5])) + def test_var_mean_some_dims(self, device): + sizes = (4, 6, 7, 5, 3) + dims = len(sizes) - # test bincount on non-contiguous slices - all0s = torch.zeros((32, 2), dtype=torch.int64, device=device) - self.assertEqual(all0s[:, 0].bincount(), torch.tensor([32])) + x = torch.rand(sizes, device=device) + for num_of_dims in range(2, dims): + dim_list = list(combinations(list(range(dims)), r=num_of_dims)) + for dim in dim_list: + for unbiased in [False, True]: + for keepdim in [False, True]: + var1, mean1 = torch.var_mean(x, dim=dim, unbiased=unbiased, keepdim=keepdim) + var2 = x.var(dim=dim, unbiased=unbiased, keepdim=keepdim) + mean2 = x.mean(dim=dim, keepdim=keepdim) + self.assertEqual(var1, var2) + self.assertEqual(mean1, mean2) - all1s = torch.ones((32, 2), dtype=torch.int64, device=device) - self.assertEqual(all1s[:, 0].bincount(), torch.tensor([0, 32])) + # passes on ROCm w/ python 2.7, fails w/ python 3.6 + @skipCUDAIfRocm + def test_stft(self, device): + if not TEST_LIBROSA: + raise unittest.SkipTest('librosa not found') - # test large number of bins - global memory use - big_exp = torch.zeros(10000000, device=device) - big_exp[-1] = 50.0 - big_w = torch.tensor([.5] * 100, device=device) - big_out = torch.tensor([9999999] * 100, device=device).bincount(big_w) - self.assertEqual(big_exp, big_out) - # test large input size - big_exp = torch.zeros(2, device=device) - big_exp[1] = 1000000 - big_out = torch.ones(1000000, dtype=torch.int8, device=device).bincount() - self.assertEqual(big_exp, big_out) + def librosa_stft(x, n_fft, hop_length, win_length, window, center): + if window is None: + window = np.ones(n_fft if win_length is None else win_length) + else: + window = window.cpu().numpy() + input_1d = x.dim() == 1 + if input_1d: + x = x.view(1, -1) + result = [] + for xi in x: + ri = librosa.stft(xi.cpu().numpy(), n_fft, hop_length, win_length, window, center=center) + result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1))) + result = torch.stack(result, 0) + if input_1d: + result = result[0] + return result - @slowTest - def test_slow_test(self): - # Just a smoketest to make sure our slowTest decorator works. - pass + def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None, + center=True, expected_error=None): + x = torch.randn(*sizes, device=device) + if win_sizes is not None: + window = torch.randn(*win_sizes, device=device) + else: + window = None + if expected_error is None: + result = x.stft(n_fft, hop_length, win_length, window, center=center) + ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center) + self.assertEqual(result, ref_result, 7e-6, 'stft comparison against librosa') + else: + self.assertRaises(expected_error, + lambda: x.stft(n_fft, hop_length, win_length, window, center=center)) - def test_bincount_cpu(self): - self._test_bincount(self, device='cpu') + for center in [True, False]: + _test((10,), 7, center=center) + _test((10, 4000), 1024, center=center) - def test_is_nonzero(self): - self.assertExpectedRaises(RuntimeError, lambda: torch.tensor([]).is_nonzero(), subname="empty") - self.assertExpectedRaises(RuntimeError, lambda: torch.tensor([0, 0]).is_nonzero(), subname="multiple") - self.assertFalse(torch.tensor(0).is_nonzero()) - self.assertTrue(torch.tensor(1).is_nonzero()) - self.assertFalse(torch.tensor([0]).is_nonzero()) - self.assertTrue(torch.tensor([1]).is_nonzero()) - self.assertFalse(torch.tensor([[0]]).is_nonzero()) - self.assertTrue(torch.tensor([[1]]).is_nonzero()) + _test((10,), 7, 2, center=center) + _test((10, 4000), 1024, 512, center=center) - def test_meshgrid(self): - a = torch.tensor(1) - b = torch.tensor([1, 2, 3]) - c = torch.tensor([1, 2]) - grid_a, grid_b, grid_c = torch.meshgrid([a, b, c]) - self.assertEqual(grid_a.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_b.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_c.shape, torch.Size([1, 3, 2])) - grid_a2, grid_b2, grid_c2 = torch.meshgrid(a, b, c) - self.assertEqual(grid_a2.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_b2.shape, torch.Size([1, 3, 2])) - self.assertEqual(grid_c2.shape, torch.Size([1, 3, 2])) - expected_grid_a = torch.ones(1, 3, 2, dtype=torch.int64) - expected_grid_b = torch.tensor([[[1, 1], - [2, 2], - [3, 3]]]) - expected_grid_c = torch.tensor([[[1, 2], - [1, 2], - [1, 2]]]) - self.assertTrue(grid_a.equal(expected_grid_a)) - self.assertTrue(grid_b.equal(expected_grid_b)) - self.assertTrue(grid_c.equal(expected_grid_c)) - self.assertTrue(grid_a2.equal(expected_grid_a)) - self.assertTrue(grid_b2.equal(expected_grid_b)) - self.assertTrue(grid_c2.equal(expected_grid_c)) + _test((10,), 7, 2, win_sizes=(7,), center=center) + _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center) - # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA - # is available, we get a different error. - @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error") - def test_cuda_not_built(self): - msg = "Torch not compiled with CUDA enabled" - self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device()) - self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1], device="cuda")) - self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).cuda()) - self.assertRaisesRegex(TypeError, msg, lambda: torch.cuda.FloatTensor()) - self.assertRaisesRegex(TypeError, msg, lambda: torch.set_default_tensor_type(torch.cuda.FloatTensor)) - self.assertRaisesRegex(AssertionError, msg, lambda: torch.tensor([1]).to(device="cuda")) + # spectral oversample + _test((10,), 7, 2, win_length=5, center=center) + _test((10, 4000), 1024, 512, win_length=100, center=center) - def test_cast_binary_op(self): - # Scalar - a = torch.tensor(2) - b = torch.tensor(3) - a_copy = a.clone() - b_copy = b.clone() + _test((10, 4, 2), 1, 1, expected_error=RuntimeError) + _test((10,), 11, 1, center=False, expected_error=RuntimeError) + _test((10,), -1, 1, expected_error=RuntimeError) + _test((10,), 3, win_length=5, expected_error=RuntimeError) + _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError) + _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError) - self.assertEqual(torch.tensor(6), a.float() * b) + @skipCUDAIfRocm + def test_blas_empty(self, device): + + def fn(torchfn, *args): + return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape + for shape in args)) + + # mm, addmm + self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) + self.assertEqual((0, 5), fn(torch.mm, (0, 0), (0, 5)).shape) + self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) + self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) + self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) + + self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) + self.assertEqual((5, 6), fn(torch.addmm, (5, 6), (5, 0), (0, 6)).shape) + + # mv, addmv + self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) + self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) + self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) + + self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) + self.assertEqual((3,), fn(torch.addmv, (3,), (3, 0), (0,)).shape) + + # ger, addr + self.assertEqual((0, 0), fn(torch.ger, (0,), (0,)).shape) + self.assertEqual((5, 0), fn(torch.ger, (5,), (0,)).shape) + self.assertEqual((0, 4), fn(torch.ger, (0,), (4,)).shape) + + self.assertEqual((0, 0), fn(torch.addr, (0, 0), (0,), (0,)).shape) + self.assertEqual((5, 0), fn(torch.addr, (5, 0), (5,), (0,)).shape) + self.assertEqual((0, 4), fn(torch.addr, (0, 4), (0,), (4,)).shape) + + # bmm, baddbmm + self.assertEqual((0, 0, 0), fn(torch.bmm, (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) + self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) + self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) + + self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) + self.assertEqual((0, 5, 6), fn(torch.baddbmm, (0, 5, 6), (0, 5, 0), (0, 0, 6)).shape) + self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) + + # addbmm + self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) + self.assertEqual((5, 6), fn(torch.addbmm, (5, 6), (0, 5, 0), (0, 0, 6)).shape) + + # matmul + self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) + self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) + self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) + self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) + self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) + + # dot + self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) + + if torch._C.has_lapack: + # lu + A_LU, pivots = fn(torch.lu, (0, 5, 5)) + self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape]) + A_LU, pivots = fn(torch.lu, (0, 0, 0)) + self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape]) + A_LU, pivots = fn(torch.lu, (2, 0, 0)) + self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) + + @skipCUDAIfRocm + def test_blas_alpha_beta_empty(self, device): + # ensure beta is respected + value = 11 + input = torch.full((2,), value, device=device) + mat = torch.ones((2, 0), device=device) + vec = torch.ones((0,), device=device) + out = torch.randn((2,), device=device) + alpha = 6 + beta = 3 + self.assertEqual(torch.full((2,), beta * value, device=device), + torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta)) + self.assertEqual(torch.full((2,), beta * value, device=device), + torch.addmv(input=input, mat=mat, vec=vec, alpha=alpha, beta=beta, out=out)) + + # torch.addmm + input = torch.full((2, 3), value, device=device) + mat2 = torch.ones((0, 3), device=device) + out = torch.randn((2, 3), device=device) + self.assertEqual(torch.full((2, 3), beta * value, device=device), + torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta)) + self.assertEqual(torch.full((2, 3), beta * value, device=device), + torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out)) + + @skipCUDAIfRocm + def test_unique_dim(self, device): + self.assertFalse(hasattr(torch, 'unique_dim')) - self.assertEqual(a.type(), a_copy.type()) - self.assertEqual(a.data.type(), a_copy.data.type()) - self.assertEqual(b.type(), b_copy.type()) - self.assertEqual(b.data.type(), b_copy.type()) + def run_test(device, dtype): + x = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]], + [[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) + x_empty = torch.empty(5, 0, dtype=dtype, device=device) + x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device) + x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device) + expected_unique_dim0 = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) + expected_inverse_dim0 = torch.tensor([0, 0]) + expected_counts_dim0 = torch.tensor([2]) + expected_unique_dim1 = torch.tensor([[[0., 1.], + [1., 1.], + [2., 1.]], + [[0., 1.], + [1., 1.], + [2., 1.]]], + dtype=dtype, + device=device) + expected_inverse_dim1 = torch.tensor([1, 0, 2, 0]) + expected_counts_dim1 = torch.tensor([2, 1, 1]) + expected_unique_dim2 = torch.tensor([[[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]], + [[1., 1.], + [0., 1.], + [2., 1.], + [0., 1.]]], + dtype=dtype, + device=device) + expected_inverse_dim2 = torch.tensor([0, 1]) + expected_counts_dim2 = torch.tensor([1, 1]) + expected_unique_empty = torch.tensor([], dtype=dtype, device=device) + expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device) + expected_counts_empty = torch.tensor([], dtype=torch.long, device=device) + # dim0 + x_unique = torch.unique(x, dim=0) + self.assertEqual(expected_unique_dim0, x_unique) - def test_cartesian_prod(self): - a = torch.tensor([1]) - b = torch.tensor([1, 2, 3]) - c = torch.tensor([1, 2]) - prod = torch.cartesian_prod(a, b, c) - expected = torch.tensor(list(product([a], b, c))) - self.assertEqual(expected, prod) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_inverse_dim0, x_inverse) - # test 0 size input - d = torch.empty(0, dtype=b.dtype) - prod = torch.cartesian_prod(a, b, c, d) - expected = torch.empty(0, 4, dtype=b.dtype) - self.assertEqual(expected, prod) + x_unique, x_counts = torch.unique( + x, + return_inverse=False, + return_counts=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_counts_dim0, x_counts) - # test single input - prod = torch.cartesian_prod(b) - self.assertEqual(b, prod) + x_unique, x_inverse, x_counts = torch.unique( + x, + return_inverse=True, + return_counts=True, + dim=0) + self.assertEqual(expected_unique_dim0, x_unique) + self.assertEqual(expected_inverse_dim0, x_inverse) + self.assertEqual(expected_counts_dim0, x_counts) - def test_combinations(self): - a = torch.tensor([1, 2, 3]) + # dim1 + x_unique = torch.unique(x, dim=1) + self.assertEqual(expected_unique_dim1, x_unique) - c = torch.combinations(a, r=1) - expected = torch.tensor(list(combinations(a, r=1))) - self.assertEqual(c, expected) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=1) + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_inverse_dim1, x_inverse) - c = torch.combinations(a, r=1, with_replacement=True) - expected = torch.tensor(list(combinations_with_replacement(a, r=1))) - self.assertEqual(c, expected) + x_unique, x_counts = torch.unique( + x, + return_inverse=False, + return_counts=True, + dim=1) + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_counts_dim1, x_counts) - c = torch.combinations(a) - expected = torch.tensor(list(combinations(a, r=2))) - self.assertEqual(c, expected) + x_unique, x_inverse, x_counts = torch.unique( + x, + return_inverse=True, + return_counts=True, + dim=1) + self.assertEqual(expected_unique_dim1, x_unique) + self.assertEqual(expected_inverse_dim1, x_inverse) + self.assertEqual(expected_counts_dim1, x_counts) - c = torch.combinations(a, with_replacement=True) - expected = torch.tensor(list(combinations_with_replacement(a, r=2))) - self.assertEqual(c, expected) + # dim2 + x_unique = torch.unique(x, dim=2) + self.assertEqual(expected_unique_dim2, x_unique) - c = torch.combinations(a, r=3) - expected = torch.tensor(list(combinations(a, r=3))) - self.assertEqual(c, expected) + x_unique, x_inverse = torch.unique( + x, + return_inverse=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_inverse_dim2, x_inverse) - c = torch.combinations(a, r=4) - expected = torch.empty(0, 4, dtype=a.dtype) - self.assertEqual(c, expected) + x_unique, x_counts = torch.unique( + x, + return_inverse=False, + return_counts=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_counts_dim2, x_counts) - c = torch.combinations(a, r=5) - expected = torch.empty(0, 5, dtype=a.dtype) - self.assertEqual(c, expected) + x_unique, x_inverse, x_counts = torch.unique( + x, + return_inverse=True, + return_counts=True, + dim=2) + self.assertEqual(expected_unique_dim2, x_unique) + self.assertEqual(expected_inverse_dim2, x_inverse) + self.assertEqual(expected_counts_dim2, x_counts) + + # test empty tensor + x_unique, x_inverse, x_counts = torch.unique( + x_empty, + return_inverse=True, + return_counts=True, + dim=1) + self.assertEqual(expected_unique_empty, x_unique) + self.assertEqual(expected_inverse_empty, x_inverse) + self.assertEqual(expected_counts_empty, x_counts) - # test empty imput - a = torch.empty(0) - c1 = torch.combinations(a) - c2 = torch.combinations(a, with_replacement=True) - expected = torch.empty(0, 2, dtype=a.dtype) - self.assertEqual(c1, expected) - self.assertEqual(c2, expected) + # test not a well formed tensor + # Checking for runtime error, as this is the expected behaviour + with self.assertRaises(RuntimeError): + torch.unique( + x_ill_formed_empty, + return_inverse=True, + return_counts=True, + dim=1) - def test_has_internal_overlap(self): - OVERLAP_NO = 0 - OVERLAP_YES = 1 - OVERLAP_TOO_HARD = 2 + # test along dim2 + with self.assertRaises(RuntimeError): + torch.unique( + x_ill_formed_empty_another, + return_inverse=True, + return_counts=True, + dim=2) - # Check for contiguous tensors - a = torch.randn(3, 3) - self.assertEqual(torch._debug_has_internal_overlap(a), OVERLAP_NO) + # test consecutive version + y = torch.tensor( + [[0, 1], + [0, 1], + [0, 1], + [1, 2], + [1, 2], + [3, 4], + [0, 1], + [0, 1], + [3, 4], + [1, 2]], + dtype=dtype, + device=device + ) + expected_y_unique = torch.tensor( + [[0, 1], + [1, 2], + [3, 4], + [0, 1], + [3, 4], + [1, 2]], + dtype=dtype, + device=device + ) + expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=dtype, device=device) + expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=dtype, device=device) + y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0) + self.assertEqual(expected_y_inverse, y_inverse) + self.assertEqual(expected_y_counts, y_counts) - # Checks for zero strides - b = torch.randn(1, 3) - b_expanded = b.expand(4, 3) - self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES) + run_test(device, torch.float) + run_test(device, torch.double) + run_test(device, torch.long) + run_test(device, torch.uint8) @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected') - def test_reverse_binary_ops_multiple_device(self): + @onlyCUDA + def test_reverse_binary_ops_multiple_device(self, device): self.assertEqual(2 + torch.tensor(3), 2 + torch.tensor(3).to("cuda:1")) # __radd__ self.assertEqual(2 - torch.tensor(3), 2 - torch.tensor(3).to("cuda:1")) # __rsub__ self.assertEqual(2 * torch.tensor(3), 2 * torch.tensor(3).to("cuda:1")) # __rmul__ @@ -12828,571 +12365,566 @@ def test_reverse_binary_ops_multiple_device(self): torch.tensor(2).to("cuda:1") // torch.tensor(3).to("cuda:0"), torch.tensor(2) // torch.tensor(3)) - def test_allow_tensor_metadata_change(self): - def do_test(t): - with self.assertRaisesRegex( - RuntimeError, - "set_sizes_contiguous is not allowed on a Tensor created from .data or .detach()"): - t.resize_((2, 1)) - with self.assertRaisesRegex( - RuntimeError, - "set_storage is not allowed on a Tensor created from .data or .detach()"): - t.set_() - with self.assertRaisesRegex( - RuntimeError, - "set_storage_offset is not allowed on a Tensor created from .data or .detach()"): - t.set_(t.storage(), 0, t.size(), list(t.stride())) - - do_test(torch.tensor([[1, 2]]).data) - do_test(torch.tensor([[1, 2]]).detach()) - - def test_c10_layer_norm(self): - # test that we can call c10 ops and they return a reasonable result - X = torch.rand(5, 5, dtype=torch.float) - weight = torch.rand(*X.size()[1:], dtype=torch.float) - bias = torch.rand(*X.size()[1:], dtype=torch.float) - epsilon = 1e-4 - - expected_norm = torch.nn.functional.layer_norm( - X, X.size()[1:], weight=weight, bias=bias, eps=epsilon) - actual_norm, actual_mean, actual_stdev = \ - torch.ops._caffe2.LayerNorm(torch.tensor(X), torch.tensor( - weight), torch.tensor(bias), 1, epsilon, True) - torch.testing.assert_allclose(expected_norm, actual_norm) - - def test_memory_format(self): - x = torch.randn(10, 3, 32, 32) - nhwc = x.contiguous(memory_format=torch.channels_last) - self.assertFalse(nhwc.is_contiguous()) - self.assertTrue(nhwc.is_contiguous(memory_format=torch.channels_last)) - self.assertEqual(nhwc, x) - - def test_memory_format_preserved_after_permute(self): - x = torch.randn(10, 3, 32, 32) - nhwc = x.contiguous(memory_format=torch.channels_last) - y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2) - self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) - - def test_memory_format_contiguous_returns_same_tensor_if_already_satisfies(self): - x = torch.randn(10, 32, 32, 3).permute(0, 3, 1, 2) - alias = x.contiguous(memory_format=torch.channels_last) - alias.fill_(7) - self.assertEqual(x, alias) - - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_memory_format_permute_cuda(self): - x = torch.randn(10, 3, 32, 32).cuda() - nhwc = x.contiguous(memory_format=torch.channels_last) - y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2) - self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) - - @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA') - def test_memory_format_empty_like_cuda(self): - x = torch.randn(10, 3, 32, 32).cuda() - self._test_memory_format_empty_like(x) - - def test_memory_format_empty_like_cpu(self): - x = torch.randn(10, 3, 32, 32) - self._test_memory_format_empty_like(x) - - def _test_memory_format_empty_like(self, x): - nhwc = x.contiguous(memory_format=torch.channels_last) - - like = torch.empty_like(nhwc, memory_format=torch.preserve_format) - self.assertFalse(like.is_contiguous()) - self.assertTrue(like.is_contiguous(memory_format=torch.channels_last)) - - like_x = torch.empty_like(x, memory_format=torch.preserve_format) - self.assertTrue(like_x.is_contiguous()) - self.assertFalse(like_x.is_contiguous(memory_format=torch.channels_last)) - - like = torch.empty_like(x, memory_format=torch.channels_last) - self.assertFalse(like.is_contiguous()) - self.assertTrue(like.is_contiguous(memory_format=torch.channels_last)) - - like = torch.empty_like(nhwc, memory_format=torch.contiguous_format) - self.assertTrue(like.is_contiguous()) - self.assertFalse(like.is_contiguous(memory_format=torch.channels_last)) - - like = torch.empty_like(nhwc) - self.assertTrue(like.is_contiguous()) - self.assertFalse(like.is_contiguous(memory_format=torch.channels_last)) - - sparse = x.to_sparse() - with self.assertRaises(RuntimeError): - z = torch.empty_like(sparse, memory_format=torch.preserve_format) - - def test_memory_format_empty(self): - with self.assertRaises(RuntimeError): - x = torch.empty((3, 3), memory_format=torch.channels_last) - x = torch.empty((3, 3, 3, 3), memory_format=torch.channels_last) - self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) - - def test_subclass_tensors(self): - # raise an error when trying to subclass FloatTensor - with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"): - class Foo1(torch.FloatTensor): - pass + @onlyCUDA + def test_ceil_out_mismatch(self, device): + a = torch.randn(1) + b = torch.randn(1, device=device) + self.assertRaises(RuntimeError, lambda: torch.ceil(a, out=b)) - # but allow subclassing Tensor: - class Foo2(torch.Tensor): - def foo(self): - return 5 - f = Foo2() - self.assertEqual(f.foo(), 5) - def test_ndim(self): - a = torch.randn(1, 2, 3) - self.assertEqual(3, a.ndim) - b = torch.randn(()) - self.assertEqual(0, b.ndim) - c = torch.randn(1, 0) - self.assertEqual(2, c.ndim) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_has_storage_numpy(self, device): + for dtype in [np.float32, np.float64, np.int64, + np.int32, np.int16, np.uint8]: + arr = np.array([1], dtype=dtype) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.float32).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.double).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.int).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.long).storage()) + self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.uint8).storage()) + + def test_all_any_empty(self, device): + x = torch.ByteTensor().to(device) + self.assertTrue(x.all()) + self.assertFalse(x.any()) - def test_T(self): - a = torch.randn(2, 3, 4) - t1 = a.T - t2 = a.permute(2, 1, 0) - self.assertEqual(t2, t1) - b = torch.randn(10) - self.assertEqual(b, b.T) - scalar = torch.tensor(5) - self.assertEqual(scalar, scalar.T) + x = torch.BoolTensor().to(device) + self.assertTrue(x.all()) + self.assertFalse(x.any()) - def test_python_types(self): - a1 = torch.randn((1, 2), dtype=torch.float64) - a2 = torch.randn((1, 2), dtype=float) - self.assertEqual(a1.dtype, a2.dtype) + @onlyCUDA + def test_multinomial_device_constrain(self, device): + x = torch.empty(0, device="cpu") + y = torch.empty(0, device=device) + self.assertRaisesRegex( + RuntimeError, "multinomial arguments must have the same device", + lambda: torch.multinomial(x, 2, out=y)) + + @unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected") + @onlyCUDA + def test_multinomial_gpu_device_constrain(self, device): + x = torch.empty(0, device="cuda:0") + y = torch.empty(0, device="cuda:1") + self.assertRaisesRegex( + RuntimeError, "multinomial arguments must have the same device", + lambda: torch.multinomial(x, 2, out=y)) - b1 = torch.arange(10, 20, dtype=torch.int64) - b2 = torch.arange(10, 20, dtype=int) - self.assertEqual(b1.dtype, b2.dtype) + @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected') + @onlyCUDA + def test_zeros_like_multiple_device(self, device): + expected = torch.zeros(100, 100, device=device) + x = torch.randn(100, 100, device='cuda:1', dtype=torch.float32) + output = torch.zeros_like(x) + self.assertEqual(output, expected) - c1 = torch.tensor([True, False], dtype=torch.bool) - c2 = torch.tensor([True, False], dtype=bool) - self.assertEqual(c1.dtype, c2.dtype) + @onlyCUDA + def test_ones_like(self, device): + expected = torch.ones(100, 100, device=device) - def test_fill_diagonal(self): - a1 = torch.randn(7, 3) - a2 = a1.clone() - v = 1 - for i in range(3): - a2[i][i] = v - a1.fill_diagonal_(v) - self.assertEqual(a1, a2) + res1 = torch.ones_like(expected) + self.assertEqual(res1, expected) - b1 = torch.randn(7, 3) - b2 = b1.clone() - for i in range(3): - b2[i][i] = v - b2[i + 4][i] = v - b1.fill_diagonal_(v, wrap=True) - self.assertEqual(b1, b2) + @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected') + @onlyCUDA + def test_ones_like_multiple_device(self, device): + expected = torch.ones(100, 100, device=device) + x = torch.randn(100, 100, device='cuda:1', dtype=torch.float32) + output = torch.ones_like(x) + self.assertEqual(output, expected) - c1 = torch.rand(3, 3, 3) - c2 = c1.clone() - for i in range(3): - c2[i][i][i] = v - c1.fill_diagonal_(v) - self.assertEqual(c1, c2) + @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected') + @onlyCUDA + def test_device_guard(self, device): + # verify that all operators with `device_guard: False` behave properly with multiple devices. + # TODO: if we had operator introspection we could figure out this set of operators automatically... + current_device = torch.cuda.current_device() + device = torch.device('cuda:1') if current_device == 0 else torch.device('cuda:0') + x = torch.randn((1, 2, 3), device=device) + y = torch.zeros((1, 3, 2), device=device) + scalar = torch.tensor(5, device=device) - # non-contiguous tensor - d1 = torch.rand(3, 3, 3)[:, 1, ...] - d2 = d1.clone() - for i in range(3): - d2[i][i] = v - d1.fill_diagonal_(v) - self.assertEqual(d1, d2) + # property ops + torch.cudnn_is_acceptable(x) + x.is_distributed() + x.is_floating_point() + x.is_complex() + x.is_same_size(y) + x.is_signed() + x.size(0) + x.stride(0) + x.numel() + x.is_set_to(y) + x.data_ptr() + scalar.is_nonzero() - e1 = torch.rand(7, 3, 3)[:, 1, ...] - e2 = e1.clone() - for i in range(3): - e2[i][i] = v - e2[i + 4][i] = v - e1.fill_diagonal_(v, wrap=True) - self.assertEqual(e1, e2) + # sparse property ops + y[0][1] = 5 + y_sparse = y.to_sparse() + y_sparse.sparse_dim() + y_sparse._dimI() + y_sparse.dense_dim() + y_sparse._dimV() + y_sparse._nnz() + y_sparse.is_coalesced() + y_sparse._indices() + y_sparse._values() + y_sparse.indices() + y_sparse.values() - def test_sign(self): - for device in torch.testing.get_all_device_types(): - for dtype in torch.testing.get_all_math_dtypes(device): + # in-place ops + def inplace(): + return torch.randn((1, 2, 3), device=device) + inplace().as_strided_(y.size(), y.stride()) + inplace().resize_(y.size()) + inplace().squeeze_() + inplace().squeeze_(0) + inplace().unsqueeze_(2) + inplace().transpose_(1, 2) + inplace().squeeze_().t_() + inplace().set_(x.storage()) + inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride()) + inplace().set_(x) + inplace().set_() + y_sparse._coalesced_(True) - # Include NaN for floating point numbers - if dtype.is_floating_point: - dt_info = torch.finfo(dtype) + # shape modification + x.as_strided(y.size(), y.stride()) + x.expand((5, 2, 3)) + x.expand_as(x) + x.sum_to_size((1,)) + torch.broadcast_tensors(x , x) + x.reshape((1, 3, 2)) + x.reshape_as(y) + x.squeeze() + x.squeeze(0) + x.squeeze().t() + x.transpose(1, 2) + x.unsqueeze(2) + x.view((1, 3, 2)) + x.view_as(y) - # Create tensor (with NaN checking) - a = torch.tensor([float('nan'), -12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) - a_target = torch.tensor([0, -1, 0, 1, -1, 1], device=device, dtype=dtype) + # chunk, split, etc. + x.chunk(2, dim=1) + x.split(1, dim=2) + x.split_with_sizes([1, 2], dim=2) + x.unfold(dimension=2, size=1, step=1) - else: - dt_info = torch.iinfo(dtype) + x.narrow(1, 1, 1) + x.select(1, 1) + torch.isnan(x) - # If unsigned type, everything should be >= 0 - if dt_info.min == 0: - a = torch.tensor([12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) - a_target = torch.tensor([1, 0, 1, 0, 1], device=device, dtype=dtype) - else: - a = torch.tensor([-12, 0, 71, dt_info.min, dt_info.max], device=device, dtype=dtype) - a_target = torch.tensor([-1, 0, 1, -1, 1], device=device, dtype=dtype) + torch.empty((1, 3, 2), out=y) + torch.empty_like(x) + torch.empty_like(x, dtype=torch.int64) - self.assertEqual(a.sign(), a_target, 'sign device={} dtype={}'.format(device, dtype)) - self.assertEqual(torch.sign(a), a_target, 'sign device={} dtype={}'.format(device, dtype)) + # to + x.to(x) + x.to(y) + x.to(x, copy=True) - out = torch.empty_like(a) - torch.sign(a, out=out) - self.assertEqual(out, a_target, 'sign_out device={} dtype={}'.format(device, dtype)) + @onlyCUDA + def test_tensor_factory_gpu_type_inference(self, device): + saved_type = torch.Tensor().type() + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + torch.set_default_dtype(torch.float32) + self.assertIs(torch.float32, torch.tensor(0.).dtype) + self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device) + torch.set_default_dtype(torch.float64) + self.assertIs(torch.float64, torch.tensor(0.).dtype) + self.assertEqual(torch.device('cuda:0'), torch.tensor(0.).device) + torch.set_default_tensor_type(saved_type) - a.sign_() - self.assertEqual(a, a_target, 'sign_ device={} dtype={}'.format(device, dtype)) + @onlyCUDA + def test_tensor_factory_gpu_type(self, device): + saved_type = torch.Tensor().type() + torch.set_default_tensor_type(torch.cuda.FloatTensor) + x = torch.zeros((5, 5)) + self.assertIs(torch.float32, x.dtype) + self.assertTrue(x.is_cuda) + torch.set_default_tensor_type(torch.cuda.DoubleTensor) + x = torch.zeros((5, 5)) + self.assertIs(torch.float64, x.dtype) + self.assertTrue(x.is_cuda) + torch.set_default_tensor_type(saved_type) - # Include test for bool dtype - a_bool = torch.tensor([True, True, False, float('nan')], device=device).bool() - a_bool_target = torch.tensor([True, True, False, True], device=device).bool() - self.assertEqual(a_bool.sign(), a_bool_target, 'sign device={} dtype=bool'.format(device)) - self.assertEqual(torch.sign(a_bool), a_bool_target, 'sign device={} dtype=bool'.format(device)) + @onlyCPU + def test_renorm_ps(self, device): + # full reduction + x = torch.randn(5, 5) + xn = x.numpy() + for p in [1, 2, 3, 4, inf]: + res = x.renorm(p, 1, 1) + expected = x / x.norm(p, 0, keepdim=True).clamp(min=1) + self.assertEqual(res.numpy(), expected.numpy(), "renorm failed for {}-norm".format(p)) - a_out = torch.empty_like(a_bool) - torch.sign(a_bool, out=a_out) - self.assertEqual(a_out, a_bool_target, 'sign_out device={} dtype=bool'.format(device)) + @onlyCUDA + def test_topk_noncontiguous_gpu(self, device): + t = torch.randn(20, device=device)[::2] + top1, idx1 = t.topk(5) + top2, idx2 = t.contiguous().topk(5) + self.assertEqual(top1, top2) + self.assertEqual(idx1, idx2) - a_bool.sign_() - self.assertEqual(a_bool, a_bool_target, 'sign_ device={} dtype=bool'.format(device)) + def test_is_signed(self, device): + self.assertEqual(torch.IntTensor(5).to(device).is_signed(), True) + self.assertEqual(torch.ByteTensor(5).to(device).is_signed(), False) + self.assertEqual(torch.CharTensor(5).to(device).is_signed(), True) + self.assertEqual(torch.FloatTensor(5).to(device).is_signed(), True) + self.assertEqual(torch.HalfTensor(10).to(device).is_signed(), True) - def test_function_unwrap_message(self): - self.assertRaisesRegex(RuntimeError, ' call to _th_lt', - lambda: torch.ones(1, dtype=torch.float) < torch.ones(1, dtype=torch.double)) + @onlyCUDA + def test_solve_methods_arg_device(self, device): + for b_device, A_device in product(['cpu', 'cuda'], repeat=2): + if b_device == A_device: + continue - def check_internal_mem_overlap(self, inplace_op, num_inputs, device, - expected_failure=False): - if isinstance(inplace_op, str): - inplace_op = getattr(torch.Tensor, inplace_op) - input = torch.randn(1, device=device).expand(3, 3) - inputs = [input] + [torch.randn_like(input) - for i in range(num_inputs - 1)] - if not expected_failure: - with self.assertRaisesRegex(RuntimeError, 'single memory location'): - inplace_op(*inputs) - else: - with self.assertRaises(AssertionError): - with self.assertRaisesRegex(RuntimeError, 'single memory location'): - inplace_op(*inputs) + b = torch.randn(3, 1, device=b_device) + A = torch.randn(3, 3, device=A_device) + err_str = "Expected b and A to be on the same device" + with self.assertRaisesRegex(RuntimeError, err_str): + torch.solve(b, A) - def unary_check_input_output_mem_overlap(self, data, sz, op, - expected_failure=False): + with self.assertRaisesRegex(RuntimeError, err_str): + torch.cholesky_solve(b, A) - def _test(op, output, input): - output_exp = torch.empty_like(output) - op(input, out=output_exp) - self.assertEqual(op(input, out=output), output_exp, op.__name__) + with self.assertRaisesRegex(RuntimeError, err_str): + torch.triangular_solve(b, A) - # output is identical to input: - _test(op, output=data[0:sz], input=data[0:sz]) - # output and input are independent: - _test(op, output=data[0:sz], input=data[sz:2 * sz]) - # output partially overlaps with input: - if not expected_failure: - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - _test(op, data[0:sz], data[1:sz + 1]) - else: - with self.assertRaises(AssertionError): - with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): - _test(op, data[0:sz], data[1:sz + 1]) + # b and A have to be modified to match accepted inputs sizes for lu_solve + b = b.unsqueeze(0) + A = A.unsqueeze(0) + with self.assertRaisesRegex(RuntimeError, err_str): + torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=A_device).int()) - @torchtest.for_all_device_types() - def test_unary_out_op_mem_overlap(self, device): - sz = 3 - doubles = torch.randn(2 * sz, device=device) - positives = torch.randint(1, 100, (2 * sz,), device=device).double() - ints = torch.randint(-100, 100, (2 * sz,), device=device) - unary_mem_overlap_cases = [ - ("abs", doubles, True, True, 'cpu'), - ("abs", doubles, False, True, 'cuda'), - ("acos", doubles, True, True, 'cpu'), - ("acos", doubles, False, True, 'cuda'), - ("asin", doubles, True, True, 'cpu'), - ("asin", doubles, False, True, 'cuda'), - ("atan", doubles, True, True, 'cpu'), - ("atan", doubles, False, True, 'cuda'), - ("bitwise_not", ints, True, True, 'cpu'), - ("bitwise_not", ints, True, True, 'cuda'), - ("ceil", doubles, True, True, 'cpu'), - ("ceil", doubles, True, True, 'cuda'), - ("cos", doubles, True, True, 'cpu'), - ("cos", doubles, False, True, 'cuda'), - ("cosh", doubles, True, True, 'cpu'), - ("cosh", doubles, False, True, 'cuda'), - ("digamma", doubles, True, True, 'cpu'), - ("erf", doubles, True, True, 'cpu'), - ("erf", doubles, False, True, 'cuda'), - ("erfc", doubles, True, True, 'cpu'), - ("erfc", doubles, False, True, 'cuda'), - ("erfinv", doubles, True, True, 'cpu'), - ("erfinv", doubles, True, True, 'cuda'), - ("exp", doubles, True, True, 'cpu'), - ("exp", doubles, False, True, 'cuda'), - ("expm1", doubles, True, True, 'cpu'), - ("expm1", doubles, False, True, 'cuda'), - ("floor", doubles, True, True, 'cpu'), - ("floor", doubles, False, True, 'cuda'), - ("frac", doubles, True, True, 'cpu'), - ("frac", doubles, False, True, 'cuda'), - ("log", positives, True, True, 'cpu'), - ("log", positives, False, True, 'cuda'), - ("log10", positives, True, True, 'cpu'), - ("log10", positives, False, True, 'cuda'), - ("log1p", positives, True, True, 'cpu'), - ("log1p", positives, False, True, 'cuda'), - ("log2", positives, True, True, 'cpu'), - ("log2", positives, False, True, 'cuda'), - ("neg", doubles, True, True, 'cpu'), - ("neg", doubles, True, True, 'cuda'), - ("reciprocal", doubles, True, True, 'cpu'), - ("reciprocal", doubles, False, True, 'cuda'), - ("round", doubles, True, True, 'cpu'), - ("round", doubles, False, True, 'cuda'), - ("rsqrt", positives, True, True, 'cpu'), - ("rsqrt", positives, False, True, 'cuda'), - ("sin", doubles, True, True, 'cpu'), - ("sin", doubles, False, True, 'cuda'), - ("sinh", doubles, True, True, 'cpu'), - ("sinh", doubles, False, True, 'cuda'), - ("sigmoid", doubles, True, True, 'cpu'), - ("sigmoid", doubles, False, False, 'cuda'), - ("sqrt", doubles, True, True, 'cpu'), - ("sqrt", doubles, False, True, 'cuda'), - ("tan", doubles, True, True, 'cpu'), - ("tan", doubles, False, True, 'cuda'), - ("tanh", doubles, True, True, 'cpu'), - ("tanh", doubles, False, True, 'cuda'), - ("trunc", doubles, True, True, 'cpu'), - ("trunc", doubles, False, True, 'cuda') - ] + # This checks if a suitable error message is thrown + # when LU output and pivots are on the same device + with self.assertRaisesRegex(RuntimeError, + "Expected LU_pivots and LU_data to be on the same device"): + torch.lu_solve(b, A, torch.rand(A.shape[:-1], device=b_device).int()) - for (fn, inputs, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in unary_mem_overlap_cases: - if dev != device: - continue - out_fn = getattr(torch, fn) - in_fn = getattr(torch.Tensor, fn + '_') + # Note - reports a leak of 512 bytes on CUDA device 1 + @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected') + @skipCUDAMemoryLeakCheckIf(True) + @onlyCUDA + def test_tensor_set_errors_multigpu(self, device): + f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device='cuda:0') + f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device='cuda:1') - self.unary_check_input_output_mem_overlap(inputs, sz, out_fn, - expected_failure=not has_input_output_mem_overlap_check) + self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage())) + self.assertRaises(RuntimeError, + lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride())) + self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1)) - self.check_internal_mem_overlap(in_fn, num_inputs=1, device=dev, - expected_failure=not has_internal_mem_overlap_check) + @onlyCUDA + def test_half_tensor(self, device): + x = torch.randn(5, 5).half() + self.assertEqual(x.to(device), x) - def binary_check_input_output_mem_overlap(self, op, device, - expected_failure=False): - sz = 3 - data = torch.randn(2 * sz, device=device) - other = torch.randn(sz, device=device) + xc = x.to(device) + with tempfile.NamedTemporaryFile() as f: + torch.save(xc, f) + f.seek(0) + xc2 = torch.load(f) + self.assertIsInstance(xc2, type(xc)) + self.assertEqual(xc.float(), xc2.float()) - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(other, input, out=out), - expected_failure=expected_failure) + @onlyCUDA + def test_serialization(self, device): + def _test_serialization(filecontext_lambda): + device_count = torch.cuda.device_count() + t0 = torch.cuda.FloatTensor(5).fill_(1) + torch.cuda.set_device(device_count - 1) + tn = torch.cuda.FloatTensor(3).fill_(2) + torch.cuda.set_device(0) + b = (t0, tn) + with filecontext_lambda() as f: + torch.save(b, f) + f.seek(0) + c = torch.load(f) + self.assertEqual(b, c, 0) + u0, un = c + self.assertEqual(u0.get_device(), 0) + self.assertEqual(un.get_device(), device_count - 1) - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(input, other, out=out), - expected_failure=expected_failure) + _test_serialization(tempfile.NamedTemporaryFile) + _test_serialization(BytesIOContext) - @torchtest.for_all_device_types() - def test_binary_op_mem_overlap(self, device): - ops = [ - ("add", True, True, 'cpu'), - ("add", True, True, 'cuda'), - ("mul", True, True, 'cpu'), - ("mul", True, True, 'cuda'), - ("sub", True, True, 'cpu'), - ("sub", True, True, 'cuda'), - ("div", True, True, 'cpu'), - ("div", True, True, 'cuda'), - ("pow", True, True, 'cpu'), - ("pow", False, False, 'cuda') - ] + def test_memory_format_preserved_after_permute(self, device): + x = torch.randn(10, 3, 32, 32, device=device) + nhwc = x.contiguous(memory_format=torch.channels_last) + y = nhwc.permute(0, 1, 3, 2).permute(0, 1, 3, 2) + self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) - for (fn, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in ops: - if dev != device: - continue - out_op = getattr(torch, fn) - inplace_op = getattr(torch.Tensor, fn + '_') - self.check_internal_mem_overlap( - inplace_op, num_inputs=2, device=device, - expected_failure=not has_internal_mem_overlap_check) + def test_memory_format_empty_like(self, device): + x = torch.randn(10, 3, 32, 32, device=device) + nhwc = x.contiguous(memory_format=torch.channels_last) - self.binary_check_input_output_mem_overlap(out_op, device, - expected_failure=not has_input_output_mem_overlap_check) + like = torch.empty_like(nhwc, memory_format=torch.preserve_format) + self.assertFalse(like.is_contiguous()) + self.assertTrue(like.is_contiguous(memory_format=torch.channels_last)) - def ternary_check_input_output_mem_overlap(self, op, device, - expected_failure=False): - sz = 3 - data = torch.randn(2 * sz, device=device) - other1 = torch.randn(sz, device=device) - other2 = torch.randn(sz, device=device) + like_x = torch.empty_like(x, memory_format=torch.preserve_format) + self.assertTrue(like_x.is_contiguous()) + self.assertFalse(like_x.is_contiguous(memory_format=torch.channels_last)) - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(input, other1, other2, out=out), - expected_failure=expected_failure) + like = torch.empty_like(x, memory_format=torch.channels_last) + self.assertFalse(like.is_contiguous()) + self.assertTrue(like.is_contiguous(memory_format=torch.channels_last)) - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(other1, input, other2, out=out), - expected_failure=expected_failure) + like = torch.empty_like(nhwc, memory_format=torch.contiguous_format) + self.assertTrue(like.is_contiguous()) + self.assertFalse(like.is_contiguous(memory_format=torch.channels_last)) - self.unary_check_input_output_mem_overlap( - data, sz, lambda input, out: op(other1, other2, input, out=out), - expected_failure=expected_failure) + like = torch.empty_like(nhwc) + self.assertTrue(like.is_contiguous()) + self.assertFalse(like.is_contiguous(memory_format=torch.channels_last)) - @torchtest.for_all_device_types() - def test_ternary_op_mem_overlap(self, device): - ops = [ - ("addcmul", True, True, 'cpu'), - ("addcmul", True, True, 'cuda'), - ("addcdiv", True, True, 'cpu'), - ("addcdiv", True, True, 'cuda'), - ("lerp", True, True, 'cpu'), - ("lerp", False, False, 'cuda') - ] + sparse = x.to_sparse() + with self.assertRaises(RuntimeError): + z = torch.empty_like(sparse, memory_format=torch.preserve_format) - for (fn, has_input_output_mem_overlap_check, - has_internal_mem_overlap_check, dev) in ops: - if dev != device: - continue - out_op = getattr(torch, fn) - inplace_op = getattr(torch.Tensor, fn + '_') - self.check_internal_mem_overlap( - inplace_op, num_inputs=3, device=device, - expected_failure=not has_internal_mem_overlap_check) - self.ternary_check_input_output_mem_overlap(out_op, dev, - expected_failure=not has_input_output_mem_overlap_check) + def test_unique(self, device): + x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], device=device) + expected_unique = torch.tensor([1, 2, 3, 5, 8], device=device) + expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device) + expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device) - @torchtest.for_all_device_types() - def test_copy_mem_overlap(self, device): - self.check_internal_mem_overlap( - torch.Tensor.copy_, num_inputs=2, device=device) - sz = 3 - doubles = torch.randn(2 * sz, device=device) - self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: out.copy_(input)) + x_unique = torch.unique(x) + self.assertEqual( + expected_unique.tolist(), sorted(x_unique.tolist())) - @torchtest.for_all_device_types() - def test_pow_scalar_overloads_mem_overlap(self, device): - sz = 3 - doubles = torch.randn(2 * sz, device=device) - self.check_internal_mem_overlap( - lambda t: t.pow_(42), num_inputs=1, device=device, - expected_failure=(device == 'cuda')) - self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: torch.pow(input, 42, out=out), - expected_failure=(device == 'cuda')) - self.unary_check_input_output_mem_overlap( - doubles, sz, lambda input, out: torch.pow(42, input, out=out), - expected_failure=(device == 'cuda')) + x_unique, x_inverse = x.unique(return_inverse=True) + self.assertEqual( + expected_unique.tolist(), sorted(x_unique.tolist())) + self.assertEqual(expected_inverse.numel(), x_inverse.numel()) + + x_unique = x.unique(sorted=True) + self.assertEqual(expected_unique, x_unique) + + x_unique, x_counts = torch.unique(x, sorted=True, return_counts=True) + self.assertEqual(expected_counts, x_counts) + + x_unique, x_inverse = torch.unique( + x, sorted=True, return_inverse=True) + self.assertEqual(expected_unique, x_unique) + self.assertEqual(expected_inverse, x_inverse) + + x_unique, x_inverse, x_counts = torch.unique( + x, sorted=True, return_inverse=True, return_counts=True) + self.assertEqual(expected_unique, x_unique) + self.assertEqual(expected_inverse, x_inverse) + self.assertEqual(expected_counts, x_counts) + + # Tests per-element unique on a higher rank tensor. + y = x.view(2, 2, 2) + y_unique, y_inverse = y.unique(sorted=True, return_inverse=True) + self.assertEqual(expected_unique, y_unique) + self.assertEqual(expected_inverse.view(y.size()), y_inverse) + + y_unique, y_inverse, y_counts = torch.unique( + y, sorted=True, return_inverse=True, return_counts=True) + self.assertEqual(expected_unique, y_unique) + self.assertEqual(expected_inverse.view(y.size()), y_inverse) + self.assertEqual(expected_counts, y_counts) + + # Tests unique on other types. + int_unique, int_inverse, int_counts = torch.unique( + torch.tensor([2, 1, 2], dtype=torch.int, device=device), + sorted=True, + return_inverse=True, + return_counts=True + ) + self.assertEqual(torch.tensor([1, 2], dtype=torch.int, device=device), int_unique) + self.assertEqual(torch.tensor([1, 0, 1], dtype=torch.long, device=device), int_inverse) + self.assertEqual(torch.tensor([1, 2], dtype=torch.long, device=device), int_counts) + + double_unique, double_inverse, double_counts = torch.unique( + torch.tensor([2., 1.5, 2.1, 2.], dtype=torch.double, device=device), + sorted=True, + return_inverse=True, + return_counts=True + ) + self.assertEqual(torch.tensor([1.5, 2., 2.1], dtype=torch.double, device=device), double_unique) + self.assertEqual(torch.tensor([1, 0, 2, 1], dtype=torch.long, device=device), double_inverse) + self.assertEqual(torch.tensor([1, 2, 1], dtype=torch.long, device=device), double_counts) + + byte_unique, byte_inverse, byte_counts = torch.unique( + torch.tensor([133, 7, 7, 7, 42, 128], dtype=torch.uint8, device=device), + sorted=True, + return_inverse=True, + return_counts=True + ) + self.assertEqual(torch.tensor([7, 42, 128, 133], dtype=torch.uint8, device=device), byte_unique) + self.assertEqual(torch.tensor([3, 0, 0, 0, 1, 2], dtype=torch.long, device=device), byte_inverse) + self.assertEqual(torch.tensor([3, 1, 1, 1], dtype=torch.long, device=device), byte_counts) + + # test consecutive version + z = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], device=device) + expected_z_unique = torch.tensor([1, 2, 5, 2, 3], device=device) + expected_z_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device) + expected_z_counts = torch.tensor([1, 3, 2, 2, 1], device=device) + + z_unique = torch.unique_consecutive(z) + self.assertEqual(z_unique, expected_z_unique) + + z_unique, z_inverse = torch.unique_consecutive(z, return_inverse=True) + self.assertEqual(z_unique, expected_z_unique) + self.assertEqual(z_inverse, expected_z_inverse) + + z_unique, z_counts = torch.unique_consecutive(z, return_counts=True) + self.assertEqual(z_unique, expected_z_unique) + self.assertEqual(z_counts, expected_z_counts) + + z_unique, z_inverse, z_counts = torch.unique_consecutive(z, return_inverse=True, return_counts=True) + self.assertEqual(z_unique, expected_z_unique) + self.assertEqual(z_inverse, expected_z_inverse) + self.assertEqual(z_counts, expected_z_counts) + + @dtypesIfCUDA(torch.half, torch.float, torch.double) + @dtypes(torch.float, torch.double) + def test_erfinv(self, device, dtype): + # general testing. Narrow the range to avoid accuracy issues + input_values = torch.randn(4, 4, dtype=dtype, device=device).clamp(-0.3, 0.3) + self.assertEqual(input_values.erf().erfinv(), input_values) + # test inf + self.assertTrue(torch.equal(torch.tensor([-1, 1], dtype=dtype, device=device).erfinv(), + torch.tensor([-inf, inf], dtype=dtype, device=device))) + # test nan + self.assertEqual(torch.tensor([-2, 2], dtype=dtype, device=device).erfinv(), + torch.tensor([nan, nan], dtype=dtype, device=device)) + + if dtype == torch.double: + # double precision + a = torch.tensor([0.5, 0.8], dtype=torch.double, device=device).erfinv() + self.assertAlmostEqual(a[0].item(), 0.47693627620447, places=13) + self.assertAlmostEqual(a[1].item(), 0.90619380243682, places=13) + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_ctor_with_numpy_array(self, device): + correct_dtypes = [ + np.double, + np.float, + np.float16, + np.int64, + np.int32, + np.int16, + np.int8, + np.uint8, + np.bool, + ] -# Functions to test negative dimension wrapping -METHOD = 1 -INPLACE_METHOD = 2 -FUNCTIONAL = 4 -DIM_ARG = None + incorrect_byteorder = '>' if sys.byteorder == 'little' else '<' + incorrect_dtypes = map(lambda t: incorrect_byteorder + t, ['d', 'f']) -def make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim=0): - def neg_dim_test(self): - if isinstance(tensor_arg, list): - assert METHOD not in types and INPLACE_METHOD not in types - x = [torch.randn(arg) for arg in tensor_arg] - ndim = len(tensor_arg[-1]) - else: - x = torch.randn(*tensor_arg) - ndim = len(tensor_arg) - ndim += extra_dim + for dtype in correct_dtypes: + array = np.array([1, 2, 3, 4], dtype=dtype) - n_dim_to_test = sum(map(lambda e: e is DIM_ARG, arg_constr())) + # Upcast + tensor = torch.DoubleTensor(array).to(device) + for i in range(len(array)): + self.assertEqual(tensor[i], array[i]) - for dims_val in combinations(range(ndim), n_dim_to_test): - arg = arg_constr() - arg_neg = copy.deepcopy(arg) - idx = 0 - for i, v in enumerate(arg): - if v is DIM_ARG: - arg[i] = dims_val[idx] - arg_neg[i] = dims_val[idx] - ndim - idx += 1 + # Downcast (sometimes) + tensor = torch.FloatTensor(array).to(device) + for i in range(len(array)): + self.assertEqual(tensor[i], array[i]) - if METHOD in types: - a = getattr(x, name)(*arg) - b = getattr(x, name)(*arg_neg) - self.assertEqual(a, b) + tensor = torch.HalfTensor(array).to(device) + for i in range(len(array)): + self.assertEqual(tensor[i], array[i]) - if INPLACE_METHOD in types: - a = x.clone() - getattr(a, name + '_')(*arg) - b = x.clone() - getattr(b, name + '_')(*arg_neg) - self.assertEqual(a, b) + def test_dlpack_conversion(self, device): + x = torch.randn(1, 2, 3, 4, device=device, dtype=torch.float) + z = from_dlpack(to_dlpack(x)) + self.assertEqual(z, x) - if FUNCTIONAL in types: - a = getattr(torch, name)(x, *arg) - b = getattr(torch, name)(x, *arg_neg) - self.assertEqual(a, b) + @onlyCUDA + def test_pin_memory_from_constructor(self, device): + def _get_like(t, **kwargs): + return [ + torch.rand_like(t, **kwargs), + torch.randn_like(t, **kwargs), + torch.empty_like(t, **kwargs), + torch.full_like(t, 4, **kwargs), + torch.zeros_like(t, **kwargs), + torch.ones_like(t, **kwargs), + ] - return neg_dim_test + def _get_tensors(**kwargs): + return [ + torch.tensor([10, 11], **kwargs), + torch.randn(3, 5, **kwargs), + torch.rand(3, **kwargs), + # torch.randint(3, 5, **kwargs), // unsupported + torch.zeros(3, **kwargs), + torch.randperm(3, **kwargs), + torch.empty(6, **kwargs), + torch.ones(6, **kwargs), + torch.eye(6, **kwargs), + torch.arange(3, 5, **kwargs)] + pinned_tensors = _get_tensors(pin_memory=True) + _get_like(torch.empty(5, dtype=torch.float64), pin_memory=True) + for x in pinned_tensors: + self.assertTrue(x.is_pinned()) -def idx_tensor(size, max_val): - return torch.LongTensor(*size).random_(0, max_val - 1) + tensors = _get_tensors() + _get_like(torch.empty(5, dtype=torch.float64, pin_memory=True)) + for x in tensors: + self.assertFalse(x.is_pinned()) + def test_storage_device(self, device): + x = torch.tensor([], device=device) + self.assertEqual(x.dtype, x.storage().dtype) -def add_neg_dim_tests(): - neg_dim_tests = [ - ('narrow', (10, 20, 30), lambda: [DIM_ARG, 0, 5], [METHOD]), - ('transpose', (10, 20, 30), lambda: [DIM_ARG, DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]), - ('size', (10, 20, 30), lambda: [DIM_ARG], [METHOD]), - ('cat', [(2, 3, 4), (2, 3, 4)], lambda: [DIM_ARG], [FUNCTIONAL]), - ('chunk', (10, 20, 30), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), - ('gather', (10, 20), lambda: [DIM_ARG, idx_tensor((10, 20), 10)], [METHOD, FUNCTIONAL]), - ('index_select', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10)], [METHOD, FUNCTIONAL]), - ('split', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), - ('squeeze', (10, 1, 20, 1), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL]), - ('unbind', (2, 3, 4), lambda: [DIM_ARG], [FUNCTIONAL]), - ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1), - ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('norm', (10, 20), lambda: [2, DIM_ARG], [METHOD, FUNCTIONAL]), - ('prod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('std', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('sum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('var', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('kthvalue', (10, 20), lambda: [3, DIM_ARG], [METHOD, FUNCTIONAL]), - ('max', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('min', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('sort', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), - ('topk', (10, 20), lambda: [5, DIM_ARG], [METHOD, FUNCTIONAL]), - ('renorm', (10, 20), lambda: [2, DIM_ARG, 1], [METHOD, INPLACE_METHOD, FUNCTIONAL]), - ('index_add', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]), - ('index_copy', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), torch.randn(10, 10)], [INPLACE_METHOD]), - ('index_fill', (10, 10), lambda: [DIM_ARG, idx_tensor((10,), 10), 12], [INPLACE_METHOD]), - ('scatter', (10, 10), lambda: [DIM_ARG, idx_tensor((10, 10), 10), torch.randn(10, 10)], [INPLACE_METHOD]), - ('select', (10, 20), lambda: [DIM_ARG, 3], [METHOD]), - ('unfold', (10, 20), lambda: [DIM_ARG, 5, 2], [METHOD]), - ] + @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected') + @onlyCUDA + def test_storage_multigpu(self, device): + devices = ['cuda:0', 'cuda:1'] + for device in devices: + x = torch.tensor([], device=device) + self.assertEqual(x.dtype, x.storage().dtype) - for decl in neg_dim_tests: - if len(decl) == 4: - name, tensor_arg, arg_constr, types = decl - extra_dim = 0 - elif len(decl) == 5: - name, tensor_arg, arg_constr, types, extra_dim = decl + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + def test_lu(self, device): + from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank - test_name = 'test_' + name + '_neg_dim' + def run_test(device, pivot): + def run_subtest(matrix_size, batches, device, pivot): + a = fullrank(matrix_size, *batches).to(device) + a_LU_info, pivots_info, info_ = a.lu(pivot=pivot, get_infos=True) + self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size))) + self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,))) + self.assertEqual(info_.size(), torch.Size(batches)) + self.assertEqual(info_.abs().sum(), 0) + a_LU, pivots = a.lu(pivot=pivot) + self.assertEqual(a_LU, a_LU_info) + self.assertEqual(pivots_info, pivots) + if a.is_cuda: + a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True) + self.assertEqual(nopiv, torch.arange(1, 1 + a.size(-1), device=device, dtype=torch.int32).expand(a.shape[:-1])) + self.assertEqual(info_, info_nopiv) + P, L, U = torch.lu_unpack(a_LU, pivots) + self.assertEqual(P.matmul(L.matmul(U)), a) + + for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]): + run_subtest(ms, batch, device, pivot) + + # Info should be positive for rank deficient matrices + a = torch.ones(5, 3, 3, device=device) + self.assertGreater(a.lu(pivot=pivot, get_infos=True)[2][0], 0) + + run_test(device, True) - assert not hasattr(_TestTorchMixin, test_name), "Duplicated test name: " + test_name - setattr(_TestTorchMixin, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim)) + if device == 'cpu': + # Error checking, no pivoting variant on CPU + with self.assertRaisesRegex(RuntimeError, 'lu without pivoting is not implemented on the CPU'): + torch.lu(torch.empty(1, 2, 2), pivot=False) + else: + run_test(device, False) -add_neg_dim_tests() +add_neg_dim_tests() +instantiate_device_type_tests(TestTorchDeviceType, globals()) class TestTorch(TestCase, _TestTorchMixin): pass diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py new file mode 100644 index 0000000000000..45e83a60c0e5e --- /dev/null +++ b/test/test_type_promotion.py @@ -0,0 +1,235 @@ +import torch +import unittest + +from common_utils import TestCase, run_tests, load_tests + +# load_tests from common_utils is used to automatically filter tests for +# sharding on sandcastle. This line silences flake warnings +load_tests = load_tests + +class TestTypePromotion(TestCase): + + def setUp(self): + super(TestTypePromotion, self).setUp() + torch.set_default_dtype(torch.float32) + self.device = 'cpu' + + # In-place operations don't promote. + # `int+float -> float` but `int.add_(float)` is rejected as an error. + # Promoting inplace would require re-allocating and copying the memory of the + # tensor data, since element size could change. + def test_inplace(self): + int_tensor = torch.ones([4, 4, 4], dtype=torch.int32, device=self.device) + + self.assertRaisesRegex(RuntimeError, "can't be cast to", lambda: int_tensor.add_(1.5)) + + expected = torch.ones([4, 4, 4], dtype=torch.int32, device=self.device) + + long_tensor = torch.ones([4, 4, 4], dtype=torch.int64, device=self.device) + int_tensor.add_(long_tensor) + int_tensor.add_(1) + three = expected + 2 + self.assertEqual(int_tensor, three) + self.assertEqual(int_tensor.dtype, torch.int32) + + bool_tensor = torch.tensor([1, 1, 1], dtype=torch.bool) + uint8_tensor = torch.tensor([1, 1, 1], dtype=torch.uint8) + # We treat bool as a separate category, which means uint8 cannot cast to bool. + self.assertRaisesRegex(RuntimeError, "can't be cast to", lambda: bool_tensor.add_(uint8_tensor)) + + # We allow demotion from signed to unsigned, unlike numpy, because: + # * We don't want the performance penalty of inspecting scalar values. + # * We don't want 'signed' to be considered a distinct 'category' + # in promotion rules. + # We don't want signed to be a separate category because if it was, + # uint16_tensor + 5 would result in a long_tensor, which is not what we want. + int16_tensor = torch.tensor([1, 1, 1], dtype=torch.int16) + uint8_tensor *= int16_tensor + + def test_unsinged(self): + dont_promote = torch.ones(3, dtype=torch.uint8) + 5 + self.assertEqual(dont_promote.dtype, torch.uint8) + + # some basic examples + + def test_int_promotion(self): + a = torch.ones([4, 4, 4], dtype=torch.int32, device=self.device) + b = torch.ones([4, 4, 4], dtype=torch.int64, device=self.device) + c = a + b + self.assertEqual(c, b + b) + self.assertEqual(c.dtype, torch.int64) + + def test_float_promotion(self): + a = torch.ones([4, 4, 4], dtype=torch.float, device=self.device) + b = torch.ones([4, 4, 4], dtype=torch.double, device=self.device) + c = a + b + self.assertEqual(c, b + b) + self.assertEqual(c.dtype, torch.double) + c = b + a + self.assertEqual(c, b + b) + self.assertEqual(c.dtype, torch.double) + + def test_add_wrapped(self): + a = torch.ones([4, 4, 4], dtype=torch.int, device=self.device) + b = 1 + c = a + b + self.assertEqual(c, a + a) + self.assertEqual(c.dtype, torch.int) + + def test_int_to_float(self): + a = torch.ones([4, 4, 4], dtype=torch.int32, device=self.device) + b = torch.ones([4, 4, 4], dtype=torch.float, device=self.device) + c = a + b + self.assertEqual(c.dtype, torch.float32) + + # some examples from: + # https://github.com/pytorch/pytorch/issues/9515 + + def test_from_issue(self): + a = torch.rand(3, dtype=torch.float32, device=self.device) + u = torch.tensor([0, 0, 1], dtype=torch.uint8, device=self.device) + self.assertEqual((a * 5).dtype, torch.float32) + self.assertEqual((u + 1).dtype, torch.uint8) + self.assertEqual((u + 1000).dtype, torch.uint8) # integer overflow + + # not a "wrapped number" + other = torch.tensor(5.5, dtype=torch.double, device=self.device) + + self.assertEqual((u + 5.5).dtype, torch.get_default_dtype()) + self.assertEqual((u + other).dtype, torch.double) + # adding a 0-dim tensor to a float doesn't promote to double unless first + # type was integral. + self.assertEqual((a + other).dtype, torch.float32) + + def test_half(self): + half = torch.tensor(5.5, dtype=torch.float16, device=self.device) + if(self.device == 'cpu'): + self.assertRaisesRegex(RuntimeError, "not implemented for 'Half'", + lambda: half + 2.2) + else: + self.assertEqual((half + 2.2).dtype, torch.float16) + self.assertEqual((half + 100000).dtype, torch.float16) # inf + default_tensor = torch.tensor(100000.0, device=self.device) + self.assertEqual((half + default_tensor).dtype, torch.get_default_dtype()) + + def test_alternate_result(self): + f = torch.tensor([1, 1, 1, 1], dtype=torch.float, device=self.device) + o = torch.tensor([0, 0, 0, 0], dtype=torch.long, device=self.device) + self.assertRaisesRegex(RuntimeError, + "can't be cast to", + lambda: torch.add(f, f, out=o)) + d = torch.tensor([1, 1, 1, 1], dtype=torch.double, device=self.device) + torch.add(f, f, out=d) + self.assertEqual(d.dtype, torch.double) + self.assertEqual(f + f, d) + + def test_mixed_type_backward(self): + f = torch.ones([3, 3], dtype=torch.float, requires_grad=True, device=self.device) + ten = torch.tensor([10.], dtype=torch.double, device=self.device) + tens = f * ten + s = (tens + 2).sum() + s.backward() + self.assertEqual(f.grad, tens) + + # If we don't convert the returned grad_input to the actual input type + # we get an error like: + # RuntimeError: Function SubBackward0 returned an invalid gradient at index 0 - expected type \ + # torch.FloatTensor but got torch.DoubleTensor + f_dtypes = [torch.float, torch.double] + f_dtypes = f_dtypes if self.device == 'cpu' else f_dtypes + [torch.half] + i_dtypes = [torch.int, torch.long] + for f in ['add', 'sub', 'rsub', 'mul', 'div']: + for dtype1 in f_dtypes: + for dtype2 in (f_dtypes + i_dtypes): + x = torch.ones(10, requires_grad=True, dtype=dtype1, device=self.device) + y = torch.ones(10, dtype=dtype2, device=self.device) + + func = getattr(torch, f) + func(x, y).sum().backward() + + # verifies that a.add(b) is the same as a.to(b.dtype).add(b) in cases + # where that should hold. + def test_many_promotions(self): + from_to = { + torch.float16: torch.float32, + torch.half: torch.float16, + torch.int: torch.long, + torch.uint8: torch.long, + torch.uint8: torch.float, + torch.int: torch.float, + torch.int: torch.double, + torch.int16: torch.long, + torch.float16: torch.double, + torch.bool: torch.long, + torch.bool: torch.float + } + + for k, v in from_to.items(): + a = torch.rand([3, 3], device=self.device).to(k) # no _th_uniform for half on cpu. + b = torch.rand([3, 3], device=self.device).to(v) + c = a.add(b) + d = a.to(v).add(b) + self.assertEqual(c.dtype, d.dtype, message='from {} to {}'.format(k, v)) + self.assertEqual(c, d, message='from {} to {}'.format(k, v)) + + def test_non_promoting_ops(self): + x = torch.ones(4, dtype=torch.double) + err = 'expected dtype .ouble .*but got dtype .loat' + self.assertRaisesRegex(RuntimeError, err, + lambda: torch.neg(torch.ones(4, dtype=torch.float), out=x)) + self.assertRaisesRegex(RuntimeError, err, + lambda: torch.lerp(x, torch.ones(4, dtype=torch.float), 1)) + + def test_alpha_mismatch(self): + x = torch.ones(4, dtype=torch.int) + err = 'alpha must not be' + self.assertRaisesRegex(RuntimeError, err, + lambda: torch.add(x, x, alpha=1.1)) + x = x.to(torch.bool) + self.assertRaisesRegex(RuntimeError, err, + lambda: torch.add(x, x, alpha=1.1)) + self.assertEqual(x + x, torch.add(x, x, alpha=True)) + + def test_booleans(self): + onedim = torch.tensor([True]) + + self.assertEqual(onedim + onedim, onedim) + self.assertEqual(onedim + True, onedim) + self.assertEqual(torch.add(True, True), True) + self.assertEqual(torch.add(False, False), False) + self.assertEqual(torch.add(False, True), True) + + self.assertRaisesRegex(RuntimeError, "Boolean alpha only supported", + lambda: torch.add(1, 1, alpha=True)) + self.assertEquals(torch.add(torch.tensor(True), torch.tensor(True), True), torch.tensor(True)) + + def test_create_bool_tensors(self): + expected = torch.tensor([0], dtype=torch.int64, device=self.device) + self.assertEqual(torch.arange(False, True, device=self.device), expected) + self.assertEqual(torch.arange(True, device=self.device), expected) + expected = torch.tensor([0, 0.5], dtype=torch.get_default_dtype(), device=self.device) + self.assertEqual(torch.arange(False, True, 0.5, device=self.device), expected) + expected = torch.ones(0, dtype=torch.int64, device=self.device) + self.assertEqual(torch.arange(False, False, device=self.device), expected) + + self.assertEqual(torch.linspace(False, True, device=self.device), torch.linspace(0, 1, device=self.device)) + self.assertEqual(torch.logspace(False, True, device=self.device), torch.logspace(0, 1, device=self.device)) + + # this seems like odd behavior but ints also create float tensors, numpy doesn't have this function. + self.assertEqual(torch.scalar_tensor(False, device=self.device), torch.tensor(0., device=self.device)) + +@unittest.skipIf(not torch.cuda.is_available(), "no cuda") +class TestTypePromotionCuda(TestTypePromotion): + def setUp(self): + super(TestTypePromotionCuda, self).setUp() + self.device = 'cuda' + +# ensure type promotion logic properly handles an alternate default dtype. +class TestTypePromotionDefaultDouble(TestTypePromotion): + def setUp(self): + super(TestTypePromotionDefaultDouble, self).setUp() + torch.set_default_dtype(torch.double) + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_utils.py b/test/test_utils.py index 0970f6656055d..2a1a124baf6d3 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,7 +14,7 @@ import torch.hub as hub from torch.autograd._functions.utils import prepare_onnx_paddings from torch.autograd._functions.utils import check_onnx_broadcast -from common_utils import skipIfRocm, load_tests +from common_utils import skipIfRocm, load_tests, IS_MACOS # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings @@ -519,6 +519,7 @@ def sum_of_model_parameters(model): SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.992365 +@unittest.skipIf(IS_MACOS, 'Broken on macOS; see #26032') class TestHub(TestCase): @classmethod def setUpClass(cls): diff --git a/third_party/fbgemm b/third_party/fbgemm index f9078fdd81046..53f0c0d175ae4 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit f9078fdd8104603897948f563245f4528b77da5b +Subproject commit 53f0c0d175ae4283609a5b251052f9c6598b8aee diff --git a/third_party/gloo b/third_party/gloo index a9fa7c8d6e95a..2101e02ceabd9 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit a9fa7c8d6e95a6be22b734e22b595bc80f03aea0 +Subproject commit 2101e02ceabd9f1b0bb354f6ea705cefe83558b2 diff --git a/third_party/ideep b/third_party/ideep index cb3d393a0164d..78eafa5d23192 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit cb3d393a0164dbda9f755b0b082d996f5b094fe1 +Subproject commit 78eafa5d231924e3d525d4dc46de880015257618 diff --git a/third_party/onnx b/third_party/onnx index 4eb737c8eacf4..1316afc9f972f 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 4eb737c8eacf48f8771d515d0a373a4b964cfab7 +Subproject commit 1316afc9f972f81340faa05763e2898f38bcc3b0 diff --git a/tools/amd_build/patches/a_torch_cuda___init__.py.patch b/tools/amd_build/patches/a_torch_cuda___init__.py.patch index 94c2d5926e89d..92e772d58666c 100644 --- a/tools/amd_build/patches/a_torch_cuda___init__.py.patch +++ b/tools/amd_build/patches/a_torch_cuda___init__.py.patch @@ -1,26 +1,23 @@ diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py -index f52ab04f1..4e3f63c4b 100644 +index 8450f27812..1de27a5b0d 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py -@@ -123,7 +123,7 @@ def _lazy_call(callable): +@@ -144,8 +144,6 @@ def _lazy_call(callable): # Don't store the actual traceback to avoid memory cycle _queued_calls.append((callable, traceback.format_stack())) -_lazy_call(_check_capability) -+#_lazy_call(_check_capability) - +- class DeferredCudaCallError(Exception): -@@ -159,9 +159,9 @@ def _lazy_init(): - "Cannot re-initialize CUDA in forked subprocess. " + msg) - _check_driver() - torch._C._cuda_init() -- _cudart = _load_cudart() -- _cudart.cudaGetErrorName.restype = ctypes.c_char_p -- _cudart.cudaGetErrorString.restype = ctypes.c_char_p -+ # _cudart = _load_cudart() -+ #_cudart.cudaGetErrorName.restype = ctypes.c_char_p -+ #_cudart.cudaGetErrorString.restype = ctypes.c_char_p - _original_pid = os.getpid() - _initialized = True - # Important to do this after _initialized, since some queued calls + pass +@@ -191,9 +189,6 @@ def _lazy_init(): + "Cannot re-initialize CUDA in forked subprocess. " + msg) + _check_driver() + torch._C._cuda_init() +- _cudart = _load_cudart() +- _cudart.cudaGetErrorName.restype = ctypes.c_char_p +- _cudart.cudaGetErrorString.restype = ctypes.c_char_p + _original_pid = os.getpid() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. diff --git a/tools/amd_build/pyHIPIFY/constants.py b/tools/amd_build/pyHIPIFY/constants.py index 6198031a9c8f6..1384e6c891008 100644 --- a/tools/amd_build/pyHIPIFY/constants.py +++ b/tools/amd_build/pyHIPIFY/constants.py @@ -51,6 +51,7 @@ API_RAND = 41 API_LAST = 42 API_FFT = 43 +API_RTC = 44 HIP_UNSUPPORTED = 43 API_PYTORCH = 1337 diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py index 7f0ae6d16aace..f34d5a0f38d3a 100644 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py @@ -108,6 +108,7 @@ ("CUstreamBatchMemOpType", ("hipStreamBatchMemOpType", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), ("CUdevice_P2PAttribute", ("hipDeviceP2PAttribute", CONV_TYPE, API_DRIVER, HIP_UNSUPPORTED)), ("CUevent", ("hipEvent_t", CONV_TYPE, API_DRIVER)), + ("CUevent_st", ("ihipEvent_t", CONV_TYPE, API_DRIVER)), ("CUevent_flags", ("hipEventFlags", CONV_EVENT, API_DRIVER, HIP_UNSUPPORTED)), ("CUfilter_mode", ("hipTextureFilterMode", CONV_TEX, API_DRIVER)), ("CUGLDeviceList", ("hipGLDeviceList", CONV_GL, API_DRIVER, HIP_UNSUPPORTED)), @@ -277,6 +278,7 @@ ("cusparse.h", ("hipsparse.h", CONV_INCLUDE, API_RAND)), ("cufft.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)), ("cufftXt.h", ("hipfft.h", CONV_INCLUDE, API_BLAS)), + ("nvrtc.h", ("hip/hiprtc.h", CONV_INCLUDE, API_RTC)), ("thrust/system/cuda/", ("thrust/system/hip/", CONV_INCLUDE, API_BLAS)), ("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), ("cub/block/block_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), @@ -2181,6 +2183,29 @@ ("cufftDestroy", ("hipfftDestroy", CONV_MATH_FUNC, API_FFT)), ("cufftGetVersion", ("hipfftGetVersion", CONV_MATH_FUNC, API_FFT)), ("cufftGetProperty", ("hipfftGetProperty", CONV_MATH_FUNC, API_FFT, HIP_UNSUPPORTED)), + ("nvrtcResult", ("hiprtcResult", CONV_TYPE, API_RTC)), + ("NVRTC_SUCCESS", ("HIPRTC_SUCCESS", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_OUT_OF_MEMORY", ("HIPRTC_ERROR_OUT_OF_MEMORY", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_PROGRAM_CREATION_FAILURE", ("HIPRTC_ERROR_PROGRAM_CREATION_FAILURE", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_INVALID_INPUT", ("HIPRTC_ERROR_INVALID_INPUT", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_INVALID_PROGRAM", ("HIPRTC_ERROR_INVALID_PROGRAM", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_COMPILATION", ("HIPRTC_ERROR_COMPILATION", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_BUILTIN_OPERATION_FAILURE", ("HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", ("HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID", ("HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID", CONV_TYPE, API_RTC)), + ("NVRTC_ERROR_INTERNAL_ERROR", ("HIPRTC_ERROR_INTERNAL_ERROR", CONV_TYPE, API_RTC)), + ("nvrtcGetErrorString", ("hiprtcGetErrorString", CONV_JIT, API_RTC)), + ("nvrtcVersion", ("hiprtcVersion", CONV_JIT, API_RTC)), + ("nvrtcProgram", ("hiprtcProgram", CONV_TYPE, API_RTC)), + ("nvrtcAddNameExpression", ("hiprtcAddNameExpression", CONV_JIT, API_RTC)), + ("nvrtcCompileProgram", ("hiprtcCompileProgram", CONV_JIT, API_RTC)), + ("nvrtcCreateProgram", ("hiprtcCreateProgram", CONV_JIT, API_RTC)), + ("nvrtcDestroyProgram", ("hiprtcDestroyProgram", CONV_JIT, API_RTC)), + ("nvrtcGetLoweredName", ("hiprtcGetLoweredName", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLog", ("hiprtcGetProgramLog", CONV_JIT, API_RTC)), + ("nvrtcGetProgramLogSize", ("hiprtcGetProgramLogSize", CONV_JIT, API_RTC)), + ("nvrtcGetPTX", ("hiprtcGetCode", CONV_JIT, API_RTC)), + ("nvrtcGetPTXSize", ("hiprtcGetCodeSize", CONV_JIT, API_RTC)), ("thrust::cuda::", ("thrust::hip::", CONV_MATH_FUNC, API_BLAS)), ("cub::", ("hipcub::", CONV_MATH_FUNC, API_BLAS)), ]) diff --git a/tools/amd_build/unwrap_clang.sh b/tools/amd_build/unwrap_clang.sh new file mode 100644 index 0000000000000..1b414829a8352 --- /dev/null +++ b/tools/amd_build/unwrap_clang.sh @@ -0,0 +1,9 @@ +#!/bin/bash +shopt -s extglob + +ORIG_COMP=/opt/rocm/hcc/bin/clang-*_original +# note that the wrapping always names the compiler "clang-7.0_original" +if [ -e $ORIG_COMP ]; then + WRAPPED=/opt/rocm/hcc/bin/clang-?([0-9])?([0-9])[0-9] + sudo mv $ORIG_COMP $WRAPPED +fi diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 7631176850254..cb2c3a97339a7 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -205,8 +205,7 @@ self: cholesky_backward(grad, upper, result) - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor - self: not_implemented("cholesky_solve") - input2: not_implemented("cholesky_solve") + self, input2: cholesky_solve_backward(grad, self, input2, result, upper) - name: cholesky_inverse(Tensor self, bool upper=False) -> Tensor self: not_implemented("cholesky_inverse") @@ -386,7 +385,7 @@ - name: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor self: not_implemented("histc") -- name: index(Tensor self, Tensor?[] indices) -> Tensor +- name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor self: index_backward(zeros_like(self), indices, grad) indices: TensorList() @@ -608,11 +607,10 @@ self: norm_backward(grad, self, p, result, dim, keepdim) - name: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor - self: norm_backward(grad, self.to(grad.scalar_type()), p, result).to(self.scalar_type()) - + self: norm_backward(grad, self.to(grad.scalar_type()), p, result) - name: norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor - self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim).to(self.scalar_type()) + self: norm_backward(grad, self.to(grad.scalar_type()), p, result, dim, keepdim) - name: _pdist_forward(Tensor self, float p=2) -> Tensor self: _pdist_backward(grad, self, p, result) @@ -623,7 +621,7 @@ pdist: not_implemented("_pdist_backward") - name: cdist(Tensor x1, Tensor x2, float p=2) -> Tensor - x1: _cdist_backward(grad, x1, x2, p, result) + x1: _cdist_backward(grad.contiguous(), x1, x2, p, result) x2: _cdist_backward(grad.transpose(-1, -2).contiguous(), x2, x1, p, result.transpose(-1, -2).contiguous()) - name: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor @@ -671,10 +669,10 @@ exponent: pow_backward_exponent(grad, self, exponent) - name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor - self: prod_backward(grad, self.to(grad.scalar_type()), result).to(self.scalar_type()) + self: prod_backward(grad, self.to(grad.scalar_type()), result) - name: prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim).to(self.scalar_type()) + self: prod_backward(grad, self.to(grad.scalar_type()), result, dim, keepdim) - name: put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) self: grad.clone().put_(index, zeros_like(source), accumulate) @@ -807,10 +805,10 @@ self: -grad * alpha - name: sum(Tensor self, *, ScalarType? dtype=None) -> Tensor - self: grad.expand(self.sizes()).to(self.scalar_type()) + self: grad.expand(self.sizes()) - name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor - self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.scalar_type()) + self: sum_backward(grad, self.sizes(), dim, keepdim) - name: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) self: svd_backward(grads, self, some, compute_uv, U, S, V) @@ -850,7 +848,7 @@ - name: trace(Tensor self) -> Tensor self: trace_backward(grad, self.sizes()) -- name: transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a) +- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) self: grad.transpose(dim0, dim1) - name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) @@ -1505,7 +1503,7 @@ - name: _fft_with_size(Tensor self, int signal_ndim, bool complex_input, bool complex_output, bool inverse, int[] checked_signal_sizes, bool normalized, bool onesided, int[] output_sizes) -> Tensor self: fft_backward(self, grad, signal_ndim, complex_input, complex_output, inverse, checked_signal_sizes, normalized, onesided, output_sizes) -- name: unbind(Tensor(a) self, int dim=0) -> Tensor(a)[] +- name: unbind.int(Tensor(a) self, int dim=0) -> Tensor(a)[] self: unbind_backward(grads, dim) - name: stack(Tensor[] tensors, int dim=0) -> Tensor diff --git a/tools/autograd/env.py b/tools/autograd/env.py deleted file mode 100644 index d54f915db6dd6..0000000000000 --- a/tools/autograd/env.py +++ /dev/null @@ -1,12 +0,0 @@ -import os - -# This file copied from tools/setup_helpers/env.py. -# PLEASE DO NOT ADD ANYTHING TO THIS FILE, the BUILD_NAMEDTENSOR flag is temporary. -def check_env_flag(name, default=''): - return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] - - -def check_negative_env_flag(name, default=''): - return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N'] - -BUILD_NAMEDTENSOR = check_env_flag('BUILD_NAMEDTENSOR') diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index 3817fbfc42620..6d4d9c8436b22 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -4,7 +4,8 @@ python -m tools.autograd.gen_autograd \ build/aten/src/ATen/Declarations.yaml \ - $OUTPUT_DIR + $OUTPUT_DIR \ + tools/autograd Where $OUTPUT_DIR is where you would like the files to be generated. In the full build system, OUTPUT_DIR is @@ -180,7 +181,7 @@ def get_signature(name, params, call_args): return declarations -def gen_autograd(aten_path, out, autograd_dir): +def gen_autograd(aten_path, out, autograd_dir, disable_autograd=False): aten_decls = load_aten_declarations(aten_path) # Parse and load derivatives.yaml @@ -191,21 +192,19 @@ def gen_autograd(aten_path, out, autograd_dir): template_path = os.path.join(autograd_dir, 'templates') # Generate VariableType.h/cpp - from .gen_variable_type import gen_variable_type - gen_variable_type(out, aten_decls, template_path) + if not disable_autograd: + from .gen_variable_type import gen_variable_type + gen_variable_type(out, aten_decls, template_path) # Generate Functions.h/cpp from .gen_autograd_functions import gen_autograd_functions_lib gen_autograd_functions_lib( out, autograd_functions, template_path) - # Load deprecated signatures - deprecated = load_deprecated_signatures( - aten_decls, os.path.join(autograd_dir, 'deprecated.yaml')) - # Generate variable_factories.h from .gen_variable_factories import gen_variable_factories - gen_variable_factories(out, aten_decls, template_path) + gen_variable_factories( + out, aten_decls, template_path, disable_autograd=disable_autograd) def gen_autograd_python(aten_path, out, autograd_dir): diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 8377567ce3464..753af1519c3f2 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -33,9 +33,10 @@ 'copy_sparse_to_sparse_', 'copy_', 'numpy_T', # this needs to be an attribute in Python, not a function 'nonzero(_(out|numpy))?', - 'set_quantizer_', + 'set_quantizer_', # return types not supported yet 'set_data', '.*_overrideable', # overrideable functions for backend extension + 'data', 'is_leaf', 'output_nr' ] # These function signatures are not exposed to Python. Note that this signature @@ -117,7 +118,7 @@ """) PY_VARIABLE_METHOD_DEF = CodeTemplate("""\ -{"${name}", (PyCFunction)${pycname}, ${flags}, NULL},""") +{"${name}", (PyCFunction)${pycfunc_voidcast}${pycname}, ${flags}, NULL},""") PY_RETURN_NAMEDTUPLE_DEF = CodeTemplate("""\ static PyStructSequence_Field fields${namedtuple_type_index}[] = { @@ -155,6 +156,7 @@ 'std::vector', 'Scalar', 'bool', 'int64_t', 'void*', 'void', 'QScheme', 'double', + 'IntArrayRef', } TENSOR_OPTIONS = CodeTemplate("""\ @@ -627,7 +629,6 @@ def get_python_binding_arguments(declaration): py_default_device = 'self.device()' if is_like_or_new_function_with_options else None device_arg = { 'default': 'None', - 'default_init': 'None', 'dynamic_type': 'Device', 'kwarg_only': True, 'name': 'device', @@ -692,6 +693,7 @@ def process_function(name, declarations): 'name': name, 'dispatch_name': 'dispatch_{}'.format(name), 'pycname': 'THPVariable_{}'.format(name), + 'pycfunc_voidcast': '', 'signatures': [], 'max_args': max(len(o['arguments']) + len(o['python_binding_arguments']) for o in declarations), 'unpack_self': [], @@ -735,6 +737,7 @@ def process_function(name, declarations): else: tmpl = PY_VARIABLE_METHOD_VARARGS env['flags'] = 'METH_VARARGS | METH_KEYWORDS' + env['pycfunc_voidcast'] = '(void(*)(void))' if not is_module and not has_self: env['flags'] += ' | METH_STATIC' diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index 2db0f8777b1df..dd3984454ff40 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -34,20 +34,21 @@ def fully_qualified_type(argument_type): return "{}at::{}".format(argument_type[:index], argument_type[index:]) -def gen_variable_factories(out, declarations, template_path): +def gen_variable_factories(out, declarations, template_path, disable_autograd=False): function_definitions = [] for decl in declarations: has_tensor_options = any(a["simple_type"] == "TensorOptions" for a in decl["arguments"]) is_namespace_fn = 'namespace' in decl['method_of'] if (has_tensor_options or decl["name"].endswith("_like")) and is_namespace_fn: - function_definitions.append(process_function(decl, has_tensor_options)) + function_definitions.append( + process_function(decl, has_tensor_options, disable_autograd=disable_autograd)) write(out, "variable_factories.h", CodeTemplate.from_file(template_path + "/variable_factories.h"), {"function_definitions": function_definitions}) -def process_function(decl, has_tensor_options): +def process_function(decl, has_tensor_options, disable_autograd): formals = [] actuals = [] for argument in decl["arguments"]: @@ -65,7 +66,10 @@ def process_function(decl, has_tensor_options): # it's a tensor actuals.append('{}.options().is_variable(false)'.format(actuals[0])) - pre_record_trace, post_record_trace = format_trace(decl) + if not disable_autograd: + pre_record_trace, post_record_trace = format_trace(decl) + else: + pre_record_trace, post_record_trace = '', '' return FUNCTION_TEMPLATE.substitute( name=decl["name"], formals=formals, actuals=actuals, requires_grad=requires_grad, diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index e2e1d7d599788..5ec91de7330aa 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -26,11 +26,11 @@ from .utils import CodeTemplate, nested_dict, write, uninplace_api_name from .gen_autograd import VIEW_FUNCTIONS from .gen_autograd_functions import uses_single_grad -from .env import BUILD_NAMEDTENSOR # These functions are written manually in templates/VariableType.cpp MANUAL_IMPLEMENTATIONS = { - 'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', 'backward', 'set_data' + 'resize_', 'resize_as_', 'detach', 'detach_', 'copy_', 'backward', + 'set_data', 'data', 'is_leaf', 'output_nr' } # These functions we don't want to record for tracing, because we always want @@ -154,7 +154,10 @@ """) WRAPPER_REGISTRATION = CodeTemplate("""\ -.registerVariableOp<${return_type} (${formal_types})>("${schema_string}", &VariableType::${api_name}) +.op(torch::RegisterOperators::options() + .schema("${schema_string}") + .impl_unboxedOnlyKernel<${return_type} (${formal_types}), &VariableType::${api_name}>(TensorTypeId::VariableTensorId) + .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) UNPACK_TENSOR = CodeTemplate("""\ @@ -282,11 +285,6 @@ def find_factory_functions(declarations): def should_trace(declaration): - if BUILD_NAMEDTENSOR: - # Short-term plan: Don't support tracing Dimname. - # Long-term plan: Add Dimname as a first-class type to the JIT. - if any('Dimname' in arg['simple_type'] for arg in declaration['arguments']): - return False # Operations involving Storage or Type are not traceable at the moment if any(arg['simple_type'] in {'Storage', 'Type', 'ConstQuantizerPtr'} for arg in declaration['arguments']): return False diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 6e2f20f6b0eb1..871e36954068a 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -65,7 +65,7 @@ Tensor maybe_multiply(const Tensor & t, const Scalar & s) { bool is_one = false; if (s.isFloatingPoint()) { is_one = s.toDouble() == 1; - } else if(s.isIntegral()) { + } else if(s.isIntegral(true)) { is_one = s.toLong() == 1; } @@ -528,7 +528,7 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si int64_t out_cols = grad.size(1); Tensor t = at::zeros({}, grad.options()).expand({out_rows, out_cols}, true); Tensor r = at::empty({out_cols, out_rows}, grad.options()).t(); - at::s_native_addmm_out(r, t, mat1.t(), grad, alpha, 1); + at::addmm_out(r, t, mat1.t(), grad, alpha, 1); return r; } return maybe_multiply(grad.t().mm(mat1).t(), alpha); @@ -1982,6 +1982,27 @@ std::tuple triangular_solve_backward( return std::tuple{grad_b, grad_a}; } +std::tuple cholesky_solve_backward( + const Tensor& grad_x, const Tensor& self, + const Tensor& input2, const Tensor& result, const bool upper) { + Tensor grad_self, grad_input2; + if (grad_x.defined()) { + grad_self = grad_x.cholesky_solve(input2, /*upper=*/upper); + } else { + grad_self = at::zeros({1}, self.options()).expand_as(self); + } + + Tensor common_term = at::matmul(grad_self, result.transpose(-2, -1)); + common_term = common_term + common_term.transpose(-2, -1); + + if (upper) { + grad_input2 = -at::matmul(input2, common_term); + } else { + grad_input2 = -at::matmul(common_term, input2); + } + return std::tuple{grad_self, grad_input2}; +} + // Generally speaking, fft's backward is ifft. Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim, bool complex_input, bool complex_output, diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 6b4e33b80e3c4..d1c338bc075e1 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -1,6 +1,7 @@ #include "torch/csrc/autograd/VariableTypeUtils.h" #include +#include // ${generated_comment} @@ -31,6 +32,11 @@ namespace torch { namespace autograd { ${type_derived_method_definitions} -static auto& registerer = globalATenDispatch() +namespace { + +static auto registerer = torch::RegisterOperators() ${wrapper_registrations}; + +} + }} // namespace torch::autograd diff --git a/tools/autograd/templates/VariableType.h b/tools/autograd/templates/VariableType.h index 3d0674caadfed..f12b66fc75d6d 100644 --- a/tools/autograd/templates/VariableType.h +++ b/tools/autograd/templates/VariableType.h @@ -7,6 +7,7 @@ #include #include +#include #include // for size_t #include // for function diff --git a/tools/autograd/templates/python_nn_functions.cpp b/tools/autograd/templates/python_nn_functions.cpp index 628fd740117d6..7ee6b91b1b4eb 100644 --- a/tools/autograd/templates/python_nn_functions.cpp +++ b/tools/autograd/templates/python_nn_functions.cpp @@ -50,7 +50,7 @@ static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObje ${py_methods} static PyMethodDef nn_functions[] = { - {"_parse_to", (PyCFunction)THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr}, + {"_parse_to", (PyCFunction)(void(*)(void))THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr}, ${py_method_defs} {NULL} }; diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index a694d006bfde7..b8b7686eb2a9e 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -99,7 +99,7 @@ inline Tensor dispatch_arange(Scalar start, Scalar end, Scalar step, const Tenso static inline bool allIntegral(std::initializer_list> l) { for (Scalar& s : l) { - if (!s.isIntegral()) { + if (!s.isIntegral(true)) { return false; } } @@ -112,7 +112,7 @@ static PyObject * THPVariable_arange(PyObject* self, PyObject* args, PyObject* k static PythonArgParser parser({ "arange(Scalar end, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", "arange(Scalar start, Scalar end, Scalar step=1, *, Tensor out=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False)", - }); + }, /*traceable=*/true); ParsedArgs<9> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); @@ -451,20 +451,20 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args, PyObje ${py_methods} static PyMethodDef torch_functions[] = { - {"arange", (PyCFunction)THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"as_tensor", (PyCFunction)THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"dsmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"arange", (PyCFunction)(void(*)(void))THPVariable_arange, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"as_tensor", (PyCFunction)(void(*)(void))THPVariable_as_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"dsmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, {"from_numpy", (PyCFunction)THPVariable_from_numpy, METH_STATIC | METH_O, NULL}, - {"hsmm", (PyCFunction)THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"_promote_types", (PyCFunction)THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"nonzero", (PyCFunction)THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"randint", (PyCFunction)THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"range", (PyCFunction)THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"saddmm", (PyCFunction)THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"sparse_coo_tensor", (PyCFunction)THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"spmm", (PyCFunction)THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"tensor", (PyCFunction)THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, - {"get_device", (PyCFunction)THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"hsmm", (PyCFunction)(void(*)(void))THPVariable_hspmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"_promote_types", (PyCFunction)(void(*)(void))THPVariable__promote_types, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"randint", (PyCFunction)(void(*)(void))THPVariable_randint, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"range", (PyCFunction)(void(*)(void))THPVariable_range, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"saddmm", (PyCFunction)(void(*)(void))THPVariable_sspaddmm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"sparse_coo_tensor", (PyCFunction)(void(*)(void))THPVariable_sparse_coo_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"spmm", (PyCFunction)(void(*)(void))THPVariable_mm, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"tensor", (PyCFunction)(void(*)(void))THPVariable_tensor, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, + {"get_device", (PyCFunction)(void(*)(void))THPVariable_get_device, METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL}, ${py_method_defs} {NULL} }; diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 72bd1cc290898..f34966c2d495d 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -27,6 +27,7 @@ #include "torch/csrc/utils/tensor_numpy.h" #include "torch/csrc/utils/tensor_types.h" #include "torch/csrc/utils/structseq.h" +#include #include #include "c10/util/Optional.h" @@ -581,7 +582,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::legacy_tensor_new(self_.type_id(), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -590,7 +591,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::new_ones(self_.type_id(), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::new_ones(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -599,7 +600,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::new_tensor(self_.type_id(), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -608,7 +609,7 @@ static PyObject * THPVariable_new_zeros(PyObject* self, PyObject* args, PyObject HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::new_zeros(self_.type_id(), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::new_zeros(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -724,18 +725,18 @@ static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { } PyMethodDef variable_methods[] = { - {"__add__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__radd__", (PyCFunction)THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__iadd__", (PyCFunction)THPVariable_add_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__rmul__", (PyCFunction)THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__mul__", (PyCFunction)THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__imul__", (PyCFunction)THPVariable_mul_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__sub__", (PyCFunction)THPVariable_sub, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__isub__", (PyCFunction)THPVariable_sub_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__div__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__truediv__", (PyCFunction)THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__idiv__", (PyCFunction)THPVariable_div_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"__mod__", (PyCFunction)THPVariable_remainder, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__add__", (PyCFunction)(void(*)(void))THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__radd__", (PyCFunction)(void(*)(void))THPVariable_add, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__iadd__", (PyCFunction)(void(*)(void))THPVariable_add_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__rmul__", (PyCFunction)(void(*)(void))THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mul__", (PyCFunction)(void(*)(void))THPVariable_mul, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__imul__", (PyCFunction)(void(*)(void))THPVariable_mul_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__sub__", (PyCFunction)(void(*)(void))THPVariable_sub, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__isub__", (PyCFunction)(void(*)(void))THPVariable_sub_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__div__", (PyCFunction)(void(*)(void))THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__truediv__", (PyCFunction)(void(*)(void))THPVariable_div, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__idiv__", (PyCFunction)(void(*)(void))THPVariable_div_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__mod__", (PyCFunction)(void(*)(void))THPVariable_remainder, METH_VARARGS | METH_KEYWORDS, NULL}, {"__bool__", (PyCFunction)THPVariable_bool_scalar, METH_NOARGS, NULL}, {"__float__", (PyCFunction)THPVariable_float_scalar, METH_NOARGS, NULL}, {"__int__", (PyCFunction)THPVariable_integral_scalar, METH_NOARGS, NULL}, @@ -743,16 +744,16 @@ PyMethodDef variable_methods[] = { {"__index__", (PyCFunction)THPVariable_index_scalar, METH_NOARGS, NULL}, {"__nonzero__", (PyCFunction)THPVariable_bool_scalar, METH_NOARGS, NULL}, {"__invert__", (PyCFunction)THPVariable_invert, METH_NOARGS, NULL}, - {"__matmul__", (PyCFunction)THPVariable_matmul, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__matmul__", (PyCFunction)(void(*)(void))THPVariable_matmul, METH_VARARGS | METH_KEYWORDS, NULL}, {"_is_view", (PyCFunction)THPVariable__is_view, METH_NOARGS, NULL}, {"apply_", (PyCFunction)THPVariable_apply_, METH_O, NULL}, {"bfloat16", (PyCFunction)THPVariable_bfloat16, METH_NOARGS, NULL}, {"byte", (PyCFunction)THPVariable_byte, METH_NOARGS, NULL}, {"char", (PyCFunction)THPVariable_char, METH_NOARGS, NULL}, - {"contiguous", (PyCFunction)THPVariable_contiguous, METH_VARARGS | METH_KEYWORDS, NULL}, - {"copy_", (PyCFunction)THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"contiguous", (PyCFunction)(void(*)(void))THPVariable_contiguous, METH_VARARGS | METH_KEYWORDS, NULL}, + {"copy_", (PyCFunction)(void(*)(void))THPVariable_copy_, METH_VARARGS | METH_KEYWORDS, NULL}, {"cpu", (PyCFunction)THPVariable_cpu, METH_NOARGS, NULL}, - {"cuda", (PyCFunction)THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL}, + {"cuda", (PyCFunction)(void(*)(void))THPVariable_cuda, METH_VARARGS | METH_KEYWORDS, NULL}, {"data_ptr", (PyCFunction)THPVariable_data_ptr, METH_NOARGS, NULL}, {"dim", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL}, #ifdef BUILD_NAMEDTENSOR @@ -765,30 +766,30 @@ PyMethodDef variable_methods[] = { {"bool", (PyCFunction)THPVariable_bool, METH_NOARGS, NULL}, {"half", (PyCFunction)THPVariable_half, METH_NOARGS, NULL}, {"int", (PyCFunction)THPVariable_int, METH_NOARGS, NULL}, - {"is_contiguous", (PyCFunction)THPVariable_is_contiguous, METH_VARARGS | METH_KEYWORDS, NULL}, + {"is_contiguous", (PyCFunction)(void(*)(void))THPVariable_is_contiguous, METH_VARARGS | METH_KEYWORDS, NULL}, {"item", (PyCFunction)THPVariable_item, METH_NOARGS, NULL}, {"long", (PyCFunction)THPVariable_long, METH_NOARGS, NULL}, - {"map_", (PyCFunction)THPVariable_map_, METH_VARARGS | METH_KEYWORDS, NULL}, - {"map2_", (PyCFunction)THPVariable_map2_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"map_", (PyCFunction)(void(*)(void))THPVariable_map_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"map2_", (PyCFunction)(void(*)(void))THPVariable_map2_, METH_VARARGS | METH_KEYWORDS, NULL}, {"ndimension", (PyCFunction)THPVariable_dim, METH_NOARGS, NULL}, {"nelement", (PyCFunction)THPVariable_numel, METH_NOARGS, NULL}, - {"new", (PyCFunction)THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL}, - {"new_ones", (PyCFunction)THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL}, - {"new_tensor", (PyCFunction)THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, - {"new_zeros", (PyCFunction)THPVariable_new_zeros, METH_VARARGS | METH_KEYWORDS, NULL}, - {"nonzero", (PyCFunction)THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS, NULL}, + {"new", (PyCFunction)(void(*)(void))THPVariable_new, METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_ones", (PyCFunction)(void(*)(void))THPVariable_new_ones, METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_tensor", (PyCFunction)(void(*)(void))THPVariable_new_tensor, METH_VARARGS | METH_KEYWORDS, NULL}, + {"new_zeros", (PyCFunction)(void(*)(void))THPVariable_new_zeros, METH_VARARGS | METH_KEYWORDS, NULL}, + {"nonzero", (PyCFunction)(void(*)(void))THPVariable_nonzero, METH_VARARGS | METH_KEYWORDS, NULL}, {"numpy", (PyCFunction)THPVariable_numpy, METH_NOARGS, NULL}, {"record_stream", (PyCFunction)THPVariable_record_stream, METH_O, NULL}, - {"requires_grad_", (PyCFunction)THPVariable_requires_grad_, METH_VARARGS | METH_KEYWORDS, NULL}, + {"requires_grad_", (PyCFunction)(void(*)(void))THPVariable_requires_grad_, METH_VARARGS | METH_KEYWORDS, NULL}, {"short", (PyCFunction)THPVariable_short, METH_NOARGS, NULL}, - {"size", (PyCFunction)THPVariable_size, METH_VARARGS | METH_KEYWORDS, NULL}, + {"size", (PyCFunction)(void(*)(void))THPVariable_size, METH_VARARGS | METH_KEYWORDS, NULL}, {"storage", (PyCFunction)THPVariable_storage, METH_NOARGS, NULL}, {"storage_offset", (PyCFunction)THPVariable_storage_offset, METH_NOARGS, NULL}, {"storage_type", (PyCFunction)THPVariable_storage_type, METH_NOARGS, NULL}, - {"stride", (PyCFunction)THPVariable_stride, METH_VARARGS | METH_KEYWORDS, NULL}, - {"to", (PyCFunction)THPVariable_to, METH_VARARGS | METH_KEYWORDS, NULL}, + {"stride", (PyCFunction)(void(*)(void))THPVariable_stride, METH_VARARGS | METH_KEYWORDS, NULL}, + {"to", (PyCFunction)(void(*)(void))THPVariable_to, METH_VARARGS | METH_KEYWORDS, NULL}, {"tolist", (PyCFunction)THPVariable_tolist, METH_NOARGS, NULL}, - {"type", (PyCFunction)THPVariable_type, METH_VARARGS | METH_KEYWORDS, NULL}, + {"type", (PyCFunction)(void(*)(void))THPVariable_type, METH_VARARGS | METH_KEYWORDS, NULL}, ${py_method_defs} {NULL} }; diff --git a/tools/autograd/templates/variable_factories.h b/tools/autograd/templates/variable_factories.h index 12fa02adf0d69..0e9b31d1b06b8 100644 --- a/tools/autograd/templates/variable_factories.h +++ b/tools/autograd/templates/variable_factories.h @@ -3,8 +3,10 @@ // ${generated_comment} #include +#include #include #include +#include #include #include #include @@ -19,6 +21,134 @@ using at::DimnameList; namespace torch { +namespace detail { + enum class ListInitTensorType { Scalar, InitList }; + + // We use `ListInitTensor` to support converting an arbitrarily nested braced-init-list + // (e.g. {{1, 2}, {3, 4}}) into the equivalent Tensor, taking advantage of the fact that + // the constructor will automatically be called recursively until it reaches all innermost + // scalar values. + // + // At any time, a `ListInitTensor` object represents either of the following: + // 1. A scalar with value `scalar()` and type `scalar_type()`. + // 2. A Tensor represented in `std::initializer_list` form, with value + // `init_list()`, Tensor scalar type `scalar_type()`, and Tensor sizes `sizes()`. + struct ListInitTensor { +#define TENSOR(T, S) \ + ListInitTensor(T scalar) : \ + scalar_(scalar), init_list_(), \ + sizes_(), \ + scalar_type_(at::k##S), \ + type_(ListInitTensorType::Scalar) {} +AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) +#undef TENSOR + ListInitTensor(std::initializer_list init_list) : + scalar_(), + init_list_(init_list), + sizes_(), + scalar_type_(), + type_(ListInitTensorType::InitList) { + TORCH_CHECK(init_list.size() > 0, "Empty init-list is not supported"); + scalar_type_ = init_list.begin()->scalar_type_; + const ListInitTensor& first_elem = *(init_list.begin()); + for (const auto& elem : init_list) { + TORCH_CHECK(elem.scalar_type_ == first_elem.scalar_type_, + "Expected all elements of the tensor to have the same scalar type: ", + first_elem.scalar_type_, + ", but got element of scalar type: ", + elem.scalar_type_); + TORCH_CHECK(elem.sizes_ == first_elem.sizes_, + "Expected all sub-lists to have sizes: ", + first_elem.sizes_, + " (e.g. ", first_elem, "), ", + "but got sub-list ", + elem, + " with sizes: ", + elem.sizes_); + } + sizes_.reserve(first_elem.sizes_.size() + 1); + sizes_.push_back(init_list.size()); + sizes_.insert(sizes_.end(), first_elem.sizes_.begin(), first_elem.sizes_.end()); + } + + const c10::Scalar& scalar() const { + return scalar_; + } + + const std::initializer_list& init_list() const { + return init_list_; + } + + const std::vector& sizes() const { + return sizes_; + } + + const c10::ScalarType& scalar_type() const { + return scalar_type_; + } + + const ListInitTensorType& type() const { + return type_; + } + + at::Tensor to_tensor(const at::TensorOptions& options) const { + // NOTE: Here we explicitly choose to initialize the tensor on CPU first, + // fill each element of the tensor, and then move the tensor to the desired + // device. For CUDA device, this approach only involves 1 CUDA kernel launch, + // and is much faster than initializing the tensor on CUDA first and then + // filling each element of it (which involves `N` CUDA kernel launches where + // `N` is the number of the elements in the tensor). + at::Tensor tensor = ([&]() { + at::AutoNonVariableTypeMode non_var_type_mode(true); + return at::empty(sizes_, at::TensorOptions(options).device(at::kCPU).is_variable(false)); + })(); + fill_tensor(tensor); + return tensor.to(options.device()); + } + + void pretty_print_recursive(std::ostream& stream) const { + if (type_ == ListInitTensorType::Scalar) { + AT_DISPATCH_ALL_TYPES_AND3(at::kBool, at::kHalf, at::kBFloat16, scalar_type_, "ListInitTensor_pretty_print_scalar", [&] { + stream << scalar_.to(); + }); + } else if (type_ == ListInitTensorType::InitList) { + stream << "{"; + for (const ListInitTensor* it = init_list_.begin(); it != init_list_.end(); it++) { + it->pretty_print_recursive(stream); + if (std::next(it) != init_list_.end()) stream << ", "; + } + stream << "}"; + } + } + + private: + void fill_tensor(at::Tensor tensor) const { + size_t index = 0; + for (const auto& elem : init_list_) { + if (elem.type_ == ListInitTensorType::Scalar) { + at::NoGradGuard guard; + tensor[index].fill_(elem.scalar()); + } else if (elem.type_ == ListInitTensorType::InitList) { + elem.fill_tensor(tensor[index]); + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid ListInitTensor"); + } + index++; + } + } + c10::Scalar scalar_; + std::initializer_list init_list_; + std::vector sizes_; + c10::ScalarType scalar_type_; + ListInitTensorType type_; + }; + + inline std::ostream& operator<<(std::ostream& stream, const ListInitTensor& list_init_tensor) { + list_init_tensor.pretty_print_recursive(stream); + return stream; + } +} // namespace detail + #define TENSOR(T, S) \ inline at::Tensor tensor( \ at::ArrayRef values, const at::TensorOptions& options) { \ @@ -47,6 +177,14 @@ namespace torch { AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) #undef TENSOR +inline at::Tensor tensor(detail::ListInitTensor list_init_tensor, const at::TensorOptions& options) { + return autograd::make_variable(list_init_tensor.to_tensor(options), options.requires_grad()); +} + +inline at::Tensor tensor(detail::ListInitTensor list_init_tensor) { + return torch::tensor(list_init_tensor, at::dtype(list_init_tensor.scalar_type())); +} + /// A generic deleter function. using Deleter = std::function; using at::MemoryFormat; diff --git a/tools/build_variables.py b/tools/build_variables.py index 657d6a96d31cd..ac2180c769aee 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -56,6 +56,8 @@ "torch/csrc/distributed/rpc/future_message.cpp", "torch/csrc/distributed/rpc/message.cpp", "torch/csrc/distributed/rpc/script_call.cpp", + "torch/csrc/distributed/rpc/script_remote_call.cpp", + "torch/csrc/distributed/rpc/script_rref_proto.cpp", "torch/csrc/distributed/rpc/script_ret.cpp", "torch/csrc/Exceptions.cpp", "torch/csrc/jit/autodiff.cpp", @@ -110,6 +112,7 @@ "torch/csrc/jit/passes/peephole.cpp", "torch/csrc/jit/passes/python_print.cpp", "torch/csrc/jit/passes/quantization.cpp", + "torch/csrc/jit/passes/fuse_linear.cpp", "torch/csrc/jit/passes/remove_expands.cpp", "torch/csrc/jit/passes/requires_grad_analysis.cpp", "torch/csrc/jit/passes/shape_analysis.cpp", @@ -181,16 +184,27 @@ def add_torch_libs(): "torch/csrc/api/src/data/samplers/sequential.cpp", "torch/csrc/api/src/data/samplers/stream.cpp", "torch/csrc/api/src/jit.cpp", + "torch/csrc/api/src/serialize.cpp", "torch/csrc/api/src/nn/init.cpp", "torch/csrc/api/src/nn/module.cpp", "torch/csrc/api/src/nn/modules/batchnorm.cpp", "torch/csrc/api/src/nn/modules/conv.cpp", "torch/csrc/api/src/nn/modules/dropout.cpp", + "torch/csrc/api/src/nn/modules/distance.cpp", "torch/csrc/api/src/nn/modules/embedding.cpp", - "torch/csrc/api/src/nn/modules/functional.cpp", + "torch/csrc/api/src/nn/modules/fold.cpp", "torch/csrc/api/src/nn/modules/linear.cpp", - "torch/csrc/api/src/nn/modules/named_any.cpp", + "torch/csrc/api/src/nn/modules/loss.cpp", + "torch/csrc/api/src/nn/modules/pooling.cpp", "torch/csrc/api/src/nn/modules/rnn.cpp", + "torch/csrc/api/src/nn/modules/container/functional.cpp", + "torch/csrc/api/src/nn/modules/container/named_any.cpp", + "torch/csrc/api/src/nn/options/batchnorm.cpp", + "torch/csrc/api/src/nn/options/conv.cpp", + "torch/csrc/api/src/nn/options/dropout.cpp", + "torch/csrc/api/src/nn/options/linear.cpp", + "torch/csrc/api/src/nn/options/pooling.cpp", + "torch/csrc/api/src/nn/options/rnn.cpp", "torch/csrc/api/src/optim/adagrad.cpp", "torch/csrc/api/src/optim/adam.cpp", "torch/csrc/api/src/optim/lbfgs.cpp", @@ -215,6 +229,7 @@ def add_torch_libs(): "torch/csrc/Generator.cpp", "torch/csrc/Layout.cpp", "torch/csrc/MemoryFormat.cpp", + "torch/csrc/QEngine.cpp", "torch/csrc/QScheme.cpp", "torch/csrc/Module.cpp", "torch/csrc/PtrWrapper.cpp", @@ -245,6 +260,9 @@ def add_torch_libs(): "torch/csrc/distributed/rpc/python_functions.cpp", "torch/csrc/distributed/rpc/python_rpc_handler.cpp", "torch/csrc/distributed/rpc/rpc_agent.cpp", + "torch/csrc/distributed/rpc/rref.cpp", + "torch/csrc/distributed/rpc/rref_context.cpp", + "torch/csrc/distributed/rpc/types.cpp", "torch/csrc/jit/init.cpp", "torch/csrc/jit/passes/inline_fork_wait.cpp", "torch/csrc/jit/passes/onnx.cpp", @@ -253,6 +271,7 @@ def add_torch_libs(): "torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp", "torch/csrc/jit/passes/onnx/peephole.cpp", "torch/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp", + "torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp", "torch/csrc/jit/passes/remove_inplace_ops.cpp", "torch/csrc/jit/passes/utils/check_alias_annotation.cpp", "torch/csrc/jit/python_arg_flatten.cpp", @@ -273,6 +292,7 @@ def add_torch_libs(): "torch/csrc/utils/invalid_arguments.cpp", "torch/csrc/utils/object_ptr.cpp", "torch/csrc/utils/python_arg_parser.cpp", + "torch/csrc/utils/qengines.cpp", "torch/csrc/utils/structseq.cpp", "torch/csrc/utils/tensor_apply.cpp", "torch/csrc/utils/tensor_dtypes.cpp", diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index c7224879d3fef..1ff0ceb3672d7 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -14,7 +14,6 @@ from . import which from .env import (BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag) from .cuda import USE_CUDA -from .dist_check import USE_DISTRIBUTED, USE_GLOO_IBVERBS from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR @@ -216,7 +215,8 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e # adding a new build option to this block: Consider making these two names identical and adding this option # in the block below. '_GLIBCXX_USE_CXX11_ABI': 'GLIBCXX_USE_CXX11_ABI', - 'USE_CUDA_STATIC_LINK': 'CAFFE2_STATIC_LINK_CUDA' + 'USE_CUDA_STATIC_LINK': 'CAFFE2_STATIC_LINK_CUDA', + 'USE_GLOO_IBVERBS': 'USE_IBVERBS' # Backward compatibility. Will be removed in the future. } additional_options.update({ # Build options that have the same environment variable name and CMake variable name and that do not start @@ -237,6 +237,9 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e 'WERROR') }) + if 'USE_GLOO_IBVERBS' in my_env: + print("WARNING: USE_GLOO_IBVERBS is deprecated. Use USE_IBVERBS instead.") + for var, val in my_env.items(): # We currently pass over all environment variables that start with "BUILD_", "USE_", and "CMAKE_". This is # because we currently have no reliable way to get the list of all build options we have specified in @@ -260,7 +263,6 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e 'BUILD_PYTHON': build_python, 'BUILD_TEST': build_test, 'USE_CUDA': USE_CUDA, - 'USE_DISTRIBUTED': USE_DISTRIBUTED, 'USE_NUMPY': USE_NUMPY, }) @@ -286,14 +288,13 @@ def generate(self, version, cmake_python_library, build_python, build_test, my_e NUMPY_INCLUDE_DIR=NUMPY_INCLUDE_DIR, **build_options) - if USE_GLOO_IBVERBS: - CMake.defines(args, USE_IBVERBS="1", USE_GLOO_IBVERBS="1") - expected_wrapper = '/usr/local/opt/ccache/libexec' if IS_DARWIN and os.path.exists(expected_wrapper): - CMake.defines(args, - CMAKE_C_COMPILER="{}/gcc".format(expected_wrapper), - CMAKE_CXX_COMPILER="{}/g++".format(expected_wrapper)) + if 'CMAKE_C_COMPILER' not in build_options and 'CC' not in os.environ: + CMake.defines(args, CMAKE_C_COMPILER="{}/gcc".format(expected_wrapper)) + if 'CMAKE_CXX_COMPILER' not in build_options and 'CXX' not in os.environ: + CMake.defines(args, CMAKE_CXX_COMPILER="{}/g++".format(expected_wrapper)) + for env_var_name in my_env: if env_var_name.startswith('gh'): # github env vars use utf-8, on windows, non-ascii code may diff --git a/tools/setup_helpers/dist_check.py b/tools/setup_helpers/dist_check.py deleted file mode 100644 index 3b3a7368d836e..0000000000000 --- a/tools/setup_helpers/dist_check.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import subprocess -import glob - -from .env import IS_CONDA, IS_WINDOWS, CONDA_DIR, check_env_flag, check_negative_env_flag, gather_paths - -# On ROCm, RCCL development isn't complete. https://github.com/ROCmSoftwarePlatform/rccl -USE_DISTRIBUTED = not check_negative_env_flag("USE_DISTRIBUTED") and not IS_WINDOWS -USE_GLOO_IBVERBS = False - -IB_DEVINFO_CMD = "ibv_devinfo" - - -def get_command_path(command): - """ - Helper function that checks if the command exists in the path and gets the - full path of a given linux command if it exists. - """ - def executable(command_path): - return os.path.isfile(command_path) and os.access(command_path, os.X_OK) - - for path in os.environ["PATH"].split(os.pathsep): - command_path = os.path.join(path, command) - if executable(command_path): - return command_path - - return None - - -def should_build_ib(): - """ - Helper function that detects the system's IB support and returns if we - should build with IB support. - """ - ib_util_found = False - ib_lib_found = False - ib_header_found = False - - try: - # If the command doesn't exist, we can directly return instead of - # making a subprocess call - full_cmd_path = get_command_path(IB_DEVINFO_CMD) - if not full_cmd_path: - ib_util_found = False - subprocess.check_output([full_cmd_path, "--list"]) - # Here we just would like to simply run the command to test if IB - # related tools / lib are installed without parsing the output. We - # will enable IB build as long as the command runs successfully. - # - # The output should look like either: - # - # > ibv_devinfo --list - # 0 HCAs founds: - # - # or - # - # > ibv_devinfo --list - # 4 HCAs found: - # mlx5_3 - # mlx5_2 - # mlx5_1 - # mlx5_0 - ib_util_found = True - except Exception: - # We just take all the exceptions here without affecting the build - ib_util_found = False - - lib_paths = list(filter(bool, [ - "/usr/lib/", - "/usr/lib/x86_64-linux-gnu/", - "/usr/lib/powerpc64le-linux-gnu/", - "/usr/lib/aarch64-linux-gnu/", - ] + gather_paths([ - "LIBRARY_PATH", - ]) + gather_paths([ - "LD_LIBRARY_PATH", - ]))) - - include_paths = [ - "/usr/include/", - ] - - if IS_CONDA: - lib_paths.append(os.path.join(CONDA_DIR, "lib")) - include_paths.append(os.path.join(CONDA_DIR, "include")) - - for path in lib_paths: - if path is None or not os.path.exists(path): - continue - ib_libraries = sorted(glob.glob(os.path.join(path, "libibverbs*"))) - if ib_libraries: - ib_lib_found = True - break - - for path in include_paths: - if path is None or not os.path.exists(path): - continue - if os.path.exists(os.path.join(path, "infiniband/verbs.h")): - ib_header_found = True - break - - return ib_util_found and ib_lib_found and ib_lib_found - -if USE_DISTRIBUTED: - # If the env variable is specified, use the value, - # otherwise only build with IB when IB support is detected on the system - if "USE_GLOO_IBVERBS" in os.environ: - USE_GLOO_IBVERBS = check_env_flag("USE_GLOO_IBVERBS") - else: - USE_GLOO_IBVERBS = should_build_ib() diff --git a/tools/setup_helpers/generate_code.py b/tools/setup_helpers/generate_code.py index bb8d53ebe5f5c..697eee2bf509e 100644 --- a/tools/setup_helpers/generate_code.py +++ b/tools/setup_helpers/generate_code.py @@ -23,7 +23,8 @@ def generate_code(ninja_global=None, declarations_path=None, nn_path=None, install_dir=None, - subset=None): + subset=None, + disable_autograd=False): # cwrap depends on pyyaml, so we can't import it earlier root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, root) @@ -41,7 +42,12 @@ def generate_code(ninja_global=None, gen_autograd_python(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd') if subset == "libtorch" or not subset: - gen_autograd(declarations_path or DECLARATIONS_PATH, autograd_gen_dir, 'tools/autograd') + gen_autograd( + declarations_path or DECLARATIONS_PATH, + autograd_gen_dir, + 'tools/autograd', + disable_autograd=disable_autograd, + ) gen_jit_dispatch(declarations_path or DECLARATIONS_PATH, jit_gen_dir, 'tools/jit/templates') @@ -55,6 +61,12 @@ def main(): '--subset', help='Subset of source files to generate. Can be "libtorch" or "pybindings". Generates both when omitted.' ) + parser.add_argument( + '--disable-autograd', + default=False, + action='store_true', + help='It can skip generating autograd related code when the flag is set', + ) options = parser.parse_args() generate_code( options.ninja_global, @@ -62,6 +74,7 @@ def main(): options.nn_path, options.install_dir, options.subset, + options.disable_autograd, ) diff --git a/tools/shared/__init__.py b/tools/shared/__init__.py index c04e88be4bd99..3058f50c334f7 100644 --- a/tools/shared/__init__.py +++ b/tools/shared/__init__.py @@ -1,2 +1,2 @@ from .module_loader import import_module # noqa: F401 -from .cwrap_common import set_declaration_defaults, sort_by_number_of_options # noqa: F401 +from .cwrap_common import set_declaration_defaults, sort_by_number_of_args # noqa: F401 diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index d2ffb8c2f3976..e358761342f86 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -52,6 +52,7 @@ set(TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/Layout.cpp ${TORCH_SRC_DIR}/csrc/MemoryFormat.cpp ${TORCH_SRC_DIR}/csrc/python_dimname.cpp + ${TORCH_SRC_DIR}/csrc/QEngine.cpp ${TORCH_SRC_DIR}/csrc/QScheme.cpp ${TORCH_SRC_DIR}/csrc/Module.cpp ${TORCH_SRC_DIR}/csrc/PtrWrapper.cpp @@ -76,6 +77,7 @@ set(TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/peephole.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/constant_fold.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/scalar_type_analysis.cpp ${TORCH_SRC_DIR}/csrc/jit/python_arg_flatten.cpp ${TORCH_SRC_DIR}/csrc/jit/python_interpreter.cpp ${TORCH_SRC_DIR}/csrc/jit/python_ir.cpp @@ -94,6 +96,7 @@ set(TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/utils/invalid_arguments.cpp ${TORCH_SRC_DIR}/csrc/utils/object_ptr.cpp ${TORCH_SRC_DIR}/csrc/utils/python_arg_parser.cpp + ${TORCH_SRC_DIR}/csrc/utils/qengines.cpp ${TORCH_SRC_DIR}/csrc/utils/structseq.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_apply.cpp ${TORCH_SRC_DIR}/csrc/utils/tensor_dtypes.cpp @@ -221,7 +224,7 @@ endif() if (USE_DISTRIBUTED) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_DISTRIBUTED) - if (NOT MSVC AND NOT APPLE) + if (NOT MSVC) list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/distributed/autograd/init.cpp ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp @@ -235,6 +238,9 @@ if (USE_DISTRIBUTED) ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref.cpp + ${TORCH_SRC_DIR}/csrc/distributed/rpc/types.cpp ) list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D) diff --git a/torch/__init__.py b/torch/__init__.py index 9417796087d6f..d09e650f22a84 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -133,7 +133,7 @@ def is_storage(obj): def set_default_tensor_type(t): r"""Sets the default ``torch.Tensor`` type to floating point tensor type - :attr:`t`. This type will also be used as default floating point type for + ``t``. This type will also be used as default floating point type for type inference in :func:`torch.tensor`. The default floating point tensor type is initially ``torch.FloatTensor``. @@ -318,6 +318,7 @@ def manager_path(): import torch.backends.cuda import torch.backends.mkl import torch.backends.openmp +import torch.backends.quantized import torch.utils.data import torch.__config__ import torch.__future__ diff --git a/torch/__init__.pyi.in b/torch/__init__.pyi.in index 1073c25b457aa..02f604b93e721 100644 --- a/torch/__init__.pyi.in +++ b/torch/__init__.pyi.in @@ -88,6 +88,7 @@ class Tensor: device: _device = ... requires_grad: _bool = ... grad: Optional[Tensor] = ... + data: Tensor = ... ${tensor_method_hints} diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 7327689059c94..5dcd5bdd6cd97 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -6,8 +6,10 @@ import inspect import weakref +import warnings import torch._C from torch._six import builtins +from torch._utils_internal import get_source_lines_and_file # Wrapper functions that can call either of 2 functions depending on a boolean # argument @@ -168,7 +170,7 @@ class FunctionModifiers(object): Used to denote the behavior of a function in TorchScript. See export() and ignore() for details. """ - IGNORE_AND_DROP = "ignore (leave as a call to Python, replace with a 'raise' on torch.jit.save)" + UNUSED = "unused (ignored and replaced with raising of an exception)" IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)" EXPORT = "export (compile this function even if nothing calls it)" DEFAULT = "default (compile if called from a exported function / forward)" @@ -219,23 +221,52 @@ def unused_method(self, x): return fn -def ignore(drop_on_export=False): +def unused(fn): """ This decorator indicates to the compiler that a function or method should - be ignored and left as a Python function. - - Arguments: + be ignored and replaced with the raising of an exception. This allows you + to leave code in your model that is not yet TorchScript compatible and still + export your model. + + Example (using ``@torch.jit.unused`` on a method):: + + import torch + import torch.nn as nn + + class MyModule(nn.Module): + def __init__(self, use_memory_efficent): + super(MyModule, self).__init__() + self.use_memory_efficent = use_memory_efficent + + @torch.jit.unused + def memory_efficient(self, x): + import pdb + pdb.set_trace() + return x + 10 + + def forward(self, x): + # Use not-yet-scriptable memory efficient mode + if self.use_memory_efficient: + return self.memory_efficient(x) + else: + return x + 10 + + m = torch.jit.script(MyModule(use_memory_efficent=False)) + m.save("m.pt") + + m = torch.jit.script(MyModule(use_memory_efficient=True)) + # exception raised + m(torch.rand(100)) + """ + fn._torchscript_modifier = FunctionModifiers.UNUSED + return fn - drop_on_export (bool): When ``False``, calls to this function will - that will be run with ``example_inputs``. - arguments and returns to ``func`` must be tensors - or (possibly nested) tuples that - contain tensors. When ``True``, any calls to - this function from other TorchScript code will be replaced - with a `raise` when the model is saved. - This allows you to leave code in your TorchScript model that is only ever - run when the Python interpreter is present, but not run after you save - and load your model. +def ignore(drop=False, **kwargs): + """ + This decorator indicates to the compiler that a function or method should + be ignored and left as a Python function. This allows you to leave code in + your model that is not yet TorchScript compatible. Models with ignored + functions cannot be exported; use torch.jit.unused instead. Example (using ``@torch.jit.ignore`` on a method):: @@ -261,7 +292,7 @@ def forward(self, x): # Error! The call `debugger` cannot be saved since it calls into Python m.save("m.pt") - Example (using ``@torch.jit.ignore(drop_on_export=True)`` on a method): + Example (using ``@torch.jit.ignore(drop=True)`` on a method): .. testcode:: @@ -269,7 +300,7 @@ def forward(self, x): import torch.nn as nn class MyModule(nn.Module): - @torch.jit.ignore(drop_on_export=True) + @torch.jit.ignore(drop=True) def training_method(self, x): import pdb pdb.set_trace() @@ -290,24 +321,37 @@ def forward(self, x): import os os.remove('m.pt') """ - if callable(drop_on_export): - # used without any args, so drop_on_export is actually a function + + if callable(drop): + # used without any args, so drop is actually a function # @torch.jit.ignore # def fn(...): - fn = drop_on_export + fn = drop fn._torchscript_modifier = FunctionModifiers.IGNORE return fn - if isinstance(drop_on_export, bool): - def decorator(fn): - if drop_on_export: - fn._torchscript_modifier = FunctionModifiers.IGNORE_AND_DROP - else: - fn._torchscript_modifier = FunctionModifiers.IGNORE - return fn - return decorator - raise RuntimeError("Argument to @torch.jit.ignore must be a bool or " - "a function but got {}".format(drop_on_export)) + if not isinstance(drop, bool): + raise RuntimeError("Argument to @torch.jit.ignore must be a bool or " + "a function but got {}".format(drop)) + + # for backwards compat + drop_on_export = kwargs.pop("drop_on_export", None) + if drop_on_export: + warnings.warn("ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning) + + drop = drop_on_export + elif drop: + warnings.warn("ignore(True) has been deprecated. TorchScript will now drop the function " + "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning) + + def decorator(fn): + if drop: + fn._torchscript_modifier = FunctionModifiers.UNUSED + else: + fn._torchscript_modifier = FunctionModifiers.IGNORE + return fn + return decorator def module_has_exports(mod): @@ -318,16 +362,16 @@ def module_has_exports(mod): return True return False -def should_drop_on_export(fn): +def should_drop(fn): attr = get_torchscript_modifier(fn) if attr is None: return False - return attr is FunctionModifiers.IGNORE_AND_DROP + return attr is FunctionModifiers.UNUSED def is_ignored_fn(fn): mod = get_torchscript_modifier(fn) - return mod is FunctionModifiers.IGNORE_AND_DROP or mod is FunctionModifiers.IGNORE + return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE def get_torchscript_modifier(fn): @@ -434,11 +478,11 @@ def _get_overloaded_methods(method, mod_class): if overloads is None: return None - method_line_no = inspect.getsourcelines(method)[1] - mod_class_fileno = inspect.getsourcelines(mod_class)[1] - mod_end_fileno = mod_class_fileno + len(inspect.getsourcelines(mod_class)[0]) + method_line_no = get_source_lines_and_file(method)[1] + mod_class_fileno = get_source_lines_and_file(mod_class)[1] + mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0]) if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno): - raise Exception("Overloads are not useable when a module is redaclared within the same file: " + str(method)) + raise Exception("Overloads are not useable when a module is redeclared within the same file: " + str(method)) return overloads try: diff --git a/torch/namedtensor.py b/torch/_namedtensor_internals.py similarity index 50% rename from torch/namedtensor.py rename to torch/_namedtensor_internals.py index 1ab41ed14c501..f1d9297ec37b1 100644 --- a/torch/namedtensor.py +++ b/torch/_namedtensor_internals.py @@ -1,4 +1,5 @@ import torch +from torch._six import PY2 from collections import OrderedDict """ @@ -8,58 +9,80 @@ """ -def _assert_namedtensor_build(api_name): +def assert_namedtensor_build(api_name): if not torch._C._BUILD_NAMEDTENSOR: raise RuntimeError('NYI: {} is experimental and a part ' 'of our named tensors project.'.format(api_name)) -def _check_serializing_named_tensor(tensor): +def check_serializing_named_tensor(tensor): if torch._C._BUILD_NAMEDTENSOR and tensor.has_names(): raise RuntimeError( "NYI: Named tensors don't support serialization. Please drop " "names before serialization and/or serialize them seperately.") -def _build_dim_map(tensor): +def build_dim_map(tensor): """Returns a map of { dim: dim_name } where dim is a name if the dim is named and the dim index otherwise.""" return OrderedDict([(idx if name is None else name, name) for idx, name in enumerate(tensor.names)]) -def _namer_api_name(inplace): +def unzip_namedshape(namedshape): + if isinstance(namedshape, OrderedDict): + namedshape = namedshape.items() + if not hasattr(namedshape, '__iter__') and not isinstance(namedshape, tuple): + raise RuntimeError( + 'Expected namedshape to be OrderedDict or iterable of tuples, got: {}' + .format(type(namedshape))) + if len(namedshape) == 0: + raise RuntimeError('Expected namedshape to non-empty.') + return zip(*namedshape) + + +def namer_api_name(inplace): if inplace: return 'names_' else: - return 'view_names' + return 'renamed' -def _expand_single_glob(numel_pre_glob, numel_post_glob, names): +def is_ellipsis(item): + if PY2: + return item == '...' + else: + return item == Ellipsis or item == '...' + + +def expand_single_ellipsis(numel_pre_glob, numel_post_glob, names): return names[numel_pre_glob:len(names) - numel_post_glob] -def _update_names_with_list(tensor, names, inplace): - # Special case for tensor.view_names(None) +def resolve_ellipsis(names, tensor_names, fn_name): + ellipsis_indices = [i for i, name in enumerate(names) if is_ellipsis(name)] + if len(ellipsis_indices) >= 2: + raise RuntimeError('{}: More than one Ellipsis (\'...\') found in names (' + '{}). This function supports up to one Ellipsis.' + .format(fn_name, names)) + if len(ellipsis_indices) == 0: + return names + ellipsis_idx = ellipsis_indices[0] + globbed_names = expand_single_ellipsis(ellipsis_idx, len(names) - ellipsis_idx - 1, tensor_names) + return names[:ellipsis_idx] + globbed_names + names[ellipsis_idx + 1:] + + +def update_names_with_list(tensor, names, inplace): + # Special case for tensor.renamed(None) if len(names) == 1 and names[0] is None: return tensor._update_names(None, inplace) - glob_indices = [i for i, x in enumerate(names) if x == '*'] - if len(glob_indices) >= 2: - raise RuntimeError('{}: More than one \'*\' found in names (' - '{}). This function supports up to one \'*\'.' - .format(_namer_api_name(inplace), names)) - elif len(glob_indices) == 1: - glob_idx = glob_indices[0] - globbed_names = _expand_single_glob(glob_idx, len(names) - glob_idx - 1, tensor.names) - return tensor._update_names( - names[:glob_idx] + globbed_names + names[glob_idx + 1:], inplace) - else: - return tensor._update_names(names, inplace) + return tensor._update_names( + resolve_ellipsis(names, tensor.names, namer_api_name(inplace)), inplace) -def _update_names_with_mapping(tensor, rename_map, inplace): - dim_map = _build_dim_map(tensor) +def update_names_with_mapping(tensor, rename_map, inplace): + dim_map = build_dim_map(tensor) for old_dim in rename_map.keys(): new_dim = rename_map[old_dim] if old_dim in dim_map.keys(): @@ -68,14 +91,14 @@ def _update_names_with_mapping(tensor, rename_map, inplace): raise RuntimeError(('{api_name}: Tried to rename dim \'{old_dim}\' to dim ' '{new_dim} in Tensor[{dims}] but dim \'{old_dim}\' does not exist') .format(old_dim=old_dim, new_dim=new_dim, dims=tensor.names, - api_name=_namer_api_name(inplace))) + api_name=namer_api_name(inplace))) return tensor._update_names(tuple(dim_map.values()), inplace) -def _update_names(tensor, names, rename_map, inplace): +def update_names(tensor, names, rename_map, inplace): """There are two usages: - tensor.view_names(*names) returns a view on tensor with named dims `names`. + tensor.renamed(*names) returns a view on tensor with named dims `names`. `names` must be of length `tensor.dim()`; otherwise, if '*' is in `names`, then it is expanded greedily to be equal to the corresponding names from `tensor.names`. @@ -83,26 +106,26 @@ def _update_names(tensor, names, rename_map, inplace): For example, ``` >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) - >>> x.view_names('*', 'height', 'width').names + >>> x.renamed('*', 'height', 'width').names ('N', 'C', 'height', 'width') - >>> x.view_names('batch', '*', 'width').names + >>> x.renamed('batch', '*', 'width').names ('batch', 'C', 'H', 'width') ``` - tensor.view_names(**rename_map) returns a view on tensor that has renamed dims + tensor.renamed(**rename_map) returns a view on tensor that has renamed dims as specified in the mapping `rename_map`. For example, ``` >>> x = torch.empty(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) - >>> x.view_names(W='width', H='height').names + >>> x.renamed(W='width', H='height').names ('N', 'C', 'height', 'width') ``` - Finally, tensor.view_names has an in-place version called tensor.names_. + Finally, tensor.renamed has an in-place version called tensor.names_. """ - _assert_namedtensor_build(_namer_api_name(inplace)) + assert_namedtensor_build(namer_api_name(inplace)) has_names = len(names) > 0 has_rename_pairs = bool(rename_map) @@ -110,12 +133,12 @@ def _update_names(tensor, names, rename_map, inplace): raise RuntimeError('{api_name}: This function takes either positional ' 'args or keyword args, but not both. Use tensor.{api_name}(*names) ' 'to name dims and tensor.{api_name}(**rename_map) to rename ' - 'dims.'.format(api_name=_namer_api_name(inplace))) + 'dims.'.format(api_name=namer_api_name(inplace))) - # Special case for tensor.view_names(*[]), which is valid for a 0 dim tensor. + # Special case for tensor.renamed(*[]), which is valid for a 0 dim tensor. if not has_names and not has_rename_pairs: - return _update_names_with_list(tensor, names, inplace) + return update_names_with_list(tensor, names, inplace) if has_names: - return _update_names_with_list(tensor, names, inplace) - return _update_names_with_mapping(tensor, rename_map, inplace) + return update_names_with_list(tensor, names, inplace) + return update_names_with_mapping(tensor, rename_map, inplace) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 280af531f79f1..73228ce514f90 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1903,7 +1903,9 @@ def callable(a, b) -> number q_per_channel_scales() -> Tensor Given a Tensor quantized by linear (affine) per-channel quantization, -returns a Tensor of scales of the underlying quantizer(). +returns a Tensor of scales of the underlying quantizer. It has the number of +elements that matches the corresponding dimensions (from q_per_channel_axis) of +the tensor. """) add_docstr_all('q_per_channel_zero_points', @@ -1911,7 +1913,18 @@ def callable(a, b) -> number q_per_channel_zero_points() -> Tensor Given a Tensor quantized by linear (affine) per-channel quantization, -returns a tensor of zero_points of the underlying quantizer(). +returns a tensor of zero_points of the underlying quantizer. It has the number of +elements that matches the corresponding dimensions (from q_per_channel_axis) of +the tensor. +""") + +add_docstr_all('q_per_channel_axis', + r""" +q_per_channel_axis() -> tuple of ints + +Given a Tensor quantized by linear (affine) per-channel quantization, +returns a indices of dimensions on which per-channel quantization is applied. +In the most commmon case there is only one such dimension. """) add_docstr_all('random_', diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py index bd0267565f7f9..d1e87cbaca486 100644 --- a/torch/_tensor_str.py +++ b/torch/_tensor_str.py @@ -201,7 +201,7 @@ def _tensor_str(self, indent): # - tensor data needs to be summarized # Some of the codepaths don't fully support named tensors, so we send in # an unnamed tensor to the formatting code as a workaround. - self = self.view_names(None) + self = self.renamed(None) summarize = self.numel() > PRINT_OPTS.threshold if self.dtype is torch.float16 or self.dtype is torch.bfloat16: @@ -288,6 +288,7 @@ def _str(self): elif self.qscheme() == torch.per_channel_affine or self.qscheme() == torch.per_channel_symmetric: suffixes.append('scale=' + str(self.q_per_channel_scales())) suffixes.append('zero_point=' + str(self.q_per_channel_zero_points())) + suffixes.append('axis=' + ','.join(map(str, self.q_per_channel_axis()))) tensor_str = _tensor_str(self.dequantize(), indent) else: if self.numel() == 0 and not self.is_sparse: diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 227735db050af..d8deb1adae201 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -6047,20 +6047,20 @@ def merge_dicts(*dicts): add_docstr(torch.where, r""" -.. function:: where(condition, input, other) -> Tensor +.. function:: where(condition, x, y) -> Tensor -Return a tensor of elements selected from either :attr:`input` or :attr:`other`, depending on :attr:`condition`. +Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`. The operation is defined as: .. math:: \text{out}_i = \begin{cases} - \text{input}_i & \text{if } \text{condition}_i \\ - \text{other}_i & \text{otherwise} \\ + \text{x}_i & \text{if } \text{condition}_i \\ + \text{y}_i & \text{otherwise} \\ \end{cases} .. note:: - The tensors :attr:`condition`, :attr:`input`, :attr:`other` must be :ref:`broadcastable `. + The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be :ref:`broadcastable `. Arguments: condition (BoolTensor): When True (nonzero), yield x, otherwise yield y @@ -6068,7 +6068,7 @@ def merge_dicts(*dicts): y (Tensor): values selected at indices where :attr:`condition` is ``False`` Returns: - Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`input`, :attr:`other` + Tensor: A tensor of shape equal to the broadcasted shape of :attr:`condition`, :attr:`x`, :attr:`y` Example:: diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 81603469dc6c0..43aa5802b48ad 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import os +import inspect # this arbitrary-looking assortment of functionality is provided here # to have a central place for overrideable behavior. The motivating @@ -33,5 +34,23 @@ def resolve_library_path(path): return os.path.realpath(path) +def get_source_lines_and_file(obj): + """ + Wrapper around inspect.getsourcelines and inspect.getsourcefile. + + Returns: (sourcelines, file_lino, filename) + """ + filename = None # in case getsourcefile throws + try: + filename = inspect.getsourcefile(obj) + sourcelines, file_lineno = inspect.getsourcelines(obj) + except OSError as e: + raise OSError(( + "Can't get source for {}. TorchScript requires source access in order to carry out compilation. " + + "Make sure original .py files are available. Original error: {}").format(filename, e)) + + return sourcelines, file_lineno, filename + + TEST_MASTER_ADDR = '127.0.0.1' TEST_MASTER_PORT = 29500 diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index e69de29bb2d1d..9d74b8f9f0f05 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -0,0 +1,47 @@ +from contextlib import contextmanager +import types +# The idea for this parameter is that we forbid bare assignment +# to torch.backends..enabled and friends when running our +# test suite, where it's very easy to forget to undo the change +# later. +__allow_nonbracketed_mutation_flag = True + +def disable_global_flags(): + global __allow_nonbracketed_mutation_flag + __allow_nonbracketed_mutation_flag = False + +def flags_frozen(): + return not __allow_nonbracketed_mutation_flag + +@contextmanager +def __allow_nonbracketed_mutation(): + global __allow_nonbracketed_mutation_flag + old = __allow_nonbracketed_mutation_flag + __allow_nonbracketed_mutation_flag = True + try: + yield + finally: + __allow_nonbracketed_mutation_flag = old + +class ContextProp(object): + def __init__(self, getter, setter): + self.getter = getter + self.setter = setter + + def __get__(self, obj, objtype): + return self.getter() + + def __set__(self, obj, val): + if not flags_frozen(): + self.setter(val) + else: + raise RuntimeError("not allowed to set %s flags " + "after disable_global_flags; please use flags() context manager instead" % obj.__name__) + +class PropModule(types.ModuleType): + def __init__(self, m, name): + super(PropModule, self).__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index 4f196b0259fb3..c24e9c48c9ee6 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -2,11 +2,11 @@ import ctypes import sys import torch -import types import warnings from torch.version import cuda from contextlib import contextmanager from subprocess import Popen, PIPE +from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation # Write: # @@ -18,14 +18,6 @@ __cudnn_version = None # TODO: dynamic version checks via cudnnGetVersion - -# The idea for this parameter is that we forbid bare assignment -# to torch.backends.cudnn.enabled and friends when running our -# test suite, where it's very easy to forget to undo the change -# later. -__allow_nonbracketed_mutation_flag = True - - def find_cudnn_windows_lib(): # Override the default search process # Fixes https://github.com/pytorch/pytorch/issues/20202 @@ -164,27 +156,6 @@ def set_flags(_enabled, _benchmark, _deterministic, _verbose): torch._C._set_cudnn_deterministic(_deterministic) return orig_flags - -def disable_global_flags(): - global __allow_nonbracketed_mutation_flag - __allow_nonbracketed_mutation_flag = False - - -def flags_frozen(): - return not __allow_nonbracketed_mutation_flag - - -@contextmanager -def __allow_nonbracketed_mutation(): - global __allow_nonbracketed_mutation_flag - old = __allow_nonbracketed_mutation_flag - __allow_nonbracketed_mutation_flag = True - try: - yield - finally: - __allow_nonbracketed_mutation_flag = old - - @contextmanager def flags(enabled=False, benchmark=False, deterministic=False, verbose=False): with __allow_nonbracketed_mutation(): @@ -451,31 +422,11 @@ def add_tensor(*args): # The magic here is to allow us to intercept code like this: # -# torch.backends.cudnn.enabled = True - -class ContextProp(object): - def __init__(self, getter, setter): - self.getter = getter - self.setter = setter +# torch.backends..enabled = True - def __get__(self, obj, objtype): - return self.getter() - - def __set__(self, obj, val): - if not flags_frozen(): - self.setter(val) - else: - raise RuntimeError("not allowed to set torch.backends.cudnn flags " - "after disable_global_flags; please use flags() context manager instead") - - -class CudnnModule(types.ModuleType): +class CudnnModule(PropModule): def __init__(self, m, name): - super(CudnnModule, self).__init__(name) - self.m = m - - def __getattr__(self, attr): - return self.m.__getattribute__(attr) + super(CudnnModule, self).__init__(m, name) enabled = ContextProp(torch._C._get_cudnn_enabled, torch._C._set_cudnn_enabled) deterministic = ContextProp(torch._C._get_cudnn_deterministic, torch._C._set_cudnn_deterministic) diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py index 1b852c97a775c..c62af5df3b048 100644 --- a/torch/backends/mkldnn/__init__.py +++ b/torch/backends/mkldnn/__init__.py @@ -1,6 +1,33 @@ +import sys import torch - +from contextlib import contextmanager +from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation def is_available(): r"""Returns whether PyTorch is built with MKL-DNN support.""" return torch._C.has_mkldnn + +def set_flags(_enabled): + orig_flags = (torch._C._get_mkldnn_enabled(),) + torch._C._set_mkldnn_enabled(_enabled) + return orig_flags + +@contextmanager +def flags(enabled=False): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags(enabled) + try: + yield + finally: + with __allow_nonbracketed_mutation(): + set_flags(orig_flags[0]) + +class MkldnnModule(PropModule): + def __init__(self, m, name): + super(MkldnnModule, self).__init__(m, name) + + enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled) + +# Cool stuff from torch/backends/cudnn/__init__.py and +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__) diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py new file mode 100644 index 0000000000000..3ac3d30292fd8 --- /dev/null +++ b/torch/backends/quantized/__init__.py @@ -0,0 +1,29 @@ +from __future__ import absolute_import, division, print_function, unicode_literals +import sys +import torch +import types + +class ContextProp(object): + def __init__(self, getter, setter): + self.getter = getter + self.setter = setter + + def __get__(self, obj, objtype): + return self.getter() + + def __set__(self, obj, val): + self.setter(val) + +class QuantizedEngine(types.ModuleType): + def __init__(self, m, name): + super(QuantizedEngine, self).__init__(name) + self.m = m + + def __getattr__(self, attr): + return self.m.__getattribute__(attr) + # TODO: replace with strings(https://github.com/pytorch/pytorch/pull/26330/files#r324951460) + engine = ContextProp(torch._C._get_qengine, torch._C._set_qengine) + +# This is the sys.modules replacement trick, see +# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273 +sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__) diff --git a/torch/csrc/DataLoader.cpp b/torch/csrc/DataLoader.cpp index beb7bf91fef4f..778be341ee39b 100644 --- a/torch/csrc/DataLoader.cpp +++ b/torch/csrc/DataLoader.cpp @@ -98,7 +98,7 @@ static PyObject *THPModule_setWorkerSignalHandlers(PyObject *module, PyObject *a static std::map> worker_pids = {}; -static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { +static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module, PyObject *noargs) { HANDLE_TH_ERRORS int error; std::set *pid_set; @@ -131,6 +131,10 @@ static PyObject *THPModule_errorIfAnyWorkerFails(PyObject *module) { std::ostringstream oss; oss << "DataLoader worker (pid " << worker_pid << ") is killed " << "by signal: " << strsignal(infop.si_status) << ". "; + if (infop.si_status == SIGBUS) { + oss << "It is possible that dataloader's workers are out of shared memory. " + << "Please try to raise your shared memory limit."; + } // This is necessary. Otherwise, the runtime error will kill the other // workers, and trigger this again. pid_set->clear(); diff --git a/torch/csrc/Device.cpp b/torch/csrc/Device.cpp index d17bbabc12469..6bdfcb47570d8 100644 --- a/torch/csrc/Device.cpp +++ b/torch/csrc/Device.cpp @@ -76,7 +76,7 @@ PyObject *THPDevice_pynew(PyTypeObject *type, PyObject *args, PyObject *kwargs) END_HANDLE_TH_ERRORS } -PyObject *THPDevice_type(THPDevice *self) +PyObject *THPDevice_type(THPDevice *self, PyObject *noargs) { HANDLE_TH_ERRORS std::ostringstream oss; @@ -86,7 +86,7 @@ PyObject *THPDevice_type(THPDevice *self) END_HANDLE_TH_ERRORS } -PyObject *THPDevice_index(THPDevice *self) +PyObject *THPDevice_index(THPDevice *self, PyObject *noargs) { HANDLE_TH_ERRORS if (self->device.has_index()) { @@ -138,7 +138,7 @@ PyObject *THPDevice_rc(PyObject *a, PyObject *b, int op) { END_HANDLE_TH_ERRORS } -PyObject *THPDevice_reduce(THPDevice *self) +PyObject *THPDevice_reduce(THPDevice *self, PyObject *noargs) { HANDLE_TH_ERRORS auto ret = THPObjectPtr{PyTuple_New(2)}; diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp index ab6e6e93fb802..dc3c7c44ac9d6 100644 --- a/torch/csrc/Dtype.cpp +++ b/torch/csrc/Dtype.cpp @@ -20,7 +20,7 @@ PyObject * THPDtype_New(at::ScalarType scalar_type, const std::string& name) return self.release(); } -PyObject *THPDtype_is_floating_point(THPDtype *self) +PyObject *THPDtype_is_floating_point(THPDtype *self, PyObject *noargs) { if (at::isFloatingType(self->scalar_type) || at::isComplexType(self->scalar_type)) { Py_RETURN_TRUE; @@ -29,7 +29,7 @@ PyObject *THPDtype_is_floating_point(THPDtype *self) } } -PyObject *THPDtype_reduce(THPDtype *self) +PyObject *THPDtype_reduce(THPDtype *self, PyObject *noargs) { /* * For singletons, a string is returned. The string should be interpreted diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 5154b76772cf0..2d7458dd99016 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -74,7 +74,7 @@ static PyObject * THPGenerator_pynew(PyTypeObject *type, PyObject *args, PyObjec END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_getState(THPGenerator *self) +static PyObject * THPGenerator_getState(THPGenerator *self, PyObject *noargs) { using namespace torch::autograd; HANDLE_TH_ERRORS @@ -134,7 +134,7 @@ static PyObject * THPGenerator_manualSeed(THPGenerator *self, PyObject *seed) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_seed(THPGenerator *self) +static PyObject * THPGenerator_seed(THPGenerator *self, PyObject *noargs) { HANDLE_TH_ERRORS // See Note [Acquire lock when using random generators] @@ -144,14 +144,14 @@ static PyObject * THPGenerator_seed(THPGenerator *self) END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_initialSeed(THPGenerator *self) +static PyObject * THPGenerator_initialSeed(THPGenerator *self, PyObject *noargs) { HANDLE_TH_ERRORS return THPUtils_packUInt64(self->cdata->current_seed()); END_HANDLE_TH_ERRORS } -static PyObject * THPGenerator_get_device(THPGenerator *self) { +static PyObject * THPGenerator_get_device(THPGenerator *self, void *unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cdata->device()); END_HANDLE_TH_ERRORS diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 30198548d49b5..cce111e53079f 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -46,6 +48,7 @@ #include #include #include +#include #ifdef USE_CUDNN #include @@ -107,6 +110,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag torch::utils::initializeLayouts(); torch::utils::initializeMemoryFormats(); torch::utils::initializeQSchemes(); + torch::utils::initializeQEngines(); torch::utils::initializeDtypes(); torch::tensors::initialize_python_bindings(); std::string path = THPUtils_unpackString(shm_manager_path); @@ -160,7 +164,7 @@ static PyObject * THPModule_crashIfATenASAN(PyObject *module, PyObject *arg) { return PyLong_FromLong(at::_crash_if_asan(static_cast(THPUtils_unpackLong(arg)))); } -static PyObject * THPModule_getNumThreads(PyObject *module) +static PyObject * THPModule_getNumThreads(PyObject *module, PyObject *noargs) { return PyLong_FromLong(at::get_num_threads()); } @@ -175,7 +179,7 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg) Py_RETURN_NONE; } -static PyObject * THPModule_getNumInteropThreads(PyObject *module) +static PyObject * THPModule_getNumInteropThreads(PyObject *module, PyObject *noargs) { return PyLong_FromLong(at::get_num_interop_threads()); } @@ -309,7 +313,7 @@ static PyObject *THPModule_setBackcompatBroadcastWarn(PyObject *module, PyObject Py_RETURN_NONE; } -static PyObject *THPModule_getBackcompatBroadcastWarn(PyObject *module) +static PyObject *THPModule_getBackcompatBroadcastWarn(PyObject *module, PyObject *noargs) { if (getBackCompatBroadcastWarn()) Py_RETURN_TRUE; else Py_RETURN_FALSE; @@ -322,13 +326,13 @@ static PyObject *THPModule_setBackcompatKeepdimWarn(PyObject *module, PyObject * Py_RETURN_NONE; } -static PyObject *THPModule_getBackcompatKeepdimWarn(PyObject *module) +static PyObject *THPModule_getBackcompatKeepdimWarn(PyObject *module, PyObject *noargs) { if (getBackCompatKeepdimWarn()) Py_RETURN_TRUE; else Py_RETURN_FALSE; } -PyObject *THPModule_hasDistributed(PyObject *_unused) +PyObject *THPModule_hasDistributed(PyObject *_unused, PyObject *noargs) { #ifdef USE_DISTRIBUTED Py_RETURN_TRUE; @@ -337,14 +341,14 @@ PyObject *THPModule_hasDistributed(PyObject *_unused) #endif } -static PyObject *THPModule_showConfig(PyObject *module) +static PyObject *THPModule_showConfig(PyObject *module, PyObject *noargs) { HANDLE_TH_ERRORS return THPUtils_packString(at::show_config()); END_HANDLE_TH_ERRORS } -static PyObject *THPModule_parallelInfo(PyObject *module) +static PyObject *THPModule_parallelInfo(PyObject *module, PyObject *noargs) { HANDLE_TH_ERRORS return THPUtils_packString(at::get_parallel_info()); @@ -411,12 +415,26 @@ PyObject *THPModule_setUserEnabledCuDNN(PyObject *_unused, PyObject *arg) Py_RETURN_NONE; } -PyObject *THPModule_userEnabledCuDNN(PyObject *_unused) +PyObject *THPModule_userEnabledCuDNN(PyObject *_unused, PyObject *noargs) { if (at::globalContext().userEnabledCuDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; } +PyObject *THPModule_setUserEnabledMkldnn(PyObject *_unused, PyObject *arg) +{ + THPUtils_assert(PyBool_Check(arg), "set_enabled_mkldnn expects a bool, " + "but got %s", THPUtils_typename(arg)); + at::globalContext().setUserEnabledMkldnn(arg == Py_True); + Py_RETURN_NONE; +} + +PyObject *THPModule_userEnabledMkldnn(PyObject *_unused, PyObject *noargs) +{ + if (at::globalContext().userEnabledMkldnn()) Py_RETURN_TRUE; + else Py_RETURN_FALSE; +} + PyObject *THPModule_setDeterministicCuDNN(PyObject *_unused, PyObject *arg) { THPUtils_assert(PyBool_Check(arg), "set_deterministic_cudnn expects a bool, " @@ -425,7 +443,7 @@ PyObject *THPModule_setDeterministicCuDNN(PyObject *_unused, PyObject *arg) Py_RETURN_NONE; } -PyObject *THPModule_deterministicCuDNN(PyObject *_unused) +PyObject *THPModule_deterministicCuDNN(PyObject *_unused, PyObject *noargs) { if (at::globalContext().deterministicCuDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; @@ -439,7 +457,7 @@ PyObject *THPModule_setBenchmarkCuDNN(PyObject *_unused, PyObject *arg) Py_RETURN_NONE; } -PyObject *THPModule_benchmarkCuDNN(PyObject *_unused) +PyObject *THPModule_benchmarkCuDNN(PyObject *_unused, PyObject *noargs) { if (at::globalContext().benchmarkCuDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; @@ -471,13 +489,26 @@ PyObject *THPModule_getDefaultDevice(PyObject *_unused, PyObject *arg) { END_HANDLE_TH_ERRORS } +PyObject *THPModule_setQEngine(PyObject */* unused */, PyObject *arg) +{ + TORCH_CHECK(THPQEngine_Check(arg), "qengine arg must be an instance of the torch.qengine"); + const auto qengine = reinterpret_cast(arg); + at::globalContext().setQEngine(qengine->qengine); + Py_RETURN_NONE; +} + +PyObject *THPModule_qEngine(PyObject */* unused */) +{ + return THPQEngine_New(at::globalContext().qEngine(), toString(at::globalContext().qEngine())); +} + static PyMethodDef TorchMethods[] = { {"_initExtension", (PyCFunction)THPModule_initExtension, METH_O, nullptr}, {"_autograd_init", (PyCFunction)THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", (PyCFunction)THPModule_addDocStr, METH_VARARGS, nullptr}, {"_init_names", (PyCFunction)THPModule_initNames, METH_O, nullptr}, {"_has_distributed",(PyCFunction)THPModule_hasDistributed, METH_NOARGS, nullptr}, - {"_safe_call", (PyCFunction)THPModule_safeCall, METH_VARARGS | METH_KEYWORDS, nullptr}, + {"_safe_call", (PyCFunction)(void(*)(void))THPModule_safeCall, METH_VARARGS | METH_KEYWORDS, nullptr}, {"_set_default_tensor_type", (PyCFunction)THPModule_setDefaultTensorType, METH_O, nullptr}, {"_set_default_dtype", (PyCFunction)THPModule_setDefaultDtype, METH_O, nullptr}, {"_infer_size", (PyCFunction)THPModule_inferSize, METH_VARARGS, nullptr}, @@ -496,6 +527,8 @@ static PyMethodDef TorchMethods[] = { {"set_num_interop_threads", (PyCFunction)THPModule_setNumInteropThreads, METH_O, nullptr}, {"_get_cudnn_enabled", (PyCFunction)THPModule_userEnabledCuDNN, METH_NOARGS, nullptr}, {"_set_cudnn_enabled", (PyCFunction)THPModule_setUserEnabledCuDNN, METH_O, nullptr}, + {"_get_mkldnn_enabled", (PyCFunction)THPModule_userEnabledMkldnn, METH_NOARGS, nullptr}, + {"_set_mkldnn_enabled", (PyCFunction)THPModule_setUserEnabledMkldnn, METH_O, nullptr}, {"_get_cudnn_benchmark", (PyCFunction)THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, {"_set_cudnn_benchmark", (PyCFunction)THPModule_setBenchmarkCuDNN, METH_O, nullptr}, {"_get_cudnn_deterministic", (PyCFunction)THPModule_deterministicCuDNN, METH_NOARGS, nullptr}, @@ -504,7 +537,9 @@ static PyMethodDef TorchMethods[] = { {"_from_dlpack", (PyCFunction)THPModule_fromDLPack, METH_O, nullptr}, {"set_flush_denormal", (PyCFunction)THPModule_setFlushDenormal, METH_O, nullptr}, {"get_default_dtype", (PyCFunction)THPModule_getDefaultDtype, METH_NOARGS, nullptr}, - {"_get_default_device", (PyCFunction)THPModule_getDefaultDevice, METH_NOARGS, nullptr}, + {"_get_default_device", (PyCFunction)THPModule_getDefaultDevice, METH_NOARGS, nullptr}, + {"_get_qengine", (PyCFunction)THPModule_qEngine, METH_NOARGS, nullptr}, + {"_set_qengine", (PyCFunction)THPModule_setQEngine, METH_O, nullptr}, {nullptr, nullptr, 0, nullptr} }; @@ -645,6 +680,7 @@ PyObject* initModule() { THPDTypeInfo_init(module); THPLayout_init(module); THPMemoryFormat_init(module); + THPQEngine_init(module); THPQScheme_init(module); THPDevice_init(module); ASSERT_TRUE(THPVariable_initModule(module)); diff --git a/torch/csrc/QEngine.cpp b/torch/csrc/QEngine.cpp new file mode 100644 index 0000000000000..4ffcb88ad5342 --- /dev/null +++ b/torch/csrc/QEngine.cpp @@ -0,0 +1,79 @@ +#include + +#include +#include +#include + +#include + +#include +#include +#include + +PyObject* THPQEngine_New(at::QEngine qengine, const std::string& name) { + auto type = (PyTypeObject*)&THPQEngineType; + auto self = THPObjectPtr{type->tp_alloc(type, 0)}; + if (!self) + throw python_error(); + auto self_ = reinterpret_cast(self.get()); + self_->qengine = qengine; + std::strncpy(self_->name, name.c_str(), QENGINE_NAME_LEN); + self_->name[QENGINE_NAME_LEN] = '\0'; + return self.release(); +} + +PyObject* THPQEngine_repr(THPQEngine* self) { + std::string name = self->name; + return THPUtils_packString("torch." + name); +} + +PyTypeObject THPQEngineType = { + PyVarObject_HEAD_INIT(nullptr, 0) "torch.qengine", /* tp_name */ + sizeof(THPQEngine), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + nullptr, /* tp_print */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + (reprfunc)THPQEngine_repr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + nullptr, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + nullptr, /* tp_alloc */ + nullptr, /* tp_new */ +}; + +void THPQEngine_init(PyObject* module) { + if (PyType_Ready(&THPQEngineType) < 0) { + throw python_error(); + } + Py_INCREF(&THPQEngineType); + if (PyModule_AddObject(module, "qengine", (PyObject*)&THPQEngineType) != + 0) { + throw python_error(); + } +} diff --git a/torch/csrc/QEngine.h b/torch/csrc/QEngine.h new file mode 100644 index 0000000000000..3af9f40ec95f9 --- /dev/null +++ b/torch/csrc/QEngine.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include + +#include + +constexpr int QENGINE_NAME_LEN = 64; + +struct THPQEngine { + PyObject_HEAD at::QEngine qengine; + char name[QENGINE_NAME_LEN + 1]; +}; + +extern PyTypeObject THPQEngineType; + +inline bool THPQEngine_Check(PyObject* obj) { + return Py_TYPE(obj) == &THPQEngineType; +} + +PyObject* THPQEngine_New(at::QEngine qengine, const std::string& name); + +void THPQEngine_init(PyObject* module); diff --git a/torch/csrc/QScheme.cpp b/torch/csrc/QScheme.cpp index e83f8d643b171..eb73652a6a408 100644 --- a/torch/csrc/QScheme.cpp +++ b/torch/csrc/QScheme.cpp @@ -22,7 +22,7 @@ PyObject *THPQScheme_New(at::QScheme qscheme, const std::string& name) return self.release(); } -PyObject *THPQScheme_reduce(THPQScheme *self) { +PyObject *THPQScheme_reduce(THPQScheme *self, PyObject *noargs) { return THPUtils_packString(self->name); } diff --git a/torch/csrc/Size.cpp b/torch/csrc/Size.cpp index de99fcbb550f4..15902d9b09667 100644 --- a/torch/csrc/Size.cpp +++ b/torch/csrc/Size.cpp @@ -137,7 +137,7 @@ static PyMappingMethods THPSize_as_mapping = { nullptr }; -static PyObject *THPSize_numel(THPSize *self) +static PyObject *THPSize_numel(THPSize *self, PyObject *noargs) { HANDLE_TH_ERRORS int64_t numel = 1; @@ -148,7 +148,7 @@ static PyObject *THPSize_numel(THPSize *self) END_HANDLE_TH_ERRORS } -static PyObject *THPSize_reduce(THPSize *self) +static PyObject *THPSize_reduce(THPSize *self, PyObject *noargs) { HANDLE_TH_ERRORS auto ret = THPObjectPtr{PyTuple_New(2)}; diff --git a/torch/csrc/api/include/torch/arg.h b/torch/csrc/api/include/torch/arg.h index d4640ebe3f6d7..4b041d88f51e9 100644 --- a/torch/csrc/api/include/torch/arg.h +++ b/torch/csrc/api/include/torch/arg.h @@ -3,15 +3,17 @@ #include #define TORCH_ARG(T, name) \ - auto name(const T& new_##name)->decltype(*this) { /* NOLINT */ \ + public: \ + inline auto name(const T& new_##name)->decltype(*this) { /* NOLINT */ \ this->name##_ = new_##name; \ return *this; \ } \ - auto name(T&& new_##name)->decltype(*this) { /* NOLINT */ \ + inline auto name(T&& new_##name)->decltype(*this) { /* NOLINT */ \ this->name##_ = std::move(new_##name); \ return *this; \ } \ - const T& name() const noexcept { /* NOLINT */ \ + inline const T& name() const noexcept { /* NOLINT */ \ return this->name##_; \ } \ + private: \ T name##_ /* NOLINT */ diff --git a/torch/csrc/api/include/torch/data/dataloader_options.h b/torch/csrc/api/include/torch/data/dataloader_options.h index 4e0442f1303d1..600f704cffdc4 100644 --- a/torch/csrc/api/include/torch/data/dataloader_options.h +++ b/torch/csrc/api/include/torch/data/dataloader_options.h @@ -47,12 +47,12 @@ struct DataLoaderOptions { /// instance, which will do any necessary coalescing. struct FullDataLoaderOptions { explicit FullDataLoaderOptions(DataLoaderOptions options) - : batch_size(options.batch_size_), - workers(options.workers_), - max_jobs(options.max_jobs_.value_or(2 * workers)), - timeout(options.timeout_), - enforce_ordering(options.enforce_ordering_), - drop_last(options.drop_last_) {} + : batch_size(options.batch_size()), + workers(options.workers()), + max_jobs(options.max_jobs().value_or(2 * workers)), + timeout(options.timeout()), + enforce_ordering(options.enforce_ordering()), + drop_last(options.drop_last()) {} size_t batch_size; size_t workers; diff --git a/torch/csrc/api/include/torch/data/datasets/chunk.h b/torch/csrc/api/include/torch/data/datasets/chunk.h index 8063ff1be35c5..53784620405c8 100644 --- a/torch/csrc/api/include/torch/data/datasets/chunk.h +++ b/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -350,16 +350,16 @@ class ChunkDataset final "Dataset needs to call reset() before calling get_batch()."); TORCH_CHECK( - batch_size == options_.batch_size_, + batch_size == options_.batch_size(), "The requested batch size does not match with the initialized batch size.\n" " The requested batch size is ", batch_size, - ", while the dataset is created with batch size equal to ", options_.batch_size_); + ", while the dataset is created with batch size equal to ", options_.batch_size()); return batch_buffer_->get_batch(); } /// Helper method around get_batch as `batch_size` is not strictly necessary BatchType get_batch() { - return get_batch(options_.batch_size_); + return get_batch(options_.batch_size()); } /// This will clear any internal state and starts the internal prefetching @@ -383,16 +383,16 @@ class ChunkDataset final // chunk buffer. batch_buffer_ = torch::make_unique< detail::BatchDataBuffer>( - options_.batch_size_, + options_.batch_size(), example_sampler_, - options_.cache_size_); + options_.cache_size()); // create new workers for this new epoch. quit_worker_ = false; AT_ASSERT(running_preloaders_ == 0); - running_preloaders_ = options_.preloader_count_; - for (size_t i = 0; i < options_.preloader_count_; ++i) { + running_preloaders_ = options_.preloader_count(); + for (size_t i = 0; i < options_.preloader_count(); ++i) { preload_threads_.emplace_back([this, i]() { this->preloader(i); }); } } @@ -427,7 +427,7 @@ class ChunkDataset final std::vector chunk_idx; { std::lock_guard lock(chunk_index_guard_); - if (auto chunk_sampler_result = chunk_sampler_.next(this->options_.cross_chunk_shuffle_count_)) { + if (auto chunk_sampler_result = chunk_sampler_.next(this->options_.cross_chunk_shuffle_count())) { chunk_idx = chunk_sampler_result.value(); } else { break; diff --git a/torch/csrc/api/include/torch/nn.h b/torch/csrc/api/include/torch/nn.h index 4d64f990e7c84..10474512a2580 100644 --- a/torch/csrc/api/include/torch/nn.h +++ b/torch/csrc/api/include/torch/nn.h @@ -1,7 +1,9 @@ #pragma once #include +#include #include #include #include +#include #include diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h new file mode 100644 index 0000000000000..f670f3bd8f0cd --- /dev/null +++ b/torch/csrc/api/include/torch/nn/functional.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/torch/csrc/api/include/torch/nn/functional/distance.h b/torch/csrc/api/include/torch/nn/functional/distance.h new file mode 100644 index 0000000000000..6b139c8c0c604 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/functional/distance.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +namespace torch { +namespace nn { +namespace functional { + +inline Tensor cosine_similarity( + const Tensor& x1, + const Tensor& x2, + const CosineSimilarityOptions& options) { + return torch::cosine_similarity( + x1, + x2, + options.dim(), + options.eps()); +} + +// ============================================================================ + +inline Tensor pairwise_distance( + const Tensor& x1, + const Tensor& x2, + const PairwiseDistanceOptions& options) { + return torch::pairwise_distance( + x1, + x2, + options.p(), + options.eps(), + options.keepdim()); +} + +} // namespace functional +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h new file mode 100644 index 0000000000000..b21304026b2ca --- /dev/null +++ b/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +namespace torch { +namespace nn{ +namespace functional { + +inline Tensor avg_pool1d(const Tensor& input, const AvgPool1dOptions& options) { + return torch::avg_pool1d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad()); +} + +inline Tensor avg_pool2d(const Tensor& input, const AvgPool2dOptions& options) { + return torch::avg_pool2d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); +} + +inline Tensor avg_pool3d(const Tensor& input, const AvgPool3dOptions& options) { + return torch::avg_pool3d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.ceil_mode(), + options.count_include_pad(), + options.divisor_override()); +} + +// ============================================================================ + +inline Tensor max_pool1d(const Tensor& input, const MaxPool1dOptions& options) { + return torch::max_pool1d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +inline Tensor max_pool2d(const Tensor& input, const MaxPool2dOptions& options) { + return torch::max_pool2d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +inline Tensor max_pool3d(const Tensor& input, const MaxPool3dOptions& options) { + return torch::max_pool3d( + input, + options.kernel_size(), + options.stride(), + options.padding(), + options.dilation(), + options.ceil_mode()); +} + +} // namespace functional +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h index 85d6d5c5d3f30..e6d21014c5eaa 100644 --- a/torch/csrc/api/include/torch/nn/module.h +++ b/torch/csrc/api/include/torch/nn/module.h @@ -504,6 +504,10 @@ class TORCH_API Module : public std::enable_shared_from_this { const std::string& name, ModuleHolder module_holder); + /// Unregisters a submodule from this `Module`. If there is no such module + /// with `name` an exception is thrown. + void unregister_module(const std::string& name); + private: // Friend classes. diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h index 16e74ab715435..5d91f2924cc2a 100644 --- a/torch/csrc/api/include/torch/nn/modules.h +++ b/torch/csrc/api/include/torch/nn/modules.h @@ -1,13 +1,20 @@ #pragma once -#include +// Containers +#include +#include +#include +#include +#include + +// Layers #include #include #include +#include #include -#include +#include #include -#include -#include +#include +#include #include -#include diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h index 782a5d950d99b..210effd35a17c 100644 --- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -9,29 +10,6 @@ namespace torch { namespace nn { -/// Options for the `BatchNorm` module. -struct TORCH_API BatchNormOptions { - /* implicit */ BatchNormOptions(int64_t features); - /// The number of features of the input tensor. - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(int64_t, features); - /// Whether to learn a scale and bias that are applied in an affine - /// transformation on the input. - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(bool, affine) = true; - /// Whether to store and update batch statistics (mean and variance) in the - /// module. If `false`, you should call `pure_forward` and supply those batch - /// statistics yourself. - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(bool, stateful) = true; - /// The epsilon value added for numerical stability. - /// Changing this parameter after construction __is effective__. - TORCH_ARG(double, eps) = 1e-5; - /// A momentum multiplier for the mean and variance. - /// Changing this parameter after construction __is effective__. - TORCH_ARG(double, momentum) = 0.1; -}; - /// Applies [Batch Normalization](https://arxiv.org/abs/1502.03167) to an input. /// /// Refer to the documentation for @@ -49,7 +27,7 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable { public: explicit BatchNormImpl(int64_t features) : BatchNormImpl(BatchNormOptions(features)) {} - explicit BatchNormImpl(BatchNormOptions options); + explicit BatchNormImpl(const BatchNormOptions& options_); void reset() override; diff --git a/torch/csrc/api/include/torch/nn/modules/any.h b/torch/csrc/api/include/torch/nn/modules/container/any.h similarity index 100% rename from torch/csrc/api/include/torch/nn/modules/any.h rename to torch/csrc/api/include/torch/nn/modules/container/any.h diff --git a/torch/csrc/api/include/torch/nn/modules/functional.h b/torch/csrc/api/include/torch/nn/modules/container/functional.h similarity index 100% rename from torch/csrc/api/include/torch/nn/modules/functional.h rename to torch/csrc/api/include/torch/nn/modules/container/functional.h diff --git a/torch/csrc/api/include/torch/nn/modules/modulelist.h b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h similarity index 100% rename from torch/csrc/api/include/torch/nn/modules/modulelist.h rename to torch/csrc/api/include/torch/nn/modules/container/modulelist.h diff --git a/torch/csrc/api/include/torch/nn/modules/named_any.h b/torch/csrc/api/include/torch/nn/modules/container/named_any.h similarity index 99% rename from torch/csrc/api/include/torch/nn/modules/named_any.h rename to torch/csrc/api/include/torch/nn/modules/container/named_any.h index d4fa54b02d823..ba8eef78ea0eb 100644 --- a/torch/csrc/api/include/torch/nn/modules/named_any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/named_any.h @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/api/include/torch/nn/modules/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h similarity index 99% rename from torch/csrc/api/include/torch/nn/modules/sequential.h rename to torch/csrc/api/include/torch/nn/modules/container/sequential.h index 1f1c17e731f75..bd816ed90fa2c 100644 --- a/torch/csrc/api/include/torch/nn/modules/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index c492ef6f0d341..5c601f9a222f4 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -13,69 +14,6 @@ namespace torch { namespace nn { -/// Options for a `D`-dimensional convolution module. -template -struct ConvOptions { - ConvOptions( - int64_t input_channels, - int64_t output_channels, - ExpandingArray kernel_size) : - input_channels_(input_channels), - output_channels_(output_channels), - kernel_size_(std::move(kernel_size)) {} - - /// The number of channels the input volumes will have. - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(int64_t, input_channels); - - /// The number of output channels the convolution should produce. - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(int64_t, output_channels); - - /// The kernel size to use. - /// For a `D`-dim convolution, must be a single number or a list of `D` - /// numbers. - /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, kernel_size); - - /// The stride of the convolution. - /// For a `D`-dim convolution, must be a single number or a list of `D` - /// numbers. - /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, stride) = 1; - - /// The padding to add to the input volumes. - /// For a `D`-dim convolution, must be a single number or a list of `D` - /// numbers. - /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, padding) = 0; - - /// The kernel dilation. - /// For a `D`-dim convolution, must be a single number or a list of `D` - /// numbers. - /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, dilation) = 1; - - /// For transpose convolutions, the padding to add to output volumes. - /// For a `D`-dim convolution, must be a single number or a list of `D` - /// numbers. - /// This parameter __can__ be changed after construction. - TORCH_ARG(ExpandingArray, output_padding) = 0; - - /// If true, convolutions will be transpose convolutions (a.k.a. - /// deconvolutions). - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(bool, transposed) = false; - - /// Whether to add a bias after individual applications of the kernel. - /// Changing this parameter after construction __has no effect__. - TORCH_ARG(bool, with_bias) = true; - - /// The number of convolution groups. - /// This parameter __can__ be changed after construction. - TORCH_ARG(int64_t, groups) = 1; -}; - /// Base class for all (dimension-specialized) convolution modules. template class TORCH_API ConvImpl : public torch::nn::Cloneable { @@ -86,7 +24,7 @@ class TORCH_API ConvImpl : public torch::nn::Cloneable { ExpandingArray kernel_size) : ConvImpl(ConvOptions(input_channels, output_channels, kernel_size)) { } - explicit ConvImpl(ConvOptions options); + explicit ConvImpl(const ConvOptions& options_); void reset() override; @@ -114,9 +52,6 @@ class TORCH_API Conv1dImpl : public ConvImpl<1, Conv1dImpl> { Tensor forward(const Tensor& input); }; -/// `ConvOptions` specialized for 1-D convolution. -using Conv1dOptions = ConvOptions<1>; - /// A `ModuleHolder` subclass for `Conv1dImpl`. /// See the documentation for `Conv1dImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's @@ -134,9 +69,6 @@ class TORCH_API Conv2dImpl : public ConvImpl<2, Conv2dImpl> { Tensor forward(const Tensor& input); }; -/// `ConvOptions` specialized for 2-D convolution. -using Conv2dOptions = ConvOptions<2>; - /// A `ModuleHolder` subclass for `Conv2dImpl`. /// See the documentation for `Conv2dImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's @@ -154,9 +86,6 @@ class TORCH_API Conv3dImpl : public ConvImpl<3, Conv3dImpl> { Tensor forward(const Tensor& input); }; -/// `ConvOptions` specialized for 3-D convolution. -using Conv3dOptions = ConvOptions<3>; - /// A `ModuleHolder` subclass for `Conv3dImpl`. /// See the documentation for `Conv3dImpl` class to learn what methods it /// provides, or the documentation for `ModuleHolder` to learn about PyTorch's diff --git a/torch/csrc/api/include/torch/nn/modules/distance.h b/torch/csrc/api/include/torch/nn/modules/distance.h new file mode 100644 index 0000000000000..1a441cd6b8517 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/modules/distance.h @@ -0,0 +1,63 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace nn { + +/// Returns the cosine similarity between :math:`x_1` and :math:`x_2`, computed +/// along `dim`. +class TORCH_API CosineSimilarityImpl : public Cloneable { + public: + explicit CosineSimilarityImpl(const CosineSimilarityOptions& options_ = {}); + + void reset() override; + + /// Pretty prints the `CosineSimilarity` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input1, const Tensor& input2); + + /// The options with which this `Module` was constructed. + CosineSimilarityOptions options; +}; + +/// A `ModuleHolder` subclass for `CosineSimilarityImpl`. +/// See the documentation for `CosineSimilarityImpl` class to learn what methods +/// it provides, or the documentation for `ModuleHolder` to learn about +/// Pytorch's module storage semantics. +TORCH_MODULE(CosineSimilarity); + +// ============================================================================ + +/// Returns the batchwise pairwise distance between vectors :math:`v_1`, +/// :math:`v_2` using the p-norm. +class TORCH_API PairwiseDistanceImpl : public Cloneable { + public: + explicit PairwiseDistanceImpl(const PairwiseDistanceOptions& options_ = {}); + + void reset() override; + + /// Pretty prints the `PairwiseDistance` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input1, const Tensor& input2); + + /// The options with which this `Module` was constructed. + PairwiseDistanceOptions options; +}; + +/// A `ModuleHolder` subclass for `PairwiseDistanceImpl`. +/// See the documentation for `PairwiseDistanceImpl` class to learn what methods +/// it provides, or the documentation for `ModuleHolder` to learn about +/// Pytorch's module storage semantics. +TORCH_MODULE(PairwiseDistance); + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h index ad70eaff86d33..e29f603bc6505 100644 --- a/torch/csrc/api/include/torch/nn/modules/dropout.h +++ b/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -12,20 +13,11 @@ namespace torch { namespace nn { -/// Options for `Dropout` and `FeatureDropout`. -struct TORCH_API DropoutOptions { - /* implicit */ DropoutOptions(double rate = 0.5); - /// The probability with which a particular component of the input is set to - /// zero. - /// Changes to this parameter at runtime are effective. - TORCH_ARG(double, rate); -}; - namespace detail { template class DropoutImplBase : public torch::nn::Cloneable { public: - explicit DropoutImplBase(DropoutOptions options_ = DropoutOptions()); + explicit DropoutImplBase(const DropoutOptions& options_ = DropoutOptions()); void reset() override; @@ -40,7 +32,7 @@ class DropoutImplBase : public torch::nn::Cloneable { /// about the exact semantics of this module. class TORCH_API DropoutImpl : public detail::DropoutImplBase { public: - explicit DropoutImpl(DropoutOptions options_ = DropoutOptions()); + explicit DropoutImpl(const DropoutOptions& options_ = DropoutOptions()); /// During training, applies a noise mask to the input tensor. /// During evaluation, applies an identity function. @@ -62,7 +54,7 @@ class TORCH_API DropoutImpl : public detail::DropoutImplBase { class TORCH_API FeatureDropoutImpl : public detail::DropoutImplBase { public: - explicit FeatureDropoutImpl(DropoutOptions options_ = DropoutOptions()); + explicit FeatureDropoutImpl(const DropoutOptions& options_ = DropoutOptions()); /// During training, applies a noise mask to the input tensor. /// During evaluation, applies an identity function. diff --git a/torch/csrc/api/include/torch/nn/modules/fold.h b/torch/csrc/api/include/torch/nn/modules/fold.h new file mode 100644 index 0000000000000..a9e837f1a8a70 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/modules/fold.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Applies fold over a 3-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.Fold to learn about +/// the exact behavior of this module. +class TORCH_API FoldImpl : public torch::nn::Cloneable { + public: + FoldImpl(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) + : FoldImpl(FoldOptions(output_size, kernel_size)) {} + explicit FoldImpl(const FoldOptions& options_); + + void reset() override {} + + /// Pretty prints the `Fold` module into the given `stream`. + void pretty_print(std::ostream& stream) const override { + stream << "torch::nn::Fold"; + } + + Tensor forward(const Tensor& input); + + /// The options with which this `Module` was constructed. + FoldOptions options; +}; + +/// A `ModuleHolder` subclass for `FoldImpl`. +/// See the documentation for `FoldImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(Fold); + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h index d11cd93e87d34..4063b347e0cc6 100644 --- a/torch/csrc/api/include/torch/nn/modules/linear.h +++ b/torch/csrc/api/include/torch/nn/modules/linear.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -10,22 +11,12 @@ namespace torch { namespace nn { -/// Options for the `Linear` module. -struct TORCH_API LinearOptions { - LinearOptions(int64_t in, int64_t out); - /// The number of input features (columns of the input matrix). - TORCH_ARG(int64_t, in); - /// The number of output features to produce (columns of the output matrix). - TORCH_ARG(int64_t, out); - /// Whether to learn and add a bias after the linear transformation. - TORCH_ARG(bool, with_bias) = true; -}; /// Applies a linear transformation with optional bias. class TORCH_API LinearImpl : public Cloneable { public: LinearImpl(int64_t in, int64_t out) : LinearImpl(LinearOptions(in, out)) {} - explicit LinearImpl(LinearOptions options); + explicit LinearImpl(const LinearOptions& options_); void reset() override; diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h new file mode 100644 index 0000000000000..1dbf2a334726a --- /dev/null +++ b/torch/csrc/api/include/torch/nn/modules/loss.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace torch { +namespace nn { + +/// Creates a criterion that measures the mean absolute error (MAE) between each +/// element in the input : math :`x` and target : `y`. +struct TORCH_API L1LossImpl : Module { + explicit L1LossImpl(const L1LossOptions& options_ = {}); + + /// Pretty prints the `L1Loss` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + Tensor forward(const Tensor& input, const Tensor& target); + + /// The options with which this `Module` was constructed. + L1LossOptions options; +}; + +/// A `ModuleHolder` subclass for `L1LossImpl`. +/// See the documentation for `L1LossImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(L1Loss); + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/pooling.h b/torch/csrc/api/include/torch/nn/modules/pooling.h new file mode 100644 index 0000000000000..315e7f62cbfb0 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/modules/pooling.h @@ -0,0 +1,152 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch { +namespace nn { + +/// Base class for all (dimension-specialized) avgpool modules. +template +class TORCH_API AvgPoolImpl : public torch::nn::Cloneable { + public: + AvgPoolImpl(ExpandingArray kernel_size) + : AvgPoolImpl(AvgPoolOptions(kernel_size)) {} + explicit AvgPoolImpl(const AvgPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `AvgPool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + AvgPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 1-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AvgPool1d to learn +/// about the exact behavior of this module. +class TORCH_API AvgPool1dImpl : public AvgPoolImpl<1, AvgPool1dImpl> { + public: + using AvgPoolImpl<1, AvgPool1dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool1dImpl`. +/// See the documentation for `AvgPool1dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(AvgPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 2-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AvgPool2d to learn +/// about the exact behavior of this module. +class TORCH_API AvgPool2dImpl : public AvgPoolImpl<2, AvgPool2dImpl> { + public: + using AvgPoolImpl<2, AvgPool2dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool2dImpl`. +/// See the documentation for `AvgPool2dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(AvgPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AvgPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies avgpool over a 3-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.AvgPool3d to learn +/// about the exact behavior of this module. +class TORCH_API AvgPool3dImpl : public AvgPoolImpl<3, AvgPool3dImpl> { + public: + using AvgPoolImpl<3, AvgPool3dImpl>::AvgPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `AvgPool2dImpl`. +/// See the documentation for `AvgPool2dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(AvgPool3d); + +// ============================================================================ + +/// Base class for all (dimension-specialized) maxpool modules. +template +class TORCH_API MaxPoolImpl : public torch::nn::Cloneable { + public: + MaxPoolImpl(ExpandingArray kernel_size) + : MaxPoolImpl(MaxPoolOptions(kernel_size)) {} + explicit MaxPoolImpl(const MaxPoolOptions& options_); + + void reset() override; + + /// Pretty prints the `MaxPool{1,2,3}d` module into the given `stream`. + void pretty_print(std::ostream& stream) const override; + + /// The options with which this `Module` was constructed. + MaxPoolOptions options; +}; + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 1-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool1d to learn +/// about the exact behavior of this module. +class TORCH_API MaxPool1dImpl : public MaxPoolImpl<1, MaxPool1dImpl> { + public: + using MaxPoolImpl<1, MaxPool1dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool1dImpl`. +/// See the documentation for `MaxPool1dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(MaxPool1d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 2-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool2d to learn +/// about the exact behavior of this module. +class TORCH_API MaxPool2dImpl : public MaxPoolImpl<2, MaxPool2dImpl> { + public: + using MaxPoolImpl<2, MaxPool2dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool2dImpl`. +/// See the documentation for `MaxPool2dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(MaxPool2d); + +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MaxPool3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +/// Applies maxpool over a 3-D input. +/// See https://pytorch.org/docs/master/nn.html#torch.nn.MaxPool3d to learn +/// about the exact behavior of this module. +class TORCH_API MaxPool3dImpl : public MaxPoolImpl<3, MaxPool3dImpl> { + public: + using MaxPoolImpl<3, MaxPool3dImpl>::MaxPoolImpl; + Tensor forward(const Tensor& input); +}; + +/// A `ModuleHolder` subclass for `MaxPool3dImpl`. +/// See the documentation for `MaxPool3dImpl` class to learn what methods it +/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(MaxPool3d); + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h index e6d161e9f56b7..d6fa591b30812 100644 --- a/torch/csrc/api/include/torch/nn/modules/rnn.h +++ b/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -27,30 +28,6 @@ struct TORCH_API RNNOutput { }; namespace detail { - -/// Common options for LSTM and GRU modules. -struct TORCH_API RNNOptionsBase { - RNNOptionsBase(int64_t input_size, int64_t hidden_size); - virtual ~RNNOptionsBase() = default; - /// The number of features of a single sample in the input sequence `x`. - TORCH_ARG(int64_t, input_size); - /// The number of features in the hidden state `h`. - TORCH_ARG(int64_t, hidden_size); - /// The number of recurrent layers (cells) to use. - TORCH_ARG(int64_t, layers) = 1; - /// Whether a bias term should be added to all linear operations. - TORCH_ARG(bool, with_bias) = true; - /// If non-zero, adds dropout with the given probability to the output of each - /// RNN layer, except the final layer. - TORCH_ARG(double, dropout) = 0.0; - /// Whether to make the RNN bidirectional. - TORCH_ARG(bool, bidirectional) = false; - /// If true, the input sequence should be provided as `(batch, sequence, - /// features)`. If false (default), the expected layout is `(sequence, batch, - /// features)`. - TORCH_ARG(bool, batch_first) = false; -}; - /// Base class for all RNN implementations (intended for code sharing). template class TORCH_API RNNImplBase : public torch::nn::Cloneable { @@ -139,38 +116,6 @@ class TORCH_API RNNImplBase : public torch::nn::Cloneable { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -enum class RNNActivation : uint32_t {ReLU, Tanh}; - -/// Options for RNN modules. -struct TORCH_API RNNOptions { - RNNOptions(int64_t input_size, int64_t hidden_size); - - /// Sets the activation after linear operations to `tanh`. - RNNOptions& tanh(); - /// Sets the activation after linear operations to `relu`. - RNNOptions& relu(); - - /// The number of features of a single sample in the input sequence `x`. - TORCH_ARG(int64_t, input_size); - /// The number of features in the hidden state `h`. - TORCH_ARG(int64_t, hidden_size); - /// The number of recurrent layers (cells) to use. - TORCH_ARG(int64_t, layers) = 1; - /// Whether a bias term should be added to all linear operations. - TORCH_ARG(bool, with_bias) = true; - /// If non-zero, adds dropout with the given probability to the output of each - /// RNN layer, except the final layer. - TORCH_ARG(double, dropout) = 0.0; - /// Whether to make the RNN bidirectional. - TORCH_ARG(bool, bidirectional) = false; - /// If true, the input sequence should be provided as `(batch, sequence, - /// features)`. If false (default), the expected layout is `(sequence, batch, - /// features)`. - TORCH_ARG(bool, batch_first) = false; - /// The activation to use after linear operations. - TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU; -}; - /// A multi-layer Elman RNN module with Tanh or ReLU activation. /// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn about the /// exact behavior of this module. @@ -178,7 +123,7 @@ class TORCH_API RNNImpl : public detail::RNNImplBase { public: RNNImpl(int64_t input_size, int64_t hidden_size) : RNNImpl(RNNOptions(input_size, hidden_size)) {} - explicit RNNImpl(const RNNOptions& options); + explicit RNNImpl(const RNNOptions& options_); /// Pretty prints the `RNN` module into the given `stream`. void pretty_print(std::ostream& stream) const override; @@ -200,8 +145,6 @@ TORCH_MODULE(RNN); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -using LSTMOptions = detail::RNNOptionsBase; - /// A multi-layer long-short-term-memory (LSTM) module. /// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn about the /// exact behavior of this module. @@ -209,7 +152,7 @@ class TORCH_API LSTMImpl : public detail::RNNImplBase { public: LSTMImpl(int64_t input_size, int64_t hidden_size) : LSTMImpl(LSTMOptions(input_size, hidden_size)) {} - explicit LSTMImpl(const LSTMOptions& options); + explicit LSTMImpl(const LSTMOptions& options_); /// Applies the `LSTM` module to an input sequence and input state. /// The `input` should follow a `(sequence, batch, features)` layout unless @@ -226,8 +169,6 @@ TORCH_MODULE(LSTM); // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -using GRUOptions = detail::RNNOptionsBase; - /// A multi-layer gated recurrent unit (GRU) module. /// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the /// exact behavior of this module. @@ -235,7 +176,7 @@ class TORCH_API GRUImpl : public detail::RNNImplBase { public: GRUImpl(int64_t input_size, int64_t hidden_size) : GRUImpl(GRUOptions(input_size, hidden_size)) {} - explicit GRUImpl(const GRUOptions& options); + explicit GRUImpl(const GRUOptions& options_); /// Applies the `GRU` module to an input sequence and input state. /// The `input` should follow a `(sequence, batch, features)` layout unless diff --git a/torch/csrc/api/include/torch/nn/options.h b/torch/csrc/api/include/torch/nn/options.h new file mode 100644 index 0000000000000..caa49c4088d1c --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h new file mode 100644 index 0000000000000..ca6a952603d97 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for the `BatchNorm` module. +struct TORCH_API BatchNormOptions { + /* implicit */ BatchNormOptions(int64_t features); + /// The number of features of the input tensor. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, features); + /// Whether to learn a scale and bias that are applied in an affine + /// transformation on the input. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, affine) = true; + /// Whether to store and update batch statistics (mean and variance) in the + /// module. If `false`, you should call `pure_forward` and supply those batch + /// statistics yourself. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, stateful) = true; + /// The epsilon value added for numerical stability. + /// Changing this parameter after construction __is effective__. + TORCH_ARG(double, eps) = 1e-5; + /// A momentum multiplier for the mean and variance. + /// Changing this parameter after construction __is effective__. + TORCH_ARG(double, momentum) = 0.1; +}; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h new file mode 100644 index 0000000000000..6559f827d2fad --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/conv.h @@ -0,0 +1,84 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for a `D`-dimensional convolution module. +template +struct ConvOptions { + ConvOptions( + int64_t input_channels, + int64_t output_channels, + ExpandingArray kernel_size) : + input_channels_(input_channels), + output_channels_(output_channels), + kernel_size_(std::move(kernel_size)) {} + + /// The number of channels the input volumes will have. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, input_channels); + + /// The number of output channels the convolution should produce. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(int64_t, output_channels); + + /// The kernel size to use. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, kernel_size); + + /// The stride of the convolution. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, stride) = 1; + + /// The padding to add to the input volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, padding) = 0; + + /// The kernel dilation. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// For transpose convolutions, the padding to add to output volumes. + /// For a `D`-dim convolution, must be a single number or a list of `D` + /// numbers. + /// This parameter __can__ be changed after construction. + TORCH_ARG(ExpandingArray, output_padding) = 0; + + /// If true, convolutions will be transpose convolutions (a.k.a. + /// deconvolutions). + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, transposed) = false; + + /// Whether to add a bias after individual applications of the kernel. + /// Changing this parameter after construction __has no effect__. + TORCH_ARG(bool, with_bias) = true; + + /// The number of convolution groups. + /// This parameter __can__ be changed after construction. + TORCH_ARG(int64_t, groups) = 1; +}; + +/// `ConvOptions` specialized for 1-D convolution. +using Conv1dOptions = ConvOptions<1>; + +/// `ConvOptions` specialized for 2-D convolution. +using Conv2dOptions = ConvOptions<2>; + +/// `ConvOptions` specialized for 3-D convolution. +using Conv3dOptions = ConvOptions<3>; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/distance.h b/torch/csrc/api/include/torch/nn/options/distance.h new file mode 100644 index 0000000000000..6e7b6367ad7f9 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/distance.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for the `CosineSimilarity` module. +struct TORCH_API CosineSimilarityOptions { + /// Dimension where cosine similarity is computed. Default: 1 + TORCH_ARG(int64_t, dim) = 1; + /// Small value to avoid division by zero. Default: 1e-8 + TORCH_ARG(double, eps) = 1e-8; +}; + +// ============================================================================ + +/// Options for the `PairwiseDistance` module. +struct TORCH_API PairwiseDistanceOptions { + PairwiseDistanceOptions(double p = 2.0) + : p_(p) {} + /// The norm degree. Default: 2 + TORCH_ARG(double, p); + /// Small value to avoid division by zero. Default: 1e-6 + TORCH_ARG(double, eps) = 1e-6; + /// Determines whether or not to keep the vector dimension. Default: false + TORCH_ARG(bool, keepdim) = false; +}; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/dropout.h b/torch/csrc/api/include/torch/nn/options/dropout.h new file mode 100644 index 0000000000000..b82cb45036bf7 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/dropout.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for `Dropout` and `FeatureDropout`. +struct TORCH_API DropoutOptions { + /* implicit */ DropoutOptions(double rate = 0.5); + /// The probability with which a particular component of the input is set to + /// zero. + /// Changes to this parameter at runtime are effective. + TORCH_ARG(double, rate); +}; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/fold.h b/torch/csrc/api/include/torch/nn/options/fold.h new file mode 100644 index 0000000000000..ccf6bba2ab33b --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/fold.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for a fold module. +struct TORCH_API FoldOptions { + FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) + : output_size_(std::move(output_size)), + kernel_size_(std::move(kernel_size)) {} + + /// describes the spatial shape of the large containing tensor of the sliding + /// local blocks. It is useful to resolve the ambiguity when multiple input + /// shapes map to same number of sliding blocks, e.g., with stride > 0. + TORCH_ARG(ExpandingArray<2>, output_size); + + /// the size of the sliding blocks + TORCH_ARG(ExpandingArray<2>, kernel_size); + + /// controls the spacing between the kernel points; also known as the à trous + /// algorithm. + TORCH_ARG(ExpandingArray<2>, dilation) = 1; + + /// controls the amount of implicit zero-paddings on both sides for padding + /// number of points for each dimension before reshaping. + TORCH_ARG(ExpandingArray<2>, padding) = 0; + + /// controls the stride for the sliding blocks. + TORCH_ARG(ExpandingArray<2>, stride) = 1; +}; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/linear.h b/torch/csrc/api/include/torch/nn/options/linear.h new file mode 100644 index 0000000000000..d3df9cecc8881 --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/linear.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for the `Linear` module. +struct TORCH_API LinearOptions { + LinearOptions(int64_t in, int64_t out); + /// The number of input features (columns of the input matrix). + TORCH_ARG(int64_t, in); + /// The number of output features to produce (columns of the output matrix). + TORCH_ARG(int64_t, out); + /// Whether to learn and add a bias after the linear transformation. + TORCH_ARG(bool, with_bias) = true; +}; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h new file mode 100644 index 0000000000000..0fda84fbbf4df --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -0,0 +1,20 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for a L1 loss module. +struct TORCH_API L1LossOptions { + L1LossOptions(Reduction::Reduction reduction = Reduction::Mean) + : reduction_(reduction) {} + + /// Specifies the reduction to apply to the output. + TORCH_ARG(Reduction::Reduction, reduction); +}; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/pooling.h b/torch/csrc/api/include/torch/nn/options/pooling.h new file mode 100644 index 0000000000000..ca04a1e010b9b --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/pooling.h @@ -0,0 +1,83 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace nn { + +/// Options for a `D`-dimensional avgpool functional and module. +template +struct AvgPoolOptions { + AvgPoolOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size), stride_(kernel_size) {} + + /// the size of the window to take an average over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the stride of the window. Default value is `kernel_size` + TORCH_ARG(ExpandingArray, stride); + + /// implicit zero padding to be added on both sides + TORCH_ARG(ExpandingArray, padding) = 0; + + /// when True, will use `ceil` instead of `floor` to compute the output shape + TORCH_ARG(bool, ceil_mode) = false; + + /// when True, will include the zero-padding in the averaging calculation + TORCH_ARG(bool, count_include_pad) = true; + + /// if specified, it will be used as divisor, otherwise `kernel_size` will be used + TORCH_ARG(c10::optional, divisor_override) = c10::nullopt; +}; + +/// `AvgPoolOptions` specialized for 1-D avgpool. +using AvgPool1dOptions = AvgPoolOptions<1>; + +/// `AvgPoolOptions` specialized for 2-D avgpool. +using AvgPool2dOptions = AvgPoolOptions<2>; + +/// `AvgPoolOptions` specialized for 3-D avgpool. +using AvgPool3dOptions = AvgPoolOptions<3>; + +// ============================================================================ + +/// Options for a `D`-dimensional maxpool functional and module. +template +struct MaxPoolOptions { + MaxPoolOptions(ExpandingArray kernel_size) + : kernel_size_(kernel_size), stride_(kernel_size) {} + + /// the size of the window to take a max over + TORCH_ARG(ExpandingArray, kernel_size); + + /// the stride of the window. Default value is `kernel_size + TORCH_ARG(ExpandingArray, stride); + + /// implicit zero padding to be added on both sides + TORCH_ARG(ExpandingArray, padding) = 0; + + /// a parameter that controls the stride of elements in the window + TORCH_ARG(ExpandingArray, dilation) = 1; + + /// if true, will return the max indices along with the outputs. Useful + /// for `MaxUnpool1d` later + TORCH_ARG(bool, return_indices) = false; + + /// when True, will use `ceil` instead of `floor` to compute the output shape + TORCH_ARG(bool, ceil_mode) = false; +}; + +/// `MaxPoolOptions` specialized for 1-D maxpool. +using MaxPool1dOptions = MaxPoolOptions<1>; + +/// `MaxPoolOptions` specialized for 2-D maxpool. +using MaxPool2dOptions = MaxPoolOptions<2>; + +/// `MaxPoolOptions` specialized for 3-D maxpool. +using MaxPool3dOptions = MaxPoolOptions<3>; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/rnn.h b/torch/csrc/api/include/torch/nn/options/rnn.h new file mode 100644 index 0000000000000..d30d6d8949ead --- /dev/null +++ b/torch/csrc/api/include/torch/nn/options/rnn.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace nn { + +namespace detail { + +/// Common options for LSTM and GRU modules. +struct TORCH_API RNNOptionsBase { + RNNOptionsBase(int64_t input_size, int64_t hidden_size); + virtual ~RNNOptionsBase() = default; + /// The number of features of a single sample in the input sequence `x`. + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h`. + TORCH_ARG(int64_t, hidden_size); + /// The number of recurrent layers (cells) to use. + TORCH_ARG(int64_t, layers) = 1; + /// Whether a bias term should be added to all linear operations. + TORCH_ARG(bool, with_bias) = true; + /// If non-zero, adds dropout with the given probability to the output of each + /// RNN layer, except the final layer. + TORCH_ARG(double, dropout) = 0.0; + /// Whether to make the RNN bidirectional. + TORCH_ARG(bool, bidirectional) = false; + /// If true, the input sequence should be provided as `(batch, sequence, + /// features)`. If false (default), the expected layout is `(sequence, batch, + /// features)`. + TORCH_ARG(bool, batch_first) = false; +}; + +} // namespace detail + +enum class RNNActivation : uint32_t {ReLU, Tanh}; + +/// Options for RNN modules. +struct TORCH_API RNNOptions { + RNNOptions(int64_t input_size, int64_t hidden_size); + + /// Sets the activation after linear operations to `tanh`. + RNNOptions& tanh(); + /// Sets the activation after linear operations to `relu`. + RNNOptions& relu(); + + /// The number of features of a single sample in the input sequence `x`. + TORCH_ARG(int64_t, input_size); + /// The number of features in the hidden state `h`. + TORCH_ARG(int64_t, hidden_size); + /// The number of recurrent layers (cells) to use. + TORCH_ARG(int64_t, layers) = 1; + /// Whether a bias term should be added to all linear operations. + TORCH_ARG(bool, with_bias) = true; + /// If non-zero, adds dropout with the given probability to the output of each + /// RNN layer, except the final layer. + TORCH_ARG(double, dropout) = 0.0; + /// Whether to make the RNN bidirectional. + TORCH_ARG(bool, bidirectional) = false; + /// If true, the input sequence should be provided as `(batch, sequence, + /// features)`. If false (default), the expected layout is `(sequence, batch, + /// features)`. + TORCH_ARG(bool, batch_first) = false; + /// The activation to use after linear operations. + TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU; +}; + +using LSTMOptions = detail::RNNOptionsBase; +using GRUOptions = detail::RNNOptionsBase; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/include/torch/optim/adagrad.h b/torch/csrc/api/include/torch/optim/adagrad.h index 1174decd6e712..ba6c7788511a0 100644 --- a/torch/csrc/api/include/torch/optim/adagrad.h +++ b/torch/csrc/api/include/torch/optim/adagrad.h @@ -30,9 +30,9 @@ class TORCH_API Adagrad : public Optimizer { template explicit Adagrad( ParameterContainer&& parameters, - const AdagradOptions& options) + const AdagradOptions& options_) : Optimizer(std::forward(parameters)), - options(options) {} + options(options_) {} void step() override; diff --git a/torch/csrc/api/include/torch/optim/adam.h b/torch/csrc/api/include/torch/optim/adam.h index f21f87dc8cdbc..11d41451211ad 100644 --- a/torch/csrc/api/include/torch/optim/adam.h +++ b/torch/csrc/api/include/torch/optim/adam.h @@ -31,9 +31,9 @@ struct TORCH_API AdamOptions { class TORCH_API Adam : public Optimizer { public: template - explicit Adam(ParameterContainer&& parameters, const AdamOptions& options) + explicit Adam(ParameterContainer&& parameters, const AdamOptions& options_) : Optimizer(std::forward(parameters)), - options(options) {} + options(options_) {} void step() override; diff --git a/torch/csrc/api/include/torch/optim/lbfgs.h b/torch/csrc/api/include/torch/optim/lbfgs.h index 33b877bccfec2..2155b9a02cd24 100644 --- a/torch/csrc/api/include/torch/optim/lbfgs.h +++ b/torch/csrc/api/include/torch/optim/lbfgs.h @@ -27,11 +27,11 @@ struct TORCH_API LBFGSOptions { class TORCH_API LBFGS : public LossClosureOptimizer { public: template - explicit LBFGS(ParameterContainer&& parameters, const LBFGSOptions& options) + explicit LBFGS(ParameterContainer&& parameters, const LBFGSOptions& options_) : LossClosureOptimizer(std::forward(parameters)), - options(options), - ro(options.history_size_), - al(options.history_size_) {} + options(options_), + ro(options_.history_size()), + al(options_.history_size()) {} torch::Tensor step(LossClosure closure) override; diff --git a/torch/csrc/api/include/torch/optim/rmsprop.h b/torch/csrc/api/include/torch/optim/rmsprop.h index 7f80c710ae1ec..36129e23ae160 100644 --- a/torch/csrc/api/include/torch/optim/rmsprop.h +++ b/torch/csrc/api/include/torch/optim/rmsprop.h @@ -36,9 +36,9 @@ class TORCH_API RMSprop : public Optimizer { template explicit RMSprop( ParameterContainer&& parameters, - const RMSpropOptions& options) + const RMSpropOptions& options_) : Optimizer(std::forward(parameters)), - options(options) {} + options(options_) {} void step() override; diff --git a/torch/csrc/api/include/torch/optim/sgd.h b/torch/csrc/api/include/torch/optim/sgd.h index d01d45d912d90..0a9cc822b27c6 100644 --- a/torch/csrc/api/include/torch/optim/sgd.h +++ b/torch/csrc/api/include/torch/optim/sgd.h @@ -31,9 +31,9 @@ struct TORCH_API SGDOptions { class TORCH_API SGD : public Optimizer { public: template - explicit SGD(ParameterContainer&& parameters, const SGDOptions& options) + explicit SGD(ParameterContainer&& parameters, const SGDOptions& options_) : Optimizer(std::forward(parameters)), - options(options) {} + options(options_) {} void step() override; diff --git a/torch/csrc/api/include/torch/ordered_dict.h b/torch/csrc/api/include/torch/ordered_dict.h index a26de43215488..73e007fed6c69 100644 --- a/torch/csrc/api/include/torch/ordered_dict.h +++ b/torch/csrc/api/include/torch/ordered_dict.h @@ -146,6 +146,10 @@ class OrderedDict { /// `other` is already present in this `OrderedDict`, an exception is thrown. void update(const OrderedDict& other); + /// Removes the item that has `key` from this `OrderedDict` if exists and if + /// it doesn't an exception is thrown. + void erase(const Key& key); + /// Removes all items from this `OrderedDict`. void clear(); @@ -221,14 +225,14 @@ class OrderedDict::Item { } /// Returns a `(key, value)` pair. - const std::pair& pair() const noexcept { + const std::pair& pair() const noexcept { return pair_; } private: /// This is stored as an std::pair because it will make Python binding a lot, /// lot easier. - ::std::pair pair_; + ::std::pair pair_; }; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ OrderedDict ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -401,6 +405,20 @@ const Value* OrderedDict::find(const Key& key) const noexcept { return &items_[iterator->second].value(); } +template +void OrderedDict::erase(const Key& key) { + auto it = index_.find(key); + TORCH_CHECK(it != index_.end(), "Key '", key, "' doesn't exist"); + + auto index = it->second; + index_.erase(it); + items_.erase(items_.begin() + index); + + for (auto& pair : index_) + if (pair.second > index) + --pair.second; +} + template bool OrderedDict::contains(const Key& key) const noexcept { return find(key) != nullptr; diff --git a/torch/csrc/api/include/torch/serialize.h b/torch/csrc/api/include/torch/serialize.h index 5fd19f52e4a77..5833a8cdd7499 100644 --- a/torch/csrc/api/include/torch/serialize.h +++ b/torch/csrc/api/include/torch/serialize.h @@ -2,6 +2,7 @@ #include #include +#include #include @@ -72,6 +73,8 @@ void save(const std::vector& tensor_vec, SaveToArgs&&... args) { archive.save_to(std::forward(args)...); } +TORCH_API std::vector pickle_save(const torch::IValue& ivalue); + /// Deserializes the given `value`. /// There must be an overload of `operator>>` between `serialize::InputArchive` /// and `Value` for this method to be well-formed. Currently, such an overload diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp index 42785b0736da0..b887568f72802 100644 --- a/torch/csrc/api/src/nn/module.cpp +++ b/torch/csrc/api/src/nn/module.cpp @@ -330,6 +330,15 @@ Tensor& Module::register_buffer(std::string name, Tensor tensor) { return buffers_.insert(std::move(name), std::move(tensor)); } +void Module::unregister_module(const std::string& name) { + TORCH_CHECK( + children_.contains(name), + "No Module with name `", + name, + "` is registered"); + children_.erase(name); +} + void Module::pretty_print(std::ostream& stream) const { stream << name(); } diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp index 7fab1f5f645a6..806d77bdb2c92 100644 --- a/torch/csrc/api/src/nn/modules/batchnorm.cpp +++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp @@ -12,38 +12,37 @@ namespace torch { namespace nn { -BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {} -BatchNormImpl::BatchNormImpl(BatchNormOptions options) : options(options) { +BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { reset(); } void BatchNormImpl::reset() { - if (options.affine_) { + if (options.affine()) { weight = register_parameter( - "weight", torch::empty({options.features_}).uniform_()); - bias = register_parameter("bias", torch::zeros({options.features_})); + "weight", torch::empty({options.features()}).uniform_()); + bias = register_parameter("bias", torch::zeros({options.features()})); } - if (options.stateful_) { + if (options.stateful()) { running_mean = - register_buffer("running_mean", torch::zeros({options.features_})); + register_buffer("running_mean", torch::zeros({options.features()})); running_var = - register_buffer("running_var", torch::ones({options.features_})); + register_buffer("running_var", torch::ones({options.features()})); } } void BatchNormImpl::pretty_print(std::ostream& stream) const { stream << std::boolalpha - << "torch::nn::BatchNorm(features=" << options.features_ - << ", eps=" << options.eps_ << ", momentum=" << options.momentum_ - << ", affine=" << options.affine_ << ", stateful=" << options.stateful_ + << "torch::nn::BatchNorm(features=" << options.features() + << ", eps=" << options.eps() << ", momentum=" << options.momentum() + << ", affine=" << options.affine() << ", stateful=" << options.stateful() << ")"; } Tensor BatchNormImpl::forward(const Tensor& input) { TORCH_CHECK( - options.stateful_, + options.stateful(), "Calling BatchNorm::forward is only permitted when " "the 'stateful' option is true (was false). " "Use BatchNorm::pure_forward instead."); @@ -68,8 +67,8 @@ Tensor BatchNormImpl::pure_forward( mean, variance, is_training(), - options.momentum_, - options.eps_, + options.momentum(), + options.eps(), torch::cuda::cudnn_is_available()); } diff --git a/torch/csrc/api/src/nn/modules/functional.cpp b/torch/csrc/api/src/nn/modules/container/functional.cpp similarity index 92% rename from torch/csrc/api/src/nn/modules/functional.cpp rename to torch/csrc/api/src/nn/modules/container/functional.cpp index c6ae368e3ff9d..215ba8739b943 100644 --- a/torch/csrc/api/src/nn/modules/functional.cpp +++ b/torch/csrc/api/src/nn/modules/container/functional.cpp @@ -1,4 +1,4 @@ -#include +#include #include diff --git a/torch/csrc/api/src/nn/modules/named_any.cpp b/torch/csrc/api/src/nn/modules/container/named_any.cpp similarity index 88% rename from torch/csrc/api/src/nn/modules/named_any.cpp rename to torch/csrc/api/src/nn/modules/container/named_any.cpp index 85c7656df9dfb..3237a1e42c39c 100644 --- a/torch/csrc/api/src/nn/modules/named_any.cpp +++ b/torch/csrc/api/src/nn/modules/container/named_any.cpp @@ -1,4 +1,4 @@ -#include +#include namespace torch { namespace nn { diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp index 739b7ccd7d21a..661503b6a40d8 100644 --- a/torch/csrc/api/src/nn/modules/conv.cpp +++ b/torch/csrc/api/src/nn/modules/conv.cpp @@ -13,44 +13,44 @@ namespace torch { namespace nn { template -ConvImpl::ConvImpl(ConvOptions options) - : options(std::move(options)) { +ConvImpl::ConvImpl(const ConvOptions& options_) + : options(options_) { reset(); } template void ConvImpl::reset() { - if (!options.transposed_) { - for (auto pad : *options.output_padding_) { + if (!options.transposed()) { + for (auto pad : *options.output_padding()) { TORCH_CHECK( pad == 0, "Only transposed convolutions support output padding!"); } } std::vector weights_size; - if (options.transposed_) { - weights_size.push_back(options.input_channels_); - weights_size.push_back(options.output_channels_ / options.groups_); + if (options.transposed()) { + weights_size.push_back(options.input_channels()); + weights_size.push_back(options.output_channels() / options.groups()); } else { - weights_size.push_back(options.output_channels_); - weights_size.push_back(options.input_channels_ / options.groups_); + weights_size.push_back(options.output_channels()); + weights_size.push_back(options.input_channels() / options.groups()); } weights_size.insert( weights_size.end(), - options.kernel_size_->begin(), - options.kernel_size_->end()); - AT_ASSERT(weights_size.size() == 2 + options.kernel_size_->size()); + options.kernel_size()->begin(), + options.kernel_size()->end()); + AT_ASSERT(weights_size.size() == 2 + options.kernel_size()->size()); weight = this->register_parameter("weight", torch::empty(weights_size)); - if (options.with_bias_) { + if (options.with_bias()) { bias = this->register_parameter( - "bias", torch::empty(options.output_channels_)); + "bias", torch::empty(options.output_channels())); } const auto number_of_features = std::accumulate( - options.kernel_size_->begin(), - options.kernel_size_->end(), - options.input_channels_, + options.kernel_size()->begin(), + options.kernel_size()->end(), + options.input_channels(), std::multiplies{}); const auto stdv = 1.0 / std::sqrt(number_of_features); NoGradGuard no_grad; @@ -62,86 +62,81 @@ void ConvImpl::reset() { template void ConvImpl::pretty_print(std::ostream& stream) const { stream << "torch::nn::Conv" << D << "d" - << "(input_channels=" << options.input_channels_ - << ", output_channels=" << options.output_channels_ - << ", kernel_size=" << options.kernel_size_ - << ", stride=" << options.stride_ << ")"; + << "(input_channels=" << options.input_channels() + << ", output_channels=" << options.output_channels() + << ", kernel_size=" << options.kernel_size() + << ", stride=" << options.stride() << ")"; } Tensor Conv1dImpl::forward(const Tensor& input) { - if (options.transposed_) { + if (options.transposed()) { return torch::conv_transpose1d( input, weight, bias, - options.stride_, - options.padding_, - options.output_padding_, - options.groups_, - options.dilation_); + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); } return torch::conv1d( input, weight, bias, - options.stride_, - options.padding_, - options.dilation_, - options.groups_); + options.stride(), + options.padding(), + options.dilation(), + options.groups()); } Tensor Conv2dImpl::forward(const Tensor& input) { - if (options.transposed_) { + if (options.transposed()) { return torch::conv_transpose2d( input, weight, bias, - options.stride_, - options.padding_, - options.output_padding_, - options.groups_, - options.dilation_); + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); } return torch::conv2d( input, weight, bias, - options.stride_, - options.padding_, - options.dilation_, - options.groups_); + options.stride(), + options.padding(), + options.dilation(), + options.groups()); } Tensor Conv3dImpl::forward(const Tensor& input) { - if (options.transposed_) { + if (options.transposed()) { return torch::conv_transpose3d( input, weight, bias, - options.stride_, - options.padding_, - options.output_padding_, - options.groups_, - options.dilation_); + options.stride(), + options.padding(), + options.output_padding(), + options.groups(), + options.dilation()); } else { return torch::conv3d( input, weight, bias, - options.stride_, - options.padding_, - options.dilation_, - options.groups_); + options.stride(), + options.padding(), + options.dilation(), + options.groups()); } } -template struct ConvOptions<1>; template class ConvImpl<1, Conv1dImpl>; - -template struct ConvOptions<2>; template class ConvImpl<2, Conv2dImpl>; - -template struct ConvOptions<3>; template class ConvImpl<3, Conv3dImpl>; } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/distance.cpp b/torch/csrc/api/src/nn/modules/distance.cpp new file mode 100644 index 0000000000000..7e098da19b4ee --- /dev/null +++ b/torch/csrc/api/src/nn/modules/distance.cpp @@ -0,0 +1,44 @@ +#include + +namespace F = torch::nn::functional; + +namespace torch { +namespace nn { + +CosineSimilarityImpl::CosineSimilarityImpl(const CosineSimilarityOptions& options_) + : options(options_) {} + +void CosineSimilarityImpl::reset() {} + +void CosineSimilarityImpl::pretty_print(std::ostream& stream) const { + stream << std::boolalpha + << "torch::nn::CosineSimilarity" + << "(dim=" << options.dim() + << ", eps=" << options.eps() << ")"; +} + +Tensor CosineSimilarityImpl::forward(const Tensor& x1, const Tensor& x2) { + return F::cosine_similarity(x1, x2, options); +} + +// ============================================================================ + +PairwiseDistanceImpl::PairwiseDistanceImpl(const PairwiseDistanceOptions& options_) + : options(options_) {} + +void PairwiseDistanceImpl::reset() {} + +void PairwiseDistanceImpl::pretty_print(std::ostream& stream) const { + stream << std::boolalpha + << "torch::nn::PairwiseDistance" + << "(p=" << options.p() + << ", eps=" << options.eps() + << ", keepdim=" << options.keepdim() << ")"; +} + +Tensor PairwiseDistanceImpl::forward(const Tensor& x1, const Tensor& x2) { + return F::pairwise_distance(x1, x2, options); +} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp index 84a0d916b7e13..e4268f0ef08aa 100644 --- a/torch/csrc/api/src/nn/modules/dropout.cpp +++ b/torch/csrc/api/src/nn/modules/dropout.cpp @@ -12,10 +12,10 @@ namespace torch { namespace nn { namespace detail { template -DropoutImplBase::DropoutImplBase(DropoutOptions options_) +DropoutImplBase::DropoutImplBase(const DropoutOptions& options_) : options(options_) { - TORCH_CHECK(options.rate_ >= 0, "Dropout rate must not be less than zero"); - TORCH_CHECK(options.rate_ <= 1, "Dropout rate must not be greater than one"); + TORCH_CHECK(options.rate() >= 0, "Dropout rate must not be less than zero"); + TORCH_CHECK(options.rate() <= 1, "Dropout rate must not be greater than one"); } template @@ -25,27 +25,25 @@ template class DropoutImplBase; template class DropoutImplBase; } // namespace detail -DropoutOptions::DropoutOptions(double rate) : rate_(rate) {} - -DropoutImpl::DropoutImpl(DropoutOptions options_) : DropoutImplBase(options_) {} +DropoutImpl::DropoutImpl(const DropoutOptions& options_) : DropoutImplBase(options_) {} Tensor DropoutImpl::forward(const Tensor& input) { - return torch::dropout(input, options.rate_, this->is_training()); + return torch::dropout(input, options.rate(), this->is_training()); } void DropoutImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::Dropout(rate=" << options.rate_ << ")"; + stream << "torch::nn::Dropout(rate=" << options.rate() << ")"; } -FeatureDropoutImpl::FeatureDropoutImpl(DropoutOptions options_) +FeatureDropoutImpl::FeatureDropoutImpl(const DropoutOptions& options_) : DropoutImplBase(options_) {} Tensor FeatureDropoutImpl::forward(const Tensor& input) { - return torch::feature_dropout(input, options.rate_, this->is_training()); + return torch::feature_dropout(input, options.rate(), this->is_training()); } void FeatureDropoutImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::FeatureDropout(rate=" << options.rate_ << ")"; + stream << "torch::nn::FeatureDropout(rate=" << options.rate() << ")"; } } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 786b4272b79d7..16d4c95a854cb 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -20,14 +20,14 @@ EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options) : options(options) { void EmbeddingImpl::reset() { weight = register_parameter( - "weight", torch::empty({options.count_, options.dimension_})); + "weight", torch::empty({options.count(), options.dimension()})); NoGradGuard guard; weight.normal_(0, 1); } void EmbeddingImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::Embedding(count=" << options.count_ - << ", dimension=" << options.dimension_ << ")"; + stream << "torch::nn::Embedding(count=" << options.count() + << ", dimension=" << options.dimension() << ")"; } Tensor EmbeddingImpl::forward(const Tensor& input) { diff --git a/torch/csrc/api/src/nn/modules/fold.cpp b/torch/csrc/api/src/nn/modules/fold.cpp new file mode 100644 index 0000000000000..6b0de8bc8cf87 --- /dev/null +++ b/torch/csrc/api/src/nn/modules/fold.cpp @@ -0,0 +1,31 @@ +#include + +#include +#include +#include + +namespace torch { +namespace nn { + +FoldImpl::FoldImpl(const FoldOptions& options_) : options(options_) { + reset(); +} + +Tensor FoldImpl::forward(const Tensor& input) { + TORCH_CHECK( + input.dim() == 3, + "Input Error: Only 3D input Tensors are supported (got ", + input.dim(), + "D)"); + + return torch::col2im( + input, + options.output_size(), + options.kernel_size(), + options.dilation(), + options.padding(), + options.stride()); +} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index 8cd6842befc48..453bd127660a9 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -8,17 +8,16 @@ namespace torch { namespace nn { -LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {} -LinearImpl::LinearImpl(LinearOptions options) : options(options) { +LinearImpl::LinearImpl(const LinearOptions& options_) : options(options_) { reset(); } void LinearImpl::reset() { weight = - register_parameter("weight", torch::empty({options.out_, options.in_})); - if (options.with_bias_) { - bias = register_parameter("bias", torch::empty(options.out_)); + register_parameter("weight", torch::empty({options.out(), options.in()})); + if (options.with_bias()) { + bias = register_parameter("bias", torch::empty(options.out())); } const auto stdv = 1.0 / std::sqrt(weight.size(1)); @@ -29,13 +28,13 @@ void LinearImpl::reset() { } void LinearImpl::pretty_print(std::ostream& stream) const { - stream << std::boolalpha << "torch::nn::Linear(in=" << options.in_ - << ", out=" << options.out_ << ", with_bias=" << options.with_bias_ + stream << std::boolalpha << "torch::nn::Linear(in=" << options.in() + << ", out=" << options.out() << ", with_bias=" << options.with_bias() << ")"; } Tensor LinearImpl::forward(const Tensor& input) { - AT_ASSERT(!options.with_bias_ || bias.defined()); + AT_ASSERT(!options.with_bias() || bias.defined()); return torch::linear(input, weight, bias); } } // namespace nn diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp new file mode 100644 index 0000000000000..63faf7a951f70 --- /dev/null +++ b/torch/csrc/api/src/nn/modules/loss.cpp @@ -0,0 +1,18 @@ +#include + +namespace torch { +namespace nn { + +L1LossImpl::L1LossImpl(const torch::nn::L1LossOptions& options_) + : options(options_) {} + +void L1LossImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::L1Loss"; +} + +Tensor L1LossImpl::forward(const Tensor& input, const Tensor& target) { + return torch::l1_loss(input, target, options.reduction()); +} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/modules/pooling.cpp b/torch/csrc/api/src/nn/modules/pooling.cpp new file mode 100644 index 0000000000000..8847dbe711b45 --- /dev/null +++ b/torch/csrc/api/src/nn/modules/pooling.cpp @@ -0,0 +1,73 @@ +#include + +#include + +namespace F = torch::nn::functional; + +namespace torch { +namespace nn { + +template +AvgPoolImpl::AvgPoolImpl(const AvgPoolOptions& options_) + : options(options_) {} + +template +void AvgPoolImpl::reset() {} + +template +void AvgPoolImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::AvgPool" << D << "d" + << "(kernel_size=" << options.kernel_size() + << ", stride=" << options.stride() << ")"; +} + +Tensor AvgPool1dImpl::forward(const Tensor& input) { + return F::avg_pool1d(input, options); +} + +Tensor AvgPool2dImpl::forward(const Tensor& input) { + return F::avg_pool2d(input, options); +} + +Tensor AvgPool3dImpl::forward(const Tensor& input) { + return F::avg_pool3d(input, options); +} + +template class AvgPoolImpl<1, AvgPool1dImpl>; +template class AvgPoolImpl<2, AvgPool2dImpl>; +template class AvgPoolImpl<3, AvgPool3dImpl>; + +// ============================================================================ + +template +MaxPoolImpl::MaxPoolImpl(const MaxPoolOptions& options_) + : options(options_) {} + +template +void MaxPoolImpl::reset() {} + +template +void MaxPoolImpl::pretty_print(std::ostream& stream) const { + stream << "torch::nn::MaxPool" << D << "d" + << "(kernel_size=" << options.kernel_size() + << ", stride=" << options.stride() << ")"; +} + +Tensor MaxPool1dImpl::forward(const Tensor& input) { + return F::max_pool1d(input, options); +} + +Tensor MaxPool2dImpl::forward(const Tensor& input) { + return F::max_pool2d(input, options); +} + +Tensor MaxPool3dImpl::forward(const Tensor& input) { + return F::max_pool3d(input, options); +} + +template class MaxPoolImpl<1, MaxPool1dImpl>; +template class MaxPoolImpl<2, MaxPool2dImpl>; +template class MaxPoolImpl<3, MaxPool3dImpl>; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index 8d6c1d8fa89c8..ec24498df7e6a 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -19,14 +19,8 @@ namespace torch { namespace nn { -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNOptionsBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -namespace detail { -RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size) - : input_size_(input_size), hidden_size_(hidden_size) {} - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - +namespace detail { template RNNImplBase::RNNImplBase( const RNNOptionsBase& options_, @@ -40,19 +34,19 @@ RNNImplBase::RNNImplBase( template void RNNImplBase::reset() { - const auto num_directions = options.bidirectional_ ? 2 : 1; + const auto num_directions = options.bidirectional() ? 2 : 1; - w_ih.resize(options.layers_ * num_directions); - w_hh.resize(options.layers_ * num_directions); - b_ih.resize(options.layers_ * num_directions); - b_hh.resize(options.layers_ * num_directions); + w_ih.resize(options.layers() * num_directions); + w_hh.resize(options.layers() * num_directions); + b_ih.resize(options.layers() * num_directions); + b_hh.resize(options.layers() * num_directions); - const int64_t gate_size = options.hidden_size_ * number_of_gates_; + const int64_t gate_size = options.hidden_size() * number_of_gates_; - for (int64_t layer = 0; layer < options.layers_; ++layer) { + for (int64_t layer = 0; layer < options.layers(); ++layer) { for (auto direction = 0; direction < num_directions; direction++) { - const auto layer_input_size = layer == 0 ? options.input_size_ : - options.hidden_size_ * num_directions; + const auto layer_input_size = layer == 0 ? options.input_size() : + options.hidden_size() * num_directions; const auto suffix = direction == 1 ? "_reverse" : ""; const auto layer_idx = (layer * num_directions) + direction; w_ih[layer_idx] = this->register_parameter( @@ -60,9 +54,9 @@ void RNNImplBase::reset() { torch::empty({gate_size, layer_input_size})); w_hh[layer_idx] = this->register_parameter( "weight_hh_l" + std::to_string(layer) + suffix, - torch::empty({gate_size, options.hidden_size_})); + torch::empty({gate_size, options.hidden_size()})); - if (options.with_bias_) { + if (options.with_bias()) { b_ih[layer_idx] = this->register_parameter( "bias_ih_l" + std::to_string(layer) + suffix, torch::empty({gate_size})); @@ -75,7 +69,7 @@ void RNNImplBase::reset() { { NoGradGuard no_grad; - const auto stdv = 1.0 / std::sqrt(options.hidden_size_); + const auto stdv = 1.0 / std::sqrt(options.hidden_size()); for (auto& p : this->parameters()) { p.uniform_(-stdv, stdv); } @@ -102,13 +96,13 @@ void RNNImplBase::to(torch::Dtype dtype, bool non_blocking) { template void RNNImplBase::to(torch::Device device, bool non_blocking) { nn::Module::to(device, non_blocking); - const auto num_directions = options.bidirectional_ ? 2 : 1; - for (int64_t layer = 0; layer < options.layers_; layer++) { + const auto num_directions = options.bidirectional() ? 2 : 1; + for (int64_t layer = 0; layer < options.layers(); layer++) { for (auto direction = 0; direction < num_directions; direction++) { const auto layer_idx = (layer * num_directions) + direction; w_ih[layer_idx] = w_ih[layer_idx].to(device, non_blocking); w_hh[layer_idx] = w_hh[layer_idx].to(device, non_blocking); - if (options.with_bias_) { + if (options.with_bias()) { b_ih[layer_idx] = b_ih[layer_idx].to(device, non_blocking); b_hh[layer_idx] = b_hh[layer_idx].to(device, non_blocking); } @@ -121,9 +115,9 @@ template void RNNImplBase::pretty_print(std::ostream& stream) const { const std::string name = this->name(); const std::string name_without_impl = name.substr(0, name.size() - 4); - stream << name_without_impl << "(input_size=" << options.input_size_ - << ", hidden_size=" << options.hidden_size_ - << ", layers=" << options.layers_ << ", dropout=" << options.dropout_ + stream << name_without_impl << "(input_size=" << options.input_size() + << ", hidden_size=" << options.hidden_size() + << ", layers=" << options.layers() << ", dropout=" << options.dropout() << ")"; } @@ -139,13 +133,13 @@ void RNNImplBase::flatten_parameters() { NoGradGuard no_grad; torch::_cudnn_rnn_flatten_weight( flat_weights_, - /*weight_stride0=*/options.with_bias_ ? 4 : 2, - options.input_size_, + /*weight_stride0=*/options.with_bias() ? 4 : 2, + options.input_size(), static_cast(*cudnn_mode_), - options.hidden_size_, - options.layers_, - /*batch_first=*/options.batch_first_, - /*bidirectional=*/options.bidirectional_); + options.hidden_size(), + options.layers(), + /*batch_first=*/options.batch_first(), + /*bidirectional=*/options.bidirectional()); } template @@ -155,10 +149,10 @@ RNNOutput RNNImplBase::generic_forward( Tensor state) { if (!state.defined()) { // #layers, batch size, state size - const auto batch_size = input.size(options.batch_first_ ? 0 : 1); - const auto num_directions = options.bidirectional_ ? 2 : 1; + const auto batch_size = input.size(options.batch_first() ? 0 : 1); + const auto num_directions = options.bidirectional() ? 2 : 1; state = torch::zeros( - {options.layers_ * num_directions, batch_size, options.hidden_size_}, + {options.layers() * num_directions, batch_size, options.hidden_size()}, input.options()); } Tensor output, new_state; @@ -166,12 +160,12 @@ RNNOutput RNNImplBase::generic_forward( input, std::move(state), flat_weights_, - options.with_bias_, - options.layers_, - options.dropout_, + options.with_bias(), + options.layers(), + options.dropout(), this->is_training(), - options.bidirectional_, - options.batch_first_); + options.bidirectional(), + options.batch_first()); return {output, new_state}; } @@ -180,13 +174,13 @@ std::vector RNNImplBase::flat_weights() const { // Organize all weights in a flat vector in the order // (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other). std::vector flat; - const auto num_directions = options.bidirectional_ ? 2 : 1; - for (int64_t layer = 0; layer < options.layers_; layer++) { + const auto num_directions = options.bidirectional() ? 2 : 1; + for (int64_t layer = 0; layer < options.layers(); layer++) { for (auto direction = 0; direction < num_directions; direction++) { const auto layer_idx = (layer * num_directions) + direction; flat.push_back(w_ih[layer_idx]); flat.push_back(w_hh[layer_idx]); - if (options.with_bias_) { + if (options.with_bias()) { flat.push_back(b_ih[layer_idx]); flat.push_back(b_hh[layer_idx]); } @@ -217,39 +211,28 @@ template class RNNImplBase; // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size) - : input_size_(input_size), hidden_size_(hidden_size) {} - -RNNOptions& RNNOptions::tanh() { - return activation(RNNActivation::Tanh); -} - -RNNOptions& RNNOptions::relu() { - return activation(RNNActivation::ReLU); -} - -RNNImpl::RNNImpl(const RNNOptions& options) +RNNImpl::RNNImpl(const RNNOptions& options_) : detail::RNNImplBase( - detail::RNNOptionsBase(options.input_size_, options.hidden_size_) - .layers(options.layers_) - .with_bias(options.with_bias_) - .dropout(options.dropout_) - .bidirectional(options.bidirectional_) - .batch_first(options.batch_first_), - static_cast(options.activation_)), - options(options) {} + detail::RNNOptionsBase(options_.input_size(), options_.hidden_size()) + .layers(options_.layers()) + .with_bias(options_.with_bias()) + .dropout(options_.dropout()) + .bidirectional(options_.bidirectional()) + .batch_first(options_.batch_first()), + static_cast(options_.activation())), + options(options_) {} void RNNImpl::pretty_print(std::ostream& stream) const { - stream << "torch::nn::RNN(input_size=" << options.input_size_ - << ", hidden_size=" << options.hidden_size_ - << ", layers=" << options.layers_ << ", dropout=" << options.dropout_ + stream << "torch::nn::RNN(input_size=" << options.input_size() + << ", hidden_size=" << options.hidden_size() + << ", layers=" << options.layers() << ", dropout=" << options.dropout() << ", activation=" - << (options.activation_ == RNNActivation::Tanh ? "tanh" : "relu") + << (options.activation() == RNNActivation::Tanh ? "tanh" : "relu") << ")"; } RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) { - switch (options.activation_) { + switch (options.activation()) { case RNNActivation::ReLU: return generic_forward( static_cast(&torch::rnn_relu), @@ -267,9 +250,9 @@ RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -LSTMImpl::LSTMImpl(const LSTMOptions& options) +LSTMImpl::LSTMImpl(const LSTMOptions& options_) : detail::RNNImplBase( - options, + options_, CuDNNMode::LSTM, /*number_of_gates=*/4) {} @@ -282,10 +265,10 @@ RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) { // different. So we just re-implement it specifically for the LSTM here. if (!state.defined()) { // 2 for hidden state and cell state, then #layers, batch size, state size - const auto batch_size = input.size(options.batch_first_ ? 0 : 1); - const auto num_directions = options.bidirectional_ ? 2 : 1; + const auto batch_size = input.size(options.batch_first() ? 0 : 1); + const auto num_directions = options.bidirectional() ? 2 : 1; state = torch::zeros( - {2, options.layers_ * num_directions, batch_size, options.hidden_size_}, + {2, options.layers() * num_directions, batch_size, options.hidden_size()}, input.options()); } Tensor output, hidden_state, cell_state; @@ -293,20 +276,20 @@ RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) { input, {state[0], state[1]}, flat_weights_, - options.with_bias_, - options.layers_, - options.dropout_, + options.with_bias(), + options.layers(), + options.dropout(), this->is_training(), - options.bidirectional_, - options.batch_first_); + options.bidirectional(), + options.batch_first()); return {output, torch::stack({hidden_state, cell_state})}; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -GRUImpl::GRUImpl(const GRUOptions& options) +GRUImpl::GRUImpl(const GRUOptions& options_) : detail::RNNImplBase( - options, + options_, CuDNNMode::GRU, /*number_of_gates=*/3) {} diff --git a/torch/csrc/api/src/nn/options/batchnorm.cpp b/torch/csrc/api/src/nn/options/batchnorm.cpp new file mode 100644 index 0000000000000..60c363b8d9e5c --- /dev/null +++ b/torch/csrc/api/src/nn/options/batchnorm.cpp @@ -0,0 +1,9 @@ +#include + +namespace torch { +namespace nn { + +BatchNormOptions::BatchNormOptions(int64_t features) : features_(features) {} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/options/conv.cpp b/torch/csrc/api/src/nn/options/conv.cpp new file mode 100644 index 0000000000000..d42d3fc8bc0d6 --- /dev/null +++ b/torch/csrc/api/src/nn/options/conv.cpp @@ -0,0 +1,11 @@ +#include + +namespace torch { +namespace nn { + +template struct ConvOptions<1>; +template struct ConvOptions<2>; +template struct ConvOptions<3>; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/options/dropout.cpp b/torch/csrc/api/src/nn/options/dropout.cpp new file mode 100644 index 0000000000000..9479fdda172b2 --- /dev/null +++ b/torch/csrc/api/src/nn/options/dropout.cpp @@ -0,0 +1,9 @@ +#include + +namespace torch { +namespace nn { + +DropoutOptions::DropoutOptions(double rate) : rate_(rate) {} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/options/linear.cpp b/torch/csrc/api/src/nn/options/linear.cpp new file mode 100644 index 0000000000000..00a5be6504234 --- /dev/null +++ b/torch/csrc/api/src/nn/options/linear.cpp @@ -0,0 +1,9 @@ +#include + +namespace torch { +namespace nn { + +LinearOptions::LinearOptions(int64_t in, int64_t out) : in_(in), out_(out) {} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/options/pooling.cpp b/torch/csrc/api/src/nn/options/pooling.cpp new file mode 100644 index 0000000000000..1d247b308c12d --- /dev/null +++ b/torch/csrc/api/src/nn/options/pooling.cpp @@ -0,0 +1,15 @@ +#include + +namespace torch { +namespace nn { + +template struct AvgPoolOptions<1>; +template struct AvgPoolOptions<2>; +template struct AvgPoolOptions<3>; + +template struct MaxPoolOptions<1>; +template struct MaxPoolOptions<2>; +template struct MaxPoolOptions<3>; + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/nn/options/rnn.cpp b/torch/csrc/api/src/nn/options/rnn.cpp new file mode 100644 index 0000000000000..5bdaa8f0dca23 --- /dev/null +++ b/torch/csrc/api/src/nn/options/rnn.cpp @@ -0,0 +1,25 @@ +#include + +namespace torch { +namespace nn { + +namespace detail { + +RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size) + : input_size_(input_size), hidden_size_(hidden_size) {} + +} // namespace detail + +RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size) + : input_size_(input_size), hidden_size_(hidden_size) {} + +RNNOptions& RNNOptions::tanh() { + return activation(RNNActivation::Tanh); +} + +RNNOptions& RNNOptions::relu() { + return activation(RNNActivation::ReLU); +} + +} // namespace nn +} // namespace torch diff --git a/torch/csrc/api/src/optim/adagrad.cpp b/torch/csrc/api/src/optim/adagrad.cpp index e5da19b8d3d06..4f931149f4f2e 100644 --- a/torch/csrc/api/src/optim/adagrad.cpp +++ b/torch/csrc/api/src/optim/adagrad.cpp @@ -23,14 +23,14 @@ void Adagrad::step() { continue; } - if (options.weight_decay_ > 0) { + if (options.weight_decay() > 0) { NoGradGuard guard; - p.grad() = p.grad() + options.weight_decay_ * p; + p.grad() = p.grad() + options.weight_decay() * p; } buffer_at(step_buffers, i) += 1.0; - const auto clr = options.learning_rate_ / - (1.0 + (buffer_at(step_buffers, i) - 1.0) * options.lr_decay_); + const auto clr = options.learning_rate() / + (1.0 + (buffer_at(step_buffers, i) - 1.0) * options.lr_decay()); auto& sum = buffer_at(sum_buffers, i); sum.addcmul_(p.grad(), p.grad(), 1.0); diff --git a/torch/csrc/api/src/optim/adam.cpp b/torch/csrc/api/src/optim/adam.cpp index 145a74e3753f1..c91f6dc49138f 100644 --- a/torch/csrc/api/src/optim/adam.cpp +++ b/torch/csrc/api/src/optim/adam.cpp @@ -22,9 +22,9 @@ void Adam::step() { continue; } - if (options.weight_decay_ > 0) { + if (options.weight_decay() > 0) { NoGradGuard guard; - p.grad() = p.grad() + options.weight_decay_ * p; + p.grad() = p.grad() + options.weight_decay() * p; } auto& exp_average = buffer_at(exp_average_buffers, i); @@ -32,16 +32,16 @@ void Adam::step() { buffer_at(step_buffers, i) += 1; const auto bias_correction1 = - 1 - std::pow(options.beta1_, buffer_at(step_buffers, i)); + 1 - std::pow(options.beta1(), buffer_at(step_buffers, i)); const auto bias_correction2 = - 1 - std::pow(options.beta2_, buffer_at(step_buffers, i)); + 1 - std::pow(options.beta2(), buffer_at(step_buffers, i)); - exp_average.mul_(options.beta1_).add_(p.grad(), 1 - options.beta1_); - exp_average_sq.mul_(options.beta2_) - .addcmul_(p.grad(), p.grad(), 1 - options.beta2_); + exp_average.mul_(options.beta1()).add_(p.grad(), 1 - options.beta1()); + exp_average_sq.mul_(options.beta2()) + .addcmul_(p.grad(), p.grad(), 1 - options.beta2()); Tensor denom; - if (options.amsgrad_) { + if (options.amsgrad()) { auto& max_exp_average_sq = buffer_at(max_exp_average_sq_buffers, i); max_exp_average_sq = torch::max(max_exp_average_sq, exp_average_sq); denom = max_exp_average_sq / bias_correction2; @@ -50,10 +50,10 @@ void Adam::step() { } const auto step_size = - options.learning_rate_ / bias_correction1; + options.learning_rate() / bias_correction1; NoGradGuard guard; - p.addcdiv_(exp_average, denom.sqrt() + options.eps_, -step_size); + p.addcdiv_(exp_average, denom.sqrt() + options.eps(), -step_size); } } diff --git a/torch/csrc/api/src/optim/lbfgs.cpp b/torch/csrc/api/src/optim/lbfgs.cpp index ab4c2e112a971..4346308223e0d 100644 --- a/torch/csrc/api/src/optim/lbfgs.cpp +++ b/torch/csrc/api/src/optim/lbfgs.cpp @@ -46,14 +46,14 @@ torch::Tensor LBFGS::step(LossClosure closure) { Tensor flat_grad = gather_flat_grad(); Tensor abs_grad_sum = flat_grad.abs().sum(); - if (abs_grad_sum.item() <= options.tolerance_grad_) { + if (abs_grad_sum.item() <= options.tolerance_grad()) { return loss; } Tensor ONE = torch::tensor(1, flat_grad.options()); int64_t n_iter = 0; - while (n_iter < options.max_iter_) { + while (n_iter < options.max_iter()) { n_iter++; state_n_iter++; @@ -69,7 +69,7 @@ torch::Tensor LBFGS::step(LossClosure closure) { if (ys.item() > 1e-10) { // updating memory - if (old_dirs.size() == options.history_size_) { + if (old_dirs.size() == options.history_size()) { // shift history by one (limited memory) old_dirs.pop_front(); old_stps.pop_front(); @@ -108,20 +108,20 @@ torch::Tensor LBFGS::step(LossClosure closure) { } /** - * comute step length + * compute step length */ // reset initial guess for step size if (n_iter == 1) { - t = torch::min(ONE, ONE / abs_grad_sum) * options.learning_rate_; + t = torch::min(ONE, ONE / abs_grad_sum) * options.learning_rate(); } else { - t = torch::tensor(options.learning_rate_, torch::kFloat32); + t = torch::tensor(options.learning_rate(), flat_grad.options()); } Tensor gtd = flat_grad.dot(d); add_grad(t, d); int64_t ls_func_evals = 0; - if (n_iter != options.max_iter_) { + if (n_iter != options.max_iter()) { // re-evaluate function only if not in last iteration // the reason we do this: in a stochastic setting, // no use to re-evaluate that function here @@ -137,20 +137,20 @@ torch::Tensor LBFGS::step(LossClosure closure) { * Check conditions */ - if (n_iter == options.max_iter_) { + if (n_iter == options.max_iter()) { break; - } else if (current_evals >= options.max_eval_) { + } else if (current_evals >= options.max_eval()) { break; - } else if (abs_grad_sum.item() <= options.tolerance_grad_) { + } else if (abs_grad_sum.item() <= options.tolerance_grad()) { break; - } else if (gtd.item() > -options.tolerance_grad_) { + } else if (gtd.item() > -options.tolerance_grad()) { break; } else if ( - d.mul(t).abs_().sum().item() <= options.tolerance_change_) { + d.mul(t).abs_().sum().item() <= options.tolerance_change()) { break; } else if ( std::abs(loss.item() - prev_loss.item()) < - options.tolerance_change_) { + options.tolerance_change()) { break; } } diff --git a/torch/csrc/api/src/optim/rmsprop.cpp b/torch/csrc/api/src/optim/rmsprop.cpp index 7018a16ce65bb..6e1d0fba6918f 100644 --- a/torch/csrc/api/src/optim/rmsprop.cpp +++ b/torch/csrc/api/src/optim/rmsprop.cpp @@ -23,33 +23,33 @@ void RMSprop::step() { continue; } - if (options.weight_decay_ > 0) { + if (options.weight_decay() > 0) { NoGradGuard guard; - p.grad() = p.grad() + options.weight_decay_ * p; + p.grad() = p.grad() + options.weight_decay() * p; } auto square_average = buffer_at(square_average_buffers, i); - square_average.mul_(options.alpha_) - .addcmul_(p.grad(), p.grad(), 1.0 - options.alpha_); + square_average.mul_(options.alpha()) + .addcmul_(p.grad(), p.grad(), 1.0 - options.alpha()); Tensor average; - if (options.centered_ > 0) { + if (options.centered() > 0) { auto& grad_average = buffer_at(grad_average_buffers, i); - grad_average.mul_(options.alpha_).add_(p.grad(), 1.0 - options.alpha_); + grad_average.mul_(options.alpha()).add_(p.grad(), 1.0 - options.alpha()); average = square_average.addcmul(grad_average, grad_average, -1.0) .sqrt() - .add_(options.eps_); + .add_(options.eps()); } else { - average = square_average.sqrt().add_(options.eps_); + average = square_average.sqrt().add_(options.eps()); } NoGradGuard guard; - if (options.momentum_ > 0) { + if (options.momentum() > 0) { auto& momentum = buffer_at(momentum_buffers, i); - momentum.mul_(options.momentum_).addcdiv_(p.grad(), average); - p.add_(momentum, -options.learning_rate_); + momentum.mul_(options.momentum()).addcdiv_(p.grad(), average); + p.add_(momentum, -options.learning_rate()); } else { - p.addcdiv_(p.grad(), average, -options.learning_rate_); + p.addcdiv_(p.grad(), average, -options.learning_rate()); } } } diff --git a/torch/csrc/api/src/optim/sgd.cpp b/torch/csrc/api/src/optim/sgd.cpp index 78a6fd847b1e9..86764aa618815 100644 --- a/torch/csrc/api/src/optim/sgd.cpp +++ b/torch/csrc/api/src/optim/sgd.cpp @@ -25,26 +25,26 @@ void SGD::step() { auto update = p.grad(); - if (options.weight_decay_ > 0) { + if (options.weight_decay() > 0) { NoGradGuard guard; - update += options.weight_decay_ * p; + update += options.weight_decay() * p; } - if (options.momentum_ != 0) { - const auto dampening = iteration_ == 0 ? 1 : 1 - options.dampening_; + if (options.momentum() != 0) { + const auto dampening = iteration_ == 0 ? 1 : 1 - options.dampening(); auto& momentum = buffer_at(momentum_buffers, i); - momentum = (options.momentum_ * momentum) + (dampening * update); - if (options.nesterov_) { + momentum = (options.momentum() * momentum) + (dampening * update); + if (options.nesterov()) { // See github.com/lisa-lab/pylearn2/pull/136#issuecomment-10381617 // for notes on this implementation of nesterov momentum. - update += options.momentum_ * momentum; + update += options.momentum() * momentum; } else { update = momentum; } } NoGradGuard guard; - p.add_(-options.learning_rate_ * update); + p.add_(-options.learning_rate() * update); } iteration_ += 1; } diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp new file mode 100644 index 0000000000000..2c1fcb28d9cad --- /dev/null +++ b/torch/csrc/api/src/serialize.cpp @@ -0,0 +1,13 @@ +#include +#include +#include + +#include + +namespace torch { + +std::vector pickle_save(const at::IValue& ivalue) { + return jit::pickle_save(ivalue); +} + +} // namespace torch diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 84047295978ec..4015c9fb72028 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -93,6 +93,18 @@ void VariableType::set_data(const Tensor & self, const Tensor & new_data) { as_variable_ref(self).set_data(new_data); } +Tensor VariableType::data(const Tensor & self) { + return as_variable_ref(self).variable_data(); +} + +bool VariableType::is_leaf(const Tensor & self) { + return as_variable_ref(self).is_leaf(); +} + +int64_t VariableType::output_nr(const Tensor & self) { + return as_variable_ref(self).output_nr(); +} + // We don't have an outplace copy, so this can't be generated automatically Tensor & VariableType::copy_(Tensor & self, const Tensor & src, bool non_blocking) { jit::Value* output = nullptr; diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 125a4e5c17be6..2a43c9197e029 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -12,6 +12,11 @@ #include #include #include +#include +#include +#include +#include +#include #include #include @@ -137,6 +142,35 @@ struct ReadyQueue { // When the GraphTask is finished, the parent worker thread that is waiting on // the task is notified and the current thread returns to the pool. +// Note [Streaming backwards] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~ +// On CUDA devices the autograd engine's device operations are run on the +// same stream that ran them in forward. This requires automatically +// syncing the streams so that function A finishes producing its +// output before function B consumes it. +// +// This synchronization occurs when outputs are placed into input buffers. +// The functions corresponding to input buffer positions have metadata +// recording their streams from forward, and during backward this +// data is used to sync the producer's stream with the consumer's. +// +// When a CUDA function is run either all its inputs were accumulated on the +// stream used to run the function OR the inputs are on different devices +// and the function is responsible for properly acquiring them. +// +// Historically, the autograd engine ran all CUDA operations on their +// device's DEFAULT stream. This meant that syncing (implicitly or +// explicitly) with the default streams was required before and after +// calling backward(). It also meant, however, that syncing with +// the default streams after backward() was sufficient to ensure +// that backward() had finished running. To preserve this historic +// behavior the engine records "leaf streams," the streams of the +// leaf variables, and syncs them with their device's default stream +// at the end of backward. All other streams are already synchronized +// to happen before at least one leaf stream (per the above), so syncing +// the leaf streams with the default streams is sufficient to implement +// the historic behavior. + // GraphTask holds metadata needed for a single execution of backward() struct GraphTask { std::exception_ptr exception_; @@ -181,6 +215,7 @@ struct GraphTask { std::vector captured_vars_; std::shared_ptr debug_info_ = at::getThreadLocalDebugInfo(); + std::unordered_set leaf_streams; void init_to_execute(Node& graph_root, const edge_list& outputs); @@ -454,11 +489,15 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const } grads[i] = at::sum_to(std::move(grads[i]), metadata.shape()); } + TORCH_CHECK(isFloatingType(grads[i].type().scalarType())); + if (metadata.type().scalarType() != grads[i].type().scalarType()) { + grads[i] = grads[i].to(metadata.type().scalarType()); + } if (!is_compatible_type(metadata.type(), grads[i].type())) { - std::stringstream ss; - ss << "invalid gradient at index " << i << " - expected type "; - ss << metadata.type() << " but got " << grads[i].type(); - AT_ERROR(format_error(ss.str())); + std::stringstream ss; + ss << "invalid gradient at index " << i << " - expected type "; + ss << metadata.type() << " but got " << grads[i].type(); + AT_ERROR(format_error(ss.str())); } auto output_device = output.device(); if (output_device != metadata.device()) { @@ -537,6 +576,10 @@ auto Engine::evaluate_function(NodeTask& task) -> void { if (!fn_info.needed_) return; } + // Switches to a function's CUDA stream (if applicable) before calling it + const auto opt_parent_stream = (*task.fn_).stream(c10::DeviceType::CUDA); + c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream}; + auto outputs = call_function(task); auto& fn = *task.fn_; @@ -545,7 +588,14 @@ auto Engine::evaluate_function(NodeTask& task) -> void { } int num_outputs = outputs.size(); - if (num_outputs == 0) return; // Don't even acquire the mutex + if (num_outputs == 0) { // Note: doesn't acquire the mutex + // Records leaf stream (if applicable) + // See note "Streaming backwards" + if (opt_parent_stream) { + task.base_->leaf_streams.emplace(*opt_parent_stream); + } + return; + } if (AnomalyMode::is_enabled()) { AutoGradMode grad_mode(false); @@ -592,7 +642,14 @@ auto Engine::evaluate_function(NodeTask& task) -> void { } // No buffers have been allocated for the function InputBuffer input_buffer(next.function->num_inputs()); - input_buffer.add(next.input_nr, std::move(output)); + + // Accumulates into buffer + const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA); + input_buffer.add(next.input_nr, + std::move(output), + opt_parent_stream, + opt_next_stream); + if (is_ready) { auto& queue = ready_queue(input_buffer.device()); queue.push(NodeTask(task.base_, next.function, std::move(input_buffer))); @@ -602,7 +659,13 @@ auto Engine::evaluate_function(NodeTask& task) -> void { } else { // The function already has a buffer auto &input_buffer = not_ready_it->second; - input_buffer.add(next.input_nr, std::move(output)); + + // Accumulates into buffer + const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA); + input_buffer.add(next.input_nr, + std::move(output), + opt_parent_stream, + opt_next_stream); if (is_ready) { auto& queue = ready_queue(input_buffer.device()); queue.push(NodeTask(task.base_, next.function, std::move(input_buffer))); @@ -727,6 +790,18 @@ auto Engine::execute(const edge_list& roots, cb_lock.lock(); } + // Syncs leaf streams with default streams (if necessary) + // See note "Streaming backwards" + for (const auto& leaf_stream : graph_task.leaf_streams) { + const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA}; + const auto default_stream = guard.getDefaultStream(leaf_stream.device()); + if (leaf_stream != default_stream) { + auto event = c10::Event{c10::DeviceType::CUDA}; + event.record(leaf_stream); + default_stream.wait(event); + } + } + return graph_task.captured_vars_; } diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index 4b4d7d6eb8271..29fd02b184410 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/autograd/functions/tensor.cpp b/torch/csrc/autograd/functions/tensor.cpp index 5499dd116154b..33932b5862844 100644 --- a/torch/csrc/autograd/functions/tensor.cpp +++ b/torch/csrc/autograd/functions/tensor.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index dfa94fbc067da..c98de3fd714b6 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -7,7 +7,7 @@ #include #include -PyObject* THPAutograd_initExtension(PyObject* _unused) { +PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) { using namespace torch::autograd::profiler; auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch.tensor")); if (!tensor_module) diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 8b37a7be2d1e7..96d9a6f41855a 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -1,6 +1,9 @@ #include -#include +#include +#include +#include +#include #include #include @@ -8,21 +11,15 @@ namespace torch { namespace autograd { - -void InputBuffer::add(size_t pos, Variable var) { - AT_ASSERT(pos < buffer.size()); - if (!var.defined()) { - return; - } - auto& old_var = buffer[pos]; - if (!old_var.defined()) { - buffer[pos] = std::move(var); - } else { - at::OptionalDeviceGuard device_guard(device_of(var)); + static void accumulate(std::vector& buffer, + const size_t pos, + Variable&& var) { + TORCH_INTERNAL_ASSERT(pos < buffer.size()); + auto& old_var = buffer[pos]; // ATen doesn't route sparse additions correctly... // do dense + sparse in-place if possible if (old_var.is_sparse()) { -//storage use_count is a big hammer, but for anything lighter there's an adversarial example with unexpected inplace modification + //storage use_count is a big hammer, but for anything lighter there's an adversarial example with unexpected inplace modification if (!var.is_sparse() && var.is_contiguous() && var.storage().use_count() == 1) { buffer[pos] = var.add_(old_var); } else { @@ -36,6 +33,69 @@ void InputBuffer::add(size_t pos, Variable var) { } } } + + void InputBuffer::add(size_t pos, + Variable&& var, + const c10::optional& opt_producer_stream, + const c10::optional& opt_consumer_stream) { + TORCH_INTERNAL_ASSERT(pos < buffer.size()); + if (!var.defined()) { + return; + } + + // Switches to accumulate device + // The device (and stream) chosen for accumulation is: + // (1) If the variable is not a CUDA variable, accumulation happens on the + // device of the variable. + // (2) If the variable is a CUDA variable, and the producer and consumer + // share its device, then: + // (2a) if the producer and consumer do not share a stream, + // the consumer is synced with the producer. + // (2b) accumulation happens on the consumer's stream + // (3) If the variable is a CUDA variable but it, the producer, and the + // consumer are on multiple devices, then accumulation happens on + // the default stream of the variable's device. + + TORCH_INTERNAL_ASSERT(device_of(var)); + c10::optional opt_accumulate_stream = c10::nullopt; + if (device_of(var)->is_cuda()) { + const auto on_producer = opt_producer_stream + && device_of(var) == opt_producer_stream->device(); + const auto on_consumer = opt_consumer_stream + && device_of(var) == opt_consumer_stream->device(); + if (on_producer && on_consumer) { + // (2) CUDA variable with producer and consumer sharing a device + // Accumulation happens on consumer's stream + opt_accumulate_stream = opt_consumer_stream; + if (opt_producer_stream != opt_consumer_stream) { + // (2a) Syncs consumer with producer + auto event = c10::Event{c10::DeviceType::CUDA}; + event.record(*opt_producer_stream); + opt_consumer_stream->wait(event); + } + } else { + // (3) CUDA variable with multiple devices + // Accumulation happens on variable's device's default stream + const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA}; + const auto default_stream = guard.getDefaultStream(*device_of(var)); + opt_accumulate_stream = default_stream; + } + } + + auto& old_var = buffer[pos]; + if (!old_var.defined()) { + buffer[pos] = std::move(var); + } else { + if (opt_accumulate_stream) { + c10::OptionalStreamGuard stream_guard{opt_accumulate_stream}; + accumulate(buffer, pos, std::move(var)); + } else { + // (1) non-CUDA variable + // Accumulation happens on variable's device + c10::OptionalDeviceGuard device_guard{device_of(var)}; + accumulate(buffer, pos, std::move(var)); + } + } } auto InputBuffer::device() const -> at::Device { diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h index b0d506ef756a7..02bcde5d9f968 100644 --- a/torch/csrc/autograd/input_buffer.h +++ b/torch/csrc/autograd/input_buffer.h @@ -11,6 +11,8 @@ #include #include +#include +#include namespace torch { namespace autograd { @@ -22,7 +24,12 @@ struct InputBuffer { InputBuffer& operator=(InputBuffer&& other) = default; // Accumulates the variable at a specified index. - void add(size_t pos, Variable var); + // The optional CUDA streams determine which stream the accumulation + // is run on and how the addition is synchronized. + void add(size_t pos, + Variable&& var, + const c10::optional& opt_producer_stream, + const c10::optional& opt_consumer_stream); at::Device device() const; diff --git a/torch/csrc/autograd/profiler_cuda.cpp b/torch/csrc/autograd/profiler_cuda.cpp index b03d574431877..ae3b1bbc86a56 100644 --- a/torch/csrc/autograd/profiler_cuda.cpp +++ b/torch/csrc/autograd/profiler_cuda.cpp @@ -1,6 +1,8 @@ #include #include +#ifndef __HIP_PLATFORM_HCC__ #include +#endif #include @@ -33,13 +35,19 @@ struct CUDAMethods : public CUDAStubs { return ms*1000.0; } void nvtxMarkA(const char* name) override { +#ifndef __HIP_PLATFORM_HCC__ ::nvtxMark(name); +#endif } void nvtxRangePushA(const char* name) override { +#ifndef __HIP_PLATFORM_HCC__ ::nvtxRangePushA(name); +#endif } void nvtxRangePop() override { +#ifndef __HIP_PLATFORM_HCC__ ::nvtxRangePop(); +#endif } void onEachDevice(std::function op) override { at::cuda::OptionalCUDAGuard device_guard; diff --git a/torch/csrc/autograd/python_autograd.h b/torch/csrc/autograd/python_autograd.h index 0f9ab2e9846c3..04771dca2af58 100644 --- a/torch/csrc/autograd/python_autograd.h +++ b/torch/csrc/autograd/python_autograd.h @@ -1,7 +1,7 @@ #ifndef THP_AUTOGRAD_H #define THP_AUTOGRAD_H -PyObject * THPAutograd_initExtension(PyObject *_unused); +PyObject * THPAutograd_initExtension(PyObject *_unused, PyObject *unused); void THPAutograd_initFunctions(); namespace torch { namespace autograd { diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index 58341e95e7ea5..6aab79b816140 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -130,7 +130,7 @@ PyObject* THPCppFunction_metadata(THPCppFunction *self, void *_unused) return metadata; } -PyObject* THPCppFunction_requires_grad(THPCppFunction* self) { +PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void *unused) { Py_RETURN_TRUE; } @@ -153,7 +153,7 @@ PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook) return registerFunctionHook(fn, hook); } -PyObject* THPCppFunction_name(PyObject* self) { +PyObject* THPCppFunction_name(PyObject* self, PyObject *noargs) { auto& fn = *((THPCppFunction*)self)->cdata; return THPUtils_packString(fn.name()); } diff --git a/torch/csrc/autograd/python_cpp_function.h b/torch/csrc/autograd/python_cpp_function.h index 6a4228ecd6122..1b51b69adcab7 100644 --- a/torch/csrc/autograd/python_cpp_function.h +++ b/torch/csrc/autograd/python_cpp_function.h @@ -42,10 +42,10 @@ PyObject* CppFunction_pynew(PyTypeObject *type, PyObject *args, PyObject *kwds) PyObject* THPCppFunction_next_functions(THPCppFunction* self, PyObject* hook); PyObject* THPCppFunction_metadata(THPCppFunction *self, void *_unused); -PyObject* THPCppFunction_requires_grad(THPCppFunction* self); +PyObject* THPCppFunction_requires_grad(THPCppFunction* self, void *_unused); PyObject* THPCppFunction_register_hook_dict(PyObject* self, PyObject* _var); PyObject* THPCppFunction_register_hook(PyObject* self, PyObject* hook); -PyObject* THPCppFunction_name(PyObject* self); +PyObject* THPCppFunction_name(PyObject* self, PyObject *noargs); PyTypeObject* _initFunctionPyTypeObject(PyTypeObject& type, const char* name, PyGetSetDef* function_properties, PyMethodDef* function_methods); diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 3de513b18305d..0cdfae6b6c545 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #ifndef _WIN32 #include @@ -212,7 +213,7 @@ PyObject* THPEngine_queue_callback(PyObject *self, PyObject *_callback) { END_HANDLE_TH_ERRORS } -PyObject* THPEngine_is_checkpoint_valid(PyObject *self) { +PyObject* THPEngine_is_checkpoint_valid(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS if(engine.is_checkpoint_valid()) { Py_RETURN_TRUE; @@ -228,7 +229,7 @@ PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) } static struct PyMethodDef THPEngine_methods[] = { - {(char*)"run_backward", (PyCFunction)THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr}, + {(char*)"run_backward", (PyCFunction)(void(*)(void))THPEngine_run_backward, METH_VARARGS | METH_KEYWORDS, nullptr}, {(char*)"queue_callback", (PyCFunction)THPEngine_queue_callback, METH_O, nullptr}, {(char*)"is_checkpoint_valid", (PyCFunction)THPEngine_is_checkpoint_valid, METH_NOARGS, nullptr}, {nullptr} diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index dcb888d9405e2..a66c0b49d262e 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #ifdef BUILD_NAMEDTENSOR #include #endif @@ -176,7 +177,7 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P typedef PyObject *(*getter)(PyObject *, void *); typedef int (*setter)(PyObject *, PyObject *, void *); -PyObject *THPVariable_get_T(THPVariable *self) +PyObject *THPVariable_get_T(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& var = self->cdata; @@ -184,7 +185,7 @@ PyObject *THPVariable_get_T(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_cdata(THPVariable *self) +PyObject *THPVariable_get_cdata(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& var = self->cdata; @@ -192,7 +193,7 @@ PyObject *THPVariable_get_cdata(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_version(THPVariable *self) +PyObject *THPVariable_get_version(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& var = self->cdata; @@ -200,7 +201,7 @@ PyObject *THPVariable_get_version(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_grad_fn(THPVariable *self) +PyObject *THPVariable_get_grad_fn(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& var = self->cdata; @@ -211,7 +212,7 @@ PyObject *THPVariable_get_grad_fn(THPVariable *self) END_HANDLE_TH_ERRORS } -static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj) +static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj, void *unused) { HANDLE_TH_ERRORS THPUtils_assertRet(-1, obj, "Deletion of _grad_fn not allowed. Detach tensor instead!"); @@ -221,14 +222,14 @@ static int THPVariable_set_grad_fn(THPVariable *self, PyObject *obj) END_HANDLE_TH_ERRORS_RET(-1) } -static PyObject *THPVariable_is_leaf(THPVariable *self) +static PyObject *THPVariable_is_leaf(THPVariable *self, void *unused) { HANDLE_TH_ERRORS return PyBool_FromLong(!self->cdata.grad_fn()); END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_get_data(THPVariable *self) +static PyObject * THPVariable_get_data(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto var = self->cdata.variable_data(); @@ -236,7 +237,7 @@ static PyObject * THPVariable_get_data(THPVariable *self) END_HANDLE_TH_ERRORS } -int THPVariable_set_data(THPVariable *self, PyObject *data) +int THPVariable_set_data(THPVariable *self, PyObject *data, void *unused) { HANDLE_TH_ERRORS THPUtils_assertRet(-1, data, "Deleting tensor data is not allowed. Delete tensor instead!"); @@ -249,14 +250,14 @@ int THPVariable_set_data(THPVariable *self, PyObject *data) END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_grad(THPVariable *self) +PyObject *THPVariable_get_grad(THPVariable *self, void *unused) { HANDLE_TH_ERRORS return THPVariable_Wrap(self->cdata.grad()); END_HANDLE_TH_ERRORS } -int THPVariable_set_grad(THPVariable *self, PyObject *py_grad) +int THPVariable_set_grad(THPVariable *self, PyObject *py_grad, void *unused) { HANDLE_TH_ERRORS auto& var = self->cdata; @@ -273,7 +274,7 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad) auto& grad = ((THPVariable*)py_grad)->cdata; bool gradIsSparse = (var.dtype() == grad.dtype() && var.device().type() == grad.device().type() && - layout_from_backend(tensorTypeIdToBackend(grad.type_id())) == kSparse); + grad.layout() == kSparse); THPUtils_assertRet(-1, grad.type() == var.type() || gradIsSparse, "assigned grad has data of a different type"); if (var.is_cuda()) { @@ -288,19 +289,19 @@ int THPVariable_set_grad(THPVariable *self, PyObject *py_grad) END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_volatile(THPVariable *self) +PyObject *THPVariable_get_volatile(THPVariable *self, void *unused) { const char* msg = "volatile was removed (Variable.volatile is always False)"; PyErr_WarnEx(PyExc_UserWarning, msg, 1); Py_RETURN_FALSE; } -int THPVariable_set_volatile(THPVariable *self, PyObject *obj) +int THPVariable_set_volatile(THPVariable *self, PyObject *obj, void *unused) { return PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); } -PyObject *THPVariable_get_output_nr(THPVariable *self) +PyObject *THPVariable_get_output_nr(THPVariable *self, void *unused) { HANDLE_TH_ERRORS const auto output_nr = static_cast(self->cdata.output_nr()); @@ -308,14 +309,14 @@ PyObject *THPVariable_get_output_nr(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_requires_grad(THPVariable *self) +PyObject *THPVariable_get_requires_grad(THPVariable *self, void *unused) { HANDLE_TH_ERRORS return PyBool_FromLong(self->cdata.requires_grad()); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_ndim(THPVariable *self) +PyObject *THPVariable_get_ndim(THPVariable *self, void *unused) { HANDLE_TH_ERRORS return PyInt_FromLong(self->cdata.dim()); @@ -323,7 +324,7 @@ PyObject *THPVariable_get_ndim(THPVariable *self) } #ifdef BUILD_NAMEDTENSOR -PyObject *THPVariable_get_names(THPVariable *self) +PyObject *THPVariable_get_names(THPVariable *self, void *unused) { HANDLE_TH_ERRORS // The long-term plan is to return a list of (python) torch.Dimname. @@ -336,7 +337,7 @@ PyObject *THPVariable_get_names(THPVariable *self) for (size_t i = 0; i < size; ++i) { PyObject* str = Py_None; if (dimnames[i].type() != at::NameType::WILDCARD) { - str = THPUtils_packString(dimnames[i].full_name().toUnqualString()); + str = THPUtils_packString(dimnames[i].symbol().toUnqualString()); if (!str) throw python_error(); } PyTuple_SET_ITEM(tuple.get(), i, str); @@ -361,7 +362,7 @@ int THPVariable_set_names(THPVariable *self, PyObject *names) { } #endif -int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj) +int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj, void *unused) { HANDLE_TH_ERRORS THPUtils_assertRet(-1, obj && PyBool_Check(obj), "requires_grad must be a bool"); @@ -380,14 +381,14 @@ int THPVariable_set_requires_grad(THPVariable *self, PyObject *obj) END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_name(THPVariable* self) +PyObject *THPVariable_get_name(THPVariable* self, void *unused) { if (self->cdata.name() == "") Py_RETURN_NONE; return THPUtils_packString(self->cdata.name().c_str()); } -PyObject *THPVariable_get_backwards_hooks(THPVariable *self) +PyObject *THPVariable_get_backwards_hooks(THPVariable *self, void *unused) { HANDLE_TH_ERRORS if (self->backward_hooks) { @@ -398,7 +399,7 @@ PyObject *THPVariable_get_backwards_hooks(THPVariable *self) END_HANDLE_TH_ERRORS } -int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj) +int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj, void *unused) { HANDLE_TH_ERRORS THPUtils_assertRet(-1, obj, "Deletion of _backwards_hooks not allowed!"); @@ -416,7 +417,7 @@ int THPVariable_set_backwards_hooks(THPVariable *self, PyObject *obj) END_HANDLE_TH_ERRORS_RET(-1) } -PyObject *THPVariable_get_base(THPVariable *self) +PyObject *THPVariable_get_base(THPVariable *self, void *unused) { HANDLE_TH_ERRORS if (self->cdata.is_view()) { @@ -426,14 +427,14 @@ PyObject *THPVariable_get_base(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_get_shape(THPVariable *self) +PyObject *THPVariable_get_shape(THPVariable *self, void *unused) { HANDLE_TH_ERRORS return THPSize_New(self->cdata); END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_cuda(THPVariable *self) +PyObject *THPVariable_is_cuda(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; @@ -441,7 +442,7 @@ PyObject *THPVariable_is_cuda(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_sparse(THPVariable *self) +PyObject *THPVariable_is_sparse(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; @@ -449,7 +450,7 @@ PyObject *THPVariable_is_sparse(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_mkldnn(THPVariable *self) +PyObject *THPVariable_is_mkldnn(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; @@ -457,7 +458,7 @@ PyObject *THPVariable_is_mkldnn(THPVariable *self) END_HANDLE_TH_ERRORS } -PyObject *THPVariable_is_quantized(THPVariable *self) +PyObject *THPVariable_is_quantized(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; @@ -465,7 +466,7 @@ PyObject *THPVariable_is_quantized(THPVariable *self) END_HANDLE_TH_ERRORS } -static PyObject *THPVariable_dtype(THPVariable *self) +static PyObject *THPVariable_dtype(THPVariable *self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; @@ -473,14 +474,14 @@ static PyObject *THPVariable_dtype(THPVariable *self) END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_layout(THPVariable* self) { +static PyObject * THPVariable_layout(THPVariable* self, void *unused) { HANDLE_TH_ERRORS auto& self_ = self->cdata; return torch::autograd::utils::wrap(torch::getLayout(self_.type().backend())); END_HANDLE_TH_ERRORS } -static PyObject * THPVariable_device(THPVariable* self) { +static PyObject * THPVariable_device(THPVariable* self, void *unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cdata.device()); END_HANDLE_TH_ERRORS @@ -524,7 +525,7 @@ static PyMappingMethods THPVariable_as_mapping = { }; static PyMethodDef extra_methods[] = { - {"_make_subclass", (PyCFunction)THPVariable_make_subclass, METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, + {"_make_subclass", (PyCFunction)(void(*)(void))THPVariable_make_subclass, METH_STATIC | METH_VARARGS | METH_KEYWORDS, nullptr}, {nullptr} }; diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 2e1fda6ef472c..df2920daaa3cc 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -111,14 +111,11 @@ static Variable sequenceToVariable(c10::TensorTypeId type_id, PyObject* seq) { return torch::utils::indexing_tensor_from_data(type_id, kLong, c10::nullopt, seq); } -static Variable valueToTensor(c10::TensorTypeId type_id, ScalarType scalar_type, PyObject* value) { +static Variable valueToTensor(c10::TensorOptions options, PyObject* value) { if (THPVariable_Check(value)) { return reinterpret_cast(value)->cdata; } - auto options = TensorOptions(scalar_type) - .device(computeDeviceType(type_id)) - .layout(layout_from_backend(tensorTypeIdToBackend(type_id))) - .is_variable(true); + options = options.is_variable(true); if (THPUtils_checkLong(value) || PyBool_Check(value)) { return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), options); } @@ -128,7 +125,7 @@ static Variable valueToTensor(c10::TensorTypeId type_id, ScalarType scalar_type, throw TypeError( "can't assign a %s to a %s", Py_TYPE(value)->tp_name, - torch::utils::type_to_string(getNonVariableDeprecatedTypeProperties(tensorTypeIdToBackend(type_id), scalar_type)).c_str()); + torch::utils::type_to_string(getNonVariableDeprecatedTypeProperties(options.backend(), typeMetaToScalarType(options.dtype()))).c_str()); } static Variable boolToIndexingTensor(const Variable& self, bool value) { @@ -190,7 +187,9 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis handle_var(var); } } else if (PySequence_Check(obj)) { - handle_var(sequenceToVariable(self.type_id(), obj)); + // TODO: Naughty naughty get out of jail free + // (Fixing this means I have to fix the call chain though :/) + handle_var(sequenceToVariable(legacyExtractTypeId(self), obj)); } else { auto index = THPObjectPtr(PyNumber_Index(obj)); if (!index) { @@ -347,10 +346,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); Variable value; + // TODO: This qint special case looks very suspicious... if (isQIntType(self_.scalar_type())) { - value = valueToTensor(TensorTypeId::CPUTensorId, kFloat, py_value); + value = valueToTensor(device(kCPU).dtype(kFloat), py_value); } else { - value = valueToTensor(self_.type_id(), self_.scalar_type(), py_value); + value = valueToTensor(self_.options(), py_value); } // handle simple types: integers, slices, ellipsis, bool diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h index 4b87063a563ef..f327470ea6a36 100644 --- a/torch/csrc/autograd/utils/wrap_outputs.h +++ b/torch/csrc/autograd/utils/wrap_outputs.h @@ -150,4 +150,12 @@ inline PyObject* wrap(at::TensorList tl) { return r.release(); } +inline PyObject* wrap(at::IntArrayRef list) { + auto r = THPObjectPtr{PyTuple_New(list.size())}; + if (!r) throw python_error(); + for (size_t i = 0; i < list.size(); ++i) { + PyTuple_SET_ITEM(r.get(), i, wrap(list[i])); + } + return r.release(); +} }}} // namespace torch::autograd::utils diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 915b3c2a9d9bf..0e5e42ae75b76 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -1,12 +1,12 @@ #include +#include #include #include #include #include #include #include -#include #include #include @@ -71,17 +71,7 @@ void Variable::backward( const Tensor& gradient, bool keep_graph, bool create_graph) const { - auto autograd_meta = get_autograd_meta(); - std::vector edges; - edges.emplace_back(autograd_meta->grad_fn_, autograd_meta->output_nr_); - - std::vector inputs; - Tensor gradient_ = gradient; - if (!gradient.defined()) { - gradient_ = at::ones_like(*this); - } - inputs.push_back(std::move(as_variable_ref(gradient_))); - Engine::get_default_engine().execute(edges, inputs, keep_graph, create_graph); + torch::autograd::backward({*this}, {gradient}, keep_graph, create_graph); } void Variable::set_data(const at::Tensor &new_data) const { diff --git a/torch/csrc/cuda/Event.cpp b/torch/csrc/cuda/Event.cpp index c0e3bbb68edd7..6615dbd4c417b 100644 --- a/torch/csrc/cuda/Event.cpp +++ b/torch/csrc/cuda/Event.cpp @@ -81,13 +81,13 @@ static void THCPEvent_dealloc(THCPEvent *self) { Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject * THCPEvent_get_cuda_event(THCPEvent *self) { +static PyObject * THCPEvent_get_cuda_event(THCPEvent *self, void *unused) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(self->cuda_event.event()); END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_get_device(THCPEvent *self) { +static PyObject * THCPEvent_get_device(THCPEvent *self, void *unused) { HANDLE_TH_ERRORS at::optional device = self->cuda_event.device(); if (!device) { @@ -111,7 +111,7 @@ static PyObject * THCPEvent_wait(THCPEvent *self, THCPStream *stream) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_query(THCPEvent *self) { +static PyObject * THCPEvent_query(THCPEvent *self, PyObject *noargs) { HANDLE_TH_ERRORS return PyBool_FromLong(self->cuda_event.query()); END_HANDLE_TH_ERRORS @@ -123,14 +123,14 @@ static PyObject * THCPEvent_elapsed_time(THCPEvent *self, THCPEvent *other) { END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_synchronize(THCPEvent *self) { +static PyObject * THCPEvent_synchronize(THCPEvent *self, PyObject *noargs) { HANDLE_TH_ERRORS with_no_gil([&] { self->cuda_event.synchronize(); }); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } -static PyObject * THCPEvent_ipc_handle(THCPEvent *self) { +static PyObject * THCPEvent_ipc_handle(THCPEvent *self, PyObject *noargs) { HANDLE_TH_ERRORS cudaIpcEventHandle_t handle; self->cuda_event.ipc_handle(&handle); @@ -145,7 +145,7 @@ static struct PyGetSetDef THCPEvent_properties[] = { }; static PyMethodDef THCPEvent_methods[] = { - {(char*)"from_ipc_handle", (PyCFunction)THCPEvent_from_ipc_handle, + {(char*)"from_ipc_handle", (PyCFunction)(void(*)(void))THCPEvent_from_ipc_handle, METH_CLASS | METH_VARARGS | METH_KEYWORDS, nullptr}, {(char*)"record", (PyCFunction)THCPEvent_record, METH_O, nullptr}, {(char*)"wait", (PyCFunction)THCPEvent_wait, METH_O, nullptr}, diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 6998601c7d991..27072108a6d74 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -50,7 +50,7 @@ PyObject * THCPModule_setDevice_wrap(PyObject *self, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getDevice_wrap(PyObject *self) +PyObject * THCPModule_getDevice_wrap(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS int device; @@ -60,7 +60,7 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_getDeviceCount_wrap(PyObject *self) +PyObject * THCPModule_getDeviceCount_wrap(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS //torch::utils::cuda_lazy_init(); @@ -68,7 +68,7 @@ PyObject * THCPModule_getDeviceCount_wrap(PyObject *self) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self) +PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS torch::utils::set_run_yet_variable_to_false(); @@ -117,7 +117,7 @@ PyObject * THCPModule_setStream_wrap(PyObject *self, PyObject *obj) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_isDriverSufficient(PyObject *self) +PyObject * THCPModule_isDriverSufficient(PyObject *self, PyObject *noargs) { int count; cudaError_t err = cudaGetDeviceCount(&count); @@ -127,7 +127,7 @@ PyObject * THCPModule_isDriverSufficient(PyObject *self) return PyBool_FromLong(1); } -PyObject * THCPModule_getDriverVersion(PyObject *self) +PyObject * THCPModule_getDriverVersion(PyObject *self, PyObject *noargs) { int driverVersion = -1; cudaError_t err = cudaDriverGetVersion(&driverVersion); @@ -140,12 +140,12 @@ PyObject * THCPModule_getDriverVersion(PyObject *self) return PyLong_FromLong((int64_t) driverVersion); } -PyObject * THCPModule_getCompiledVersion(PyObject *self) +PyObject * THCPModule_getCompiledVersion(PyObject *self, PyObject *noargs) { return PyLong_FromLong((long) CUDA_VERSION); } -PyObject * THCPModule_cudaHostAllocator(PyObject *_unused) +PyObject * THCPModule_cudaHostAllocator(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS c10::Allocator* allocator = THCState_getCudaHostAllocator(state); @@ -153,7 +153,7 @@ PyObject * THCPModule_cudaHostAllocator(PyObject *_unused) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaSynchronize(PyObject *_unused) +PyObject * THCPModule_cudaSynchronize(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS THCudaCheck(cudaDeviceSynchronize()); @@ -161,7 +161,7 @@ PyObject * THCPModule_cudaSynchronize(PyObject *_unused) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_cudaIPCCollect(PyObject *_unused /* unused */) +PyObject * THCPModule_cudaIPCCollect(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS torch::CudaIPCCollect(); @@ -185,7 +185,7 @@ PyObject * THCPModule_cudaSleep(PyObject *_unused, PyObject *cycles) // by the thread that owns the mutex (obviously there can be only one such thread). static PyGILState_STATE cudaMutexGILState; -PyObject * THCPModule_cudaLockMutex(PyObject *module) +PyObject * THCPModule_cudaLockMutex(PyObject *module, PyObject *noargs) { auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex(); // This has to be a busy loop because we **absolutely need to** hold the GIL @@ -206,7 +206,7 @@ PyObject * THCPModule_cudaLockMutex(PyObject *module) Py_RETURN_NONE; } -PyObject * THCPModule_cudaUnlockMutex(PyObject *module) +PyObject * THCPModule_cudaUnlockMutex(PyObject *module, PyObject *noargs) { auto mutex = c10::cuda::CUDACachingAllocator::getFreeMutex(); PyGILState_Release(cudaMutexGILState); @@ -227,7 +227,7 @@ PyObject * THCPModule_hasPrimaryContext(PyObject *_unused, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THCPModule_emptyCache(PyObject *_unused) +PyObject * THCPModule_emptyCache(PyObject *_unused, PyObject *noargs) { HANDLE_TH_ERRORS c10::cuda::CUDACachingAllocator::emptyCache(); @@ -323,7 +323,7 @@ static void bindCudaDeviceProperties(PyObject* module) { } // Callback for python part. Used for additional initialization of python classes -static PyObject * THCPModule_initExtension(PyObject *self) +static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS state = at::globalContext().lazyInitCUDA(); @@ -386,7 +386,7 @@ void THCPModule_useNccl() } #endif -PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self) +PyObject * THCPModule_getCurrentBlasHandle_wrap(PyObject *self, PyObject *noargs) { HANDLE_TH_ERRORS cublasHandle_t handle = THCState_getCurrentBlasHandle(state); diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index 283aa3f9782fe..bbe4b0eeadaa3 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -50,19 +50,19 @@ static void THCPStream_dealloc(THCPStream *self) { Py_TYPE(self)->tp_free((PyObject*)self); } -static PyObject * THCPStream_get_device(THCPStream *self) { +static PyObject * THCPStream_get_device(THCPStream *self, void *unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cuda_stream.device()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_get_cuda_stream(THCPStream *self) { +static PyObject * THCPStream_get_cuda_stream(THCPStream *self, void *unused) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(self->cuda_stream.stream()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_get_priority(THCPStream *self) { +static PyObject * THCPStream_get_priority(THCPStream *self, void *unused) { HANDLE_TH_ERRORS return PyLong_FromLong(self->cuda_stream.priority()); END_HANDLE_TH_ERRORS @@ -77,13 +77,13 @@ static PyObject * THCPStream_priority_range() { END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_query(THCPStream *self) { +static PyObject * THCPStream_query(THCPStream *self, PyObject *noargs) { HANDLE_TH_ERRORS return PyBool_FromLong(self->cuda_stream.query()); END_HANDLE_TH_ERRORS } -static PyObject * THCPStream_synchronize(THCPStream *self) { +static PyObject * THCPStream_synchronize(THCPStream *self, PyObject *noargs) { HANDLE_TH_ERRORS with_no_gil([&] { self->cuda_stream.synchronize(); }); Py_RETURN_NONE; @@ -115,7 +115,7 @@ static PyMethodDef THCPStream_methods[] = { {(char*)"synchronize", (PyCFunction)THCPStream_synchronize, METH_NOARGS, nullptr}, {(char*)"priority_range", - (PyCFunction)THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr}, + (PyCFunction)(void(*)(void))THCPStream_priority_range, METH_STATIC | METH_NOARGS, nullptr}, {(char*)"__eq__", (PyCFunction)THCPStream_eq, METH_O, nullptr}, {nullptr} }; diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index efeec801904e3..ff14d5ee932b7 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -17,7 +17,6 @@ #include #include -#include #include #include @@ -35,28 +34,6 @@ namespace { #ifdef USE_C10D_GLOO constexpr char* GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME"; - -std::shared_ptr<::gloo::transport::Device> createDeviceForDefaultHostname() { - ::gloo::transport::tcp::attr attr; - - // Use the hostname to resolve the network address to - // use. Note: if the hostname does not resolve to an address (e.g. - // because of misconfigured /etc/hosts file), this will not work. - std::array hostname{}; - auto rv = gethostname(hostname.data(), hostname.size()); - if (rv != 0) { - throw std::system_error(errno, std::system_category()); - } - attr.hostname = hostname.data(); - return ::gloo::transport::tcp::CreateDevice(attr); -} - -std::shared_ptr<::gloo::transport::Device> createDeviceForInterface( - std::string iface) { - ::gloo::transport::tcp::attr attr; - attr.iface = std::move(iface); - return ::gloo::transport::tcp::CreateDevice(attr); -} #endif std::vector split(char separator, const std::string& string) { @@ -446,20 +423,17 @@ They are used in specifying strategies for reduction collectives, e.g., .def_readwrite("threads", &::c10d::ProcessGroupGloo::Options::threads); processGroupGloo.def_static( - "create_tcp_device", + "create_device", [](const std::string& hostname, const std::string& interface) -> std::shared_ptr<::gloo::transport::Device> { - ::gloo::transport::tcp::attr attr; if (!hostname.empty()) { - attr.hostname = hostname; - } else if (!interface.empty()) { - attr.iface = interface; - } else { - // Neither argument is specified; Gloo itself will use the - // hostname - // Nothing specified, default to something useful + return ::c10d::ProcessGroupGloo::createDeviceForHostname(hostname); + } + if (!interface.empty()) { + return ::c10d::ProcessGroupGloo::createDeviceForInterface(interface); } - return ::gloo::transport::tcp::CreateDevice(attr); + throw std::invalid_argument( + "Specify either `hostname` or `interface` argument."); }, py::arg("hostname") = "", py::arg("interface") = ""); @@ -481,10 +455,15 @@ They are used in specifying strategies for reduction collectives, e.g., char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV); if (ifnameEnv) { for (const auto& iface : split(',', ifnameEnv)) { - options.devices.push_back(createDeviceForInterface(iface)); + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDeviceForInterface(iface)); } } else { - options.devices.push_back(createDeviceForDefaultHostname()); + // If no hostname is specified, this function looks up + // the machine's hostname and returns a device instance + // associated with the address that the hostname resolves to. + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDefaultDevice()); } options.timeout = timeout; @@ -506,12 +485,10 @@ They are used in specifying strategies for reduction collectives, e.g., const std::shared_ptr<::c10d::Store>&, int, int, - const std::string&, const std::chrono::milliseconds&>(), py::arg("store"), py::arg("rank"), py::arg("size"), - py::arg("groupName") = "", py::arg("timeout") = std::chrono::milliseconds( ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis)); #endif diff --git a/torch/csrc/distributed/rpc/functions.cpp b/torch/csrc/distributed/rpc/functions.cpp index 423db88b49b1a..c613d0892be71 100644 --- a/torch/csrc/distributed/rpc/functions.cpp +++ b/torch/csrc/distributed/rpc/functions.cpp @@ -1,5 +1,14 @@ #include +#include +#include +#include +#include +#include +#include +#include +#include + namespace torch { namespace distributed { namespace rpc { @@ -18,18 +27,20 @@ Message processRequestBlocking(Message&& request) { switch (request.type()) { case MessageType::SCRIPT_CALL: { try { - ScriptCall op = ScriptCall::fromMessage(request); + ScriptCall sc = ScriptCall::fromMessage(request); + + // sc is only alive within this block, use reference to avoid copy + auto& stack = sc.stackRef(); + sc.op()->getOperation()(stack); - auto stack = op.stack(); - op.op()->getOperation()(stack); AT_ASSERT( stack.size() == 1, "Return value of a builtin operator or a " "TorchScript function should be a single IValue, got a vector of " "size ", stack.size()); - auto response = ScriptRet(std::move(stack.front())).toMessage(); + response.setId(request.id()); return response; } catch (std::exception& e) { @@ -39,7 +50,8 @@ Message processRequestBlocking(Message&& request) { } case MessageType::PYTHON_CALL: { try { - auto payload = PythonRpcHandler::generatePythonUDFResult(request); + auto payload = + PythonRpcHandler::getInstance().generatePythonUDFResult(request); return Message( std::move(payload), std::vector(), @@ -50,6 +62,50 @@ Message processRequestBlocking(Message&& request) { } break; } + case MessageType::REMOTE_CALL: { + ScriptRemoteCall src = ScriptRemoteCall::fromMessage(request); + + auto rrefId = RRefId::fromIValue(src.retRRefId()); + auto forkId = ForkId::fromIValue(src.retForkId()); + TORCH_CHECK(rrefId != forkId, "Does not support remote call to self."); + + auto& ctx = RRefContext::getInstance(); + auto ownerRRef = ctx->getOrCreateOwnerRRef(rrefId); + + // TODO: make this asynchronous + // src is only alive within this block, use reference to avoid copy + auto& stack = src.stackRef(); + src.op()->getOperation()(stack); + AT_ASSERT( + stack.size() == 1, + "Return value of a builtin operator or a " + "TorchScript function should be a single IValue, got a vector of " + "size ", + stack.size()); + + ownerRRef->setValue(std::move(stack.front())); + return Message(); + } + case MessageType::RREF_FETCH_CALL: { + ScriptRRefFetchCall srf = ScriptRRefFetchCall::fromMessage(request); + // TODO: make this asynchronous + std::shared_ptr> rref = + RRefContext::getInstance()->getOrCreateOwnerRRef( + RRefId::fromIValue(srf.value())); + auto response = ScriptRRefFetchRet(rref->getValue()).toMessage(); + response.setId(request.id()); + return response; + } + case MessageType::RREF_USER_CREATE: { + ScriptRRefCreate sra = ScriptRRefCreate::fromMessage(request); + RRefContext::getInstance()->addFork(sra.valueRef()); + return Message(); + } + case MessageType::RREF_USER_DELETE: { + ScriptRRefDelete srd = ScriptRRefDelete::fromMessage(request); + RRefContext::getInstance()->delFork(srd.valueRef()); + return Message(); + } default: { AT_ERROR("Request type ", request.type(), " not supported."); } diff --git a/torch/csrc/distributed/rpc/functions.h b/torch/csrc/distributed/rpc/functions.h index d579db505eb2e..d5f82885712f7 100644 --- a/torch/csrc/distributed/rpc/functions.h +++ b/torch/csrc/distributed/rpc/functions.h @@ -1,11 +1,7 @@ #pragma once -#include #include -#include #include -#include -#include namespace torch { namespace distributed { diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 83ba0d8dd3c50..e05239ee048c7 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -40,11 +43,19 @@ PyObject* rpc_init(PyObject* /* unused */) { &RpcAgent::sync, py::call_guard()); + auto rref = + shared_ptr_class_(module, "RRef") + .def("owner", &RRef::owner, py::call_guard()) + .def( + "to_here", + [&](RRef& rref) { return torch::jit::toPyObject(rref.toHere()); }, + py::call_guard()); + auto futureMessage = shared_ptr_class_(module, "FutureMessage") .def( "wait", - [&](FutureMessage& fut) { return to_py_obj(fut.wait()); }, + [&](FutureMessage& fut) { return toPyObj(fut.wait()); }, py::call_guard()); shared_ptr_class_(module, "ProcessGroupAgent", rpcAgent) @@ -72,6 +83,10 @@ PyObject* rpc_init(PyObject* /* unused */) { &ProcessGroupAgent::sync, py::call_guard()); + module.def("init_rref_context", [](std::shared_ptr agent) { + RRefContext::initInstance(std::move(agent)); + }); + module.def( "invoke_rpc_builtin", [](RpcAgent& agent, @@ -79,7 +94,7 @@ PyObject* rpc_init(PyObject* /* unused */) { const std::string& opName, const py::args& args, const py::kwargs& kwargs) { - return py_rpc_builtin(agent, dst, opName, args, kwargs); + return pyRpcBuiltin(agent, dst, opName, args, kwargs); }); module.def( @@ -87,7 +102,17 @@ PyObject* rpc_init(PyObject* /* unused */) { [](RpcAgent& agent, const WorkerId& dst, const std::string& pickledPythonUDF) { - return py_rpc_python_udf(agent, dst, pickledPythonUDF); + return pyRpcPythonUdf(agent, dst, pickledPythonUDF); + }); + + module.def( + "invoke_remote_builtin", + [](RpcAgent& agent, + const WorkerId& dst, + const std::string& opName, + const py::args& args, + const py::kwargs& kwargs) { + return pyRemoteBuiltin(agent, dst, opName, args, kwargs); }); Py_RETURN_TRUE; diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp index 1d1f5379a7b66..cbfa567c868b6 100644 --- a/torch/csrc/distributed/rpc/message.cpp +++ b/torch/csrc/distributed/rpc/message.cpp @@ -57,11 +57,22 @@ const MessageType& Message::type() const { } bool Message::isRequest() const { - return MessageType::SCRIPT_CALL == type_ || MessageType::PYTHON_CALL == type_; + return MessageType::SCRIPT_CALL == type_ || + MessageType::PYTHON_CALL == type_ || MessageType::REMOTE_CALL == type_ || + MessageType::RREF_FETCH_CALL == type_ || + MessageType::RREF_USER_CREATE == type_ || + MessageType::RREF_USER_DELETE == type_; +} + +bool Message::requiresResponse() const { + return MessageType::SCRIPT_CALL == type_ || + MessageType::PYTHON_CALL == type_ || + MessageType::RREF_FETCH_CALL == type_; } bool Message::isResponse() const { - return MessageType::SCRIPT_RET == type_ || MessageType::PYTHON_RET == type_; + return MessageType::SCRIPT_RET == type_ || MessageType::PYTHON_RET == type_ || + MessageType::RREF_FETCH_RET == type_; } bool Message::isShutdown() const { diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h index c7932efa251cd..eea216f7354a7 100644 --- a/torch/csrc/distributed/rpc/message.h +++ b/torch/csrc/distributed/rpc/message.h @@ -12,6 +12,11 @@ enum MessageType { SCRIPT_RET, PYTHON_CALL, PYTHON_RET, + REMOTE_CALL, + RREF_FETCH_CALL, + RREF_FETCH_RET, + RREF_USER_CREATE, + RREF_USER_DELETE, SHUTDOWN, EXCEPTION, UNKNOWN @@ -62,6 +67,7 @@ class TORCH_API Message final { const MessageType& type() const; bool isRequest() const; + bool requiresResponse() const; bool isResponse() const; bool isShutdown() const; diff --git a/torch/csrc/distributed/rpc/process_group_agent.cpp b/torch/csrc/distributed/rpc/process_group_agent.cpp index 27e1b98c56c50..4cff1ef6f5dae 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.cpp +++ b/torch/csrc/distributed/rpc/process_group_agent.cpp @@ -126,7 +126,8 @@ ProcessGroupAgent::ProcessGroupAgent( workerIds_.emplace_back(std::move(tmpWorkerIds[rank]), rank); } - PythonRpcHandler::init(); + // construct PythonRpcHandler singleton here + PythonRpcHandler::getInstance(); listenerThread_ = std::thread(&ProcessGroupAgent::listenLoop, this); } @@ -139,6 +140,10 @@ const WorkerId& ProcessGroupAgent::getWorkerId( return workerIds_[idIter->second]; } +const WorkerId& ProcessGroupAgent::getWorkerId(worker_id_t id) const { + return workerIds_[id]; +} + void ProcessGroupAgent::join() { // Every process i sends a SHUTDOWN message to process i + 1. This is // necessary for now because: @@ -154,10 +159,6 @@ void ProcessGroupAgent::join() { listenerThread_.join(); } -int16_t ProcessGroupAgent::getWorkerId() { - return pg_->getRank(); -} - void ProcessGroupAgent::sync() { // Block until all processes wants to sync. This is necessary before acquiring // the lock below, because other processes might not enter sync() until it @@ -186,7 +187,7 @@ std::shared_ptr ProcessGroupAgent::sendImpl( auto requestId = nextId(); auto future = std::make_shared(); - if (message.isRequest()) { + if (message.requiresResponse()) { { std::lock_guard lock{futureMutex_}; futures_[requestId] = future; @@ -260,9 +261,10 @@ void ProcessGroupAgent::enqueueRecv(RecvWork work) { Message message = deserialize(work.type_, ss); - if (message.isRequest()) { - auto response = cb_(std::move(message)); - send(work.from_, std::move(response)); + if (message.requiresResponse()) { + send(work.from_, cb_(std::move(message))); + } else if (message.isRequest()) { + cb_(std::move(message)); } else if (message.isResponse()) { auto id = message.id(); { diff --git a/torch/csrc/distributed/rpc/process_group_agent.h b/torch/csrc/distributed/rpc/process_group_agent.h index 7b0cb8042cfd4..d9815b3bf2ddb 100644 --- a/torch/csrc/distributed/rpc/process_group_agent.h +++ b/torch/csrc/distributed/rpc/process_group_agent.h @@ -43,12 +43,12 @@ class ProcessGroupAgent : public RpcAgent { const WorkerId& getWorkerId(const std::string& workerName) const override; + const WorkerId& getWorkerId(worker_id_t id) const override; + void join() override; void sync() override; - int16_t getWorkerId() override; - protected: // This method wraps the destination information and the message into a // SendWork object, and put the SendWork into a queue. Another thread will diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp index 33b3a5aae1b95..f97493279ae93 100644 --- a/torch/csrc/distributed/rpc/python_functions.cpp +++ b/torch/csrc/distributed/rpc/python_functions.cpp @@ -4,37 +4,15 @@ namespace torch { namespace distributed { namespace rpc { -py::object to_py_obj(const Message& message) { - switch (message.type()) { - case MessageType::SCRIPT_RET: { - ScriptRet ret = ScriptRet::fromMessage(message); - Stack stack; - stack.push_back(ret.value()); - return torch::jit::createPyObjectForStack(std::move(stack)); - } - case MessageType::PYTHON_RET: { - return PythonRpcHandler::loadPythonUDFResult(message); - } - case MessageType::EXCEPTION: { - std::string err(message.payload().begin(), message.payload().end()); - throw std::runtime_error(err); - } - default: { - AT_ERROR("Unrecognized response message type ", message.type()); - } - } -} +namespace { -std::shared_ptr py_rpc_builtin( - RpcAgent& agent, - const WorkerId& dst, +std::shared_ptr matchBuiltinOp( const std::string& opName, const py::args& args, - const py::kwargs& kwargs) { - // builtin operators. + const py::kwargs& kwargs, + Stack& stack) { Symbol symbol = Symbol::fromQualString(opName); if (symbol.is_aten()) { - Stack stack; for (const auto& op : torch::jit::getAllOperatorsFor(symbol)) { try { // FIXME: This is temporary solution. We should at least refactor @@ -49,8 +27,8 @@ std::shared_ptr py_rpc_builtin( continue; } - // Found the right op! Send it along... - return agent.send(dst, ScriptCall(op, std::move(stack)).toMessage()); + // Found the right op! + return op; } } @@ -63,9 +41,67 @@ std::shared_ptr py_rpc_builtin( ", kwargs: ", kwargs, ") to a builtin operator"); + + // builtin operators. +} + +} // namespace + +py::object toPyObj(const Message& message) { + switch (message.type()) { + case MessageType::SCRIPT_RET: { + ScriptRet ret = ScriptRet::fromMessage(message); + Stack stack; + stack.push_back(ret.value()); + return torch::jit::createPyObjectForStack(std::move(stack)); + } + case MessageType::PYTHON_RET: { + return PythonRpcHandler::getInstance().loadPythonUDFResult(message); + } + case MessageType::EXCEPTION: { + std::string err(message.payload().begin(), message.payload().end()); + throw std::runtime_error(err); + } + default: { + AT_ERROR("Unrecognized response message type ", message.type()); + } + } +} + +std::shared_ptr pyRpcBuiltin( + RpcAgent& agent, + const WorkerId& dst, + const std::string& opName, + const py::args& args, + const py::kwargs& kwargs) { + Stack stack; + auto op = matchBuiltinOp(opName, args, kwargs, stack); + return agent.send(dst, ScriptCall(op, std::move(stack)).toMessage()); +} + +std::shared_ptr pyRemoteBuiltin( + RpcAgent& agent, + const WorkerId& dst, + const std::string& opName, + const py::args& args, + const py::kwargs& kwargs) { + Stack stack; + auto op = matchBuiltinOp(opName, args, kwargs, stack); + + auto& ctx = RRefContext::getInstance(); + auto userRRef = ctx->createUserRRef(dst.id_); + agent.send( + dst, + ScriptRemoteCall( + op, + std::move(stack), + userRRef->id().toIValue(), + userRRef->forkId().toIValue()) + .toMessage()); + return userRRef; } -std::shared_ptr py_rpc_python_udf( +std::shared_ptr pyRpcPythonUdf( RpcAgent& agent, const WorkerId& dst, const std::string& pickledPythonUDF) { diff --git a/torch/csrc/distributed/rpc/python_functions.h b/torch/csrc/distributed/rpc/python_functions.h index 57535d7c86e75..67a5e6429476b 100644 --- a/torch/csrc/distributed/rpc/python_functions.h +++ b/torch/csrc/distributed/rpc/python_functions.h @@ -4,7 +4,10 @@ #include #include #include +#include +#include #include +#include #include #include #include @@ -13,20 +16,27 @@ namespace torch { namespace distributed { namespace rpc { -py::object to_py_obj(const Message& message); +py::object toPyObj(const Message& message); -std::shared_ptr py_rpc_builtin( +std::shared_ptr pyRpcBuiltin( RpcAgent& agent, const WorkerId& dst, const std::string& opName, const py::args& args, const py::kwargs& kwargs); -std::shared_ptr py_rpc_python_udf( +std::shared_ptr pyRpcPythonUdf( RpcAgent& agent, const WorkerId& dst, const std::string& pickledPythonUDF); +std::shared_ptr pyRemoteBuiltin( + RpcAgent& agent, + const WorkerId& dst, + const std::string& opName, + const py::args& args, + const py::kwargs& kwargs); + } // namespace rpc } // namespace distributed } // namespace torch diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp index 9b70f80907f5d..09c78f518a459 100644 --- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp +++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp @@ -3,41 +3,37 @@ namespace torch { namespace distributed { namespace rpc { -namespace { -py::object module_; -py::object runUDFFunction_; -py::object loadResultFunction_; -} // anonymous namespace -namespace PythonRpcHandler { -void init() { +PythonRpcHandler::PythonRpcHandler() { AutoGIL ag; - if (module_ == nullptr) { - module_ = py::module::import("torch.distributed.internal_rpc_utils"); - } - if (runUDFFunction_ == nullptr) { - runUDFFunction_ = module_.attr("run_python_udf_internal"); - } - if (loadResultFunction_ == nullptr) { - loadResultFunction_ = module_.attr("load_python_udf_result_internal"); - } + py::object module = + py::module::import("torch.distributed.internal_rpc_utils"); + runUDFFunction_ = module.attr("run_python_udf_internal"); + loadResultFunction_ = module.attr("load_python_udf_result_internal"); } -std::vector generatePythonUDFResult(const Message& request) { +PythonRpcHandler& PythonRpcHandler::getInstance() { + static PythonRpcHandler handler; + return handler; +} + +std::vector PythonRpcHandler::generatePythonUDFResult( + const Message& request) { AutoGIL ag; auto pargs = py::bytes(request.payload().data(), request.payload().size()); + TORCH_CHECK(runUDFFunction_ != nullptr, "runUDFFunction_ is nullptr"); py::bytes pres = runUDFFunction_(pargs); const auto& presStr = static_cast(pres); std::vector payload(presStr.begin(), presStr.end()); return payload; } -py::object loadPythonUDFResult(const Message& message) { +py::object PythonRpcHandler::loadPythonUDFResult(const Message& message) { AutoGIL ag; auto pargs = py::bytes(message.payload().data(), message.payload().size()); + TORCH_CHECK(loadResultFunction_ != nullptr, "loadResultFunction_ is nullptr"); return loadResultFunction_(pargs); } -} // namespace PythonRpcHandler } // namespace rpc } // namespace distributed diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.h b/torch/csrc/distributed/rpc/python_rpc_handler.h index ae393383b58f3..c20de9768998a 100644 --- a/torch/csrc/distributed/rpc/python_rpc_handler.h +++ b/torch/csrc/distributed/rpc/python_rpc_handler.h @@ -7,16 +7,33 @@ namespace torch { namespace distributed { namespace rpc { -namespace PythonRpcHandler { -// initialize python module object and function objects in which python user -// defined function (UDF) will run there -void init(); -// execute python UDF, result is pickled to binary string -std::vector generatePythonUDFResult(const Message& request); -// returned python UDF result is pickled binary string, so run python -// function to unpickle the python UDF result and return pyObject to user -py::object loadPythonUDFResult(const Message& message); -} // namespace PythonRpcHandler +// Singleton class provides interface to execute python UDF remote call +// and deserialize the returned results by running python function +// in internal_rpc_utilities. +// The singleton object is constructed at first when RPC agent is +// constructed, where the python function in +// torch/distributed/internal_rpc_utils.py are imported only once. +class PYBIND11_EXPORT PythonRpcHandler { + public: + static PythonRpcHandler& getInstance(); + // Execute python UDF, result is pickled to binary string + std::vector generatePythonUDFResult(const Message& request); + // Returned python UDF result is pickled binary string, so run python + // function to unpickle the python UDF result and return py::object to user + py::object loadPythonUDFResult(const Message& message); + + private: + PythonRpcHandler(); + ~PythonRpcHandler() = default; + + PythonRpcHandler(const PythonRpcHandler&) = delete; + PythonRpcHandler& operator=(const PythonRpcHandler&) = delete; + PythonRpcHandler(PythonRpcHandler&&) = delete; + PythonRpcHandler& operator=(PythonRpcHandler&&) = delete; + + py::object runUDFFunction_; + py::object loadResultFunction_; +}; } // namespace rpc } // namespace distributed diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index 894e3872dec13..f5ebda81efa3f 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -2,6 +2,7 @@ #include #include +#include #include @@ -9,8 +10,6 @@ namespace torch { namespace distributed { namespace rpc { -using worker_id_t = int16_t; - // A globally unique ID to identify an RpcAgent struct WorkerId { WorkerId(std::string name, int id) @@ -24,12 +23,13 @@ struct WorkerId { WorkerId(std::string name, worker_id_t id) : name_(std::move(name)), id_(id) { bool validSize = name_.length() < MAX_NAME_LEN && name_.length() > 0; - bool validChar = std::find_if(name_.begin(), name_.end(), [](char c) { - return !(std::isalnum(c) || c == '-' || c == '_'); - }) == name_.end(); + bool validChar = + std::find_if(name_.begin(), name_.end(), [](char c) { + return !(std::isalnum(c) || c == '-' || c == '_' || c == ':'); + }) == name_.end(); TORCH_CHECK( validSize && validChar, - "Worker name must match ^[A-Za-z0-9-_]*$, " + "Worker name must match ^[A-Za-z0-9-_:]*$, " "and must be non-empty and shorter than ", MAX_NAME_LEN, " chars, " @@ -93,8 +93,7 @@ class RpcAgent { // Return a reference to the ``WorkerId`` of the given ``workerName``. virtual const WorkerId& getWorkerId(const std::string& workerName) const = 0; - // Retrieves the worker_id for this node. - virtual int16_t getWorkerId() = 0; + virtual const WorkerId& getWorkerId(worker_id_t id) const = 0; // Call sync and join all internal threads. This method should be called // before every RPC process exits. diff --git a/torch/csrc/distributed/rpc/rref.cpp b/torch/csrc/distributed/rpc/rref.cpp new file mode 100644 index 0000000000000..e81281b3a7fa2 --- /dev/null +++ b/torch/csrc/distributed/rpc/rref.cpp @@ -0,0 +1,119 @@ +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +std::atomic RRefContext::nextLocalId_{0}; + +////////////////////////// RRefForkData ///////////////////////////////// + +RRefForkData::RRefForkData( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId) + : ownerId_(ownerId), rrefId_(rrefId), forkId_(forkId) {} + +at::IValue RRefForkData::toIValue() const { + std::vector ivalues = { + (int64_t)ownerId_, rrefId_.toIValue(), forkId_.toIValue()}; + + return c10::ivalue::Tuple::create(std::move(ivalues)); +} + +RRefForkData RRefForkData::fromIValue(const at::IValue& ivalue) { + auto ivalues = ivalue.toTuple()->elements(); + + TORCH_CHECK( + ivalues.size() == 3, + "Constructing RRefForkData from ivalue " + "expects a GenericList of 3 elements, but got ", + ivalues.size()); + + int64_t ownerId = ivalues[0].toInt(); + TORCH_CHECK( + ownerId < std::numeric_limits::max(), + "RRefId createdOn out of range, got ", + ownerId); + + RRefId rrefId = RRefId::fromIValue(ivalues[1]); + ForkId forkId = ForkId::fromIValue(ivalues[2]); + + return RRefForkData(ownerId, rrefId, forkId); +} + +////////////////////////////// RRef ///////////////////////////////////// + +RRef::RRef(worker_id_t ownerId, const RRefId& rrefId) + : ownerId_(ownerId), rrefId_(rrefId) {} + +worker_id_t RRef::owner() const { + return ownerId_; +} + +const RRefId& RRef::id() const { + return rrefId_; +} + +at::IValue RRef::fork() const { + return RRefForkData( + ownerId_, rrefId_, RRefContext::getInstance()->genRRefId()) + .toIValue(); + // NB: does not support sharing RRefs between users + // TODO: notify the owner +} + +////////////////////////// UserRRef ///////////////////////////////////// + +UserRRef::UserRRef( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId) + : RRef(ownerId, rrefId), forkId_(forkId) { + AT_ASSERT( + !(forkId_ == rrefId_), + "User RRef's fork ID should not be the same as its rref Id"); + if (RRefContext::getInstance()->getWorkerId() == rrefId_.createdOn_) { + // creator user, notify owner. + auto& agent = RRefContext::getInstance()->agent(); + agent->send( + agent->getWorkerId(ownerId_), + ScriptRRefCreate(RRefForkData(ownerId_, rrefId_, forkId_).toIValue()) + .toMessage()); + } else { + AT_ERROR("Does not support sharing RRefs between users yet"); + } +} + +UserRRef::~UserRRef() { + auto& ctx = RRefContext::getInstance(); + if (ctx->getWorkerId() != ownerId_) { + ctx->agent()->send( + ctx->agent()->getWorkerId(ownerId_), + ScriptRRefDelete(RRefForkData(ownerId_, rrefId_, forkId_).toIValue()) + .toMessage()); + } +} + +const ForkId& UserRRef::forkId() const { + return forkId_; +} + +bool UserRRef::isOwner() const { + return false; +} + +IValue UserRRef::toHere() { + auto& agent = RRefContext::getInstance()->agent(); + std::shared_ptr fm = agent->send( + agent->getWorkerId(ownerId_), + ScriptRRefFetchCall(id().toIValue()).toMessage()); + auto srv = ScriptRRefFetchRet::fromMessage(fm->wait()); + return srv.value(); +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rref.h b/torch/csrc/distributed/rpc/rref.h new file mode 100644 index 0000000000000..03fa11a25c629 --- /dev/null +++ b/torch/csrc/distributed/rpc/rref.h @@ -0,0 +1,140 @@ +#pragma once + +#include +#include +#include +#include + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +class RRef; +class RRefContext; +class UserRRef; + +// Represents fork of an RRef to be sent over the wire. +// +// In order to preserve correctness of reference counting, each RRefForkData +// **MUST** be deserialized into a RRef. This means that if RRefForkData is to +// be transferred across the network, we need the guarantee that the message +// will *eventually* get to the peer, and that the peer will create a RRef out +// of it. Therefore, no constructor of RRefForkData is exposed, and +// applications should never directly use RRefForkData. All construction are +// done within ``RRef`` and ``RRefContext``. +struct RRefForkData { + at::IValue toIValue() const; + + private: + friend class RRef; + friend class RRefContext; + friend class UserRRef; + + RRefForkData( + worker_id_t ownerId, + const RRefId& rrefId_, + const ForkId& forkId_); + + static RRefForkData fromIValue(const at::IValue&); + + const worker_id_t ownerId_; + const RRefId rrefId_; + const ForkId forkId_; +}; + +static_assert( + C10_IS_TRIVIALLY_COPYABLE(RRefForkData), + "RRefForkData must be trivially copyable"); + +// TODO: make RRef an IValue, and edit createStackForSchema accordingly +class RRef { + public: + // RRef is made NOT copyable NOT movable to prevent messing up reference + // counting + RRef(const RRef& other) = delete; + RRef(RRef&& other) = delete; + + virtual ~RRef() = default; + + worker_id_t owner() const; + const RRefId& id() const; + IValue fork() const; + + virtual bool isOwner() const = 0; + virtual IValue toHere() = 0; + + protected: + RRef(worker_id_t ownerId, const RRefId& rrefId); + + const worker_id_t ownerId_; + const RRefId rrefId_; +}; + +class UserRRef final : public RRef { + public: + const ForkId& forkId() const; + bool isOwner() const override; + IValue toHere() override; + + ~UserRRef() override; + + private: + friend class RRefContext; + + UserRRef(worker_id_t ownerId, const RRefId& rrefId, const ForkId& forkId); + + const ForkId forkId_; +}; + +// Keep the template only on the derived class because ``RRefContext`` needs to +// erase the type on ``RRef`` and keep them in one map. +template +class OwnerRRef final : public RRef { + public: + bool isOwner() const override { + return true; + } + + T getValue() const { + // TODO: use callback to make this non-blocking + std::unique_lock lock(mutex_); + valueCV_.wait(lock, [this] { return value_.has_value(); }); + return value_.value(); + } + + void setValue(T&& value) { + { + std::lock_guard lock(mutex_); + value_ = std::move(value); + } + valueCV_.notify_all(); + } + + IValue toHere() override { + AT_ERROR("OwnerRRef does not support toHere(), use getValue() instead."); + } + + private: + friend class RRefContext; + + OwnerRRef(worker_id_t ownerId, const RRefId& rrefId) + : OwnerRRef(ownerId, rrefId, {}) {} + + OwnerRRef(OwnerRRef&& other) noexcept + : OwnerRRef(other.owner(), other.id(), std::move(other.value_)) {} + + OwnerRRef(worker_id_t ownerId, const RRefId& rrefId, c10::optional value) + : RRef(ownerId, rrefId) { + value_ = std::move(value); + } + + c10::optional value_; + mutable std::mutex mutex_; + mutable std::condition_variable valueCV_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp new file mode 100644 index 0000000000000..dc7302e8adce9 --- /dev/null +++ b/torch/csrc/distributed/rpc/rref_context.cpp @@ -0,0 +1,72 @@ +#include + +namespace torch { +namespace distributed { +namespace rpc { + +std::unique_ptr RRefContext::context_; + +void RRefContext::initInstance(std::shared_ptr agent) { + TORCH_CHECK(!RRefContext::context_, "Can only initialize RRefContext once."); + TORCH_CHECK(agent, "RRefContext requires a non-null RpcAgent shared_ptr."); + + RRefContext::context_ = + std::unique_ptr(new RRefContext(std::move(agent))); +} + +std::unique_ptr& RRefContext::getInstance() { + TORCH_CHECK( + RRefContext::context_, "Have to initialize RRefContext before use."); + return RRefContext::context_; +} + +RRefContext::RRefContext(std::shared_ptr agent) + : agent_(std::move(agent)) {} + +worker_id_t RRefContext::getWorkerId() const { + return agent_->getWorkerId().id_; +} + +RRefId RRefContext::genRRefId() { + return RRefId(getWorkerId(), nextLocalId_++); +} + +const std::shared_ptr& RRefContext::agent() const { + return agent_; +} + +void RRefContext::addFork(const at::IValue& value) { + auto rfd = RRefForkData::fromIValue(value); + AT_ASSERT( + rfd.ownerId_ == getWorkerId(), + "RRef user should never receive fork notification."); + std::lock_guard lock(mutex_); + auto& rrefForks = forks_[rfd.rrefId_]; + AT_ASSERT( + rrefForks.find(rfd.forkId_) == rrefForks.end(), + "Got fork notification twice on the same RRef ", + rfd.rrefId_); + rrefForks.insert(rfd.forkId_); +} + +void RRefContext::delFork(const at::IValue& value) { + auto rfd = RRefForkData::fromIValue(value); + AT_ASSERT( + rfd.ownerId_ == getWorkerId(), + "RRef user should never receive delete notification."); + std::lock_guard lock(mutex_); + auto& rrefForks = forks_[rfd.rrefId_]; + AT_ASSERT( + rrefForks.find(rfd.forkId_) != rrefForks.end(), + "Attempt to delete a non-exist fork ", + rfd.forkId_); + rrefForks.erase(rfd.forkId_); + if (rrefForks.empty()) { + owners_.erase(rfd.rrefId_); + forks_.erase(rfd.rrefId_); + } +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h new file mode 100644 index 0000000000000..e18967416eb24 --- /dev/null +++ b/torch/csrc/distributed/rpc/rref_context.h @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// Manages RRef lifetime and keeps track of RRef forks. +class RRefContext { + public: + static void initInstance(std::shared_ptr); + static std::unique_ptr& getInstance(); + + RRefContext(const RRefContext&) = delete; + void operator=(const RRefContext&) = delete; + + worker_id_t getWorkerId() const; + RRefId genRRefId(); + const std::shared_ptr& agent() const; + + // create a new RRef + template + std::shared_ptr> createOwnerRRef(worker_id_t ownerId) { + TORCH_CHECK(ownerId == getWorkerId(), "Cannot create OwnerRRef on user."); + return getOrCreateOwnerRRef(genRRefId()); + } + + std::shared_ptr createUserRRef(worker_id_t ownerId) { + TORCH_CHECK(ownerId != getWorkerId(), "Cannot create UserRRef on owner."); + return createUserRRef(ownerId, genRRefId(), genRRefId()); + } + + std::shared_ptr createUserRRef( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId) { + TORCH_CHECK( + ownerId != getWorkerId(), "RRef owner cannot create user RRef."); + // RRefContext does not track user RRefs, it will be destructed when there + // is no shared_ptrs pointing to it. NB: cannot use make_shared here as the + // constructor of UserRRef is private + return std::shared_ptr(new UserRRef(ownerId, rrefId, forkId)); + } + + // get an existing RRef or create a new one from a serialized + // ``RRefForkData``. + template + std::shared_ptr getOrCreateRRef(at::IValue&& value) { + auto rfd = RRefForkData::fromIValue(std::move(value)); + return getOrCreateRRef(rfd.ownerId_, rfd.rrefId_, rfd.forkId_); + } + + template + std::shared_ptr getOrCreateRRef( + worker_id_t ownerId, + const RRefId& rrefId, + const ForkId& forkId) { + if (ownerId == getWorkerId()) { + return getOrCreateOwnerRRef(rrefId); + } else { + return createUserRRef(ownerId, rrefId, forkId); + } + } + + template + std::shared_ptr> getOrCreateOwnerRRef(const RRefId& rrefId) { + std::lock_guard lock(mutex_); + const auto iter = owners_.find(rrefId); + if (iter == owners_.end()) { + // Scenario (1) the first time this owner knows about this RRef + // Scenario (2) This owner is also the creator. + // + // NB: cannot use make_shared here as the constructor of OwnerRRef is + // private. + auto rref = std::shared_ptr>( + new OwnerRRef(getWorkerId(), rrefId)); + owners_[rref->id()] = rref; + return rref; + + } else { + // Scenario (3) retrieving an existing RRef + return std::dynamic_pointer_cast>(iter->second); + } + } + + void addFork(const at::IValue& value); + void delFork(const at::IValue& value); + + private: + RRefContext(std::shared_ptr); + + static std::unique_ptr context_; + static std::atomic nextLocalId_; + + const std::shared_ptr agent_; + std::mutex mutex_; + // Keep OwnerRRefs alive until there is no living UserRRefs. + std::unordered_map, RRefId::Hash> owners_; + // Tracks known living UserRRefs of an OwnerRRef + std::unordered_map< + RRefId, + std::unordered_set, + RRefId::Hash> + forks_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/script_call.cpp b/torch/csrc/distributed/rpc/script_call.cpp index 742d3b2b96e36..e8ee647e0a4bb 100644 --- a/torch/csrc/distributed/rpc/script_call.cpp +++ b/torch/csrc/distributed/rpc/script_call.cpp @@ -21,14 +21,16 @@ const std::vector& ScriptCall::stack() const { return stack_; } -Message ScriptCall::toMessage() { - std::vector ivalues; +std::vector& ScriptCall::stackRef() { + return stack_; +} + +void ScriptCall::toIValues(std::vector& ivalues) const { for (auto& value : stack_) { ivalues.push_back(value); } - if (op_) { - // builtin ops + if (op_) { // TODO: replace this with a real overload_name when FunctionSchema supports // that. ivalues.emplace_back(toString((*op_)->schema())); @@ -43,6 +45,28 @@ Message ScriptCall::toMessage() { opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_); ivalues.emplace_back(std::move(opName)); } +} + +std::shared_ptr ScriptCall::fromIValues( + std::vector& ivalues) { + const std::string& qualifiedName = ivalues.back().toStringRef(); + + if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) { + ivalues.pop_back(); + const std::string& str_schema = ivalues.back().toStringRef(); + auto op = matchOperator(str_schema); + + ivalues.pop_back(); + // remove str_schema from ivalues + return op; + } else { + AT_ERROR("Unrecognized qualified name ", qualifiedName); + } +} + +Message ScriptCall::toMessage() { + std::vector ivalues; + toIValues(ivalues); std::vector tensor_table; auto payload = @@ -55,40 +79,29 @@ Message ScriptCall::toMessage() { ScriptCall ScriptCall::fromMessage(const Message& message) { auto payload = static_cast(message.payload().data()); auto payload_size = message.payload().size(); - auto value = jit::unpickle(payload, payload_size, nullptr, &message.tensors()); auto values = value.toTuple()->elements(); - - const std::string& qualifiedName = values.back().toStringRef(); - if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) { - values.pop_back(); - - const std::string& str_schema = values.back().toStringRef(); - // extract symbol from the schema - auto schema = torch::jit::parseSchema(str_schema); - auto symbol = at::Symbol::fromQualString(schema.name()); - auto op = matchOperator(symbol, str_schema); - // remove str_schema from values - values.pop_back(); - - return ScriptCall(op, std::move(values)); - } else { - AT_ERROR("Unrecognized qualified name ", qualifiedName); - } + auto op = fromIValues(values); + return ScriptCall(op, std::move(values)); } std::shared_ptr ScriptCall::matchOperator( - at::Symbol& symbol, const std::string& str_schema) { // TODO: This is a temporary solution. We should pass enough information to // allow deterministically matched to one operator. + + // extract symbol from the schema + auto schema = torch::jit::parseSchema(str_schema); + auto symbol = at::Symbol::fromQualString(schema.name()); + for (auto op : torch::jit::getAllOperatorsFor(symbol)) { if (toString(op->schema()).compare(str_schema) == 0) { return op; } } + AT_ERROR("Cannot find matching operator for schema ", str_schema); } diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 6a918fb8bdd54..4a38eed754f8e 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -15,22 +15,28 @@ using torch::jit::Operator; // A ScriptCall instance represents an invocation of a builtin operator for a // TorchScript function (not implemented yet). If it is a builtin operator, it // contains a shared ptr to the `Operator` and a list of arguments. -class TORCH_API ScriptCall final { +class TORCH_API ScriptCall { public: ScriptCall(std::shared_ptr op, std::vector&& args); std::shared_ptr op() const; // return the argument stack of this builtin operator const std::vector& stack() const; + std::vector& stackRef(); Message toMessage(); static ScriptCall fromMessage(const Message& message); + virtual ~ScriptCall() = default; + + protected: + virtual void toIValues(std::vector& ivalues) const; + static std::shared_ptr fromIValues( + std::vector& ivalues); + private: // Given an operator symbol and a string schema, return the matched operator. - static std::shared_ptr matchOperator( - at::Symbol& symbol, - const std::string& str_schema); + static std::shared_ptr matchOperator(const std::string& str_schema); static const std::string BUILTIN_OP_NAMESPACE_; static const std::string ATEN_PREFIX_; @@ -38,7 +44,7 @@ class TORCH_API ScriptCall final { // This field has value if this ScriptCall represents invocation of a builtin // operator. c10::optional> op_; - const std::vector stack_; + std::vector stack_; }; } // namespace rpc diff --git a/torch/csrc/distributed/rpc/script_remote_call.cpp b/torch/csrc/distributed/rpc/script_remote_call.cpp new file mode 100644 index 0000000000000..40a8638f19eca --- /dev/null +++ b/torch/csrc/distributed/rpc/script_remote_call.cpp @@ -0,0 +1,60 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +ScriptRemoteCall::ScriptRemoteCall( + std::shared_ptr op, + std::vector&& args, + at::IValue retRRefId, + at::IValue retForkId) + : ScriptCall(std::move(op), std::move(args)), + retRRefId_(std::move(retRRefId)), + retForkId_(std::move(retForkId)) {} + +const at::IValue& ScriptRemoteCall::retRRefId() { + return retRRefId_; +} + +const at::IValue& ScriptRemoteCall::retForkId() { + return retForkId_; +} + +Message ScriptRemoteCall::toMessage() const { + std::vector ivalues; + ScriptCall::toIValues(ivalues); + ivalues.push_back(retRRefId_); + ivalues.push_back(retForkId_); + + std::vector tensor_table; + auto payload = + jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table); + + return Message( + std::move(payload), std::move(tensor_table), MessageType::REMOTE_CALL); +} + +ScriptRemoteCall ScriptRemoteCall::fromMessage(const Message& message) { + auto payload = static_cast(message.payload().data()); + auto payload_size = message.payload().size(); + + auto value = + jit::unpickle(payload, payload_size, nullptr, &message.tensors()); + auto values = value.toTuple()->elements(); + + // remove the last element from values and convert it back to an RRef + auto retForkId = std::move(values.back()); + values.pop_back(); + auto retRRefId = std::move(values.back()); + values.pop_back(); + + auto op = ScriptCall::fromIValues(values); + return ScriptRemoteCall( + op, std::move(values), std::move(retRRefId), std::move(retForkId)); +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h new file mode 100644 index 0000000000000..0602884f6e40d --- /dev/null +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +using torch::jit::Operator; + +// A ScriptCall instance represents an invocation of a builtin operator for a +// TorchScript function (not implemented yet). If it is a builtin operator, it +// contains a shared ptr to the `Operator` and a list of arguments. +class TORCH_API ScriptRemoteCall final : public ScriptCall { + public: + ScriptRemoteCall( + std::shared_ptr op, + std::vector&& args, + at::IValue retRRefId, + at::IValue retForkId); + + const at::IValue& retRRefId(); + const at::IValue& retForkId(); + + Message toMessage() const; + static ScriptRemoteCall fromMessage(const Message& message); + + private: + const at::IValue retRRefId_; + const at::IValue retForkId_; +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/script_rref_proto.cpp b/torch/csrc/distributed/rpc/script_rref_proto.cpp new file mode 100644 index 0000000000000..4b8be7d7d5518 --- /dev/null +++ b/torch/csrc/distributed/rpc/script_rref_proto.cpp @@ -0,0 +1,56 @@ +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +const at::IValue& RRefMessageBase::value() { + return value_; +} + +at::IValue& RRefMessageBase::valueRef() { + return value_; +} + +Message RRefMessageBase::toMessage() const { + std::vector ivalues; + ivalues.push_back(value_); + std::vector tensor_table; + auto payload = + jit::pickle(c10::ivalue::Tuple::create(ivalues), &tensor_table); + + return Message(std::move(payload), std::move(tensor_table), type_); +} + +at::IValue RRefMessageBase::fromMessage(const Message& message) { + auto payload = static_cast(message.payload().data()); + auto payload_size = message.payload().size(); + + auto value = + jit::unpickle(payload, payload_size, nullptr, &message.tensors()); + auto values = value.toTuple()->elements(); + + AT_ASSERT(values.size() == 1, "Expect a single IValue from message."); + return std::move(values.front()); +} + +ScriptRRefFetchCall ScriptRRefFetchCall::fromMessage(const Message& message) { + return ScriptRRefFetchCall(RRefMessageBase::fromMessage(message)); +} + +ScriptRRefFetchRet ScriptRRefFetchRet::fromMessage(const Message& message) { + return ScriptRRefFetchRet(RRefMessageBase::fromMessage(message)); +} + +ScriptRRefCreate ScriptRRefCreate::fromMessage(const Message& message) { + return ScriptRRefCreate(RRefMessageBase::fromMessage(message)); +} + +ScriptRRefDelete ScriptRRefDelete::fromMessage(const Message& message) { + return ScriptRRefDelete(RRefMessageBase::fromMessage(message)); +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/script_rref_proto.h b/torch/csrc/distributed/rpc/script_rref_proto.h new file mode 100644 index 0000000000000..de35b7e72a251 --- /dev/null +++ b/torch/csrc/distributed/rpc/script_rref_proto.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +// Temporary solution of RRef operations. +// TODO: Remove all these messages and use rpc + registered functions instead. +class TORCH_API RRefMessageBase { + public: + RRefMessageBase(at::IValue value, MessageType type) + : value_(std::move(value)), type_(type) {} + + const at::IValue& value(); + at::IValue& valueRef(); + + Message toMessage() const; + static at::IValue fromMessage(const Message& message); + + private: + at::IValue value_; + const MessageType type_; +}; + +// UserRRef uses this message to fetch the remote RRef value from the owner. +class TORCH_API ScriptRRefFetchCall final : public RRefMessageBase { + public: + ScriptRRefFetchCall(at::IValue rrefForkData) + : RRefMessageBase(std::move(rrefForkData), MessageType::RREF_FETCH_CALL) { + } + + static ScriptRRefFetchCall fromMessage(const Message& message); +}; + +// OwnerRRef uses this message to send the RRef value to a remote UserRRef +class TORCH_API ScriptRRefFetchRet final : public RRefMessageBase { + public: + ScriptRRefFetchRet(at::IValue value) + : RRefMessageBase(std::move(value), MessageType::RREF_FETCH_RET) {} + + static ScriptRRefFetchRet fromMessage(const Message& message); +}; + +// Creator UserRRef uses this message to notify OwnerRRef on create. +class TORCH_API ScriptRRefCreate final : public RRefMessageBase { + public: + ScriptRRefCreate(at::IValue value) + : RRefMessageBase(std::move(value), MessageType::RREF_USER_CREATE) {} + + static ScriptRRefCreate fromMessage(const Message& message); +}; + +// UserRRef (regardless of it's the creator or not) uses this message to notify +// OwnerRRef on delete. +class TORCH_API ScriptRRefDelete final : public RRefMessageBase { + public: + ScriptRRefDelete(at::IValue value) + : RRefMessageBase(std::move(value), MessageType::RREF_USER_DELETE) {} + + static ScriptRRefDelete fromMessage(const Message& message); +}; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/types.cpp b/torch/csrc/distributed/rpc/types.cpp new file mode 100644 index 0000000000000..2071a6a59ae3c --- /dev/null +++ b/torch/csrc/distributed/rpc/types.cpp @@ -0,0 +1,54 @@ +#include + +namespace torch { +namespace distributed { +namespace rpc { + +GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId) + : createdOn_(createdOn), localId_(localId) {} + +bool GloballyUniqueId::operator==(const GloballyUniqueId& other) const { + return createdOn_ == other.createdOn_ && localId_ == other.localId_; +} + +bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const { + return createdOn_ != other.createdOn_ || localId_ != other.localId_; +} + +at::IValue GloballyUniqueId::toIValue() const { + std::vector ivalues = {(int64_t)createdOn_, (int64_t)localId_}; + return c10::ivalue::Tuple::create(std::move(ivalues)); +} + +GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) { + auto ivalues = ivalue.toTuple()->elements(); + TORCH_CHECK( + ivalues.size() == 2, + "Constructing GloballyUniqueId from ivalue " + "expects a GenericList of two elements, but got ", + ivalues.size()); + + worker_id_t createdOn = ivalues[0].toInt(); + local_id_t localId = ivalues[1].toInt(); + + TORCH_CHECK( + createdOn < std::numeric_limits::max(), + "GloballyUniqueId createdOn out of range, got ", + createdOn); + + TORCH_CHECK( + localId < std::numeric_limits::max(), + "GloballyUniqueId localId out of range, got ", + localId); + + return GloballyUniqueId(createdOn, localId); +} + +std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) { + return os << "GloballyUniqueId(" << globalId.createdOn_ << ", " + << globalId.localId_ << ")"; +} + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/distributed/rpc/types.h b/torch/csrc/distributed/rpc/types.h new file mode 100644 index 0000000000000..47de3d09f205a --- /dev/null +++ b/torch/csrc/distributed/rpc/types.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +namespace torch { +namespace distributed { +namespace rpc { + +using worker_id_t = int16_t; +using local_id_t = uint64_t; + +struct GloballyUniqueId final { + GloballyUniqueId(worker_id_t createdOn, local_id_t localId); + GloballyUniqueId(const GloballyUniqueId& other) = default; + + bool operator==(const GloballyUniqueId& other) const; + bool operator!=(const GloballyUniqueId& other) const; + + at::IValue toIValue() const; + static GloballyUniqueId fromIValue(const at::IValue&); + + struct Hash { + size_t operator()(const GloballyUniqueId& key) const { + return (uint64_t(key.createdOn_) << kLocalIdBits) | key.localId_; + } + }; + + static constexpr int kLocalIdBits = 48; + + const worker_id_t createdOn_; + const local_id_t localId_; +}; + +std::ostream& operator<<(std::ostream& os, const GloballyUniqueId& globalId); + +using RRefId = GloballyUniqueId; +using ForkId = GloballyUniqueId; + +} // namespace rpc +} // namespace distributed +} // namespace torch diff --git a/torch/csrc/generic/Storage.cpp b/torch/csrc/generic/Storage.cpp index 699d8090eae32..003724efd2dd6 100644 --- a/torch/csrc/generic/Storage.cpp +++ b/torch/csrc/generic/Storage.cpp @@ -279,13 +279,13 @@ static struct PyMemberDef THPStorage_(members)[] = { {nullptr} }; -static PyObject * THPStorage_(device)(THPStorage* self) { +static PyObject * THPStorage_(device)(THPStorage* self, void *unused) { HANDLE_TH_ERRORS return THPDevice_New(self->cdata->device()); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(dtype)(THPStorage *self) +static PyObject * THPStorage_(dtype)(THPStorage *self, void *unused) { HANDLE_TH_ERRORS return torch::autograd::utils::wrap( diff --git a/torch/csrc/generic/StorageMethods.cpp b/torch/csrc/generic/StorageMethods.cpp index d268646615085..1a3411fe0ee48 100644 --- a/torch/csrc/generic/StorageMethods.cpp +++ b/torch/csrc/generic/StorageMethods.cpp @@ -4,14 +4,14 @@ #include #endif -static PyObject * THPStorage_(size)(THPStorage *self) +static PyObject * THPStorage_(size)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS return PyLong_FromLong(THWStorage_(size)(LIBRARY_STATE self->cdata)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(dataPtr)(THPStorage *self) +static PyObject * THPStorage_(dataPtr)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS return PyLong_FromVoidPtr(THWStorage_(data)(LIBRARY_STATE self->cdata)); @@ -25,7 +25,7 @@ static PyObject * THPStorage_(copy_)(PyObject *self, PyObject *args, PyObject *k END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(isPinned)(THPStorage *self) +static PyObject * THPStorage_(isPinned)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS #if defined(USE_CUDA) @@ -36,14 +36,14 @@ static PyObject * THPStorage_(isPinned)(THPStorage *self) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(elementSize)(THPStorage *self) +static PyObject * THPStorage_(elementSize)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS return PyLong_FromLong(THWStorage_(elementSize)(LIBRARY_STATE_NOARGS)); END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(new)(THPStorage *self) +static PyObject * THPStorage_(new)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS THWStoragePtr new_storage(THWStorage_(new)(LIBRARY_STATE_NOARGS)); @@ -278,7 +278,7 @@ static PyObject *THPStorage_(setFromFile)(THPStorage *self, PyObject *args) } #ifdef THC_GENERIC_FILE -PyObject * THPStorage_(getDevice)(THPStorage *self) +PyObject * THPStorage_(getDevice)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS return PyLong_FromLong(THCStorage_(getDevice)(LIBRARY_STATE self->cdata)); @@ -302,7 +302,7 @@ PyObject * THPStorage_(_setCdata)(THPStorage *self, PyObject *new_cdata) } static PyMethodDef THPStorage_(methods)[] = { - {"copy_", (PyCFunction)THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, nullptr}, + {"copy_", (PyCFunction)(void(*)(void))THPStorage_(copy_), METH_VARARGS | METH_KEYWORDS, nullptr}, {"element_size", (PyCFunction)THPStorage_(elementSize), METH_NOARGS, nullptr}, {"fill_", (PyCFunction)THPStorage_(fill_), METH_O, nullptr}, {"new", (PyCFunction)THPStorage_(new), METH_NOARGS, nullptr}, @@ -311,12 +311,12 @@ static PyMethodDef THPStorage_(methods)[] = { {"data_ptr", (PyCFunction)THPStorage_(dataPtr), METH_NOARGS, nullptr}, {"is_pinned", (PyCFunction)THPStorage_(isPinned), METH_NOARGS, nullptr}, {"_write_file", (PyCFunction)THPStorage_(writeFile), METH_VARARGS, nullptr}, - {"_new_with_file", (PyCFunction)THPStorage_(newWithFile), METH_O | METH_STATIC, nullptr}, + {"_new_with_file", (PyCFunction)(void(*)(void))THPStorage_(newWithFile), METH_O | METH_STATIC, nullptr}, {"_set_from_file", (PyCFunction)THPStorage_(setFromFile), METH_VARARGS, nullptr}, #if !defined(THC_GENERIC_FILE) - {"from_buffer", (PyCFunction)THPStorage_(fromBuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"from_buffer", (PyCFunction)(void(*)(void))THPStorage_(fromBuffer), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, #endif - {"from_file", (PyCFunction)THPStorage_(fromFile), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, + {"from_file", (PyCFunction)(void(*)(void))THPStorage_(fromFile), METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr}, #ifdef THC_GENERIC_FILE {"get_device", (PyCFunction)THPStorage_(getDevice), METH_NOARGS, nullptr}, #endif diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 919b9de95d50f..edc8a1b5d6528 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -6,7 +6,7 @@ #include -static PyObject * THPStorage_(sharedDecref)(THPStorage *self) +static PyObject * THPStorage_(sharedDecref)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS #ifndef THC_GENERIC_FILE @@ -21,7 +21,7 @@ static PyObject * THPStorage_(sharedDecref)(THPStorage *self) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(sharedIncref)(THPStorage *self) +static PyObject * THPStorage_(sharedIncref)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS #ifndef THC_GENERIC_FILE @@ -69,7 +69,7 @@ static PyObject * THPStorage_(pyNewFilenameStorage)(PyObject *_unused, PyObject END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(shareFilename)(THPStorage *self) +static PyObject * THPStorage_(shareFilename)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS THWStorage *storage = self->cdata; @@ -150,7 +150,7 @@ static PyObject * THPStorage_(pyNewFdStorage)(PyObject *_unused, PyObject *args) END_HANDLE_TH_ERRORS } -static PyObject * THPStorage_(shareFd)(THPStorage *self) +static PyObject * THPStorage_(shareFd)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS THWStorage *storage = self->cdata; @@ -212,7 +212,7 @@ static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args) #else // THC_GENERIC_FILE -static PyObject * THPStorage_(shareCuda)(THPStorage *self) +static PyObject * THPStorage_(shareCuda)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS THWStorage *storage = self->cdata; @@ -496,7 +496,7 @@ PyObject * THPStorage_(expired)(PyObject *_unused, PyObject *arg) END_HANDLE_TH_ERRORS } -PyObject * THPStorage_(sharedFd)(THPStorage *self) +PyObject * THPStorage_(sharedFd)(THPStorage *self, PyObject *noargs) { HANDLE_TH_ERRORS THMapAllocator *ctx = nullptr; @@ -510,7 +510,7 @@ PyObject * THPStorage_(sharedFd)(THPStorage *self) END_HANDLE_TH_ERRORS } -PyObject * THPStorage_(isShared)(THPStorage *self) +PyObject * THPStorage_(isShared)(THPStorage *self, PyObject *noargs) { #ifdef THC_GENERIC_FILE Py_RETURN_TRUE; @@ -525,22 +525,22 @@ PyObject * THPStorage_(isShared)(THPStorage *self) } static PyMethodDef THPStorage_(sharingMethods)[] = { - {"_new_with_weak_ptr", (PyCFunction)THPStorage_(newWithWeakPtr), METH_O | METH_CLASS, nullptr}, + {"_new_with_weak_ptr", (PyCFunction)(void(*)(void))THPStorage_(newWithWeakPtr), METH_O | METH_CLASS, nullptr}, #ifdef THC_GENERIC_FILE {"_share_cuda_", (PyCFunction)THPStorage_(shareCuda), METH_NOARGS, nullptr}, - {"_new_shared_cuda", (PyCFunction)THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr}, - {"_release_ipc_counter", (PyCFunction)THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_shared_cuda", (PyCFunction)(void(*)(void))THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr}, + {"_release_ipc_counter", (PyCFunction)(void(*)(void))THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr}, #else {"_share_fd_", (PyCFunction)THPStorage_(shareFd), METH_NOARGS, nullptr}, - {"_new_shared_fd", (PyCFunction)THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr}, - {"_new_using_fd", (PyCFunction)THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_shared_fd", (PyCFunction)(void(*)(void))THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_using_fd", (PyCFunction)(void(*)(void))THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr}, {"_share_filename_", (PyCFunction)THPStorage_(shareFilename), METH_NOARGS, nullptr}, - {"_new_shared_filename", (PyCFunction)THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr}, - {"_new_using_filename", (PyCFunction)THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_shared_filename", (PyCFunction)(void(*)(void))THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr}, + {"_new_using_filename", (PyCFunction)(void(*)(void))THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr}, #endif {"_weak_ref", (PyCFunction)THPStorage_(weakRef), METH_NOARGS, nullptr}, - {"_free_weak_ref", (PyCFunction)THPStorage_(freeWeakRef), METH_O | METH_STATIC, nullptr}, - {"_expired", (PyCFunction)THPStorage_(expired), METH_O | METH_STATIC, nullptr}, + {"_free_weak_ref", (PyCFunction)(void(*)(void))THPStorage_(freeWeakRef), METH_O | METH_STATIC, nullptr}, + {"_expired", (PyCFunction)(void(*)(void))THPStorage_(expired), METH_O | METH_STATIC, nullptr}, {"_shared_decref", (PyCFunction)THPStorage_(sharedDecref), METH_NOARGS, nullptr}, {"_shared_incref", (PyCFunction)THPStorage_(sharedIncref), METH_NOARGS, nullptr}, {"_get_shared_fd", (PyCFunction)THPStorage_(sharedFd), METH_NOARGS, nullptr}, diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp index f4e47a436c8c9..7e2c8685082e6 100644 --- a/torch/csrc/generic/serialization.cpp +++ b/torch/csrc/generic/serialization.cpp @@ -22,7 +22,13 @@ void THPStorage_(writeFileRaw)(THWStorage *self, io fd) data = (scalar_t*)cpu_data.get(); THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost)); #endif - doWrite(fd, &size, sizeof(int64_t)); + if (THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) + doWrite(fd, &size, sizeof(int64_t)); + else { + int64_t nsize; // convert big endian cpu to little endian storage + THP_encodeInt64Buffer((uint8_t*)&nsize, (const int64_t *)&size, THPByteOrder::THP_LITTLE_ENDIAN, 1); + doWrite(fd, &nsize, sizeof(int64_t)); + } // fast track for bytes and little endian if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) { doWrite(fd, data, sizeof(scalar_t) * size); @@ -68,6 +74,11 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage) scalar_t *data; int64_t size; doRead(file, &size, sizeof(int64_t)); + if (THP_nativeByteOrder() == THPByteOrder::THP_BIG_ENDIAN) { + int64_t nsize; // convert little endian storage to big endian cpu + nsize = size; + THP_decodeInt64Buffer(&size, (const uint8_t*)&nsize, THP_nativeByteOrder(), 1); + } THWStoragePtr storage; if (_storage == nullptr) { storage = THWStorage_(newWithSize)(LIBRARY_STATE size); @@ -100,17 +111,17 @@ THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage) if (sizeof(scalar_t) == 2) { THP_decodeInt16Buffer((int16_t*)data + i, le_buffer.get(), - THPByteOrder::THP_LITTLE_ENDIAN, + THP_nativeByteOrder(), to_convert); } else if (sizeof(scalar_t) == 4) { THP_decodeInt32Buffer((int32_t*)data + i, le_buffer.get(), - THPByteOrder::THP_LITTLE_ENDIAN, + THP_nativeByteOrder(), to_convert); } else if (sizeof(scalar_t) == 8) { THP_decodeInt64Buffer((int64_t*)data + i, le_buffer.get(), - THPByteOrder::THP_LITTLE_ENDIAN, + THP_nativeByteOrder(), to_convert); } } diff --git a/torch/csrc/jit/docs/OVERVIEW.md b/torch/csrc/jit/docs/OVERVIEW.md index 4f54f21f5104f..daa5999b3f6ec 100644 --- a/torch/csrc/jit/docs/OVERVIEW.md +++ b/torch/csrc/jit/docs/OVERVIEW.md @@ -1026,6 +1026,32 @@ with prim::DifferentiableGraph_0 = graph(%13 : Float(*, *), return (%hy, %cy) ``` +## JIT Logging ## + +Logging is a very useful debugging technique, especially in the context of compilers. Compilers perform a series of passes and analyses and logging can help to trace issues such as wrong results or segmentation faults +all the way back to the original erroneous transformation. + +`TorchScript` offers a simple logging facility that can enabled by setting an environment variable `PYTORCH_JIT_LOG_LEVEL`. + +Logging is enabled on a per file basis. To enable logging in `dead_code_elimination.cpp`, `PYTORCH_JIT_LOG_LEVEL` should be +set to `dead_code_elimination.cpp` or, simply, to `dead_code_elimination` (i.e. `PYTORCH_JIT_LOG_LEVEL=dead_code_elimination`). + +Multiple files can be logged by separating each file name with a colon `:` as in the following example, `PYTORCH_JIT_LOG_LEVEL=dead_code_elimination:guard_elimination` + +There are 3 logging levels available for your use ordered by the detail level from lowest to highest. + +* `GRAPH_DUMP` should be used for printing entire graphs after optimization passes +* `GRAPH_UPDATE` should be used for reporting graph transformations (i.e. node deletion, constant folding, etc) +* `GRAPH_DEBUG` should be used for providing information useful for debugging + the internals of a particular optimization pass or analysis + +The current logging level is `GRAPH_UPDATE` meaning that both `GRAPH_DUMP` and `GRAPH_UPDATE` will be enabled when +one specifies a file(s) in `PYTORCH_JIT_LOG_LEVEL`. + +`GRAPH_DEBUG` can be enabled by prefixing a file name with an `>` as in `>alias_analysis`. +`>>` and `>>>` are also valid and **currently** are equivalent to `GRAPH_DEBUG` as there is no logging level that is +higher than `GRAPH_DEBUG`. + ## DifferentiableGraphOp ## [graph_executor.cpp](../graph_executor.cpp) diff --git a/torch/csrc/jit/function.h b/torch/csrc/jit/function.h index 73e3d06dfdd3e..0fc79657d11b8 100644 --- a/torch/csrc/jit/function.h +++ b/torch/csrc/jit/function.h @@ -21,9 +21,7 @@ struct TORCH_API Function { graph_(std::move(graph)), function_creator_(std::move(function_creator)) {} - void run(Stack& stack) { - get_executor().run(stack); - } + void run(Stack &stack) { get_executor().run(stack); } void run(Stack&& stack) { run(stack); diff --git a/torch/csrc/jit/fuser/cpu/fused_kernel.cpp b/torch/csrc/jit/fuser/cpu/fused_kernel.cpp index ae662ad6a2e12..6784f4ceb3bae 100644 --- a/torch/csrc/jit/fuser/cpu/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cpu/fused_kernel.cpp @@ -1,10 +1,15 @@ #include #include +#include #include #include #include #include +#ifdef _MSC_VER +#include +#endif + #include #include #include @@ -16,9 +21,32 @@ namespace jit { namespace fuser { namespace cpu { +#ifdef _MSC_VER +static const std::string getTempPath() { + char lpTempPathBuffer[MAX_PATH]; + + DWORD dwRetVal = GetTempPath( + MAX_PATH, // length of the buffer + lpTempPathBuffer); // buffer for path + + TORCH_CHECK(dwRetVal < MAX_PATH && dwRetVal != 0, "GetTempPath failed."); + + return std::string(lpTempPathBuffer); +} +static const std::string temp_dir = getTempPath(); +static const std::string so_template = temp_dir + "pytorch_fuserXXXXXX.dll"; +static const std::string cpp_template = temp_dir + "pytorch_fuserXXXXXX.cpp"; +static const std::string check_exists_string = "where \"${program}\" > nul 2> nul"; +static std::vector env_list; +constexpr int so_suffix_len = 4; +constexpr int cpp_suffix_len = 4; +#else static const std::string so_template = "/tmp/pytorch_fuserXXXXXX.so"; static const std::string cpp_template = "/tmp/pytorch_fuserXXXXXX.cpp"; static const std::string check_exists_string = "which '${program}' > /dev/null"; +constexpr int so_suffix_len = 3; +constexpr int cpp_suffix_len = 4; +#endif static bool programExists(const std::string& program) { TemplateEnv env; @@ -27,6 +55,117 @@ static bool programExists(const std::string& program) { return (system(cmd.c_str()) == 0); } +#ifdef _MSC_VER +c10::optional exec(const std::string& cmd) { + std::array buffer; + std::string result; + std::unique_ptr pipe( + _popen(cmd.c_str(), "r"), _pclose); + if (!pipe) { + return c10::nullopt; + } + while (fgets(buffer.data(), static_cast(buffer.size()), pipe.get()) != nullptr) { + result += buffer.data(); + } + return result; +} + +inline std::string& rtrim(std::string& s, const char* t = " \t\n\r\f\v") { + s.erase(s.find_last_not_of(t) + 1); + return s; +} + +void activate() { + char* root = nullptr; + std::string cmd; + c10::optional exec_out; + std::string path; + std::string vcruntime_plat; + std::string envvars; + + // Checking whether the environment is already activated + if (getenv("VSCMD_ARG_TGT_ARCH")) { + return; + } + + // Getting `ProgramFiles` through environment variable queries + root = getenv("ProgramFiles(x86)"); + if (!root) { + root = getenv("ProgramFiles"); + } + if (!root) { + return; + } + + // Getting VS 2017 installation path through `vswhere` + cmd = "\"" + std::string(root) + + "\\Microsoft Visual Studio\\Installer\\vswhere.exe\"" + " -latest -prerelease -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath"; + exec_out = exec(cmd); + if (!exec_out) { + return; + } + path = *exec_out; + rtrim(path); + + // Checking whether the activation script `vcvarsall.bat` exists + path += "\\VC\\Auxiliary\\Build"; + struct stat st; + if (stat(path.c_str(), &st) == -1 || !(st.st_mode & _S_IFDIR)) { + return; + } + path += "\\vcvarsall.bat"; + if (_access(path.c_str(), 0) == -1) { + return; + } + + // Determining current platform + if (sizeof(void*) == 8) { + vcruntime_plat = "x64"; + } else { + vcruntime_plat = "x86"; + } + + // Getting environment variables after activating VS development shell + cmd = "\"" + path + "\" " + vcruntime_plat + ">NUL && set"; + exec_out = exec(cmd); + if (!exec_out) { + return; + } + envvars = *exec_out; + + // Setting environment variables to the current environment + std::istringstream f(envvars); + std::string envvar; + while (getline(f, envvar, '\n')) { + env_list.push_back(envvar); + } +} + +intptr_t run(const std::string& cmd) { + // Getting the path of `cmd.exe` + char* comspec = getenv("COMSPEC"); + if (!comspec) { + comspec = "C:\\Windows\\System32\\cmd.exe"; + } + // Constructing the command line + const char* a[] = {"/c", cmd.c_str()}; + // Constructing the env array + // If `env_list` is not empty, then add char pointers ending with nullptr. + // Otherwise, it will be nullptr, which implies the default env. + std::vector e; + if (!env_list.empty()) { + for (auto& s : env_list) { + e.push_back(s.c_str()); + } + e.push_back(nullptr); + } + // Running the command + intptr_t r = _spawnve(_P_WAIT, comspec, a, e.data()); + return r; +} +#endif + // A single compiler config is accessed through getConfig() (below) // Controls compilation options and may be updated based on the result // of compilation attempts. @@ -37,6 +176,10 @@ struct CompilerConfig { cxx = cxx_env; } +#ifdef _MSC_VER + activate(); +#endif + if (!programExists(cxx)) { cxx = ""; } @@ -44,7 +187,13 @@ struct CompilerConfig { ~CompilerConfig() = default; - std::string cxx = "g++"; // compiler location + #ifdef _MSC_VER + std::string cxx = "cl"; + const std::string openmp_flags = "/openmp"; + #else + std::string cxx = "g++"; + const std::string openmp_flags = "-fopenmp"; + #endif bool openmp = true; }; @@ -63,24 +212,46 @@ static CompilerConfig& getConfig() { // understand for AVX512. When we need better CPU performance this // optimization can be re-enabled by tracking down the platforms where // this error occurs and only selectively disabling it. +#ifdef _MSC_VER +static std::string getArchFlags() { + if (InstructionSet::AVX512F()) { + return "/arch:AVX512"; + } else if (InstructionSet::AVX2()) { + return "/arch:AVX2"; + } else if (InstructionSet::AVX()) { + return "/arch:AVX"; + } else { + return ""; + } +} +static const std::string arch_flags = getArchFlags(); +static const std::string compile_string = + "cd /D \"" + temp_dir + "\" && " + "${cxx} /nologo /MD /Ox " + arch_flags + " /LD /EHsc " + "${fopenmp} \"${cpp_file}\" /link /out:\"${so_file}\""; +#else static const std::string compile_string = "\"${cxx}\" -O3 -g " #ifndef __PPC64__ // "-march=native " #endif "-std=c++11 -fPIC ${fopenmp} -shared \"${cpp_file}\" -o \"${so_file}\" -lm"; - +#endif static void runCompiler( const std::string& cpp_file, const std::string& so_file) { auto& config = getConfig(); TemplateEnv env; env.s("cxx", config.cxx); - env.s("fopenmp", config.openmp ? "-fopenmp" : ""); + env.s("fopenmp", config.openmp ? config.openmp_flags : ""); env.s("cpp_file", cpp_file); env.s("so_file", so_file); std::string result = format(compile_string, env); +#ifdef _MSC_VER + intptr_t r = run(result); +#else int r = system(result.c_str()); +#endif if (config.openmp && r != 0) { std::cerr << "warning: pytorch jit fuser failed to compile with openmp, trying without it...\n"; @@ -90,7 +261,11 @@ static void runCompiler( TORCH_CHECK(r == 0, "Failed to compile a fused CPU kernel"); } +#ifdef _MSC_VER +static const std::string disas_string = "dumpbin /DISASM:NOBYTES \"${so_file}\""; +#else static const std::string disas_string = "objdump -M intel -d \"${so_file}\""; +#endif static void disas(const std::string& so_file) { TemplateEnv env; env.s("so_file", so_file); @@ -115,10 +290,14 @@ FusedKernelCPU::FusedKernelCPU( std::move(chunk_desc), std::move(concat_desc), has_random) { - TempFile so_file(so_template, 3); - TempFile cpp_file(cpp_template, 4); + TempFile so_file(so_template, so_suffix_len); + TempFile cpp_file(cpp_template, cpp_suffix_len); cpp_file.write(code_); cpp_file.sync(); +#ifdef _MSC_VER + so_file.close(); + cpp_file.close(); +#endif runCompiler(cpp_file.name(), so_file.name()); if (debugFuser() >= 2) disas(so_file.name()); diff --git a/torch/csrc/jit/fuser/cpu/msvc_arch.h b/torch/csrc/jit/fuser/cpu/msvc_arch.h new file mode 100644 index 0000000000000..a45102f5f90b4 --- /dev/null +++ b/torch/csrc/jit/fuser/cpu/msvc_arch.h @@ -0,0 +1,108 @@ +// Example code extracted from MSDN page of __cpuidex + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace fuser { +namespace cpu { + +class InstructionSet { + // forward declarations + class InstructionSet_Internal; + + public: + // getters + + static bool AVX(void) { + return CPU_Rep.f_1_ECX_[28]; + } + static bool AVX2(void) { + return CPU_Rep.f_7_EBX_[5]; + } + static bool AVX512F(void) { + return CPU_Rep.f_7_EBX_[16]; + } + + private: + static const InstructionSet_Internal CPU_Rep; + + class InstructionSet_Internal { + public: + InstructionSet_Internal() + : nIds_{0}, + nExIds_{0}, + f_1_ECX_{0}, + f_1_EDX_{0}, + f_7_EBX_{0}, + f_7_ECX_{0}, + f_81_ECX_{0}, + f_81_EDX_{0}, + data_{}, + extdata_{} { + // int cpuInfo[4] = {-1}; + std::array cpui; + + // Calling __cpuid with 0x0 as the function_id argument + // gets the number of the highest valid function ID. + __cpuid(cpui.data(), 0); + nIds_ = cpui[0]; + + for (int i = 0; i <= nIds_; ++i) { + __cpuidex(cpui.data(), i, 0); + data_.push_back(cpui); + } + + // load bitset with flags for function 0x00000001 + if (nIds_ >= 1) { + f_1_ECX_ = data_[1][2]; + f_1_EDX_ = data_[1][3]; + } + + // load bitset with flags for function 0x00000007 + if (nIds_ >= 7) { + f_7_EBX_ = data_[7][1]; + f_7_ECX_ = data_[7][2]; + } + + // Calling __cpuid with 0x80000000 as the function_id argument + // gets the number of the highest valid extended ID. + __cpuid(cpui.data(), 0x80000000); + nExIds_ = cpui[0]; + + for (int i = 0x80000000; i <= nExIds_; ++i) { + __cpuidex(cpui.data(), i, 0); + extdata_.push_back(cpui); + } + + // load bitset with flags for function 0x80000001 + if (nExIds_ >= 0x80000001) { + f_81_ECX_ = extdata_[1][2]; + f_81_EDX_ = extdata_[1][3]; + } + }; + + int nIds_; + int nExIds_; + std::bitset<32> f_1_ECX_; + std::bitset<32> f_1_EDX_; + std::bitset<32> f_7_EBX_; + std::bitset<32> f_7_ECX_; + std::bitset<32> f_81_ECX_; + std::bitset<32> f_81_EDX_; + std::vector> data_; + std::vector> extdata_; + }; +}; + +// Initialize static member data +const InstructionSet::InstructionSet_Internal InstructionSet::CPU_Rep; + +} // namespace cpu +} // namespace fuser +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/fuser/cpu/resource_strings.h b/torch/csrc/jit/fuser/cpu/resource_strings.h index dfc701b9dd7b0..54ec659a3b522 100644 --- a/torch/csrc/jit/fuser/cpu/resource_strings.h +++ b/torch/csrc/jit/fuser/cpu/resource_strings.h @@ -53,11 +53,34 @@ float fracf(float x) { ${type_declarations} +#ifdef _MSC_VER +template struct int_of_size; + +#define DEFINE_INT_OF_SIZE(int_t) \ +template<> struct int_of_size { using type = int_t; } + +DEFINE_INT_OF_SIZE(int64_t); +DEFINE_INT_OF_SIZE(int32_t); +DEFINE_INT_OF_SIZE(int16_t); +DEFINE_INT_OF_SIZE(int8_t); + +#undef DEFINE_INT_OF_SIZE + +template +using int_same_size_t = typename int_of_size::type; + +#define IndexTypeLoop int_same_size_t +#define ToIndexTypeLoop(x) static_cast(x) +#else +#define IndexTypeLoop IndexType +#define ToIndexTypeLoop(x) x +#endif + #define OMP_THRESHOLD 100000 static void ${kernelName}_kernel(IndexType totalElements, ${formals}) { #pragma omp parallel for if(totalElements > OMP_THRESHOLD) - for (IndexType linearIndex = 0; - linearIndex < totalElements; + for (IndexTypeLoop linearIndex = 0; + linearIndex < ToIndexTypeLoop(totalElements); linearIndex += 1) { // Convert `linearIndex` into an offset of tensor: ${tensorOffsets} @@ -66,8 +89,14 @@ static void ${kernelName}_kernel(IndexType totalElements, ${formals}) { } } +#ifdef _WIN32 +#define JIT_API __declspec(dllexport) +#else +#define JIT_API +#endif + extern "C" -void ${kernelName}(IndexType totalElements, void ** args) { +JIT_API void ${kernelName}(IndexType totalElements, void ** args) { ${kernelName}_kernel(totalElements ${,argument_loads}); } )"); diff --git a/torch/csrc/jit/fuser/cpu/temp_file.h b/torch/csrc/jit/fuser/cpu/temp_file.h index 009f66e2ba308..954718f0f8608 100644 --- a/torch/csrc/jit/fuser/cpu/temp_file.h +++ b/torch/csrc/jit/fuser/cpu/temp_file.h @@ -5,7 +5,18 @@ #include #include +#ifdef _WIN32 +#include +#include +#include +#include +#include +#include +#include +#include +#else #include +#endif #include #include @@ -15,6 +26,39 @@ namespace jit { namespace fuser { namespace cpu { +#ifdef _MSC_VER +int mkstemps(char* tmpl, int suffix_len) { + int len; + char* name; + int fd = -1; + int save_errno = errno; + + len = strlen(tmpl); + if (len < 6 + suffix_len || + strncmp(&tmpl[len - 6 - suffix_len], "XXXXXX", 6)) { + return -1; + } + + name = &tmpl[len - 6 - suffix_len]; + + std::random_device rd; + do { + for (unsigned i = 0; i < 6; ++i) { + name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36]; + } + + fd = _open(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD); + } while (errno == EEXIST); + + if (fd >= 0) { + errno = save_errno; + return fd; + } else { + return -1; + } +} +#endif + struct TempFile { TH_DISALLOW_COPY_AND_ASSIGN(TempFile); @@ -24,7 +68,11 @@ struct TempFile { std::vector tt(t.c_str(), t.c_str() + t.size() + 1); int fd = mkstemps(tt.data(), suffix); AT_ASSERT(fd != -1); + #ifdef _MSC_VER + file_ = _fdopen(fd, "r+"); + #else file_ = fdopen(fd, "r+"); + #endif // - 1 becuase tt.size() includes the null terminator, // but std::string does not expect one @@ -44,17 +92,35 @@ struct TempFile { AT_ASSERT(str.size() == result); } +#ifdef _MSC_VER + void close() { + if (file_ != nullptr) { + fclose(file_); + } + file_ = nullptr; + } +#endif + FILE* file() { return file_; } ~TempFile() { +#ifdef _MSC_VER + if (file_ != nullptr) { + fclose(file_); + } + if (!name_.empty() && _access(name_.c_str(), 0) != -1) { + _unlink(name_.c_str()); + } +#else if (file_ != nullptr) { // unlink first to ensure another mkstemps doesn't // race between close and unlink unlink(name_.c_str()); fclose(file_); } +#endif } private: diff --git a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp index 1d7ca3f2e4a55..4f2f7ebb9b045 100644 --- a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp @@ -108,13 +108,17 @@ FusedKernelCUDA::FusedKernelCUDA( AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( &program, code_.c_str(), nullptr, 0, nullptr, nullptr)); +#ifdef __HIP_PLATFORM_HCC__ + std::vector args = {}; +#else const std::string compute = "--gpu-architecture=compute_" + std::to_string(major) + std::to_string(minor); const std::vector args = { "--std=c++11", compute.c_str(), "-default-device"}; +#endif const auto result = nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); - if (result == NVRTC_ERROR_COMPILATION) { + if (result != NVRTC_SUCCESS) { size_t logsize; nvrtc().nvrtcGetProgramLogSize(program, &logsize); std::vector log(logsize); @@ -136,9 +140,14 @@ FusedKernelCUDA::FusedKernelCUDA( nvrtc().cuModuleGetFunction(&function_, module_, name_.c_str())); // Computes max blocks +#ifdef __HIP_PLATFORM_HCC__ + // XXX this is a temporary hack until the occupancy API is supported in ROCm + maxBlocks_ = 16 * prop_->multiProcessorCount; +#else AT_CUDA_DRIVER_CHECK(nvrtc().cuOccupancyMaxActiveBlocksPerMultiprocessor( &maxBlocks_, function_, 128, 0)); maxBlocks_ *= prop_->multiProcessorCount; +#endif // Resets device (end of hacked at::DeviceGuard) at::cuda::set_device(prior_device); diff --git a/torch/csrc/jit/fuser/cuda/resource_strings.h b/torch/csrc/jit/fuser/cuda/resource_strings.h index b45d9ca06caf3..c623348f7f072 100644 --- a/torch/csrc/jit/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/fuser/cuda/resource_strings.h @@ -13,6 +13,28 @@ tensor as input. Correct code for this case is generated, however, nvrtc does not know how to handle int*_t integer types, so typedefs help it handle those cases*/ +#ifdef __HIP_PLATFORM_HCC__ +static auto type_declarations_template = CodeTemplate(R"( +#include +${HalfHeader} +${RandHeader} + +#define POS_INFINITY INFINITY +#define NEG_INFINITY -INFINITY + +typedef ${IndexType} IndexType; +template +struct TensorInfo { + T* data; + IndexType sizes[N]; + IndexType strides[N]; +}; +template +struct TensorInfo { + T * data; +}; +)"); +#else static auto type_declarations_template = CodeTemplate(R"( typedef unsigned char uint8_t; typedef signed char int8_t; @@ -37,6 +59,7 @@ struct TensorInfo { T * data; }; )"); +#endif // We rewrite the code for philox RNG from curand as nvrtc couldn't resolve the // curand header correctly. diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 9b42b92ccb80a..a7f2bf536cf5e 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -545,7 +546,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { ExecutionPlan compileSpec(const ArgumentSpec& spec) { auto opt_graph = graph->copy(); - + SOURCE_DUMP("Optimizing the following function:", opt_graph); arg_spec_creator_.specializeTypes(*opt_graph, spec); // Phase 0. Inline functions, then clean up any artifacts that the inliner diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index cc08e7697b1b2..372ff3d5acc8a 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -123,6 +125,7 @@ void initJITBindings(PyObject* module) { return paramsDict; }, pybind11::return_value_policy::move) + .def("_jit_pass_onnx_scalar_type_analysis", ScalarTypeAnalysisForONNX) .def("_jit_pass_fuse", FuseGraph) .def( "_jit_pass_dce", @@ -166,6 +169,11 @@ void initJITBindings(PyObject* module) { "_jit_pass_quant_fusion", [](std::shared_ptr& g) { return QuantFusion(g); }) .def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d) + .def("_jit_pass_fuse_linear", &FuseLinear) + .def("_jit_pass_fold_quantize", + [](script::Module& module, const std::string& method_name) { + FoldQuantizeCallIntoBuffer(module, method_name); + }) .def( "_jit_pass_quantlint", [](std::shared_ptr& g) { return QuantLinting(g); }) @@ -417,6 +425,7 @@ void initJITBindings(PyObject* module) { script::parseIR(input, &*graph); return graph; }); + m.def("parse_schema", parseSchema); py::class_(m, "FunctionSchema") .def_property_readonly( @@ -428,6 +437,14 @@ void initJITBindings(PyObject* module) { "arguments", [](FunctionSchema& self) { return self.arguments(); }) .def_property_readonly( "returns", [](FunctionSchema& self) { return self.returns(); }) + .def("is_backward_compatible_with", + [](const FunctionSchema& self, const FunctionSchema& old_schema) { + return self.isBackwardCompatibleWith(old_schema); + }) + .def("__eq__", [](const FunctionSchema& self, + const FunctionSchema& other) { + return self == other; + }) .def("__str__", [](FunctionSchema& self) { std::stringstream ss; ss << self; @@ -442,11 +459,18 @@ void initJITBindings(PyObject* module) { return (self.N()) ? py::cast(*self.N()) : py::none(); }) .def_property_readonly("default_value", [](Argument& self) -> py::object { - if (!self.default_value()) - return py::none(); - IValue v = *self.default_value(); - return toPyObject(std::move(v)); - }); + if (!self.default_value()) + return py::none(); + IValue v = *self.default_value(); + return toPyObject(std::move(v)); + }); + m.def( + "_jit_get_all_schemas", []() { + const std::vector>& operations = getAllOperators(); + return fmap(operations, [](const std::shared_ptr& op) { + return op->schema(); + }); + }); m.def("_jit_get_schemas_for_operator", [](const std::string& qualified_name) { auto symbol = Symbol::fromQualString(qualified_name); auto operations = getAllOperatorsFor(symbol); @@ -481,7 +505,6 @@ void initJITBindings(PyObject* module) { Value* node_output; py::object py_func_output; - auto retval = c10::make_intrusive(); // Insert new trace ops into the fork op's sub-block WithInsertPoint guard(body_block); IValue output_ivalue; @@ -504,6 +527,9 @@ void initJITBindings(PyObject* module) { torch::jit::script::lambdaLiftFork(fork_node); } + auto retval = + c10::make_intrusive(output_ivalue.type()); + // Record the ivalue in the tracer jit::tracer::setValueTrace(retval, node_output); @@ -512,8 +538,9 @@ void initJITBindings(PyObject* module) { return PythonFutureWrapper(retval); } else { - auto retval = c10::make_intrusive(); - retval->markCompleted(toIValue(f(*args_tup))); + auto result = toIValue(f(*args_tup)); + auto retval = c10::make_intrusive(result.type()); + retval->markCompleted(std::move(result)); return PythonFutureWrapper(retval); } }); diff --git a/torch/csrc/jit/interpreter.cpp b/torch/csrc/jit/interpreter.cpp index d73d1b88b6281..eaa21b921d7de 100644 --- a/torch/csrc/jit/interpreter.cpp +++ b/torch/csrc/jit/interpreter.cpp @@ -336,6 +336,7 @@ struct CodeImpl { int register_size_ = 0; size_t n_outputs; size_t n_inputs; + TypePtr return_type_; // We MUST hold onto graph here because some Operators stored in the // instruction lists have dependencies on meta-data stored in the graph @@ -366,6 +367,12 @@ struct CodeImpl { : preprocess_(*graph), current_node_(preprocess_.graph->return_node()) { graph_ = preprocess_.graph; n_outputs = graph_->outputs().size(); + if (n_outputs == 1) { + return_type_ = graph->outputs().at(0)->type(); + } else { + return_type_ = TupleType::create( + fmap(graph->outputs(), [](const Value* v) { return v->type(); })); + } n_inputs = graph_->inputs().size(); // std::cout << *graph_ << "\n"; emitCodeForBlock(graph_->block()); @@ -1041,7 +1048,8 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { public: c10::intrusive_ptr getOrCreateFuture() { if (!future_) { - future_ = c10::make_intrusive(); + future_ = + c10::make_intrusive(frames.front().function->return_type_); } return future_; } diff --git a/torch/csrc/jit/interpreter.h b/torch/csrc/jit/interpreter.h index 45ebb055b8f8d..2dbaa0b0b2fd6 100644 --- a/torch/csrc/jit/interpreter.h +++ b/torch/csrc/jit/interpreter.h @@ -11,7 +11,8 @@ class Tensor; } namespace c10 { struct IValue; -} +struct Type; +} // namespace c10 namespace torch { namespace jit { diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 83f91e46be0a9..6a962ee05e24c 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1293,10 +1293,6 @@ struct ProfileOp : public Node { struct TORCH_API PythonOp : public Node { using Node::Node; - // should this Python function be skipped over when exported (i.e. for - // debugging functions that only run in Python) - bool ignore_on_export = false; - virtual std::string name() const = 0; virtual void writeScalars(std::ostream& out) const = 0; void cloneFrom(Node* other_) override = 0; diff --git a/torch/csrc/jit/jit_log.cpp b/torch/csrc/jit/jit_log.cpp index 9da42d9ccb738..db84d00240f94 100644 --- a/torch/csrc/jit/jit_log.cpp +++ b/torch/csrc/jit/jit_log.cpp @@ -1,20 +1,79 @@ -#include -#include -#include -#include + #include #include #include +#include +#include +#include +#include +#include +#include +#include + namespace torch { namespace jit { -JitLoggingLevels jit_log_level() { +static std::unordered_map +parseJITLogOption(const char *option) { + + std::stringstream in_ss; + in_ss << "function:"; + if (option) { + in_ss << option; + } + + std::unordered_map files_to_levels; + std::string line; + while (std::getline(in_ss, line, ':')) { + if (line.size() == 0) { + continue; + } + + auto index_at = line.find_last_of('>'); + auto begin_index = index_at == std::string::npos ? 0 : index_at + 1; + size_t logging_level = index_at == std::string::npos ? 1 : index_at + 2; + auto end_index = line.find_last_of('.') == std::string::npos + ? line.size() + : line.find_last_of('.'); + auto filename = line.substr(begin_index, end_index - begin_index); + files_to_levels.insert({filename, logging_level}); + } + + return files_to_levels; +} + +bool is_enabled(const char *cfname, JitLoggingLevels level) { + static const char* c_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL"); - static const JitLoggingLevels log_level = c_log_level - ? static_cast(std::atoi(c_log_level)) - : JitLoggingLevels::OFF; - return log_level; + static const std::unordered_map files_to_levels = + parseJITLogOption(c_log_level); + std::string fname{cfname}; + fname = c10::detail::StripBasename(fname); + auto end_index = fname.find_last_of('.') == std::string::npos + ? fname.size() + : fname.find_last_of('.'); + auto fname_no_ext = fname.substr(0, end_index); + + auto it = files_to_levels.find(fname_no_ext); + if (it == files_to_levels.end()) { + return false; + } + + return level <= static_cast(it->second); +} + +// Unfortunately, in `GraphExecutor` where `log_function` is invoked +// we won't have access to an original function, so we have to construct +// a dummy function to give to PythonPrint +std::string log_function(const std::shared_ptr &graph) { + torch::jit::Function func("source_dump", graph, nullptr); + std::stringstream ss; + std::vector tensors; + std::vector deps; + SourceRangeRecords source_ranges; + PythonPrint(ss, source_ranges, func, false, tensors, deps, false); + return ss.str(); } std::string debugValueOrDefault(const Node* n) { @@ -25,9 +84,9 @@ std::string jit_log_prefix( const std::string& prefix, const std::string& in_str) { std::stringstream in_ss(in_str); - std::stringstream out_ss(in_str); + std::stringstream out_ss; std::string line; - while (std::getline(in_ss, line, '\n')) { + while (std::getline(in_ss, line)) { out_ss << prefix << line << std::endl; } @@ -51,9 +110,6 @@ std::string jit_log_prefix( std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) { switch (level) { - case JitLoggingLevels::OFF: - TORCH_INTERNAL_ASSERT("UNREACHABLE"); - break; case JitLoggingLevels::GRAPH_DUMP: out << "DUMP"; break; diff --git a/torch/csrc/jit/jit_log.h b/torch/csrc/jit/jit_log.h index 73fb4574c613c..e612a73b8acef 100644 --- a/torch/csrc/jit/jit_log.h +++ b/torch/csrc/jit/jit_log.h @@ -1,30 +1,56 @@ #pragma once +#include #include #include -// To enable logging please set(export) PYTORCH_JIT_LOG_LEVEL to -// the ordinal value of one of the following logging levels: 1 for GRAPH_DUMP, -// 2 for GRAPH_UPDATE, 3 for GRAPH_DEBUG. -// * Use GRAPH_DUMP for dumping graphs after optimization passes -// * Use GRAPH_UPDATE for reporting graph transformations (i.e. node deletion, -// constant folding, CSE) -// * Use GRAPH_DEBUG to provide information useful for debugging +// `TorchScript` offers a simple logging facility that can enabled by setting an +// environment variable `PYTORCH_JIT_LOG_LEVEL`. + +// Logging is enabled on a per file basis. To enable logging in +// `dead_code_elimination.cpp`, `PYTORCH_JIT_LOG_LEVEL` should be +// set to `dead_code_elimination.cpp` or, simply, to `dead_code_elimination` +// (i.e. `PYTORCH_JIT_LOG_LEVEL=dead_code_elimination`). + +// Multiple files can be logged by separating each file name with a colon `:` as +// in the following example, +// `PYTORCH_JIT_LOG_LEVEL=dead_code_elimination:guard_elimination` + +// There are 3 logging levels available for your use ordered by the detail level +// from lowest to highest. + +// * `GRAPH_DUMP` should be used for printing entire graphs after optimization +// passes +// * `GRAPH_UPDATE` should be used for reporting graph transformations (i.e. +// node deletion, constant folding, etc) +// * `GRAPH_DEBUG` should be used for providing information useful for debugging // the internals of a particular optimization pass or analysis +// The current logging level is `GRAPH_UPDATE` meaning that both `GRAPH_DUMP` +// and `GRAPH_UPDATE` will be enabled when +// one specifies a file(s) in `PYTORCH_JIT_LOG_LEVEL`. + +// `GRAPH_DEBUG` can be enabled by prefixing a file name with an `>` as in +// `>alias_analysis`. +// `>>` and `>>>` are also valid and **currently** are equivalent to +// `GRAPH_DEBUG` as there is no logging level that is +// higher than `GRAPH_DEBUG`. + namespace torch { namespace jit { struct Node; +struct Graph; enum class JitLoggingLevels { - OFF, - GRAPH_DUMP, + GRAPH_DUMP = 0, GRAPH_UPDATE, GRAPH_DEBUG, }; std::string debugValueOrDefault(const Node* n); +std::string TORCH_API log_function(const std::shared_ptr &graph); + TORCH_API JitLoggingLevels jit_log_level(); // Prefix every line in a multiline string \p IN_STR with \p PREFIX. @@ -38,14 +64,19 @@ TORCH_API std::string jit_log_prefix( int l, const std::string& in_str); +TORCH_API bool is_enabled(const char *cfname, JitLoggingLevels level); + TORCH_API std::ostream& operator<<(std::ostream& out, JitLoggingLevels level); -#define JIT_LOG(level, ...) \ - if (jit_log_level() != JitLoggingLevels::OFF && jit_log_level() >= level) { \ - std::cerr << jit_log_prefix( \ - level, __FILE__, __LINE__, ::c10::str(__VA_ARGS__)); \ +#define JIT_LOG(level, ...) \ + if (is_enabled(__FILE__, level)) { \ + std::cerr << jit_log_prefix(level, __FILE__, __LINE__, \ + ::c10::str(__VA_ARGS__)); \ } +// tries to reconstruct original python source +#define SOURCE_DUMP(MSG, G) \ + JIT_LOG(JitLoggingLevels::GRAPH_DUMP, MSG, "\n", log_function(G)); // use GRAPH_DUMP for dumping graphs after optimization passes #define GRAPH_DUMP(MSG, G) \ JIT_LOG(JitLoggingLevels::GRAPH_DUMP, MSG, "\n", (G)->toString()); diff --git a/torch/csrc/jit/operator.cpp b/torch/csrc/jit/operator.cpp index a3ac20b64cb89..02b6a19ae77fd 100644 --- a/torch/csrc/jit/operator.cpp +++ b/torch/csrc/jit/operator.cpp @@ -114,6 +114,17 @@ struct OperatorRegistry { } return ret; } + + const std::vector> getAllOperators() { + std::lock_guard guard(lock); + registerPendingOperators(); + std::vector> values; + values.clear(); + for (auto & kv : operators) { + values.insert(values.end(), kv.second.begin(), kv.second.end()); + } + return values; + } }; OperatorRegistry& getRegistry() { @@ -151,6 +162,10 @@ void registerOperator(Operator&& op) { getRegistry().registerOperator(std::move(op)); } +const std::vector> getAllOperators() { + return getRegistry().getAllOperators(); +} + const std::vector>& getAllOperatorsFor(Symbol name) { return getRegistry().getOperators(name); } @@ -211,8 +226,8 @@ bool Operator::matches(const Node* node) const { TypeEnv type_env; for (size_t i = 0; i < formals.size(); ++i) { auto formal = formals[i].type(); - const MatchTypeReturn matched_type = - matchTypeVariables(formal, actuals[i]->type(), type_env); + const MatchTypeReturn matched_type = matchTypeVariables( + formal, actuals[i]->type(), type_env); if (!matched_type.success()) { return false; } diff --git a/torch/csrc/jit/operator.h b/torch/csrc/jit/operator.h index 8deba097d5ee2..0352200d8ef36 100644 --- a/torch/csrc/jit/operator.h +++ b/torch/csrc/jit/operator.h @@ -167,6 +167,7 @@ struct TORCH_API Operator { TORCH_API std::string canonicalSchemaString(const FunctionSchema& schema); +TORCH_API const std::vector> getAllOperators(); TORCH_API const std::vector>& getAllOperatorsFor( Symbol name); diff --git a/torch/csrc/jit/passes/fuse_linear.cpp b/torch/csrc/jit/passes/fuse_linear.cpp new file mode 100644 index 0000000000000..0a764dc31a692 --- /dev/null +++ b/torch/csrc/jit/passes/fuse_linear.cpp @@ -0,0 +1,52 @@ +#include +#include + +namespace torch { +namespace jit { + +void FuseLinear(std::shared_ptr& graph) { + std::string addmm_pattern = R"IR( + graph(%input, %weight, %bias, %4): + %weight_t = aten::t(%weight) + %res = aten::addmm(%bias, %input, %weight_t, %4, %4) + return (%res))IR"; + std::string matmul_add_pattern = R"IR( + graph(%input, %weight, %bias, %4): + %weight_t = aten::t(%weight) + %output = aten::matmul(%input, %weight_t) + %res = aten::add_(%output, %bias, %4) + return (%res))IR"; + std::string fused_linear = R"IR( + graph(%input, %weight, %bias, %4): + %res = aten::linear(%input, %weight, %bias) + return (%res))IR"; + + std::string matmul_pattern = R"IR( + graph(%input, %weight): + %weight_t = aten::t(%weight) + %output = aten::matmul(%input, %weight_t) + return (%output))IR"; + std::string fused_linear_bias_none = R"IR( + graph(%input, %weight): + %bias: Tensor? = prim::Constant() + %res = aten::linear(%input, %weight, %bias) + return (%res))IR"; + + // replace addmm pattern to linear + SubgraphRewriter addmm_to_linear; + addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear); + addmm_to_linear.runOnGraph(graph); + + // replace matmul + add pattern to linear + SubgraphRewriter matmuladd_to_linear; + matmuladd_to_linear.RegisterRewritePattern(matmul_add_pattern, fused_linear); + matmuladd_to_linear.runOnGraph(graph); + + // replace matmul with bias=None pattern to linear + SubgraphRewriter matmul_to_linear; + matmul_to_linear.RegisterRewritePattern( + matmul_pattern, fused_linear_bias_none); + matmul_to_linear.runOnGraph(graph); +} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/fuse_linear.h b/torch/csrc/jit/passes/fuse_linear.h new file mode 100644 index 0000000000000..34b08c4a7cad7 --- /dev/null +++ b/torch/csrc/jit/passes/fuse_linear.h @@ -0,0 +1,17 @@ +/** \brief Fusing linear patterns as single at::linear for easier pattern + * matching in later passes + */ +#pragma once + +#include + +namespace torch { +namespace jit { + +/** \brief Match the at::linear pattern and fuse it into a single at::linear + * This pass fuse the addmm or matmul + add generated by JIT back to linear + * This pass can be deleted once the JIT can emit the aten::linear in the future + */ +TORCH_API void FuseLinear(std::shared_ptr& graph); +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp new file mode 100644 index 0000000000000..201d32196f88e --- /dev/null +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -0,0 +1,231 @@ +#include +#include + +namespace torch { +namespace jit { + +namespace onnx { +using namespace ::c10::onnx; +} + +namespace { +class ScalarTypeHashFunction { + public: + size_t operator()(const c10::ScalarType& type) const { + return static_cast(type); + } +}; + +static const std::unordered_map scalarTypeToONNXTypeMap = { + {c10::kFloat, 1}, + {c10::kByte, 2}, + {c10::kChar, 3}, + {c10::kShort, 5}, + {c10::kInt, 6}, + {c10::kLong, 7}, + {c10::kBool, 9}, + {c10::kHalf, 10}, + {c10::kDouble, 11}, +}; + +static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) { + int64_t onnx_type = -1; + const auto it = scalarTypeToONNXTypeMap.find(st); + if (it != scalarTypeToONNXTypeMap.end()) { + onnx_type = it->second; + } + return onnx_type; +} + +// For these operators, all inputs and outputs share the same scalar type. +// There is no operator-wise special case handling needed. +static const std::unordered_set standardOps = { + onnx::Add, + onnx::Sub, + onnx::Mul, + onnx::Div, + onnx::Gemm, + onnx::Pow, + onnx::Mod, +}; + +static bool IsStandardOp(const NodeKind& nkind) { + return standardOps.find(nkind) != standardOps.end(); +} + +// For these operators, all inputs share the same scalar type. +// The output scalar type is always Bool. +static const std::unordered_set comparisonOps = { + onnx::Greater, + onnx::Less, + onnx::Equal, +}; + +static bool IsComparisonOp(const NodeKind& nkind) { + return comparisonOps.find(nkind) != comparisonOps.end(); +} + +static TensorTypePtr CreateProfiledTensorTypeWithScalarType( + const TensorTypePtr& typePtr, + const c10::ScalarType& scalar_type) { + return TensorType::create( + scalar_type, + typePtr->device(), + typePtr->sizes(), + typePtr->strides(), + typePtr->requiresGrad()); +} + +static bool IsImplicitCastSupported(const NodeKind& nodeKind) { + return (standardOps.find(nodeKind) != standardOps.end() || + comparisonOps.find(nodeKind) != comparisonOps.end()); +} + +static c10::optional PromoteScalarTypes(const std::vector& types) { + if (types.empty()) { + return c10::nullopt; + } + auto st = types[0]; + for (size_t i=1; i InferExpectedScalarType(const Node* n) { + std::vector typesFromTensors; + std::vector typesFromScalars; + std::for_each(n->inputs().begin(), n->inputs().end(), [&](const Value* input){ + auto nkind = input->node()->kind(); + if (nkind == onnx::Gather && input->node()->input(0)->node()->kind() == onnx::Shape) { + // This is a special pattern generated by code like `dim_size = x.size(0)`. + // It gets converted to the below ONNX IR graph + // %1 : Long() = onnx::Constant[value={0}]() + // %2 : Tensor = onnx::Shape(%x) + // %dim_size : Long() = onnx::Gather(%2, %1) + // `dim_size` is treated in PyTorch as Scalar. + // However, in the ONNX IR graph, it is an output of onnx::Gather, + // which is by default considered as a tensor. + typesFromScalars.emplace_back(c10::kLong); + } else if (nkind == onnx::Constant) { + typesFromScalars.emplace_back(input->node()->t(attr::value).scalar_type()); + } else if (auto scalar_type = input->type()->cast()->scalarType()) { + typesFromTensors.emplace_back(*scalar_type); + } + }); + + c10::optional st = c10::nullopt; + const c10::optional output_st = n->output()->type()->cast()->scalarType(); + + if (typesFromScalars.size() == n->inputs().size()) { + // If all inputs are scalars, infer scalar_type by calling c10::promoteTypes. + st = PromoteScalarTypes(typesFromScalars); + } else if (output_st && !IsComparisonOp(n->kind())) { + // If output scalar type is available, use that. + st = output_st; + } else if (!typesFromTensors.empty()) { + // When inputs consist of tensors and scalars. In PyTorch, scalars are implicitly casted to have the + // same scalar type as input tensors. + st = typesFromTensors[0]; + if (std::any_of(typesFromTensors.begin(), typesFromTensors.end(), [&st](const c10::ScalarType& type) { + return type != st; + })) { + std::cerr << "Warning: ONNX Scalar Type Analysis - Scalar types mismatch for tensor inputs of operator " + << n->kind().toDisplayString() + << ". Please report a bug to PyTorch. " + << "The scalar type " + << c10::toString(*st) + << " of the first tensor is chosen." << std::endl; + } + } else { + // When inputs consist of only scalars. + st = PromoteScalarTypes(typesFromScalars); + } + + return st; +} + +static void UpdateScalarTypeForInputs(Node* n, const c10::ScalarType& scalar_type) { + const int64_t onnx_type = ScalarTypeToONNXType(scalar_type); + if (onnx_type < 0) { + std::cerr << "Warning: ONNX Scalar Type Analysis - Scalar type: " + << c10::toString(scalar_type) + << " of input tensor in operator: " << n->kind().toDisplayString() + << " not supported in ONNX. " << std::endl; + return; + } + + for (auto input : n->inputs()) { + auto input_tensor_type = input->type()->cast(); + auto input_scalar_type = input_tensor_type->scalarType(); + + if ((input->node()->kind() == onnx::Constant) || + (input_scalar_type && (*input_scalar_type != scalar_type))) { + if (input->node()->kind() == onnx::Constant) { + // Fix up the scalar directly instead of inserting a cast operator. + // NOTE: Keep only the else branch once constant_folding is enabled by + // default. + at::Tensor val = input->node()->t(attr::value); + at::Tensor new_val = val.to(scalar_type); + Node* const_node = n->owningGraph()->create(onnx::Constant); + const_node->t_(attr::value, new_val); + const_node->insertBefore(n); + const_node->output()->setType(TensorType::create(new_val)); + n->replaceInputWith(input, const_node->output()); + } else { + Node* cast_node = n->owningGraph()->create(onnx::Cast); + cast_node->addInput(input); + cast_node->i_(attr::to, onnx_type); + cast_node->insertBefore(n); + cast_node->output()->setType(CreateProfiledTensorTypeWithScalarType( + input_tensor_type, scalar_type)); + n->replaceInputWith(input, cast_node->output()); + } + } + } +} + +static void UpdateScalarTypeForOutput(Node* n, const c10::ScalarType& scalar_type) { + auto output_tensor_type = n->output()->type()->cast(); + n->output()->setType( + CreateProfiledTensorTypeWithScalarType(output_tensor_type, scalar_type)); +} + +static void ImplicitCastForONNX(Block* block) { + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { + for (auto sub : it->blocks()) { + ImplicitCastForONNX(sub); + } + auto* subgraph = it->owningGraph(); + + if (IsImplicitCastSupported(it->kind())) { + auto expected_scalar_type = InferExpectedScalarType(*it); + if (expected_scalar_type) { + UpdateScalarTypeForInputs(*it, *expected_scalar_type); + if (!IsComparisonOp(it->kind())) { + UpdateScalarTypeForOutput(*it, *expected_scalar_type); + } + } + } + } + EliminateDeadCode(block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); +} + +// This pass tries to resolve scalar type mismatch issues between input tensors +// introduced by the implicit type conversions on scalars. +// TODO: Note that currently this pass handles traced graph only. +// More specifically, graphs that have scalar type information recorded. +// For scripted graphs we need something like scalar type propagation, +// otherwise we do not have enough information to perform the check, let alone fixes. +void ImplicitCastForONNX(const std::shared_ptr& graph) { + ImplicitCastForONNX(graph->block()); +} +} // anonymous namespace + + +void ScalarTypeAnalysisForONNX(const std::shared_ptr& graph) { + ImplicitCastForONNX(graph->block()); +} + +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.h b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h new file mode 100644 index 0000000000000..860adbd3cdd2a --- /dev/null +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + + +TORCH_API void ScalarTypeAnalysisForONNX(const std::shared_ptr& graph); + +} // namespace jit +} // namespace torch \ No newline at end of file diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 6fc71a0f3284f..99f2e1be48d89 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -950,22 +950,17 @@ struct PythonPrintPass { switch (node->kind()) { case prim::PythonOp: { auto value = static_cast(node); - if (enforce_importable_ && !value->ignore_on_export) { + if (enforce_importable_) { throw script::ErrorReport(node->sourceRange()) << "Could not export Python function call '" << value->name() << "'. Remove calls to Python functions before export. " << "Did you forget add @script or @script_method annotation? " << "If this is a nn.ModuleList, add it to __constants__"; } - - if (value->ignore_on_export) { - stmt << "ops.prim.IgnoredPythonOp"; - } else { - std::stringstream scalars_stream; - stmt << "^" << value->name(); - value->writeScalars(scalars_stream); - stmt << scalars_stream.str(); - } + std::stringstream scalars_stream; + stmt << "^" << value->name(); + value->writeScalars(scalars_stream); + stmt << scalars_stream.str(); printValueList(stmt, node->inputs(), "(", ")"); } break; case prim::Uninitialized: { diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp index 633ef4d477f33..cf6c82408b013 100644 --- a/torch/csrc/jit/passes/quantization.cpp +++ b/torch/csrc/jit/passes/quantization.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include @@ -8,6 +10,7 @@ #include #include +#include #include namespace torch { @@ -17,24 +20,25 @@ namespace { void findValuesInPattern( Graph& graph, const std::string& pattern, - std::unordered_set& values) { + std::unordered_set& values_to_skip) { Graph pattern_graph; std::unordered_map vmap; script::parseIR(pattern, &pattern_graph, vmap); auto matches = findPatternMatches(pattern_graph, graph); - for (auto match : matches) { + for (const auto& match : matches) { auto output_value = vmap.at("output"); TORCH_INTERNAL_ASSERT( match.values_map.find(output_value) != match.values_map.end(), "Didn't find Value output in match result."); - values.emplace(match.values_map[output_value]); + values_to_skip.emplace(match.values_map.at(output_value)); } } -std::unordered_set valuesToSkipObserver( +void addIntermediateValuesToSkipObserver( const script::Module& module, - const std::string& method_name) { + const std::string& method_name, + std::unordered_set& values_to_skip) { script::Method method = module.get_method(method_name); auto graph = method.graph(); @@ -54,18 +58,74 @@ graph(%self, %input): %relu = match::module[name="ReLU"](%self) %r = prim::CallMethod[name="forward"](%relu, %output) return (%r))"; - std::vector patterns = {conv_functional_relu, conv_relu_module}; + std::string matmul_add = R"( +graph(%input, %weight, %bias, %4): + %weight_t = aten::t(%weight) + %output = aten::matmul(%input, %weight_t) + %res = aten::add_(%output, %bias, %4) + return (%res))"; + std::vector patterns = { + conv_functional_relu, conv_relu_module, matmul_add}; - std::unordered_set values; for (const auto& pattern : patterns) { - findValuesInPattern(*graph, pattern, values); + findValuesInPattern(*graph, pattern, values_to_skip); + } +} + +std::string getFuncName(const c10::QualifiedName& qname) { + const auto& name = qname.qualifiedName(); + auto rdot_idx = name.rfind('.'); + if (rdot_idx != std::string::npos) { + return name.substr(rdot_idx + 1, name.length()); + } else { + return name; } +} - return values; +bool nodeQuantizable(Node* n) { + static std::vector call_funcs = { + "conv2d", + "linear", + "relu", + }; + std::vector aten_funcs = { + Symbol::aten("addmm"), Symbol::aten("matmul"), Symbol::aten("add_")}; + std::transform( + call_funcs.begin(), + call_funcs.end(), + std::back_inserter(aten_funcs), + [](const std::string& s) { return Symbol::aten(s); }); + bool is_quantizable = + std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) != + aten_funcs.end(); + if (n->kind() == prim::CallFunction) { + auto func_node = n->inputs()[0]->node(); + auto func = func_node->output()->type()->expect()->function(); + auto func_name = getFuncName(func->qualname()); + if (func_node->kind() == prim::Constant) { + is_quantizable |= + std::find(call_funcs.begin(), call_funcs.end(), func_name) != + call_funcs.end(); + } + } + return is_quantizable; } -static bool outputsNeedToBeObserved(Node* n) { - return n->kind() != prim::Constant; +bool valueNeedsToBeQuantized(Value* v) { + if (!v->type()->isSubtypeOf(TensorType::get())) { + return false; + } + // Check whether producer is quantizable + if (nodeQuantizable(v->node())) { + return true; + } + // Check whether user is quantizable + for (const auto& use : v->uses()) { + if (nodeQuantizable(use.user)) { + return true; + } + } + return false; } Node* traverseToQuantNode(Node* dq) { @@ -127,6 +187,13 @@ Node* insertObserver( observer_module = std::get<0>(qconfig); } std::string observer_name = "observer_for_" + v->debugName(); + // Temporary workaround to skip inserting duplicate modules, + // full support will come in next PR + for (script::Slot s : module.get_module_slots()) { + if (s.name() == observer_name) { + return nullptr; + } + } script::Module observer = observer_module.clone(); module.register_module(observer_name, observer); // Get handle of observer module @@ -150,24 +217,20 @@ Node* insertObserver( return call; } -c10::optional getQConfig( - const std::string& key, - const c10::optional& parent_qconfig, - const QConfigDict& qconfig_dict) { - if (qconfig_dict.find(key) != qconfig_dict.end()) { - return qconfig_dict.at(key); - } - return parent_qconfig; -} - -void getQConfigMapHelper( +void fillQConfigMap( const script::Module& module, const QConfigDict& qconfig_dict, - const std::string& key, - const c10::optional& parent_qconfig, - ModuleQConfigMap& map) { - auto qconfig = getQConfig(key, parent_qconfig, qconfig_dict); + ModuleQConfigMap& map, + const std::string& key = "", + const c10::optional& parent_qconfig = c10::nullopt) { + c10::optional qconfig; + if (qconfig_dict.find(key) != qconfig_dict.end()) { + qconfig = qconfig_dict.at(key); + } else { + qconfig = parent_qconfig; + } map[module.module_object()] = qconfig; + for (script::Slot s : module.get_module_slots()) { std::string child_key; if (key == "") { @@ -175,24 +238,24 @@ void getQConfigMapHelper( } else { child_key = key + "." + s.name(); } - getQConfigMapHelper(s.to_module(), qconfig_dict, child_key, qconfig, map); + fillQConfigMap(s.to_module(), qconfig_dict, map, child_key, qconfig); } } -ModuleQConfigMap getQConfigMap( - const script::Module& module, - const QConfigDict& qconfig_dict) { - ModuleQConfigMap map; - getQConfigMapHelper(module, qconfig_dict, "", c10::nullopt, map); - return map; -} - void InsertObserversImpl( script::Module& module, const std::string& method_name, - const ModuleQConfigMap& module_qconfig_map) { + const ModuleQConfigMap& module_qconfig_map, + std::unordered_set& values_to_skip) { + if (module_qconfig_map.count(module.module_object()) == 0) { + // the module is added by us, e.g.: observer module + return; + } + script::Method method = module.get_method(method_name); auto graph = method.graph(); + ConstantPropagation(graph); + addIntermediateValuesToSkipObserver(module, method_name, values_to_skip); // For storing all values that need to be instrumented with an observer call. std::vector values_to_observe; @@ -210,9 +273,10 @@ void InsertObserversImpl( // prim::Param nodes do not belong to the graph. Hence the Insert // point is the beginning of graph node. This also safe guards against // observing a potentially mutated value due to some in-place operation + Value* self = graph->inputs()[0]; for (size_t idx = 1; idx < method.num_inputs(); ++idx) { auto& v = graph->inputs()[idx]; - if (v->type()->isSubtypeOf(TensorType::get())) { + if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) { auto qconfig = module_qconfig_map.at(module.module_object()); if (qconfig) { auto observer_node = @@ -224,28 +288,52 @@ void InsertObserversImpl( } } - auto values_to_skip = valuesToSkipObserver(module, method_name); - blocks_to_visit.push(graph->block()); while (!blocks_to_visit.empty()) { Block* b = blocks_to_visit.top(); blocks_to_visit.pop(); for (Node* n : b->nodes()) { - // Skip nodes that we don't need to observe, e.g. 'prim::Constant' or - // observer nodes - if (!outputsNeedToBeObserved(n) || observer_for_input.count(n) != 0) { + // Skip observer nodes + if (observer_for_input.count(n) != 0) { continue; } // Record all outputs in the values_to_observe - we'll later add observers // for all values from it. for (Value* v : n->outputs()) { - if (values_to_skip.count(v) == 0) { + if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) { values_to_observe.push_back(v); } + if (v->node()->kind() == prim::CallMethod) { + // If we find a call to a method of a child module, + // we'll recursively insert observers for the forward function to + // the child module. + auto module_instance = v->node()->inputs()[0]; + auto module_method_name = v->node()->s(attr::name); + if (module_instance->node()->kind() == prim::GetAttr) { + auto child_module_name = module_instance->node()->s(attr::name); + auto child_module = module.find_module(child_module_name); + TORCH_INTERNAL_ASSERT( + child_module, + "Child module " + child_module_name + " does not exist"); + // Recursively insert observer for the forward function of child + // module + InsertObserversImpl( + child_module.value(), + module_method_name, + module_qconfig_map, + values_to_skip); + } else { + TORCH_INTERNAL_ASSERT( + module_instance == graph->inputs()[0], + "We only support call method either on %self" + "or child instance in insert_observers_pass right now"); + InsertObserversImpl( + module, module_method_name, module_qconfig_map, values_to_skip); + } + } } - // Schedule subblocks (if any) for visiting. for (Block* subblock : n->blocks()) { blocks_to_visit.push(subblock); } @@ -254,38 +342,11 @@ void InsertObserversImpl( // Actually add observer nodes. for (Value* v : values_to_observe) { - if (!v->type()->isSubtypeOf(TensorType::get())) { - continue; - } // Skip inserting observer for bias if (v->node()->kind() == prim::GetAttr && v->node()->s(c10::attr::name) == "bias") { continue; } - if (v->node()->kind() == prim::CallMethod && - v->node()->s(attr::name) == "forward") { - // If we find a call to forward function of a child module, - // we'll recursively insert observers for the forward function to - // the child module. - // One important detail is that currently we insert observer twice for - // input and output of the forward function call of the chlid module, - // this is required if child module has different qconfig from the - // parent module, but it should be removed if they have the same - // qconfig, we'll do this in a separate PR. - // Another note is that right now we only insert observer for "forward" - // function, but we may need to extend to all functions. - auto child_instance = v->node()->inputs()[0]; - TORCH_INTERNAL_ASSERT( - child_instance->node()->kind() == prim::GetAttr, - "Child instance should come from GetAttr."); - auto child_module_name = child_instance->node()->s(attr::name); - auto child_module = module.find_module(child_module_name); - TORCH_INTERNAL_ASSERT( - child_module, - "Child module " + child_module_name + " does not exist"); - // Recursively insert observer for the forward function of child module - InsertObserversImpl(child_module.value(), "forward", module_qconfig_map); - } auto qconfig = module_qconfig_map.at(module.module_object()); // Skip inserting observer if no qconfig is specified if (qconfig) { @@ -367,8 +428,8 @@ class QuantizeHelper { public: QuantizeHelper(const script::Module& m) : module_(m) {} IValue getQParams(Value* v); - c10::optional findChildModuleToQuantize(Value* v); - void quantizeBias(Value* v); + c10::optional findChildModuleToQuantize( + Value* child_instance); void quantizeTensor(Value* v, bool insert_after = true); void removeObserver(Value* v, const std::string& observer_name); void destroyNodes() { @@ -427,38 +488,6 @@ double getScale(const IValue& qparam) { return qparam.toTuple()->elements()[0].toTensor().item().toDouble(); } -void QuantizeHelper::quantizeBias(Value* v) { - // Traverse to the place where this is used - std::vector ops_with_bias = {Symbol::aten("conv2d"), - Symbol::aten("_convolution")}; - for (const auto& use : v->uses()) { - if (std::find( - ops_with_bias.begin(), ops_with_bias.end(), use.user->kind()) != - ops_with_bias.end()) { - // Make sure there is no observer module for bias - auto observer_name = findObserverName(v); - TORCH_INTERNAL_ASSERT(!observer_name, "bias should not be observed!"); - Value* activation = use.user->inputs()[0]; - Value* weight = use.user->inputs()[1]; - // Get qparam from activation - IValue act_qparam = getQParams(activation); - // Get qparam from weight - IValue weight_qparam = getQParams(weight); - IValue bias_scale = at::scalar_tensor( - c10::Scalar(getScale(act_qparam) * getScale(weight_qparam)), - at::kDouble); - IValue bias_qparam = c10::ivalue::Tuple::create( - std::vector({bias_scale, at::scalar_tensor(c10::Scalar(0))}), - act_qparam.toTuple()->type); - Node* dequant = insertQuantDeQuantCall(v, bias_qparam, at::kQInt32); - v->replaceAllUsesWith(dequant->output()); - Node* q = traverseToQuantNode(dequant); - TORCH_INTERNAL_ASSERT(q != nullptr); - q->replaceInputWith(dequant->output(), v); - } - } -} - void QuantizeHelper::quantizeTensor(Value* v, bool insert_after) { auto observer_name = findObserverName(v); if (!observer_name) { @@ -480,22 +509,18 @@ void QuantizeHelper::quantizeTensor(Value* v, bool insert_after) { } c10::optional QuantizeHelper::findChildModuleToQuantize( - Value* v) { - if (v->node()->kind() == prim::CallMethod && - v->node()->s(attr::name) == "forward") { - auto child_instance = v->node()->inputs()[0]; + Value* child_instance) { + TORCH_INTERNAL_ASSERT( + child_instance->node()->kind() == prim::GetAttr, + "Child instance should come from GetAttr."); + auto child_module_name = child_instance->node()->s(attr::name); + if (child_module_name.find("observer_for_") == std::string::npos) { + auto child_module = module_.find_module(child_module_name); TORCH_INTERNAL_ASSERT( - child_instance->node()->kind() == prim::GetAttr, - "Child instance should come from GetAttr."); - auto child_module_name = child_instance->node()->s(attr::name); - if (child_module_name.find("observer_for_") == std::string::npos) { - auto child_module = module_.find_module(child_module_name); - TORCH_INTERNAL_ASSERT( - child_module, - "InsertQuantDeQuant - Child module " + child_module_name + - " does not exist"); - return child_module; - } + child_module, + "InsertQuantDeQuant - Child module " + child_module_name + + " does not exist"); + return child_module; } return c10::nullopt; } @@ -517,62 +542,61 @@ void InsertQuantDeQuantImpl( } } - std::vector values_to_quantize; - std::unordered_map - child_modules_to_quantize; QuantizeHelper qh(module); std::stack blocks_to_visit; blocks_to_visit.push(graph->block()); while (!blocks_to_visit.empty()) { Block* b = blocks_to_visit.top(); blocks_to_visit.pop(); - for (Node* n : b->nodes()) { + for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) { + Node* n = *it++; for (Value* v : n->outputs()) { - if (v->type()->isSubtypeOf(TensorType::get())) { - auto child_module = qh.findChildModuleToQuantize(v); - if (child_module) { - child_modules_to_quantize[child_module.value().module_object()] = - child_module.value(); + if (!v->type()->isSubtypeOf(TensorType::get())) { + continue; + } + if (v->node()->kind() == prim::CallMethod) { + auto module_instance = v->node()->inputs()[0]; + auto module_method_name = v->node()->s(attr::name); + c10::optional m; + // calling method on self + if (module_instance == graph->inputs()[0]) { + m = module; + } else { + m = qh.findChildModuleToQuantize(module_instance); + } + if (m) { + InsertQuantDeQuantImpl(m.value(), module_method_name); } - values_to_quantize.push_back(v); } + if (v->node()->kind() == prim::GetAttr && + v->node()->s(c10::attr::name) == "bias") { + continue; + } + qh.quantizeTensor(v); } - // Schedule subblocks (if any) for visiting. for (Block* subblock : n->blocks()) { blocks_to_visit.push(subblock); } } } - for (Value* v : values_to_quantize) { - if (v->node()->kind() == prim::GetAttr && - v->node()->s(c10::attr::name) == "bias") { - qh.quantizeBias(v); - } else { - qh.quantizeTensor(v); - } - } - for (Value* v : input_values) { qh.quantizeTensor(v, false); } - for (auto& item : child_modules_to_quantize) { - InsertQuantDeQuantImpl(item.second, "forward"); - } - qh.destroyNodes(); } - } // namespace TORCH_API void InsertObservers( script::Module& module, const std::string& method_name, const QConfigDict& qconfig_dict) { - auto module_qconfig_map = getQConfigMap(module, qconfig_dict); - InsertObserversImpl(module, method_name, module_qconfig_map); + ModuleQConfigMap module_qconfig_map; + fillQConfigMap(module, qconfig_dict, module_qconfig_map); + std::unordered_set values_to_skip; + InsertObserversImpl(module, method_name, module_qconfig_map, values_to_skip); } script::Module InsertQuantDeQuant( @@ -601,35 +625,80 @@ void FoldQuantNodesIntoInputsOutputs(std::shared_ptr& graph) { } void QuantFusion(std::shared_ptr& graph) { - std::string pattern = R"( -graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f): + const std::string quantized_linear_with_bias = + R"( +graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4): + %w_quant_t = aten::t(%w_quant) + %packed_params = quantized::linear_prepack(%w_quant_t, %b) + %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point) + return (%r))"; + const std::unordered_map pattern_and_replacements = + {// quantized::conv2d + {R"( +graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): %a_intrepr = aten::int_repr(%a_quant) %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) %w_intrepr = aten::int_repr(%w_quant) %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) - %b_intrepr = aten::int_repr(%b_quant) - %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype) - %r = aten::conv2d(%a_dequant, %w_dequant, %b_dequant, %c, %d, %e, %f) + %r = aten::conv2d(%a_dequant, %w_dequant, %b, %stride, %padding, %dilation, %groups) %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) - return (%r_quant))"; - - std::string replacement = R"( -graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): + return (%r_quant))", + R"( +graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups): + %packed_params = quantized::conv_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups) + %r = quantized::conv2d(%a_quant, %packed_params, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point) %0 : int = prim::Constant[value=0]() %1 : int = prim::Constant[value=1]() %2 : int = prim::Constant[value=2]() %3 : int = prim::Constant[value=3]() - %in_param : int[] = prim::ListConstruct(%0, %2, %3, %1) - %a_perm : Tensor = aten::permute(%a_quant, %in_param) - %w_perm : Tensor = aten::permute(%w_quant, %in_param) - %w_packed = quantized::fbgemm_conv_prepack(%w_perm, %stride, %padding, %dilation, %groups) - %r = quantized::fbgemm_conv2d(%a_perm, %w_packed, %b_quant, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point) %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2) %r_perm = aten::permute(%r, %out_param) - return (%r_perm))"; - SubgraphRewriter rewriter; - rewriter.RegisterRewritePattern(pattern, replacement); - rewriter.runOnGraph(graph); + return (%r_perm))"}, + // addmm -> quantized::linear + {R"( +graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4): + %a_intrepr = aten::int_repr(%a_quant) + %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) + %w_intrepr = aten::int_repr(%w_quant) + %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + %r = aten::addmm(%b, %a_dequant, %w_dequant, %4, %4) + %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) + return (%r_quant))", + quantized_linear_with_bias}, + // matmul(with bias) -> quantized::linear + {R"( +graph(%a_quant, %w_quant, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype, %4): + %a_intrepr = aten::int_repr(%a_quant) + %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) + %w_intrepr = aten::int_repr(%w_quant) + %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + %output = aten::matmul(%a_dequant, %w_dequant) + %r = aten::add_(%output, %b, %4) + %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) + return (%r_quant))", + quantized_linear_with_bias}, + // matmul(without bias) -> quantized::linear + {R"( +graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype): + %a_intrepr = aten::int_repr(%a_quant) + %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype) + %w_intrepr = aten::int_repr(%w_quant) + %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype) + %r = aten::matmul(%a_dequant, %w_dequant) + %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype) + return (%r_quant))", + R"( +graph(%a_quant, %w_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %r_scale, %r_zero_point, %r_dtype): + %w_quant_t = aten::t(%w_quant) + %bias: Tensor? = prim::Constant() + %packed_params = quantized::linear_prepack(%w_quant_t, %bias) + %r = quantized::linear(%a_quant, %packed_params, %r_scale, %r_zero_point) + return (%r))"}}; + for (const auto& item : pattern_and_replacements) { + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(item.first, item.second); + rewriter.runOnGraph(graph); + } } struct ConvBNParameters { @@ -792,5 +861,54 @@ graph(%self, %x): } } +void FoldQuantizeCallIntoBuffer( + script::Module& module, + const std::string& method_name) { + const std::string pattern = R"( +graph(%self, %scale, %zero_point, %dtype): + %weight = prim::GetAttr[name="weight"](%self) + %weight_quant = aten::quantize_linear(%weight, %scale, %zero_point, %dtype) + return (%weight_quant))"; + Graph pattern_graph; + std::unordered_map vmap; + script::parseIR(pattern, &pattern_graph, vmap); + auto method = module.get_method(method_name); + auto graph = method.graph(); + auto matches = findPatternMatches(pattern_graph, *graph); + // Extra filter on scale/zero_point/dtype to make sure they are Constant + auto filter = [](const Match& match, + const std::unordered_map& vmap) { + const auto& match_vmap = match.values_map; + auto scale_node = match_vmap.at(vmap.at("scale"))->node(); + auto zero_point_node = match_vmap.at(vmap.at("zero_point"))->node(); + auto dtype_node = match_vmap.at(vmap.at("dtype"))->node(); + return scale_node->kind() == prim::Constant && + zero_point_node->kind() == prim::Constant && + dtype_node->kind() == prim::Constant; + }; + for (const auto& match : matches) { + if (!filter(match, vmap)) { + continue; + } + auto match_vmap = match.values_map; + auto float_weight = module.get_parameter("weight").variable_data(); + auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble(); + auto zero_point = + toIValue(match_vmap.at(vmap.at("zero_point"))).value().toInt(); + auto dtype = + toIValue(match_vmap.at(vmap.at("dtype"))).value().toScalarType(); + module.register_buffer( + "_quantized_weight", + at::quantize_linear(float_weight, scale, zero_point, dtype)); + } + + std::string replacement = R"( +graph(%self, %scale, %zero_point, %dtype): + %weight_quant = prim::GetAttr[name="_quantized_weight"](%self) + return (%weight_quant))"; + SubgraphRewriter rewriter; + rewriter.RegisterRewritePattern(pattern, replacement); + rewriter.runOnGraph(graph, filter); +} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h index bf2c4d6b225c3..e4beb5a362d6a 100644 --- a/torch/csrc/jit/passes/quantization.h +++ b/torch/csrc/jit/passes/quantization.h @@ -84,6 +84,15 @@ TORCH_API script::Module InsertQuantDeQuant( * Right now this is a fusion for fbgemm backend and only works for quantized * conv op, we'll extend to more ops and more backends in the future. * + * Currently supported fusion: + * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)), + * prepack(to_nhwc(w)), + * prepack(to_nhwc(b)))) + * + * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)), + * prepack(to_nhwc(w)), + * prepack(to_nhwc(b)))) + * * \param graph the graph we want to apply fusion */ TORCH_API void QuantFusion(std::shared_ptr& graph); @@ -96,5 +105,16 @@ TORCH_API void QuantFusion(std::shared_ptr& graph); */ TORCH_API void FoldConvBatchNorm2d(const script::Module& module); +/** \brief Fold quantize function call into module + * + * For the graph in the specified method of module, if we find a quantize_linear + * call on an attribute("weight") of the module, we'll quantize the attribute directly + * and register a new buffer "_quantized_weight" on the module and remove the + * quantize_linear call and replace the use of the quantized weight with + * "_quantized_weight". + */ +TORCH_API void FoldQuantizeCallIntoBuffer(script::Module& module, const std::string& method_name); + + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 19be42eea93ef..0eb66b0b1faa3 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -213,10 +213,62 @@ class ShapePropagator { continue; } } - return tensor_types; } + c10::ScalarType unionScalarTypes(c10::ScalarType original, c10::ScalarType next) { + if (original == c10::ScalarType::Undefined) { + return next; + } else { + return c10::promoteTypes(original, next); + } + } + + // Promotes result types for arithmetic operations on Tensor operands using + // new type promotion logic. See tensor_attributes.rst for details. + // This doesn't handle the case of arithmetic ops with Scalar arguments (when + // `Tensor.getUnsafeTensorImpl()->is_wrapped_nubmer()` would return true) + c10::optional getPromotedTypeForArithmeticOp(Node *node) { + c10::ScalarType dimmed = c10::ScalarType::Undefined; + c10::ScalarType zerodim = c10::ScalarType::Undefined; + // binary arithmetic ops, more than 2 args is alpha. + for (size_t i = 0 ; i < 2 ; i++ ) { + auto dtt = node->inputs()[i]->type()->expect(); + auto inputDtype = dtt->scalarType(); + if (!dtt || !inputDtype) { + return c10::nullopt; + } + if (dtt->dim() && *dtt->dim() > 0) { + dimmed = unionScalarTypes(dimmed, *inputDtype); + } else if (!isFloatingType(dimmed)) { + // if no dimensions + zerodim = unionScalarTypes(zerodim, *inputDtype); + } + } + // if a tensor with dimensions is already of the highest category, don't + // need to check zero-dim tensors. + if (isFloatingType(dimmed)) { + return dimmed; + } + // int_tensor * zero_dim_floating -> floating_tensor + if (isIntegralType(dimmed, false) && isFloatingType(zerodim) ) { + return zerodim; + } + // bool_tensor * non_bool_scalar -> non_bool_tensor + if (c10::ScalarType::Bool == dimmed && c10::ScalarType::Undefined != zerodim) { + return zerodim; + } + // types of dimensioned tensors generally take precedence over zero-dim + // tensors if not promoting due to category. e.g.: + // int_tensor * long -> int_tensor + if (c10::ScalarType::Undefined != dimmed) { + return dimmed; + } + + // no dimmed tensors. e.g. zero_dim_tensor + zero_dim_tensor. + return zerodim; + } + bool mergeTypes( ArrayRef lhs, ArrayRef rhs, @@ -682,13 +734,14 @@ class ShapePropagator { // primitive/tensor outputs. bool PropagateTensorShapeOnNode(Node* node, bool insert_expands) { - static const auto broadcast = [](std::vector& tensor_types, - size_t arg_for_type) -> TensorTypePtr { + static const auto broadcast = + [](std::vector& tensor_types, + c10::optional t) -> TensorTypePtr { if (tensor_types.size() == 1) { - return tensor_types[0]->dimensionedOnly(); + return tensor_types[0]->dimensionedOnly()->withScalarType(t); } AT_ASSERT(!tensor_types.empty()); - auto any_type = tensor_types[arg_for_type]; + auto any_type = tensor_types[0]; auto max_dims = any_type->dim(); for (auto& type : tensor_types) { if (!max_dims || !type->dim()) { @@ -698,7 +751,7 @@ class ShapePropagator { } } return TensorType::create( - any_type->scalarType(), + t, any_type->device(), max_dims, /*requires_grad=*/c10::nullopt); @@ -820,17 +873,35 @@ class ShapePropagator { // Requirements: // dims : broadcast all tensor args - // scalar type : always matching and preserved + // scalar type : promoted from input dtypes // device : always matching and preserved // tensor inputs : * // tensor outputs : 1 - static const register_formula_for broadcasting_ops{ + static const register_formula_for broadcasting_ops_arithmetic{ { // Tensor-Tensor operators "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::mul(Tensor self, Tensor other) -> Tensor", "aten::div(Tensor self, Tensor other) -> Tensor", + }, + [this](Node* node) -> type_vec_t { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { + AT_ASSERT(maybe_tensor_types->size() >= 2); + auto dtype = getPromotedTypeForArithmeticOp(node); + return {broadcast(*maybe_tensor_types, dtype)}; + } + return {}; + }}; + + // Requirements: + // dims : broadcast all tensor args + // scalar type : always matching and preserved + // device : always matching and preserved + // tensor inputs : * + // tensor outputs : 1 + static const register_formula_for broadcasting_ops{ + { "aten::pow(Tensor self, Tensor exponent) -> Tensor", "aten::fmod(Tensor self, Tensor other) -> Tensor", "aten::remainder(Tensor self, Tensor other) -> Tensor", @@ -865,7 +936,8 @@ class ShapePropagator { first_scalar_type) { arg_for_type = 1; } - return {broadcast(*maybe_tensor_types, arg_for_type)}; + auto t = (*maybe_tensor_types)[arg_for_type]->scalarType(); + return {broadcast(*maybe_tensor_types, *t)}; } return {}; }}; @@ -878,20 +950,50 @@ class ShapePropagator { }, [this](Node* node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { - return {broadcast(*maybe_tensor_types, 0)}; + auto dtype = (*maybe_tensor_types)[0]->scalarType(); + if (!dtype) { + return {}; + } + return {broadcast(*maybe_tensor_types, *dtype)}; } return {}; }}; - // NB: we always take the scalar type of the Tensor - static const register_formula_for broadcasting_tensor_scalar_ops{ + static const register_formula_for broadcasting_tensor_scalar_ops_arithmetic{ { - // Tensor-Scalar operators "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::mul(Tensor self, Scalar other) -> Tensor", "aten::div(Tensor self, Scalar other) -> Tensor", + }, + [this](Node* node) -> type_vec_t { + if (auto maybe_tensor_types = gatherTensorTypes(node)) { + auto first_scalar_type = (*maybe_tensor_types)[0]->scalarType(); + auto second_scalar_type = tryScalarTypeFromJitType(node->inputs()[1]->type()); + if (!first_scalar_type || !second_scalar_type) { + return {}; + } + if (isIntegralType(*first_scalar_type, false) && isFloatingType(*second_scalar_type) ) + { + auto default_dtype = at::typeMetaToScalarType(caffe2::get_default_dtype()); + return {broadcast(*maybe_tensor_types, default_dtype)}; + } + if (c10::ScalarType::Bool == *first_scalar_type && + c10::ScalarType::Bool != *second_scalar_type) + { + auto result_type = c10::promoteTypes(*first_scalar_type, *second_scalar_type); + return {broadcast(*maybe_tensor_types, result_type)}; + } + return {broadcast(*maybe_tensor_types, first_scalar_type)}; + } + return {}; + }}; + + // NB: we always take the scalar type of the Tensor + static const register_formula_for broadcasting_tensor_scalar_ops{ + { + "aten::pow(Tensor self, Scalar exponent) -> Tensor", "aten::fmod(Tensor self, Scalar other) -> Tensor", "aten::remainder(Tensor self, Scalar other) -> Tensor", @@ -909,7 +1011,7 @@ class ShapePropagator { }, [this](Node* node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { - return {broadcast(*maybe_tensor_types, 0)}; + return {broadcast(*maybe_tensor_types, (*maybe_tensor_types)[0]->scalarType())}; } return {}; }}; @@ -922,7 +1024,7 @@ class ShapePropagator { }, [this](Node* node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { - return {broadcast(*maybe_tensor_types, 1)}; + return {broadcast(*maybe_tensor_types, (*maybe_tensor_types)[1]->scalarType())}; } return {}; }}; @@ -980,8 +1082,7 @@ class ShapePropagator { }, [this](Node* node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { - return { - broadcast(*maybe_tensor_types, 0)->withScalarType(at::kBool)}; + return {broadcast(*maybe_tensor_types, at::kBool)}; } return {}; }}; @@ -1658,7 +1759,7 @@ class ShapePropagator { } else { // Batched matrix multiply (possibly with squeeze + unsqueeze if one // argument is 1D) - auto type = broadcast(tensor_types, 0); + auto type = broadcast(tensor_types, tensor_types[0]->scalarType()); if (dim1 == 1 || dim2 == 1) { type = type->withDim(type->dim().value() - 1); @@ -1729,14 +1830,37 @@ class ShapePropagator { // handled by the fallthrough because it's not always safe to run it due // to integer divide-by-zero. return PropagateShapeOnNodeByRunningIt(node); + } else if (node->matches("aten::pow(Tensor self, Scalar exponent) -> Tensor")) { + node->output()->setType(tensor_types.at(0)); + return true; } else if ( node->matches( "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") || node->matches( "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") || - node->matches("aten::mul(Tensor self, Scalar other) -> Tensor") || - node->matches("aten::pow(Tensor self, Scalar exponent) -> Tensor")) { - node->output()->setType(tensor_types.at(0)); + node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) { + auto first_scalar_type = (tensor_types)[0]->scalarType(); + auto second_scalar_type = tryScalarTypeFromJitType(node->inputs()[1]->type()); + if (!first_scalar_type || !second_scalar_type) { + return false; + } + if (isIntegralType(*first_scalar_type, false) && isFloatingType(*second_scalar_type) ) + { + auto default_dtype = at::typeMetaToScalarType(caffe2::get_default_dtype()); + auto type = tensor_types[0]->withScalarType(default_dtype); + node->output()->setType(type); + return true; + } + if (c10::ScalarType::Bool == *first_scalar_type && + c10::ScalarType::Bool != *second_scalar_type) + { + auto result_type = c10::promoteTypes(*first_scalar_type, *second_scalar_type); + auto type = tensor_types[0]->withScalarType(result_type); + node->output()->setType(type); + return true; + } + auto type = tensor_types[0]->withScalarType(first_scalar_type); + node->output()->setType(type); return true; } else if ( insert_expands && @@ -1866,7 +1990,9 @@ class ShapePropagator { if (inferred) { SHAPE_ASSERT(size_product != 0); size_t numel = 1; - for (int64_t s : tensor_types.at(0)->sizes().concrete_sizes().value()) + auto concrete_sizes = + tensor_types.at(0)->sizes().concrete_sizes().value(); + for (int64_t s : concrete_sizes) numel *= s; int64_t inferred_size = numel / size_product; sizes[inferred_idx] = inferred_size; diff --git a/torch/csrc/jit/passes/subgraph_rewrite.cpp b/torch/csrc/jit/passes/subgraph_rewrite.cpp index 197d3f321306d..6df94fafb164d 100644 --- a/torch/csrc/jit/passes/subgraph_rewrite.cpp +++ b/torch/csrc/jit/passes/subgraph_rewrite.cpp @@ -35,15 +35,22 @@ script::Module SubgraphRewriter::runOnModule(const script::Module& module) { return module; } -void SubgraphRewriter::runOnGraph(std::shared_ptr& graph) { +void SubgraphRewriter::runOnGraph( + std::shared_ptr& graph, + const std::function< + bool(const Match&, const std::unordered_map&)>& + filter) { for (const RewritePatternDescr& pattern : patterns_) { - rewriteSinglePatternOnGraph(graph, pattern); + rewriteSinglePatternOnGraph(graph, pattern, filter); } } void SubgraphRewriter::rewriteSinglePatternOnGraph( std::shared_ptr& graph, - RewritePatternDescr pattern) { + const RewritePatternDescr& pattern, + const std::function< + bool(const Match&, const std::unordered_map&)>& + filter) { std::unordered_map rewrite_map; std::vector values_to_rewrite; @@ -56,6 +63,9 @@ void SubgraphRewriter::rewriteSinglePatternOnGraph( const auto& matches = findPatternMatches(pattern_graph, *graph); for (const Match& match : matches) { + if (!filter(match, vmap)) { + continue; + } // Matches might overlap with each other, in that case some of the nodes in // the current match might have already been used in another folded pattern. // We need to skip such matches. diff --git a/torch/csrc/jit/passes/subgraph_rewrite.h b/torch/csrc/jit/passes/subgraph_rewrite.h index f338b13fecc1e..a1600125dfb10 100644 --- a/torch/csrc/jit/passes/subgraph_rewrite.h +++ b/torch/csrc/jit/passes/subgraph_rewrite.h @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -47,7 +48,19 @@ class TORCH_API SubgraphRewriter { script::Module runOnModule(const script::Module& module); // Run pattern-based subgraph rewrite pass on the graph (used in testing). - void runOnGraph(std::shared_ptr& graph); + // filter is a function that does extra filtering on the match, if it returns + // false for a given Match, we'll skip the match + // filter function takes a `Match` and a value map from parsing the pattern graph + // since we need to do extra filtering on the matched result but we need to refer + // to the values in the matched result through the values in pattern graph. + void runOnGraph( + std::shared_ptr& graph, + const std::function< + bool(const Match&, const std::unordered_map&)>& + filter = + [](const Match&, const std::unordered_map&) { + return true; + }); // Register standard rewrite patterns. void RegisterDefaultPatterns(); @@ -70,7 +83,13 @@ class TORCH_API SubgraphRewriter { void rewriteSinglePatternOnGraph( std::shared_ptr& graph, - RewritePatternDescr pattern); + const RewritePatternDescr& pattern, + const std::function< + bool(const Match&, const std::unordered_map&)>& + filter = + [](const Match&, const std::unordered_map&) { + return true; + }); bool overlapsWithPreviousMatches(const Match* match); }; diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index b0a9afaa96853..3e36130ea7c24 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -26,9 +26,7 @@ IValue deepCopy(const IValue& self) { // Lists of ivalues should recursively deep copy their contents if (self.isGenericList()) { auto source = std::move(self).toGenericList(); - auto newList = source._elementType().has_value() - ? c10::impl::GenericList(*source._elementType()) - : c10::impl::GenericList(c10::impl::deprecatedUntypedList()); + auto newList = c10::impl::GenericList(source.elementType()); newList.reserve(source.size()); for (const IValue& value : source) { newList.push_back(deepCopy(value)); diff --git a/torch/csrc/jit/pickle.cpp b/torch/csrc/jit/pickle.cpp index 557ac322231f3..a063a5d06d3df 100644 --- a/torch/csrc/jit/pickle.cpp +++ b/torch/csrc/jit/pickle.cpp @@ -1,33 +1,24 @@ -#include #include +#include #include #include - namespace torch { namespace jit { +// These are both defined in `torch/serialization.py` +const char* torch_save_magic_number = + "\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19"; +uint16_t protocol_version = 1001; + void pickle( - std::function writer, + std::function writer, const IValue& ivalue, std::vector* tensor_table) { Pickler pickler(std::move(writer), tensor_table); - - if (tensor_table == nullptr) { - // No tensor table provided, so tensors will be stored directly in the blob. - // Add torch.save metadata so these tensors can be de-serialized later - pickler.torchSaveStart(); - } - pickler.protocol(); pickler.pushIValue(ivalue); pickler.stop(); - - if (tensor_table == nullptr) { - // No tensor table provided, so tensors will be stored directly in the blob. - // Add torch.save metadata so these tensors can be de-serialized later - pickler.torchSaveStop(); - } } std::vector pickle( @@ -45,6 +36,61 @@ std::vector pickle( return data; } +// This has to live here instead of the C++ API to mirror torch.save since the +// mobile build excludes the C++ API +std::vector pickle_save(const at::IValue& ivalue) { + std::vector data; + + auto writer = [&](const char* bytes, size_t len) { + data.insert(data.end(), bytes, bytes + len); + }; + + jit::Pickler pickler(writer, /*tensor_table=*/nullptr); + // Output data to match torch.save, see torch/serialization.py for details + // Magic number (0x1950a86a20f9469cfc6c) + pickler.protocol(); + pickler.pushLong(torch_save_magic_number); + pickler.stop(); + + // Protocol Version + pickler.protocol(); + pickler.pushInt(protocol_version); + pickler.stop(); + + // sys_info, this isn't actually used in de-serialization so we can leave this + // one empty + pickler.protocol(); + pickler.pushEmptyDict(); + pickler.stop(); + + jit::Pickler data_pickler(writer, /*tensor_table=*/nullptr); + data_pickler.protocol(); + data_pickler.pushIValue(ivalue); + data_pickler.stop(); + + auto writeable_tensors = data_pickler.tensorData(); + + std::vector keys; + keys.reserve(writeable_tensors.size()); + std::vector types(writeable_tensors.size(), at::StringType::get()); + + for (size_t i = 0; i < writeable_tensors.size(); i++) { + keys.emplace_back(std::to_string(i)); + } + + auto keys_tuple = at::ivalue::Tuple::create(keys); + jit::pickle(writer, keys_tuple); + + for (const auto& tensor_data : writeable_tensors) { + const char* addr = tensor_data.data(); + size_t numel = tensor_data.numel(); + writer(reinterpret_cast(&numel), sizeof(numel)); + writer(addr, tensor_data.sizeInBytes()); + } + + return data; +} + IValue unpickle( std::function reader, ClassResolver class_resolver, diff --git a/torch/csrc/jit/pickle.h b/torch/csrc/jit/pickle.h index 9e9d86a5e3b3e..1eea5c2a10475 100644 --- a/torch/csrc/jit/pickle.h +++ b/torch/csrc/jit/pickle.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -6,6 +8,17 @@ namespace torch { namespace jit { +/// Pickle an IValue by calling a function to handle writing the data. +/// +/// `writer` is a function that takes in a pointer to a chunk of memory and its +/// size and consumes it. +/// +/// See `jit::pickle` for more details. +TORCH_API void pickle( + std::function writer, + const IValue& ivalue, + std::vector* tensor_table = nullptr); + /// Save a `torch::IValue` in a format compatible with Python's `pickle` module /// /// If present, `tensor_table` is a pointer to a table in which tensors that @@ -38,16 +51,9 @@ TORCH_API std::vector pickle( const IValue& ivalue, std::vector* tensor_table = nullptr); -/// Pickle an IValue by calling a function to handle writing the data. -/// -/// `writer` is a function that takes in a pointer to a chunk of memory and its -/// size and consumes it. -/// -/// See `jit::pickle` for more details. -TORCH_API void pickle( - std::function writer, - const IValue& ivalue, - std::vector* tensor_table = nullptr); + +TORCH_API std::vector pickle_save(const IValue& ivalue); + /// `reader` is a function that takes in a size to read from some pickled /// binary. `reader` should remember where it last read, and return diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp index 435006bc2f35f..1a7a52ccb2aee 100644 --- a/torch/csrc/jit/pickler.cpp +++ b/torch/csrc/jit/pickler.cpp @@ -94,56 +94,6 @@ void Pickler::stop() { push(PickleOpCode::STOP); } -void Pickler::torchSaveStop() { - // Add the binary data for all the tensors to be included in the same binary - // TODO: The pickler should be refactored to stream out to a stream directly - // instead of staging in the stack_ array - // As another pickle program in the same binary archive, add a list of - // keys for each tensor (see torch/serialization.py) - protocol(); - push(PickleOpCode::MARK); - for (size_t i = 0; i < tensor_data_.size(); ++i) { - std::string key = std::to_string(i); - push(PickleOpCode::BINUNICODE); - push(key.size()); - pushBytes(key); - } - - push(PickleOpCode::TUPLE); - stop(); - - // Now dump the tensor binary data - for (const auto& data : tensor_data_) { - // first dump size - push(data.numel()); - writer_(data.data(), data.sizeInBytes()); - } -} - -void Pickler::torchSaveStart() { - // Output data to match torch.save, see torch/serialization.py for details - // Magic number (0x1950a86a20f9469cfc6c) - protocol(); - push(PickleOpCode::LONG1); - // LONG1 size - pushBytes("\x0a"); - // LONG1 data - pushBytes("\x6c\xfc\x9c\x46\xf9\x20\x6a\xa8\x50\x19"); - stop(); - - // Protocol Version (1001) - protocol(); - push(PickleOpCode::BININT2); - pushBytes("\xe9\x03"); - stop(); - - // sys_info, this isn't actually used in de-serialization so we can leave this - // one empty - protocol(); - push(PickleOpCode::EMPTY_DICT); - stop(); -} - // unmemoized version called by pushIValue void Pickler::pushIValueImpl(const IValue& ivalue) { if (ivalue.isTensor()) { @@ -332,6 +282,7 @@ void Pickler::pushStorageOfTensor(const at::Tensor& tensor) { push(PickleOpCode::TUPLE); push(PickleOpCode::BINPERSID); + // TODO: Skip this if not writing tensors memoized_storage_map_[addr] = pushNextBinPut(); tensor_data_.push_back(getWriteableTensorData(tensor)); } @@ -426,19 +377,6 @@ void Pickler::pushClass(PicklerClass cls) { pushGlobal("torch.jit._pickle", getClassName(cls)); } -void Pickler::pushTensorReference(const IValue& ivalue) { - pushClass(PicklerClass::TENSOR); - tensor_table_->push_back(ivalue.toTensor()); - int64_t tensor_id = tensor_table_->size() - 1; - // Reduce arguments are spread (e.g. `*args`) before calling the global, - // so wrap in a tuple - push(PickleOpCode::MARK); - pushIValue(tensor_id); - push(PickleOpCode::TUPLE); - - push(PickleOpCode::REDUCE); -} - void Pickler::pushSpecializedList( const IValue& ivalue, PicklerClass cls, @@ -476,13 +414,48 @@ void Pickler::pushDouble(double value) { } } -void Pickler::pushDict(const IValue& ivalue) { +void Pickler::pushLong(const std::string& data) { + uint64_t size = data.size(); + + if (size <= std::numeric_limits::max()) { + push(PickleOpCode::LONG1); + push(size); + } else { + TORCH_INTERNAL_ASSERT( + data.size() > std::numeric_limits::max(), + "Cannot pickle a long with a size larger than 4 bytes") + push(PickleOpCode::LONG4); + push(size); + } + pushBytes(data); +} + +void Pickler::pushTensorReference(const IValue& ivalue) { + pushClass(PicklerClass::TENSOR); + tensor_table_->push_back(ivalue.toTensor()); + int64_t tensor_id = tensor_table_->size() - 1; + // Reduce arguments are spread (e.g. `*args`) before calling the global, + // so wrap in a tuple + push(PickleOpCode::MARK); + pushIValue(tensor_id); + push(PickleOpCode::TUPLE); + + push(PickleOpCode::REDUCE); +} + +void Pickler::pushEmptyDict() { push(PickleOpCode::EMPTY_DICT); +} +void Pickler::pushDict(const IValue& ivalue) { + pushEmptyDict(); + auto dict_items = iterationOrder(ivalue.toGenericDict()); + if (dict_items.size() == 0) { + return; + } push(PickleOpCode::MARK); // Sort the dict for deterministic keys - auto dict_items = iterationOrder(ivalue.toGenericDict()); for (const auto& pair : dict_items) { pushIValue(pair.first); pushIValue(pair.second); @@ -551,12 +524,136 @@ void Pickler::pushTuple(const IValue& ivalue) { } } +// Pickled objects are stored in a form compatible with Python pickling. +// In torchscript List[T]/Dict[K, V] are statically typed and contain +// dynamic type tags allow T, K, and V to be recovered. But this info +// is not stored in the Python pickling information. However, we +// can recover this information from the static type of the top-level +// object being unpickled, because we have a record of the type of the +// objects it contains as attributes. +// `IfPossible` - we can only do this recovery when we have an object as +// the top-level unpickled thing (which is guarenteed for Modules, but +// not for torch.load/torch,save). Otherwise we do not know the types +// of the contained objects and cannot restore the tags. +static void restoreAccurateTypeTagsIfPossible(const IValue& root) { + if (!root.isObject()) { + return; + } + struct Work { + TypePtr static_type; + IValue value; + }; + std::vector to_process = {{root.type(), root}}; + std::unordered_set scanned; + while (!to_process.empty()) { + Work w = std::move(to_process.back()); + to_process.pop_back(); + // ensure we only scan each pointer value once, otherwise this + // can become exponential (and if we allow recursive data in the future, + // it would not terminiate). + if (w.value.isPtrType()) { + const void* key = w.value.internalToPointer(); + auto it = scanned.find(key); + if (it != scanned.end()) { + continue; + } + scanned.emplace_hint(it, key); + } + switch (w.static_type->kind()) { + case TensorType::Kind: + case NumberType::Kind: + case FloatType::Kind: + case IntType::Kind: + case NoneType::Kind: + case GeneratorType::Kind: + case BoolType::Kind: + case VarType::Kind: + case CapsuleType::Kind: + case StringType::Kind: + case FunctionType::Kind: + case DeviceObjType::Kind: + // no op, there is nothing to tag + break; + case AnyType::Kind: + // if Any type does show up, we no longer have a way to precisely + // recover the type information since the w.value may be an untagged + // List/Dict. We should prevent objects being serialized from having the + // Any type and if we do allow it in functions limit it to non-heap + // locations. + TORCH_INTERNAL_ASSERT( + false, "AnyType should not show up in the static type of objects"); + case TupleType::Kind: { + auto t = w.value.toTuple(); + auto ttype = w.static_type->expect(); + for (size_t i = 0; i < ttype->containedTypes().size(); ++i) { + Work elem = {ttype->containedTypes().at(i), t->elements().at(i)}; + to_process.emplace_back(std::move(elem)); + } + } break; + case FutureType::Kind: { + auto f = w.value.toFuture(); + auto t = w.static_type->expect(); + if (f->completed()) { + Work elem = {t->getElementType(), f->value()}; + to_process.emplace_back(std::move(elem)); + } + } break; + case OptionalType::Kind: { + if (!w.value.isNone()) { + auto t = w.static_type->expect(); + Work elem = {t->getElementType(), w.value}; + to_process.emplace_back(std::move(elem)); + } + } break; + case ListType::Kind: { + // specialized lists do not need their type refined, so we can exit + // early here + if (!w.value.isGenericList()) { + break; + } + auto elem_type = w.static_type->cast()->getElementType(); + auto lst = w.value.toGenericList(); + lst.unsafeSetElementType(elem_type); + for (const IValue& item : lst) { + Work elem = {elem_type, item}; + to_process.emplace_back(std::move(elem)); + } + } break; + case DictType::Kind: { + auto dt = w.static_type->cast(); + auto d = w.value.toGenericDict(); + d.unsafeSetKeyType(dt->getKeyType()); + d.unsafeSetValueType(dt->getValueType()); + for (const auto& item : d) { + Work kelem = {dt->getKeyType(), item.key()}; + Work velem = {dt->getValueType(), item.value()}; + to_process.emplace_back(std::move(kelem)); + to_process.emplace_back(std::move(velem)); + } + } break; + // in both cases the dynamic type is a class, and we are going to tag with + // the dynamic type + case InterfaceType::Kind: + case ClassType::Kind: { + auto obj = w.value.toObject(); + auto typ = obj->type(); // note: intentionally using the dynamic type, + // the static type is potentially less accurate + for (size_t i = 0; i < typ->numAttributes(); ++i) { + Work elem = {typ->getAttribute(i), obj->getSlot(i)}; + to_process.emplace_back(std::move(elem)); + } + }; + } + } +} + IValue Unpickler::parse_ivalue() { run(); TORCH_CHECK( stack_.size() == 1, "Unpickler expected 1 element on the stack, but found ", stack_.size()); + restoreAccurateTypeTagsIfPossible(stack_[0]); return stack_[0]; } @@ -639,8 +736,7 @@ PickleOpCode Unpickler::readInstruction() { auto opcode = readOpCode(); switch (opcode) { case PickleOpCode::EMPTY_LIST: { - stack_.emplace_back( - c10::impl::GenericList(c10::impl::deprecatedUntypedList())); + stack_.emplace_back(c10::impl::GenericList(AnyType::get())); } break; case PickleOpCode::EMPTY_TUPLE: { if (empty_tuple_.isNone()) { @@ -725,10 +821,28 @@ PickleOpCode Unpickler::readInstruction() { stack_.emplace_back(tuple); } break; case PickleOpCode::EMPTY_DICT: - stack_.emplace_back(c10::impl::GenericDict(c10::impl::deprecatedUntypedDict())); + stack_.emplace_back( + c10::impl::GenericDict(AnyType::get(), AnyType::get())); break; case PickleOpCode::APPENDS: { - readList(); + size_t start = marks_.back(); + auto list_ivalue = stack_.at(start - 1); + readList(list_ivalue); + } break; + case PickleOpCode::LIST: { + IValue list_ivalue = c10::impl::GenericList(AnyType::get()); + readList(list_ivalue); + stack_.push_back(std::move(list_ivalue)); + } break; + case PickleOpCode::DICT: { + size_t start = marks_.back(); + marks_.pop_back(); + auto dict = c10::impl::GenericDict(AnyType::get(), AnyType::get()); + for (size_t i = start; i < stack_.size(); i += 2) { + dict.insert_or_assign(stack_[i], stack_[i + 1]); + } + stack_.erase(stack_.begin() + start, stack_.end()); + stack_.push_back(std::move(dict)); } break; case PickleOpCode::SETITEMS: { size_t start = marks_.back(); @@ -760,6 +874,9 @@ PickleOpCode Unpickler::readInstruction() { stack_.pop_back(); switch (pickler_class) { case PicklerClass::TENSOR: + TORCH_INTERNAL_ASSERT( + tensor_table_, + "Pickler tried to write a tensor but had no tensor table to write to"); stack_.emplace_back(tensor_table_->at(setitem_data.toInt())); break; case PicklerClass::INTLIST: @@ -779,7 +896,7 @@ PickleOpCode Unpickler::readInstruction() { case PicklerClass::TENSOR: TORCH_CHECK( tensor_table_, - "Found a tensor table reference but Pickler" + "Found a tensor table reference but Unpickler" " has no tensor table\n"); stack_.emplace_back(tensor_table_->at(data.toInt())); break; @@ -967,10 +1084,9 @@ std::string Unpickler::readBytes(size_t length) { // Pop all the list items off of the stack and append them to the list at // the corresponding MARK -void Unpickler::readList() { +void Unpickler::readList(IValue list_ivalue) { size_t start = marks_.back(); marks_.pop_back(); - auto list_ivalue = stack_.at(start - 1); auto num_elements = stack_.size() - start; auto elements = at::ArrayRef(stack_).slice(start); if (list_ivalue.isIntList()) { diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h index 6e648f341e43a..43b4b7206c5d5 100644 --- a/torch/csrc/jit/pickler.h +++ b/torch/csrc/jit/pickler.h @@ -127,7 +127,7 @@ class Pickler { public: Pickler( std::function writer, - std::vector* tensor_table = nullptr) + std::vector* tensor_table) : writer_(writer), tensor_table_(tensor_table) {} // Push protocol onto the stack @@ -138,30 +138,26 @@ class Pickler { void pushIValue(const IValue& ivalue); - // See torch/serialization.py for details, pushes a magic number, torch - // serialization version, and system info to the pickle archive all as - // individual pickle programs - void torchSaveStart(); - void torchSaveStop(); - void startTuple(); void endTuple(); const std::vector& tensorData() { return tensor_data_; } + void pushEmptyDict(); + void pushDict(const IValue& ivalue); + void pushInt(int64_t value); + void pushLong(const std::string& data); private: void pushIValueImpl(const IValue& ivalue); - void pushDict(const IValue& ivalue); void pushDouble(double value); void pushGenericList(const IValue& ivalue); - void pushInt(int64_t value); void pushIntList(const IValue& ivalue); void pushList(const IValue& ivalue); - void pushLiteralTensor(const IValue& ivalue); void pushTensor(const IValue& ivalue); void pushTensorReference(const IValue& ivalue); + void pushLiteralTensor(const IValue& ivalue); void pushTuple(const IValue& ivalue); void pushString(const std::string& string); // unmemoized version @@ -279,7 +275,7 @@ class Unpickler { PickleOpCode readInstruction(); PickleOpCode readOpCode(); std::string readString(); - void readList(); + void readList(IValue list_ivalue); void setInput(size_t memo_id); void run(); diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h index f1dbff5f14eb2..a5140dc90a674 100644 --- a/torch/csrc/jit/pybind_utils.h +++ b/torch/csrc/jit/pybind_utils.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -394,7 +395,9 @@ inline IValue toIValue( for (size_t i = 0; i < tuple_size; ++i) { values.push_back(toIValue(tuple[i], elem_types[i])); } - return c10::ivalue::Tuple::create(std::move(values), tuple_type); + return tuple_type->name() + ? c10::ivalue::Tuple::createNamed(std::move(values), tuple_type) + : c10::ivalue::Tuple::create(std::move(values)); } case TypeKind::StringType: return ConstantString::create(py::cast(obj)); @@ -430,6 +433,8 @@ inline IValue toIValue( } return repeated; } + case TypeKind::BoolType: + return c10::impl::toList(py::cast>(obj)); case TypeKind::TensorType: return c10::impl::toList(py::cast>(obj)); default: @@ -492,6 +497,8 @@ inline IValue toIValue( AT_ERROR("Function Values aren't yet supported"); case TypeKind::CapsuleType: AT_ERROR("Capsule Values aren't supported"); + case TypeKind::AnyType: + AT_ERROR("AnyType Values aren't supported"); } AT_ERROR( "Missing cases in toIValue for type: ", @@ -609,11 +616,11 @@ inline py::object toPyObject(IValue&& ivalue) { for (size_t i = 0; i < elements.size(); ++i) { t[i] = toPyObject(IValue{elements.at(i)}); } - if (tuple->type && tuple->type->schema() && - tuple->type->schema()->name() != "") { - auto unqualName = tuple->type->name()->name(); + if (tuple->type() && tuple->type()->schema() && + tuple->type()->schema()->name() != "") { + auto unqualName = tuple->type()->name()->name(); auto fieldNames = fmap( - tuple->type->schema()->arguments(), + tuple->type()->schema()->arguments(), [](const Argument& arg) { return arg.name(); }); return py::module::import("torch.jit") .attr("_create_named_tuple")( diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp index 3f20c9394609c..201ec0f2bda28 100644 --- a/torch/csrc/jit/python_ir.cpp +++ b/torch/csrc/jit/python_ir.cpp @@ -139,7 +139,6 @@ void ConcretePythonOp::cloneFrom(Node* other_) { this->cconv = other->cconv; Py_INCREF(other->pyobj.get()); this->pyobj = THPObjectPtr(other->pyobj.get()); - this->ignore_on_export = other->ignore_on_export; for (auto& sa : other->scalar_args) { Py_INCREF(sa.get()); this->scalar_args.emplace_back(sa.get()); diff --git a/torch/csrc/jit/register_c10_ops.cpp b/torch/csrc/jit/register_c10_ops.cpp index b8ddf23697b0b..a8d14fa45df27 100644 --- a/torch/csrc/jit/register_c10_ops.cpp +++ b/torch/csrc/jit/register_c10_ops.cpp @@ -3,9 +3,12 @@ #include #include #include +#include +#include namespace torch { namespace jit { + namespace { at::Tensor wrap_tensor(at::Tensor&& tensor) { @@ -46,15 +49,16 @@ IValue wrap(IValue&& ivalue) { Operator createOperatorFromC10(const c10::OperatorHandle& op) { return Operator(op, [op](Stack& stack) { RECORD_FUNCTION(op.schema().name(), stack); - const auto input_size = op.schema().arguments().size(); const auto output_size = op.schema().returns().size(); Node* node = nullptr; + std::shared_ptr tracer_state; // trace the input before unwrapping, otherwise we may lose // the input information if (jit::tracer::isTracing()) { + tracer_state = jit::tracer::getTracingState(); auto symbol = Symbol::fromQualString(op.schema().name()); const auto& graph = tracer::getTracingState()->graph; node = graph->create(symbol, 0); @@ -130,6 +134,8 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) { } } graph->insertNode(node); + + jit::tracer::setTracingState(nullptr); } c10::Dispatcher::singleton().lookup(op, &stack).call(&stack); @@ -139,7 +145,8 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) { *iter = wrap(std::move(*iter)); } - if (jit::tracer::isTracing()) { + if (tracer_state) { + jit::tracer::setTracingState(std::move(tracer_state)); int i = 0; for (auto iter = stack.end() - output_size; iter != stack.end(); ++iter, ++i) { @@ -170,6 +177,12 @@ Operator createOperatorFromC10(const c10::OperatorHandle& op) { class RegistrationListener final : public c10::OpRegistrationListener { public: void onOperatorRegistered(const c10::OperatorHandle& op) override { + if(at::aten_op_is_already_moved_to_c10(op.schema().operator_name())) { + // Ignore ATen ops for now because they have their own code + // to expose them to JIT in register_aten_ops.cpp + // TODO Remove register_aten_ops.cpp and also use this registration here + return; + } torch::jit::registerOperator(createOperatorFromC10(op)); } diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index a637b24294eb7..a109775c24772 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -51,12 +51,12 @@ namespace jit { namespace { template -c10::List make_result_list() { +c10::List make_result_list(const TypePtr& elemType) { return c10::List(); } -template<> -c10::impl::GenericList make_result_list() { - return c10::impl::GenericList(c10::impl::deprecatedUntypedList()); +template <> +c10::impl::GenericList make_result_list(const TypePtr& elemType) { + return c10::impl::GenericList(elemType); } Operation noop(const Node* n) { @@ -584,8 +584,13 @@ RegisterOperators reg( }, aliasAnalysisFromSchema()), Operator( - "prim::data(Tensor(b) a) -> Tensor(b)", - noop, + "prim::data(Tensor(a) a) -> Tensor(a)", + [](Stack& stack) { + at::Tensor a; + pop(stack, a); + push(stack, autograd::Variable(a).variable_data()); + return 0; + }, aliasAnalysisFromSchema()), Operator( "prim::is_cuda(Tensor a) -> bool", @@ -740,7 +745,7 @@ RegisterOperators reg( auto ivalue = pop(stack); // Pickle the tensor - auto data = pickle({ivalue}); + auto data = jit::pickle_save(ivalue); // Write file std::fstream output(filename, std::ios::out | std::ios::binary); @@ -1039,12 +1044,16 @@ RegisterOperators reg( [](const Node* node) { size_t num_inputs = node->inputs().size(); auto type = node->output()->type()->expect(); + bool named = type->name().has_value(); return [=](Stack& stack) { std::vector elems{ std::make_move_iterator(stack.end() - num_inputs), std::make_move_iterator(stack.end())}; drop(stack, num_inputs); - push(stack, c10::ivalue::Tuple::create(std::move(elems), type)); + push( + stack, + named ? c10::ivalue::Tuple::createNamed(std::move(elems), type) + : c10::ivalue::Tuple::create(std::move(elems))); return 0; }; }, @@ -1197,13 +1206,16 @@ RegisterOperators reg( throw std::runtime_error( "DictConstruct must have an even number of inputs"); } - TORCH_INTERNAL_ASSERT(node->outputs().size() == 1, "DictConstruct must have exactly one output"); + TORCH_INTERNAL_ASSERT( + node->outputs().size() == 1, + "DictConstruct must have exactly one output"); TypePtr output_type = node->outputs()[0]->type(); - TORCH_INTERNAL_ASSERT(output_type->kind() == TypeKind::DictType, "DictConstruct output must be of Dict type."); - TypePtr key_type = static_cast(output_type.get())->getKeyType(); - TypePtr value_type = static_cast(output_type.get())->getValueType(); + auto dt = output_type->expect(); + TypePtr key_type = dt->getKeyType(); + TypePtr value_type = dt->getValueType(); return [=](Stack& stack) { auto vals = c10::impl::GenericDict(key_type, value_type); + vals.reserve(num_inputs / 2); for (size_t i = 0; i < num_inputs; i += 2) { auto val = pop(stack); auto key = pop(stack); @@ -1214,6 +1226,16 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + Operator( + "aten::dict() -> Dict(str, Tensor)", + [](const Node* node) -> Operation { + return [](Stack& stack) { + auto dict = + c10::impl::GenericDict(StringType::get(), TensorType::get()); + push(stack, dict); + return 0; + }; + }), Operator( "aten::_unwrap_optional(t(a)? optional) -> t(a)", [](Stack& stack) { @@ -1485,6 +1507,42 @@ int listReverse(Stack& stack) { return 0; } +template int minList(Stack &stack) { + c10::List a = pop(stack).to>(); + c10::List b = pop(stack).to>(); + + size_t min_size = std::min(a.size(), b.size()); + for (size_t i = 0; i < min_size; i++) { + if (a[i] == b[i]) { + continue; + } + + push(stack, a[i] < b[i] ? a : b); + return 0; + } + + push(stack, b.size() < a.size() ? b : a); + return 0; +} + +template int maxList(Stack &stack) { + c10::List a = pop(stack).to>(); + c10::List b = pop(stack).to>(); + + size_t min_size = std::min(a.size(), b.size()); + for (size_t i = 0; i < min_size; i++) { + if (a[i] == b[i]) { + continue; + } + + push(stack, a[i] > b[i] ? a : b); + return 0; + } + + push(stack, b.size() > a.size() ? b : a); + return 0; +} + template int listPop(Stack& stack) { int64_t idx = pop(stack).to(); @@ -1549,6 +1607,42 @@ int listRemove(Stack& stack) { return 0; } +template +int listMin(Stack& stack) { + c10::List list = pop(stack).to>(); + size_t list_size = list.size(); + if (list_size == 0) { + throw std::runtime_error("min() arg is an empty sequence"); + } + + T min_elem = list[0]; + for (size_t i = 1; i < list_size; ++i) { + T elem = list[i]; + min_elem = elem < min_elem ? elem : min_elem; + } + + stack.push_back(min_elem); + return 0; +} + +template +int listMax(Stack& stack) { + c10::List list = pop(stack).to>(); + size_t list_size = list.size(); + if (list_size == 0) { + throw std::runtime_error("max() arg is an empty sequence"); + } + + T max_elem = list[0]; + for (size_t i = 1; i < list_size; ++i) { + T elem = list[i]; + max_elem = elem > max_elem ? elem : max_elem; + } + + stack.push_back(max_elem); + return 0; +} + template <> int listRemove(Stack& stack) { at::Tensor elem = pop(stack).to(); @@ -1734,12 +1828,26 @@ int listList(Stack& stack) { return 0; } +template +int listContains(Stack& stack) { + auto key = pop(stack).to(); + auto list = pop(stack).to>(); + for (const T& item : list) { + if (item == key) { + push(stack, true); + return 0; + } + } + push(stack, false); + return 0; +} + template int listAdd(Stack& stack) { c10::List b = pop(stack).to>(); c10::List a = pop(stack).to>(); - c10::List ret = make_result_list(); + c10::List ret = make_result_list(a.elementType()); if (a.use_count() == 1) { ret = std::move(a); @@ -1767,7 +1875,7 @@ int listMulIntLeft(Stack& stack) { int64_t n = pop(stack).to(); c10::List list = pop(stack).to>(); - c10::List ret = make_result_list(); + c10::List ret = make_result_list(list.elementType()); const auto size = list.size() * n; ret.reserve(size); @@ -1786,7 +1894,7 @@ int listMulIntRight(Stack& stack) { c10::List list = pop(stack).to>(); int64_t n = pop(stack).to(); - c10::List ret = make_result_list(); + c10::List ret = make_result_list(list.elementType()); const auto size = list.size() * n; ret.reserve(size); @@ -1815,7 +1923,7 @@ int listSlice(Stack& stack) { const auto normalized_end = std::min(list_size, normalizeIndex(end, list_size)); - c10::List sliced_list = make_result_list(); + c10::List sliced_list = make_result_list(list.elementType()); if (normalized_end <= normalized_start) { // early exit if the slice is trivially empty push(stack, std::move(sliced_list)); @@ -1943,12 +2051,10 @@ c10::List makeListForDictKeysOrValues( template c10::impl::GenericList makeGenericListForDictKeysOrValues( - const std::pair, c10::optional>& types, + const std::pair& types, const std::vector>& order) { auto type = std::get(types); - auto values = type.has_value() - ? c10::impl::GenericList(*type) - : c10::impl::GenericList(c10::impl::deprecatedUntypedList()); + auto values = c10::impl::GenericList(type); values.reserve(order.size()); for (const auto& item : order) { values.push_back(std::get(item)); @@ -1962,7 +2068,7 @@ Operation dictKeysOrValues(const Node* n) { return [=](Stack& stack) -> int { auto dict = pop(stack).toGenericDict(); const auto& order = iterationOrder(dict); - const auto types = std::make_pair(dict._keyType(), dict._valueType()); + const auto types = std::make_pair(dict.keyType(), dict.valueType()); if (outputType->getElementType()->isSubtypeOf(TensorType::get())) { push(stack, makeListForDictKeysOrValues(types, order)); } else if (outputType->getElementType() == IntType::get()) { @@ -2097,11 +2203,10 @@ int dictUpdate(Stack& stack) { int dictItems(Stack& stack) { auto dict = pop(stack).toGenericDict(); - auto key_type = dict._keyType(); - auto value_type = dict._valueType(); - auto items = (key_type.has_value() && value_type.has_value()) - ? c10::impl::GenericList(TupleType::create({*key_type, *value_type})) - : c10::impl::GenericList(c10::impl::deprecatedUntypedList()); + auto key_type = dict.keyType(); + auto value_type = dict.valueType(); + auto items = + c10::impl::GenericList(TupleType::create({key_type, value_type})); items.reserve(dict.size()); for (const auto& item : iterationOrder(dict)) { items.emplace_back(c10::ivalue::Tuple::create({item.first, item.second})); @@ -2115,6 +2220,26 @@ int dictCopy(Stack& stack) { return 0; } +Operation dictConstructFromList(const Node* node) { + TypePtr output_type = node->outputs()[0]->type(); + TypePtr key_type = + static_cast(output_type.get())->getKeyType(); + TypePtr value_type = + static_cast(output_type.get())->getValueType(); + return [key_type, value_type](Stack& stack) { + auto input_list = pop(stack); + auto list_ref = input_list.toGenericListRef(); + auto dict = c10::impl::GenericDict(key_type, value_type); + dict.reserve(list_ref.size()); + for (const auto& input : list_ref) { + const auto tup = input.toTuple()->elements(); + dict.insert_or_assign(tup[0], tup[1]); + } + push(stack, dict); + return 0; + }; +} + template int hashValue(Stack& stack) { auto value = pop(stack); @@ -2261,6 +2386,14 @@ RegisterOperators reg2({ "aten::__getitem__(" decl_type "[](a) list, int idx) -> " decl_type, \ listSelect, \ aliasAnalysisFromSchema()), \ + Operator( \ + "prim::min(" decl_type "[] l, " decl_type "[] r) -> " decl_type "[]",\ + minList, \ + aliasAnalysisFromSchema()), \ + Operator( \ + "prim::max(" decl_type "[] l, " decl_type "[] r) -> " decl_type "[]",\ + maxList, \ + aliasAnalysisFromSchema()), \ Operator( \ "aten::append(" decl_type "[](a!) self, " decl_type \ " el) -> " decl_type "[](a!)", \ @@ -2270,6 +2403,14 @@ RegisterOperators reg2({ "aten::reverse(" decl_type "[](a!) self) -> ()", \ listReverse, \ aliasAnalysisFromSchema()), \ + Operator( \ + "prim::min(" decl_type "[] self) -> " decl_type, \ + listMin, \ + aliasAnalysisFromSchema()), \ + Operator( \ + "prim::max(" decl_type "[] self) -> " decl_type, \ + listMax, \ + aliasAnalysisFromSchema()), \ Operator( \ "aten::extend(" decl_type "[](a!) self, " decl_type \ " [] other) -> ()", \ @@ -2370,6 +2511,21 @@ RegisterOperators reg2({ CREATE_LIST_OPS("bool", c10::List), CREATE_LIST_OPS("Tensor", c10::List), CREATE_LIST_OPS("t", c10::List), + + // `listContains` is not implemented for non-primitive types + // TODO: Add List[bool] once .to> doesn't throw an error + Operator( + "aten::__contains__(int[] l, int item) -> bool", + listContains, + aliasAnalysisFromSchema()), + Operator( + "aten::__contains__(float[] l, float item) -> bool", + listContains, + aliasAnalysisFromSchema()), + Operator( + "aten::__contains__(str[] l, str item) -> bool", + listContains, + aliasAnalysisFromSchema()), #undef CREATE_LIST_OPS Operator( "aten::sort(int[](a!) self, bool reverse=False) -> ()", @@ -2542,22 +2698,6 @@ RegisterOperators reg2({ DEFINE_BINARY_OP(prim::min, a < b ? a : b), DEFINE_BINARY_OP(prim::max, a > b ? a : b), - Operator( - "prim::min(int[] x) -> int", - [](Stack& stack) { - c10::List int_list = pop(stack).toIntList(); - int64_t min_element = std::numeric_limits::max(); - - for(int64_t ele: int_list) { - if(ele < min_element) { - min_element = ele; - } - } - push(stack, min_element); - return 0; - }, - aliasAnalysisFromSchema()), - // Pass in two ops for handling int and float separately as % in C++ only // works for int The modulus calculation is different between C++ and Python // (on negative), we preserve the python behavior as it's more common and @@ -2867,6 +3007,11 @@ RegisterOperators reg2({ "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \ " idx, t(b -> *) v) -> ()", \ dictSetItem, \ + aliasAnalysisFromSchema()), \ + Operator( \ + "aten::dict((" key_type ", tVal)[] inputs) -> Dict(" key_type \ + ", tVal)", \ + dictConstructFromList, \ aliasAnalysisFromSchema()) CREATE_DICT_OPS("str"), diff --git a/torch/csrc/jit/register_special_ops.cpp b/torch/csrc/jit/register_special_ops.cpp index 9603ba57c7164..61bd5c0372f86 100644 --- a/torch/csrc/jit/register_special_ops.cpp +++ b/torch/csrc/jit/register_special_ops.cpp @@ -472,6 +472,13 @@ RegisterOperators reg({ return 0; }, aliasAnalysisFromSchema()), + Operator( + "aten::is_scripting() -> bool", + [](Stack& stack) { + push(stack, true); + return 0; + }, + aliasAnalysisFromSchema()), Operator( "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)", [](Stack& stack) { diff --git a/torch/csrc/jit/register_string_ops.cpp b/torch/csrc/jit/register_string_ops.cpp index 9a14d5b30346f..ba6d8e1fe340a 100644 --- a/torch/csrc/jit/register_string_ops.cpp +++ b/torch/csrc/jit/register_string_ops.cpp @@ -389,7 +389,7 @@ auto reg_str_ops_2 = return stringFindImpl(string, substr, start, end, true); })) - .op("aten::index(str self, str substr, int start=0, int end=-1) -> int", + .op("aten::index.str(str self, str substr, int start=0, int end=-1) -> int", torch::RegisterOperators::options() .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA) .catchAllKernel([](std::string string, diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index 25edb7f585af1..fb2606e016c4b 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -323,7 +323,7 @@ struct Environment { << value->kind() << " and " << name << " is not a first-class value. Only reassignments to first-class values are allowed"; } - + auto parent_type = unshapedType(simple_parent->type()); as_simple_value = tryConvertToType( loc, @@ -1268,6 +1268,59 @@ struct to_ir { return expr.kind() == TK_IS || expr.kind() == TK_ISNOT; } + Value* emitIsInstance(Expr obj, Expr classinfo) { + // turn (float, (int, tuple)) into a flat list of types and type kind + // category checks: tuple_check = true, types = {float, int} + struct GatheredTypes { + GatheredTypes(ScriptTypeParser parser) : typeParser_(std::move(parser)) {} + void gather(Expr classinfo) { + if (classinfo.kind() == TK_TUPLE_LITERAL) { + for (Expr e : TupleLiteral(classinfo).inputs()) { + gather(e); + } + return; + } + if (classinfo.kind() == TK_VAR) { + // Special casing for list and tuple since isinstance(x, list) and + // isinstance(x, tuple) does not accept List[int] / Tuple[int] like + // subscript type annotation in python + auto name = Var(classinfo).name().name(); + if (name == "tuple") { + tuple_check = true; + return; + } else if (name == "list") { + list_check = true; + return; + } + } + TypePtr type = typeParser_.parseTypeFromExpr(classinfo); + types.emplace_back(type); + } + ScriptTypeParser typeParser_; + bool list_check = false; + bool tuple_check = false; + std::vector types; + }; + GatheredTypes gathered(typeParser_); + gathered.gather(classinfo); + auto val = emitExpr(obj); + if (val->type()->kind() == OptionalType::Kind) { + throw ErrorReport(obj.range()) + << "Optional isinstance check is not supported, " + << "consider use is/is not None instead"; + } + if ((gathered.list_check && val->type()->kind() == ListType::Kind) || + (gathered.tuple_check && val->type()->kind() == TupleType::Kind)) { + return graph->insertConstant(true, obj.range()); + } + for (const TypePtr& typ : gathered.types) { + if (val->type()->isSubtypeOf(typ)) { + return graph->insertConstant(true, obj.range()); + } + } + return graph->insertConstant(false, obj.range()); + } + void emitIf(const If& stmt) { // NOTE: emitIf checks on If stmt condition to see if the cond AST is // a potential none check with is/is not, or an isinstance check. @@ -1799,25 +1852,26 @@ struct to_ir { graph->insert( aten::index_put_, {slicedArg, indices, rhs}, {}, stmtRange); } - - // Otherwise, this is a list. Dispatch to aten::_set_item to both select - // and assign + // Otherwise, this is a list or a classtype. + // Dispatch to aten::_set_item to both select and assign } else { const auto subscript = lhs.subscript_exprs(); if (subscript.size() != 1 || subscript[0].kind() == TK_SLICE_EXPR) { throw ErrorReport(subscript) << "Sliced expression not yet supported for" - << " subscripted list assignment. " + << " subscripted assignment. " << "File a bug if you want this"; } std::vector args; - args.emplace_back(lhs.value().range(), "list", sliceable); + args.emplace_back(lhs.value().range(), "self", sliceable); args.emplace_back( lhs.subscript_exprs().range(), "idx", emitExpr(subscript[0])); args.push_back(rhs); - - graph->insert(aten::_set_item, args, {}, stmtRange); + makeMagic( + "__setitem__", + std::make_shared(aten::_set_item, at::nullopt)) + ->call(stmtRange, method, args, {}, 0); } } @@ -2126,10 +2180,8 @@ struct to_ir { }); } - void checkApplyExpr( - Apply& apply, - SourceRange& loc, - size_t expected_inputs = 2) { + void checkApplyNumInputs(Apply& apply, size_t expected_inputs) { + const SourceRange& loc = apply.range(); if (apply.inputs().size() != expected_inputs) { throw ErrorReport(loc) << Var(apply.callee()).name().name() << " expected exactly " @@ -2156,7 +2208,7 @@ struct to_ir { auto attributes = emitAttributes(apply.attributes()); return emitForkExpr(loc, forked, inputs, attributes); } else if (auto annotate_value = dynamic_cast(sv.get())) { - checkApplyExpr(apply, loc); + checkApplyNumInputs(apply, 2); TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); Value* expr = tryConvertToType( apply.range(), @@ -2187,7 +2239,7 @@ struct to_ir { return std::make_shared(expr); } else if (auto getattr = dynamic_cast(sv.get())) { - checkApplyExpr(apply, loc); + checkApplyNumInputs(apply, 2); auto obj = emitSugaredExpr(apply.inputs()[0], 1); auto selector = apply.inputs()[1]; if (selector.kind() != TK_STRINGLITERAL) { @@ -2199,13 +2251,13 @@ struct to_ir { } else if ( auto uninitialized_value = dynamic_cast(sv.get())) { - checkApplyExpr(apply, loc, 1); + checkApplyNumInputs(apply, 1); TypePtr type = typeParser_.parseTypeFromExpr(apply.inputs()[0]); auto out = graph->insertNode(graph->createUninitialized(type)) ->setSourceRange(loc); return std::make_shared(out->output()); } else if (auto tuple_call = dynamic_cast(sv.get())) { - checkApplyExpr(apply, loc, /*expected_inputs*/ 1); + checkApplyNumInputs(apply, 1); auto arg = emitSugaredExpr(apply.inputs()[0], 1); auto inputs = arg->asTuple(apply.range(), method); auto inp_values = fmap(inputs, [&](const SugaredValuePtr& sv) { @@ -2217,48 +2269,9 @@ struct to_ir { // NOTE: for `isinstance` builtin call in JIT, we only check the static // types on the inputs to evaluate, and insert the corresponding constant // node - std::function isInstanceCheck = [&](Expr obj, - Expr classinfo) { - if (classinfo.kind() == TK_TUPLE_LITERAL) { - // handle the case for recursive tuple classinfo - // return true if obj is an instance of any of the types - for (Expr e : TupleLiteral(classinfo).inputs()) { - if (isInstanceCheck(obj, e)) { - return true; - } - } - return false; - } - auto type_name = typeParser_.parseBaseTypeName(classinfo); - if (!type_name) { - throw ErrorReport(classinfo.range()) - << "type must be a type identifier"; - } - auto val = emitExpr(obj); - // Special casing for list and tuple since isinstance(x, list) and - // isinstance(x, tuple) does not accept List[int] / Tuple[int] like - // subscript type annotation in python - if (*type_name == "list" && val->type()->cast()) { - return true; - } else if (*type_name == "tuple" && val->type()->cast()) { - return true; - } else if (val->type()->cast()) { - throw ErrorReport(loc) - << "Optional isinstance check is not supported, " - << "consider use is/isnot None instead"; - } else { - TypePtr type = typeParser_.parseTypeFromExpr(classinfo); - if (val->type()->isSubtypeOf(type)) { - return true; - } - } - return false; - }; - checkApplyExpr(apply, loc); - bool is_instance_val = - isInstanceCheck(apply.inputs()[0], apply.inputs()[1]); - return std::make_shared( - graph->insertConstant(is_instance_val, loc)); + checkApplyNumInputs(apply, 2); + auto result = emitIsInstance(apply.inputs()[0], apply.inputs()[1]); + return std::make_shared(result); } else if (auto classNew = dynamic_cast(sv.get())) { if (apply.inputs().size() != 1) { throw ErrorReport(loc) << "Only one argument to __new__ allowed"; diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index c8d9030e23640..ddc2f3ecc29d7 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -370,6 +370,73 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) { module.type()->addMethod(method); } +// this is used in our test suite to check that we correctly preserved type tags +bool ivalue_tags_match(const Module& lhs, const Module& rhs) { + struct Work { + IValue a; + IValue b; + }; + std::unordered_set visited; + std::vector work = {{lhs.module_object(), rhs.module_object()}}; + while (!work.empty()) { + Work item = work.back(); + work.pop_back(); + if (item.a.isPtrType()) { + // uncomment to debug type matching errors + // std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type() << ") " + // << item.a.internalToPointer() << " " << /*item.b <<*/ " (" + // << *item.b.type() << ") " << item.b.internalToPointer() << + // "\n"; + + if (visited.count(item.a.internalToPointer())) { + continue; + } + visited.emplace(item.a.internalToPointer()); + } + if (*unshapedType(item.a.type()) != *unshapedType(item.b.type())) { + return false; + } + // check tags for objects that contain subobjects + if (item.a.isObject()) { + auto ao = item.a.toObject(); + auto bo = item.b.toObject(); + for (size_t i = 0; i < ao->slots().size(); ++i) { + work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)}); + } + } else if (item.a.isTuple()) { + auto at = item.a.toTuple(); + auto bt = item.b.toTuple(); + for (size_t i = 0; i < at->elements().size(); ++i) { + work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)}); + } + } else if (item.a.isGenericList()) { + auto al = item.a.toGenericList(); + auto bl = item.b.toGenericList(); + for (size_t i = 0; i < al.size(); ++i) { + work.emplace_back(Work{al.get(i), bl.get(i)}); + } + } else if (item.a.isGenericDict()) { + auto ad = item.a.toGenericDict(); + auto bd = item.b.toGenericDict(); + for (auto& item : ad) { + // Dictionaory keys cannot contain List/Dicts that require tags + // so we do not have to check them. + // Furthermore without ordered dicts it is expensive to find the + // equivalent key + work.emplace_back(Work{item.value(), bd.at(item.key())}); + } + } else if (item.a.isFuture()) { + auto af = item.a.toFuture(); + auto bf = item.b.toFuture(); + af->wait(); + bf->wait(); + work.emplace_back(Work{af->value(), bf->value()}); + } + } + + return true; +} + void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -403,9 +470,9 @@ void initJitScriptBindings(PyObject* module) { .def( "_dump", &Module::dump, - py::arg("omit_method_bodies") = true, - py::arg("omit_attr_values") = true, - py::arg("omit_param_values") = true) + py::arg("code") = true, + py::arg("attrs") = true, + py::arg("params") = true) .def( "_define", [](Module& m, @@ -872,6 +939,7 @@ void initJitScriptBindings(PyObject* module) { auto fn = cu->create_function(std::move(name), graph); return StrongFunctionPtr(std::move(cu), fn); }); + m.def("_ivalue_tags_match", ivalue_tags_match); py::class_(m, "FileCheck") .def(py::init<>()) @@ -926,6 +994,12 @@ void initJitScriptBindings(PyObject* module) { m.def("_get_graph_executor_optimize", &torch::jit::getGraphExecutorOptimize); + m.def( + "_resolve_type", + [](const std::string& name, SourceRange range, ResolutionCallback rcb) { + return pythonResolver(rcb)->resolveType(name, range); + }); + py::class_>( m, "LoggerBase"); py::enum_(m, "AggregationType") diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index cd6cfa8341b98..f51496c92106e 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -405,9 +405,9 @@ void Module::apply(const std::function& fn) { } std::string Module::_dump_to_string( - bool omit_method_bodies, - bool omit_attr_values, - bool omit_param_values, + bool print_method_bodies, + bool print_attr_values, + bool print_param_values, int level) const { std::stringstream ss; std::stringstream parameters_ss; @@ -417,7 +417,7 @@ std::string Module::_dump_to_string( for (Slot param : get_parameters()) { parameters_ss << param.name() << " = "; - if (!omit_param_values) { + if (print_param_values) { parameters_ss << param.value().toTensor() << std::endl; } else { parameters_ss << "..." << std::endl; @@ -426,7 +426,7 @@ std::string Module::_dump_to_string( for (Slot attr : get_attributes()) { attributes_ss << attr.name() << " = "; - if (!attr.value().isTensor() || !omit_attr_values) { + if (!attr.value().isTensor() || print_attr_values) { attributes_ss << attr.value() << std::endl; } else { attributes_ss << "..." << std::endl; @@ -435,7 +435,7 @@ std::string Module::_dump_to_string( for (const Method& method : get_methods()) { methods_ss << " method " << method.name() << " {" << std::endl; - if (!omit_method_bodies) { + if (print_method_bodies) { methods_ss << torch::jit::jit_log_prefix( " ", method.graph()->toString()) << std::endl; @@ -458,7 +458,7 @@ std::string Module::_dump_to_string( // We do level + 2, because one level of indentation comes from 'submodules' // scope and the other one goes from a specific submodule we're printing. ss << submodule._dump_to_string( - omit_method_bodies, omit_attr_values, omit_param_values, level + 2); + print_method_bodies, print_attr_values, print_param_values, level + 2); } ss << " }" << std::endl; ss << "}" << std::endl; @@ -468,11 +468,14 @@ std::string Module::_dump_to_string( } void Module::dump( - bool omit_method_bodies = true, - bool omit_attr_values = true, - bool omit_param_values = true) const { + bool print_method_bodies = true, + bool print_attr_values = true, + bool print_param_values = true) const { std::cout << _dump_to_string( - omit_method_bodies, omit_attr_values, omit_param_values, 0) + print_method_bodies, + print_attr_values, + print_param_values, + 0) << std::endl; } diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index a090fdded8047..fbc6dadd75cf1 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -206,9 +206,9 @@ struct TORCH_API Module { slot_list get_module_slots() const; void dump( - bool omit_method_bodies, - bool omit_attr_values, - bool omit_param_values) const; + bool print_method_bodies, + bool print_attr_values, + bool print_param_values) const; const std::vector get_methods() const { return fmap( diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp index 246497b117917..9ecce419f835b 100644 --- a/torch/csrc/jit/script/python_sugared_value.cpp +++ b/torch/csrc/jit/script/python_sugared_value.cpp @@ -31,7 +31,8 @@ FunctionSchema PythonValue::getSchema( const size_t n_binders, const SourceRange& loc) { auto annotations = py::module::import("torch.jit.annotations"); - auto signature = annotations.attr("get_signature")(self); + auto signature = + annotations.attr("get_signature")(self, rcb ? *rcb : py::none(), loc); std::vector args, rets; // We may mutate this if we can determine the number of args from Python // introspection. @@ -108,18 +109,28 @@ std::shared_ptr PythonValue::call( if (!matched_schema) throw ErrorReport(loc) << failure_messages.str(); + // If if a function is marked as dropped, + // we throw an exception if it is invoked. + if (py::cast(py::module::import("torch._jit_internal") + .attr("should_drop")(self))) { + auto g = m.graph(); + auto err_msg = insertConstant( + *g, + IValue( + "This Python function is annotated to be ignored and cannot be run")); + g->insert(prim::RaiseException, {err_msg}, {}, loc); + return std::make_shared( + g->insertNode( + g->createUninitialized(matched_schema->return_types.at(0))) + ->output()); + } + // Release the function object so we can wrap it in a PythonOp py::object func = self; std::string cconv(inputs.size(), 'd'); Node* new_node = m.graph()->insertNode( m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {})); - // Mark if function is ignored on export - if (py::cast(py::module::import("torch._jit_internal") - .attr("should_drop_on_export")(self))) { - auto python_op = static_cast(new_node); - python_op->ignore_on_export = true; - } new_node->setSourceRange(loc); for (auto& i : matched_schema->inputs) new_node->addInput(i); @@ -226,6 +237,7 @@ std::shared_ptr OverloadedMethodValue::call( for (const std::string& method_name : method_names_) { auto cls = module_->type()->expect(); const auto fn = cls->getMethod(method_name); + TORCH_INTERNAL_ASSERT(fn, "Expected class to have method ", method_name); auto match = tryMatchSchema( fn->getSchema(), loc, @@ -277,6 +289,53 @@ Value* ModuleValue::asValue(const SourceRange& loc, Function& m) { return self_; } +std::vector> ModuleValue::desugarModuleContainer( + bool get_keys, + bool get_values, + const SourceRange& loc, + Function& m) { + // the submodules in the module list may be a mix of python objects + // and script Modules. If we need to load a Module, we need its field + // name so we can emit 'self.field_name'. + std::unordered_map obj_to_field; + for (Slot s : module_.get_module_slots()) { + obj_to_field[s.value().toObject().get()] = s.name(); + } + + std::vector> result; + for (py::handle py_submodule : py_module_) { + py::object obj = py::reinterpret_borrow(py_submodule); + if (auto sub_module = as_module(obj)) { + const auto& name = obj_to_field.at(sub_module->module_object().get()); + auto name_v = + std::make_shared(insertConstant(*m.graph(), name)); + Value* module_v = m.graph()->insertGetAttr(self_, name); + auto mod_v = std::make_shared(module_v, *sub_module, obj); + + if (get_keys && get_values) { + std::vector> tup; + tup.push_back(name_v); + tup.push_back(mod_v); + result.push_back( + std::make_shared(ConstantTupleValue(tup))); + } else if (get_keys) { + result.push_back(name_v); + } else if (get_values) { + result.push_back(mod_v); + } else { + TORCH_INTERNAL_ASSERT(false); + } + } else { + result.push_back(toSugaredValue( + obj, + m, + loc, + /*is_constant =*/false)); + } + } + return result; +} + std::shared_ptr ModuleValue::attr( const SourceRange& loc, Function& m, @@ -317,6 +376,26 @@ std::shared_ptr ModuleValue::attr( if (!py::hasattr(py_module_, field.c_str())) { throw ErrorReport(loc) << "module has no attribute '" << field << "'"; } + + auto is_mod_dict = py::isinstance( + py_module_, py::module::import("torch.jit").attr("_ConstModuleDict")); + if (is_mod_dict) { + if (field == "items" || field == "keys" || field == "values") { + bool get_keys = false; + bool get_values = false; + if (field == "items") { + get_keys = true; + get_values = true; + } else if (field == "values") { + get_values = true; + } else { + get_keys = true; + } + return std::make_shared( + desugarModuleContainer(get_keys, get_values, loc, m), field); + } + } + py::object attr = py::getattr(py_module_, field.c_str()); // HACK: This is used for rnn.py to get all the parameters of a Module as a @@ -387,35 +466,20 @@ std::vector> ModuleValue::asTuple( const SourceRange& loc, Function& m, const c10::optional& size_hint) { - if (!py::isinstance( - py_module_, py::module::import("torch.jit").attr("_ConstModuleList"))) - return SugaredValue::asTuple(loc, m, size_hint); + auto is_mod_dict = py::isinstance( + py_module_, py::module::import("torch.jit").attr("_ConstModuleDict")); + auto is_mod_list = py::isinstance( + py_module_, py::module::import("torch.jit").attr("_ConstModuleList")); - // the submodules in the module list may be a mix of python objects - // and script Modules. If we need to load a Module, we need its field - // name so we can emit 'self.field_name'. - std::unordered_map obj_to_field; - for (Slot s : module_.get_module_slots()) { - obj_to_field[s.value().toObject().get()] = s.name(); + if (!is_mod_list && !is_mod_dict) { + return SugaredValue::asTuple(loc, m, size_hint); } - std::vector> result; - for (py::handle py_submodule : py_module_) { - py::object obj = py::reinterpret_borrow(py_submodule); - if (auto sub_module = as_module(obj)) { - Value* module_v = m.graph()->insertGetAttr( - self_, obj_to_field.at(sub_module->module_object().get())); - result.emplace_back( - std::make_shared(module_v, *sub_module, obj)); - } else { - result.push_back(toSugaredValue( - obj, - m, - loc, - /*is_constant =*/false)); - } - } - return result; + // iterating over a dictionary returns the keys, iterating over a + // list returns the values + bool get_keys = is_mod_dict; + bool get_values = !is_mod_dict; + return desugarModuleContainer(get_keys, get_values, loc, m); } void ModuleValue::setAttr( @@ -596,6 +660,8 @@ std::shared_ptr toSugaredValue( if (auto callee = as_function(compiled_fn)) { return std::make_shared(*callee); } + auto rcb = py::module::import("torch.jit").attr("_gen_rcb")(obj, 0); + return std::make_shared(obj, rcb); } return std::make_shared(obj); diff --git a/torch/csrc/jit/script/python_sugared_value.h b/torch/csrc/jit/script/python_sugared_value.h index 54387c26c8784..d8f1a53298c87 100644 --- a/torch/csrc/jit/script/python_sugared_value.h +++ b/torch/csrc/jit/script/python_sugared_value.h @@ -30,7 +30,8 @@ std::shared_ptr toSugaredValue( c10::optional as_function(const py::object& obj); struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { - PythonValue(py::object the_self) : self(std::move(the_self)) {} + PythonValue(py::object the_self, c10::optional rcb = c10::nullopt) + : self(std::move(the_self)), rcb(std::move(rcb)) {} FunctionSchema getSchema( const size_t n_args, @@ -63,6 +64,7 @@ struct VISIBILITY_HIDDEN PythonValue : public SugaredValue { void checkForAddToConstantsError(std::stringstream& ss); py::object self; + c10::optional rcb; }; struct VISIBILITY_HIDDEN PythonModuleValue : public PythonValue { @@ -104,6 +106,54 @@ struct VISIBILITY_HIDDEN ConstantParameterList : public SugaredValue { Value* the_list_; }; +struct VISIBILITY_HIDDEN ConstantTupleValue : public SugaredValue { + explicit ConstantTupleValue( + std::vector> tup, + bool callable = false) + : tup_(tup){}; + + std::vector> asTuple( + const SourceRange& loc, + Function& m, + const c10::optional& size_hint = {}) override { + return tup_; + }; + + std::string kind() const override { + return "constant tuple"; + } + + std::vector> tup_; + bool callable_; +}; + +struct VISIBILITY_HIDDEN ConstantTupleMethod : public SugaredValue { + explicit ConstantTupleMethod( + std::vector> tup, + const std::string& name) + : tup_(tup), name_(name){}; + + std::string kind() const override { + return name_; + } + + std::shared_ptr call( + const SourceRange& loc, + Function& f, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) override { + if (inputs.size() || attributes.size()) { + throw ErrorReport(loc) + << name_ << " method does not accept any arguments"; + } + return std::make_shared(tup_); + } + + std::vector> tup_; + const std::string name_; +}; + struct VISIBILITY_HIDDEN OverloadedMethodValue : public SugaredValue { OverloadedMethodValue(Value* module, std::vector method_names) : module_(module), method_names_(std::move(method_names)) {} @@ -194,6 +244,12 @@ struct VISIBILITY_HIDDEN ModuleValue : public SugaredValue { Value* self_; Module module_; py::object py_module_; + + std::vector> desugarModuleContainer( + bool get_keys, + bool get_values, + const SourceRange& loc, + Function& m); }; struct VISIBILITY_HIDDEN BooleanDispatchValue : public SugaredValue { diff --git a/torch/csrc/jit/script/schema_matching.cpp b/torch/csrc/jit/script/schema_matching.cpp index 78d1adffe39ea..8dca641cb07db 100644 --- a/torch/csrc/jit/script/schema_matching.cpp +++ b/torch/csrc/jit/script/schema_matching.cpp @@ -130,8 +130,8 @@ static Value* tryMatchArgument( } // Resolve VarType variables - const MatchTypeReturn matched = - matchTypeVariables(arg.type(), value->type(), type_env); + const MatchTypeReturn matched = matchTypeVariables( + arg.type(), value->type(), type_env); if (!matched.success()) { if (failure_messages) { err() << "Could not match type " << value->type()->python_str() << " to " @@ -242,11 +242,16 @@ static bool varargsCanBeUsedAsList( // The formal must be a list bool argument_is_list = arg.type()->kind() == TypeKind::ListType; + // matching varargs of typevar list nyi + bool typevar_list = argument_is_list && + arg.type()->cast()->getElementType()->cast(); + // it must not be a broadcasting list like int[3], // otherwise a single int is a valid input bool arg_is_broadcasting_list = bool(arg.N()); - return is_last_argument && argument_is_list & !arg_is_broadcasting_list; + return is_last_argument && argument_is_list & !arg_is_broadcasting_list && + !typevar_list; } c10::optional tryMatchSchema( diff --git a/torch/csrc/jit/script/script_type_parser.h b/torch/csrc/jit/script/script_type_parser.h index 6ebfd4b870f7f..bf76c26819e1c 100644 --- a/torch/csrc/jit/script/script_type_parser.h +++ b/torch/csrc/jit/script/script_type_parser.h @@ -19,8 +19,6 @@ class TORCH_API ScriptTypeParser { explicit ScriptTypeParser() {} explicit ScriptTypeParser(ResolverPtr resolver) : resolver_(std::move(resolver)) {} - c10::optional parseBaseTypeName(const Expr& expr) const; - c10::TypePtr parseTypeFromExpr(const Expr& expr) const; c10::optional> parseBroadcastList( @@ -31,6 +29,7 @@ class TORCH_API ScriptTypeParser { FunctionSchema parseSchemaFromDef(const Def& def, bool skip_self); private: + c10::optional parseBaseTypeName(const Expr& expr) const; at::TypePtr subscriptToType( const std::string& typeName, const Subscript& subscript) const; diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index 9fd807e2e3384..d2813ce0d5b5d 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -289,15 +289,32 @@ Value* SimpleValue::len(const SourceRange& loc, Function& m) { } } +std::shared_ptr callClassMethod( + const ClassTypePtr& class_ptr, + const std::string& desugared_name, + const SourceRange& loc, + Function& m, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) { + if (!class_ptr->getMethod(desugared_name)) { + throw ErrorReport(loc) << class_ptr->python_str() << " does not define a " + << desugared_name << " method"; + } + + Value* self = inputs.at(0).value(*m.graph()); + return MethodValue(self, desugared_name) + .call(loc, m, inputs.slice(1), attributes, n_binders); +} + Value* SimpleValue::getitem(const SourceRange& loc, Function& m, Value* idx) { Value* val = getValue(); TypePtr val_type = val->type(); Graph& g = *m.graph(); - Value* cur_elem = nullptr; // if it's a List/String/Dict, emit a regular __getitem__ op if (val_type->cast() || val_type->cast()) { - cur_elem = g.insert(aten::__getitem__, {val, idx}, {}, loc); + return g.insert(aten::__getitem__, {val, idx}, {}, loc); } else if (auto dict_type = val_type->cast()) { if (!idx->type()->isSubtypeOf(dict_type->getKeyType())) { throw ErrorReport(loc) @@ -306,14 +323,16 @@ Value* SimpleValue::getitem(const SourceRange& loc, Function& m, Value* idx) { << dict_type->getKeyType()->python_str() << "' of the dict '" << dict_type->python_str() << "'"; } - cur_elem = g.insert(aten::__getitem__, {val, idx}, {}, loc); + return g.insert(aten::__getitem__, {val, idx}, {}, loc); } else if (val_type->isSubtypeOf(TensorType::get())) { - cur_elem = g.insert(aten::select, {val, 0, idx}, {}, loc); + return g.insert(aten::select, {val, 0, idx}, {}, loc); + } else if (auto class_type = val_type->cast()) { + return callClassMethod(class_type, "__getitem__", loc, m, {val, idx}, {}, 1) + ->asValue(loc, m); } else { throw ErrorReport(loc) << "'" << val_type->python_str() << "'" << " object is not subscriptable"; } - return cur_elem; } RangeValue::RangeValue( @@ -409,6 +428,21 @@ Value* IterableTree::getitem(const SourceRange& loc, Function& m, Value* idx) { return g.insertNode(g.createTuple(child_items))->output(); } +std::shared_ptr MagicMethod::call( + const SourceRange& loc, + Function& m, + at::ArrayRef inputs, + at::ArrayRef attributes, + size_t n_binders) { + if (inputs.size() > 0) { + Value* self = inputs[0].value(*m.graph()); + if (auto class_ptr = self->type()->cast()) { + return callClassMethod( + class_ptr, desugared_name_, loc, m, inputs, attributes, n_binders); + } + } + return base_value_->call(loc, m, inputs, attributes, n_binders); +} std::shared_ptr ClassValue::call( const SourceRange& loc, Function& m, diff --git a/torch/csrc/jit/script/sugared_value.h b/torch/csrc/jit/script/sugared_value.h index 0d4ebe5e87ac6..95ec284b0aab6 100644 --- a/torch/csrc/jit/script/sugared_value.h +++ b/torch/csrc/jit/script/sugared_value.h @@ -384,23 +384,7 @@ struct TORCH_API MagicMethod : public SugaredValue { Function& m, at::ArrayRef inputs, at::ArrayRef attributes, - size_t n_binders) override { - if (inputs.size() > 0) { - Value* self = inputs[0].value(*m.graph()); - - if (auto class_ptr = self->type()->cast()) { - if (!class_ptr->getMethod(desugared_name_)) { - throw ErrorReport(loc) - << class_ptr->python_str() << " does not define a " - << desugared_name_ << " method"; - } - - return MethodValue(self, desugared_name_) - .call(loc, m, inputs.slice(1), attributes, n_binders); - } - } - return base_value_->call(loc, m, inputs, attributes, n_binders); - } + size_t n_binders) override; private: SugaredValuePtr base_value_; diff --git a/torch/csrc/jit/source_range.cpp b/torch/csrc/jit/source_range.cpp index f15a62ecf6361..91dcf66c24fa3 100644 --- a/torch/csrc/jit/source_range.cpp +++ b/torch/csrc/jit/source_range.cpp @@ -66,8 +66,9 @@ C10_EXPORT void SourceRange::highlight(std::ostream& out) const { size_t len = std::min(size(), end_line - start()); out << std::string(len, '~') << (len < size() ? "... <--- HERE" : " <--- HERE"); - out << str.substr(end_line, end_highlight - end_line); - if (!str.empty() && str.back() != '\n') + auto line_substr = str.substr(end_line, end_highlight - end_line); + out << line_substr; + if (!line_substr.empty() && line_substr.back() != '\n') out << "\n"; // Retrieve original SourceRange, if present. if (auto orig_source_range = findSourceRangeThatGenerated()) { diff --git a/torch/csrc/jit/subgraph_matcher.cpp b/torch/csrc/jit/subgraph_matcher.cpp index d0653f8dac40a..c50a41b3b4344 100644 --- a/torch/csrc/jit/subgraph_matcher.cpp +++ b/torch/csrc/jit/subgraph_matcher.cpp @@ -78,7 +78,20 @@ bool patternGraphIsValid(const Graph& pattern) { bool SubgraphMatcher::matchValues(const Value* v1, Value* v2) { // Check if we've already visited these values. if (values_map_.count(v1)) { - return values_map_.at(v1) == v2; + if (values_map_.at(v1) != v2) { + GRAPH_DEBUG( + "Values %", + v1->debugName(), + " and %", + v2->debugName(), + " did not match because %", + v1->debugName(), + " has already been matched with %", + values_map_.at(v1)->debugName(), + ".\n"); + return false; + } + return true; } // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 22b8546d9ac12..7fecd1a940af8 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -477,6 +478,14 @@ void addInputs( n->addInput(none); } } +#ifdef BUILD_NAMEDTENSOR +void addInputs( + Node* n, + const char* name, + c10::optional value) { + TORCH_CHECK(false, "NYI: Named tensors are not supported with the tracer"); +} +#endif void addInputs( Node* n, const char* name, diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h index 1a0e615cc338e..a8685cc4a319c 100644 --- a/torch/csrc/jit/tracer.h +++ b/torch/csrc/jit/tracer.h @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include @@ -280,6 +282,9 @@ TORCH_API void addInputs( const char* name, const c10::optional& value); TORCH_API void addInputs(Node* n, const char* name, at::MemoryFormat value); +#ifdef BUILD_NAMEDTENSOR +TORCH_API void addInputs(Node* n, const char* name, c10::optional value); +#endif TORCH_API void addInputs( Node* n, const char* name, diff --git a/torch/csrc/multiprocessing/init.cpp b/torch/csrc/multiprocessing/init.cpp index 75d71db2945ea..8255cf3d6e774 100644 --- a/torch/csrc/multiprocessing/init.cpp +++ b/torch/csrc/multiprocessing/init.cpp @@ -18,7 +18,7 @@ namespace multiprocessing { namespace { -PyObject* multiprocessing_init(PyObject* _unused) { +PyObject* multiprocessing_init(PyObject* _unused, PyObject *noargs) { auto multiprocessing_module = THPObjectPtr(PyImport_ImportModule("torch.multiprocessing")); if (!multiprocessing_module) { diff --git a/torch/csrc/python_dimname.cpp b/torch/csrc/python_dimname.cpp index 38a9da896138d..e02ac41e4eaa7 100644 --- a/torch/csrc/python_dimname.cpp +++ b/torch/csrc/python_dimname.cpp @@ -1,9 +1,10 @@ -#ifdef BUILD_NAMEDTENSOR #include #include #include #include +#include +#ifdef BUILD_NAMEDTENSOR namespace torch { struct InternedStringsTable { diff --git a/torch/csrc/python_dimname.h b/torch/csrc/python_dimname.h index d1c03a9a83853..9d208d06e2c08 100644 --- a/torch/csrc/python_dimname.h +++ b/torch/csrc/python_dimname.h @@ -1,8 +1,9 @@ #pragma once -#ifdef BUILD_NAMEDTENSOR #include #include +#include +#ifdef BUILD_NAMEDTENSOR at::Dimname THPDimname_parse(PyObject* obj); bool THPUtils_checkDimname(PyObject* obj); bool THPUtils_checkDimnameList(PyObject* obj); diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index 5f0291b46e3da..32b1dfef79a31 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -72,6 +72,10 @@ static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs END_HANDLE_TH_ERRORS } +// TODO: Deprecate this instancecheck entirely. It's here to make +// instanceof(t, torch.FloatTensor) work, but we are not going to keep +// adding torch.QuantizedIntTensor classes for every new tensor type +// we add... static PyObject* Tensor_instancecheck(PyTensorType* self, PyObject* arg) { HANDLE_TH_ERRORS if (THPVariable_Check(arg)) { @@ -82,7 +86,10 @@ static PyObject* Tensor_instancecheck(PyTensorType* self, PyObject* arg) { // be nullptr if you had a tensor of some type, in which case you can // skip initializign aten_type(), but TestAutograd.test_type_conversions // seems to violate this property (for whatever reason.) - if (var.type_id() == self->get_type_id() && + // + // TODO: Stop using legacyExtractTypeId here (probably need to build + // in instanceof checking to Tensor class itself) + if (legacyExtractTypeId(var.type_set()) == self->get_type_id() && var.scalar_type() == static_cast(self->scalar_type)) { Py_RETURN_TRUE; } @@ -91,15 +98,15 @@ static PyObject* Tensor_instancecheck(PyTensorType* self, PyObject* arg) { END_HANDLE_TH_ERRORS } -PyObject *Tensor_dtype(PyTensorType* self) { +PyObject *Tensor_dtype(PyTensorType* self, void *unused) { return torch::autograd::utils::wrap(self->dtype); } -PyObject *Tensor_layout(PyTensorType* self) { +PyObject *Tensor_layout(PyTensorType* self, void *unused) { return torch::autograd::utils::wrap(self->layout); } -PyObject *Tensor_is_cuda(PyTensorType* self) { +PyObject *Tensor_is_cuda(PyTensorType* self, void *unused) { if (self->is_cuda) { Py_RETURN_TRUE; } else { @@ -107,7 +114,7 @@ PyObject *Tensor_is_cuda(PyTensorType* self) { } } -PyObject *Tensor_is_sparse(PyTensorType *self) { +PyObject *Tensor_is_sparse(PyTensorType *self, void *unused) { if (self->layout->layout == at::Layout::Strided) { Py_RETURN_FALSE; } else { diff --git a/torch/csrc/tensor/python_tensor.h b/torch/csrc/tensor/python_tensor.h index ddaeb7d46ff9c..9d46dfc28b8d6 100644 --- a/torch/csrc/tensor/python_tensor.h +++ b/torch/csrc/tensor/python_tensor.h @@ -25,6 +25,10 @@ void py_set_default_tensor_type(PyObject* type_obj); void py_set_default_dtype(PyObject* dtype_obj); // Gets the TensorTypeId for the default tensor type. +// +// TODO: This is nuts! There is no reason to let the default tensor type id +// change. Probably only store ScalarType, as that's the only flex point +// we support. c10::TensorTypeId get_default_tensor_type_id(); // Gets the ScalarType for the default tensor type. diff --git a/torch/csrc/utils/init.cpp b/torch/csrc/utils/init.cpp index 469a80c1b307e..2235231a3d306 100644 --- a/torch/csrc/utils/init.cpp +++ b/torch/csrc/utils/init.cpp @@ -2,8 +2,6 @@ #include #include -#include - #include namespace torch { @@ -48,12 +46,6 @@ void initThroughputBenchmarkBindings(PyObject* module) { }); - m.def("_enable_mkldnn_conv", []() { - at::native::disable_mkldnn_conv.exchange(false); - }); - m.def("_disable_mkldnn_conv", []() { - at::native::disable_mkldnn_conv.exchange(true); - }); } } // namespace throughput_benchmark diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 6ff1ac726bfbc..2d7ce07d82b43 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include @@ -649,7 +650,9 @@ at::Tensor PythonArgs::tensor_slow(int i) { } at::Scalar scalar; - if (THPUtils_checkLong(obj)) { + if (PyBool_Check(obj)) { + scalar = at::Scalar(THPUtils_unpackBool(obj)); + } else if (THPUtils_checkLong(obj)) { scalar = at::Scalar(THPUtils_unpackLong(obj)); }else if (PyComplex_Check(obj)) { scalar = at::Scalar(THPUtils_unpackComplexDouble(obj)); @@ -681,10 +684,15 @@ at::Scalar PythonArgs::scalar_slow(int i) { if (THPVariable_Check(args[i])) { return ((THPVariable*)args[i])->cdata.item(); } + if (THPUtils_checkLong(args[i])) { return at::Scalar(static_cast(THPUtils_unpackLong(args[i]))); } + if (PyBool_Check(args[i])) { + return at::Scalar(THPUtils_unpackBool(args[i])); + } + if (PyComplex_Check(args[i])) { return at::Scalar(THPUtils_unpackComplexDouble(args[i])); } diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index c17c50a043092..b4e2ab4152a1a 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -52,6 +52,7 @@ #include #include #include +#include #ifdef BUILD_NAMEDTENSOR #include #endif diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index 0c8938da3650c..84b6b9f2c0526 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -89,6 +89,16 @@ inline int64_t THPUtils_unpackIndex(PyObject* obj) { return THPUtils_unpackLong(obj); } +inline bool THPUtils_unpackBool(PyObject* obj) { + if (obj == Py_True) { + return true; + } else if (obj == Py_False) { + return false; + } else { + throw std::runtime_error("couldn't convert python object to boolean"); + } +} + inline bool THPUtils_checkDouble(PyObject* obj) { bool is_numpy_scalar; #ifdef USE_NUMPY diff --git a/torch/csrc/utils/qengines.cpp b/torch/csrc/utils/qengines.cpp new file mode 100644 index 0000000000000..93f9890881970 --- /dev/null +++ b/torch/csrc/utils/qengines.cpp @@ -0,0 +1,37 @@ +#include + +#include +#include +#include +#include + +#include +#include + +namespace torch { +namespace utils { + +void addQEngine( + at::QEngine qengine, + const std::string& name, + PyObject* torch_module) { + PyObject* qengine_obj = THPQEngine_New(qengine, name); + Py_INCREF(qengine_obj); + if (PyModule_AddObject(torch_module, name.c_str(), qengine_obj) != 0) { + throw python_error(); + } +} + +void initializeQEngines() { + auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); + if (!torch_module) { + throw python_error(); + } + + addQEngine(at::kNoQEngine, "no_qengine", torch_module); + addQEngine(at::kFBGEMM, "fbgemm", torch_module); + addQEngine(at::kQNNPACK, "qnnpack", torch_module); +} + +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/qengines.h b/torch/csrc/utils/qengines.h new file mode 100644 index 0000000000000..b1cd56be7fbad --- /dev/null +++ b/torch/csrc/utils/qengines.h @@ -0,0 +1,10 @@ +#pragma once +#include + +namespace torch { +namespace utils { + +void initializeQEngines(); + +} // namespace utils +} // namespace torch diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 1d8e7480a10ef..08124c1a3411e 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -17,8 +17,10 @@ #include #include +#include #include #include +#include #include #include @@ -115,8 +117,9 @@ Tensor new_with_storage(c10::TensorTypeId type_id, at::ScalarType scalar_type, S } Tensor new_with_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, const Tensor& other) { - if (other.type_id() != type_id) { - throw TypeError("expected %s (got %s)", type_id, other.type_id()); + if (legacyExtractTypeId(other.type_set()) != type_id) { + // In temporary expression lifetime we trust + throw TypeError("expected %s (got %s)", type_id, toString(other.type_set()).c_str()); } if (other.scalar_type() != scalar_type) { throw TypeError("expected %s (got %s)", toString(scalar_type), toString(other.scalar_type())); @@ -535,7 +538,7 @@ Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_t at::OptionalDeviceGuard device_guard(r.deviceOptional(3)); // if no dtype provided, infer type based on value type. Tensor values = internal_new_from_data(inferred_type_id, inferred_scalar_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference); - Tensor indices = internal_new_from_data(values.type_id(), kLong, r.deviceOptional(3), r.pyobject(0), false, true, false); + Tensor indices = internal_new_from_data(legacyExtractTypeId(values.type_set()), kLong, r.deviceOptional(3), r.pyobject(0), false, true, false); return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4)); } else if (r.idx == 1) { bool type_inference = r.isNone(3); @@ -543,7 +546,7 @@ Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_t const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(4)); Tensor values = internal_new_from_data(inferred_type_id, inferred_scalar_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference); - Tensor indices = internal_new_from_data(values.type_id(), kLong, r.deviceOptional(4), r.pyobject(0), false, true, false); + Tensor indices = internal_new_from_data(legacyExtractTypeId(values.type_set()), kLong, r.deviceOptional(4), r.pyobject(0), false, true, false); return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5)); } else if (r.idx == 2) { const auto inferred_type_id = typeIdWithDefault(r, 2, type_id); @@ -556,10 +559,19 @@ Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_t Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ +#ifdef BUILD_NAMEDTENSOR + "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)", +#else "tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)", +#endif }); - ParsedArgs<5> parsed_args; +#ifdef BUILD_NAMEDTENSOR + constexpr int ctor_num_args = 6; +#else + constexpr int ctor_num_args = 5; +#endif + ParsedArgs parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { PyObject* data = r.pyobject(0); @@ -581,6 +593,12 @@ Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObje true, type_inference, pin_memory); +#ifdef BUILD_NAMEDTENSOR + auto names = r.toDimnameListOptional(5); + if (names) { + at::namedinference::propagate_names(new_tensor, std::move(names), /*validate_names=*/true); + } +#endif new_tensor.detach_(); // ensure new_tensor a leaf node new_tensor.set_requires_grad(args_requires_grad); return new_tensor; diff --git a/torch/csrc/utils/variadic.h b/torch/csrc/utils/variadic.h index 3a924a9db9bef..63f34afbc37bf 100644 --- a/torch/csrc/utils/variadic.h +++ b/torch/csrc/utils/variadic.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -10,67 +11,7 @@ namespace torch { -// This class allows you to write variadic functions which -// call a (possibly overloaded) function on each argument, -// in order. This is most commonly used in autogenerated code, -// where it is convenient to have a function that can uniformly -// take arguments of different types. If your arguments -// are homogenous consider using a std::initializer_list instead. -template -struct IterArgs { - template - inline F& apply() { - return self(); - } - - // NB: Use perfect forwarding here, otherwise we'll make value - // copies of all arguments! - template - inline F& apply(T&& arg, Args&&... args) { - self()(std::forward(arg)); - if (self().short_circuit()) { - return self(); - } else { - return apply(std::forward(args)...); - } - } - - // Here are some handy overloads which provide sensible - // defaults for container-like structures that one might - // be interested in recursing into. You can enable them - // by adding: - // - // using IterArgs::operator() - // - // to your struct. These are not enabled by default because - // you may be able to process these structures more efficiently - // than handling them one-by-one. - - template - void operator()(at::ArrayRef args) { - for (const auto& arg : args) { - self()(arg); - if (short_circuit()) - return; - } - } - - // NB: we need to specify std::vector manually as C++ won't - // do an implicit conversion to make a template deduction go through. - template - void operator()(const std::vector& args) { - self()(at::ArrayRef{args}); - } - - bool short_circuit() { - return false; - } - - private: - inline F& self() { - return *static_cast(this); - } -}; +using at::IterArgs; struct CountTensors : IterArgs { size_t out = 0; @@ -194,4 +135,5 @@ template ) { return ReturnType(function(accessor.template operator()(Is)...)); } + } // namespace torch diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 411cfb7315a3a..8450f278129e7 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -16,12 +16,15 @@ import torch import traceback import warnings +import threading from torch._six import raise_from from subprocess import Popen, PIPE from multiprocessing.util import register_after_fork as _register_after_fork from ._utils import _get_device_index _initialized = False +_tls = threading.local() +_initialization_lock = threading.Lock() _queued_calls = [] # don't invoke these until initialization occurs _in_bad_fork = False # this global is also used in torch.manual_seed _original_pid = False @@ -163,34 +166,50 @@ def init(): def _lazy_init(): global _initialized, _cudart, _original_pid, _queued_calls - if _initialized: + if _initialized or hasattr(_tls, 'is_initializing'): return - if _in_bad_fork: - from sys import version_info - if version_info < (3, 4): - msg = ("To use CUDA with multiprocessing, you must use Python " - "3.4+ and the 'spawn' start method") - else: - msg = ("To use CUDA with multiprocessing, you must use the " - "'spawn' start method") - raise RuntimeError( - "Cannot re-initialize CUDA in forked subprocess. " + msg) - _check_driver() - torch._C._cuda_init() - _cudart = _load_cudart() - _cudart.cudaGetErrorName.restype = ctypes.c_char_p - _cudart.cudaGetErrorString.restype = ctypes.c_char_p - _original_pid = os.getpid() - _initialized = True - # Important to do this after _initialized, since some queued calls - # may themselves call _lazy_init() - for queued_call, orig_traceback in _queued_calls: + with _initialization_lock: + # We be double-checked locking, boys! This is OK because + # the above test was GIL protected anyway. The inner test + # is for when a thread blocked on some other thread which was + # doing the initialization; when they get the lock, they will + # find there is nothing left to do. + if _initialized: + return + # It is important to prevent other threads from entering _lazy_init + # immediately, while we are still guaranteed to have the GIL, because some + # of the C calls we make below will release the GIL + if _in_bad_fork: + from sys import version_info + if version_info < (3, 4): + msg = ("To use CUDA with multiprocessing, you must use Python " + "3.4+ and the 'spawn' start method") + else: + msg = ("To use CUDA with multiprocessing, you must use the " + "'spawn' start method") + raise RuntimeError( + "Cannot re-initialize CUDA in forked subprocess. " + msg) + _check_driver() + torch._C._cuda_init() + _cudart = _load_cudart() + _cudart.cudaGetErrorName.restype = ctypes.c_char_p + _cudart.cudaGetErrorString.restype = ctypes.c_char_p + _original_pid = os.getpid() + # Some of the queued calls may reentrantly call _lazy_init(); + # we need to just return without initializing in that case. + # However, we must not let any *other* threads in! + _tls.is_initializing = True try: - queued_call() - except Exception as e: - msg = ("CUDA call failed lazily at initialization with error: {}\n\n" - "CUDA call was originally invoked at:\n\n{}").format(str(e), orig_traceback) - raise_from(DeferredCudaCallError(msg), e) + for queued_call, orig_traceback in _queued_calls: + try: + queued_call() + except Exception as e: + msg = ("CUDA call failed lazily at initialization with error: {}\n\n" + "CUDA call was originally invoked at:\n\n{}").format(str(e), orig_traceback) + raise_from(DeferredCudaCallError(msg), e) + finally: + delattr(_tls, 'is_initializing') + _initialized = True def _after_fork(arg): diff --git a/torch/cuda/__init__.pyi b/torch/cuda/__init__.pyi index 03da71119feb8..055275bf07e35 100644 --- a/torch/cuda/__init__.pyi +++ b/torch/cuda/__init__.pyi @@ -40,3 +40,5 @@ def max_memory_cached(device: Optional[_device_t]=...) -> int: ... def reset_max_memory_cached(device: Optional[_device_t]=...) -> None: ... def cudart() -> ctypes.CDLL: ... def find_cuda_windows_lib() -> Optional[ctypes.CDLL]: ... +def set_rng_state(new_state): ... +def get_rng_state(): ... diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 198a73a289954..2fcb14bb1682f 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -23,7 +23,11 @@ def is_available(): from .rpc import _init_rpc from .rpc import * # noqa: F401 - def init_model_parallel(worker_name, rpc_backend=RpcBackend.PROCESS_GROUP): + def init_model_parallel(self_name, + backend=RpcBackend.PROCESS_GROUP, + self_rank=-1, + init_method=None, + num_send_recv_threads=4): r""" Initializes model parallel primitives such as the local rpc agent and distributed autograd. @@ -35,16 +39,19 @@ def init_model_parallel(worker_name, rpc_backend=RpcBackend.PROCESS_GROUP): ``init_process_group`` must be invoked prior to this method. Arguments: - worker_name (str): a globally unique name of this node. (e.g., + backend (Enum): type of RPC backend implementation. + Currently, process group backend is the only + available backend implementation. (default: + ``RpcBackend.PROCESS_GROUP``). + self_name (str): a globally unique name of this node. (e.g., ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) Name can only contain number, alphabet, underscore, and/or dash, and must be shorter than 128 characters. - rpc_backend (Enum): type of RPC backend implementation. - Currently, process group backend is the only - available backend implementation. (default: - ``RpcBackend.PROCESS_GROUP``). + self_rank (int): a globally unique id/rank of this node. + init_method(str): backend specific init arguments. + num_send_recv_threads(int): Number of threads for send/recv work. """ - _init_rpc(worker_name, rpc_backend) + _init_rpc(backend, self_name, self_rank, init_method, num_send_recv_threads) from .rpc import _agent autograd._init(_agent.get_worker_id().id) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index c3a095d671aba..6d85f7b4f6a64 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -485,8 +485,7 @@ def _new_process_group_helper(world_size, pg = ProcessGroupNCCL( prefix_store, rank, - world_size, - group_name) + world_size) _pg_map[pg] = (Backend.NCCL, store) _pg_names[pg] = group_name else: diff --git a/torch/distributed/launch.py b/torch/distributed/launch.py index a744a9cf63985..7ccb785de28a4 100644 --- a/torch/distributed/launch.py +++ b/torch/distributed/launch.py @@ -181,6 +181,10 @@ def parse_args(): "'local rank'. For legacy reasons, the default value is False. " "If set to True, the script will not pass " "--local_rank as argument, and will instead set LOCAL_RANK.") + parser.add_argument("-m", "--module", default=False, action="store_true", + help="Changes each process to interpret the launch script " + "as a python module, executing with the same behavior as" + "'python -m'.") # positional parser.add_argument("training_script", type=str, @@ -223,14 +227,17 @@ def main(): current_env["LOCAL_RANK"] = str(local_rank) # spawn the processes - if args.use_env: - cmd = [sys.executable, "-u", - args.training_script] + args.training_script_args - else: - cmd = [sys.executable, - "-u", - args.training_script, - "--local_rank={}".format(local_rank)] + args.training_script_args + cmd = [sys.executable, "-u"] + + if args.module: + cmd.append("-m") + + cmd.append(args.training_script) + + if not args.use_env: + cmd.append("--local_rank={}".format(local_rank)) + + cmd.extend(args.training_script_args) process = subprocess.Popen(cmd, env=current_env) processes.append(process) diff --git a/torch/distributed/rpc.py b/torch/distributed/rpc.py index b002debf68e36..5fa5f22f69d4a 100644 --- a/torch/distributed/rpc.py +++ b/torch/distributed/rpc.py @@ -1,10 +1,11 @@ #!/usr/bin/env python3 -from __future__ import absolute_import, division, print_function, unicode_literals -from . import invoke_rpc_builtin, invoke_rpc_python_udf - +from . import invoke_rpc_builtin, invoke_rpc_python_udf, invoke_remote_builtin +from . import init_rref_context from . import ProcessGroupAgent +from . import WorkerId from .internal_rpc_utils import serialize, PythonUDF +from .rpc_backend_handler import is_backend_registered, registered_init_rpc import sys import torch @@ -52,7 +53,11 @@ class RpcBackend(Enum): # TODO: add a context manager to wrap _init_rpc and join_rpc -def _init_rpc(name, backend=RpcBackend.PROCESS_GROUP): +def _init_rpc(backend=RpcBackend.PROCESS_GROUP, + self_name=None, + self_rank=-1, + init_method=None, + num_send_recv_threads=4): if sys.version_info < (3, 0): raise RuntimeError("RPC package does not support Python2.") @@ -63,11 +68,22 @@ def _init_rpc(name, backend=RpcBackend.PROCESS_GROUP): if backend == RpcBackend.PROCESS_GROUP: from .distributed_c10d import _get_default_group + group = _get_default_group() + if (self_rank != -1) and (self_rank != group.rank()): + raise RuntimeError("self_rank argument {} doesn't match pg rank {}".format( + self_rank, group.rank())) # TODO: add try-except and destroy _agent in all processes if any fails. - _agent = ProcessGroupAgent(name, group) + _agent = ProcessGroupAgent(self_name, group, num_send_recv_threads) + init_rref_context(_agent) + elif is_backend_registered(rpc_backend): + _agent = registered_init_rpc(rpc_backend, + self_rank=self_rank, + self_name=self_name, + init_url=init_method) + init_rref_context(_agent) else: - raise RuntimeError("Unrecognized RPC backend ", backend) + raise RuntimeError("Unrecognized RPC backend ", rpc_backend) @_require_initialized @@ -86,6 +102,62 @@ def get_worker_id(worker_name=None): return _agent.get_worker_id() +def _to_worker_id(name_or_id): + if isinstance(name_or_id, WorkerId): + return name_or_id + elif isinstance(name_or_id, str): + return get_worker_id(name_or_id) + else: + raise ValueError("Unsupported RPC worker ID type {}".format(name_or_id)) + + +@_require_initialized +def remote(to, func, args=None, kwargs=None): + r""" + Make a ``remote`` call to run ``func`` on worker ``to``, and returns an + ``RRef`` to the result value immediately. Worker ``to`` will be the owner + of the return ``RRef``, and this worker is a user. The owner manages the + global reference count of its ``RRef``s, and the owner ``RRef`` is only + destructed when globally there is no living references to it. + + Arguments: + to (int or str): id or name of the destination worker. + func (callable): builtin functions (like ``torch.add``). + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + + Returns: + A user ``RRef`` instance to the result value. Use the blocking API + ``RRef.to_here()`` to retrieve the result value locally. + + Example:: + + On worker 0: + >>> import torch.distributed as dist + >>> dist.init_process_group(backend='gloo', rank=0, world_size=2) + >>> dist.init_rpc("worker0") + >>> worker1 = dist.get_worker_id("worker1") + >>> rref1 = dist.remote(worker1, torch.add, args=(torch.ones(2), 3)) + >>> rref2 = dist.remote(worker1, torch.add, args=(torch.ones(2), 1)) + >>> x = rref1.to_here() + rref2.to_here() + >>> dist.join_rpc() + + One worker 1: + >>> import torch.distributed as dist + >>> dist.init_process_group(backend='gloo', rank=1, world_size=2) + >>> dist.init_rpc("worker1") + >>> dist.join_rpc() + """ + qualified_name = torch.jit._find_builtin(func) + + args = args if args else () + kwargs = kwargs if kwargs else {} + + return invoke_remote_builtin( + _agent, _to_worker_id(to), qualified_name, *args, **kwargs) + + @_require_initialized def rpc(to, func, args=None, kwargs=None, async_call=False): r""" @@ -157,13 +229,12 @@ def rpc(to, func, args=None, kwargs=None, async_call=False): args = args if args else () kwargs = kwargs if kwargs else {} - if isinstance(to, str): - to = get_worker_id(to) - if qualified_name is not None: - fut = invoke_rpc_builtin(_agent, to, qualified_name, *args, **kwargs) + fut = invoke_rpc_builtin( + _agent, _to_worker_id(to), qualified_name, *args, **kwargs) else: - fut = invoke_rpc_python_udf(_agent, to, serialize(PythonUDF(func, args, kwargs))) + fut = invoke_rpc_python_udf( + _agent, _to_worker_id(to), serialize(PythonUDF(func, args, kwargs))) if async_call: return fut diff --git a/torch/distributed/rpc_backend_handler.py b/torch/distributed/rpc_backend_handler.py new file mode 100644 index 0000000000000..b279de4837392 --- /dev/null +++ b/torch/distributed/rpc_backend_handler.py @@ -0,0 +1,31 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + + +_rpc_init_handlers = {} + + +def register_rpc_backend(backend_str, handler): + """Registers a new rpc backend. + + Arguments: + backend (str): backend string to identify the handler. + handler (function): Handler that is invoked when the + `_init_rpc()` function is called with a backend. + This returns the agent. + """ + global _rpc_init_handlers + if backend_str in _rpc_init_handlers: + raise RuntimeError( + "Rpc backend {}: already registered".format(backend_str) + ) + _rpc_init_handlers[backend_str] = handler + + +def registered_init_rpc(backend_str, **kwargs): + if backend_str not in _rpc_init_handlers: + raise RuntimeError("No rpc_init handler for {}.".format(backend_str)) + return _rpc_init_handlers[backend_str](**kwargs) + + +def is_backend_registered(backend_str): + return backend_str in _rpc_init_handlers diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index 9e4410a734e8c..7395635971dee 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -8,8 +8,8 @@ class NegativeBinomial(Distribution): r""" Creates a Negative Binomial distribution, i.e. distribution - of the number of independent identical Bernoulli trials - needed before :attr:`total_count` failures are achieved. The probability + of the number of successful independent and identical Bernoulli trials + before :attr:`total_count` failures are achieved. The probability of success of each Bernoulli trial is :attr:`probs`. Args: diff --git a/torch/hub.py b/torch/hub.py index 1b3107f0a9263..d44f5e24f81fa 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -12,7 +12,7 @@ if sys.version_info[0] == 2: from urlparse import urlparse - import requests + from urllib2 import urlopen # noqa f811 else: from urllib.request import urlopen from urllib.parse import urlparse # noqa: F401 @@ -95,10 +95,7 @@ def _download_archive_zip(url, filename): sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename)) # We use a different API for python2 since urllib(2) doesn't recognize the CA # certificates in older Python - if sys.version_info[0] == 2: - response = requests.get(url, stream=True).raw - else: - response = urlopen(url) + response = urlopen(url) with open(filename, 'wb') as f: while True: data = response.read(READ_DATA_CHUNK) @@ -376,22 +373,14 @@ def _download_url_to_file(url, dst, hash_prefix, progress): file_size = None # We use a different API for python2 since urllib(2) doesn't recognize the CA # certificates in older Python - if sys.version_info[0] == 2: - response = requests.get(url, stream=True) - - content_length = response.headers['Content-Length'] - file_size = content_length - u = response.raw + u = urlopen(url) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders("Content-Length") else: - u = urlopen(url) - - meta = u.info() - if hasattr(meta, 'getheaders'): - content_length = meta.getheaders("Content-Length") - else: - content_length = meta.get_all("Content-Length") - if content_length is not None and len(content_length) > 0: - file_size = int(content_length[0]) + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) # We deliberately save it in a temp file and move it after # download is complete. This prevents a local working checkpoint diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 4edef2705f96c..8836f725dac56 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -9,7 +9,7 @@ from torch._jit_internal import _qualified_name from torch.autograd import Variable, function from torch.jit.frontend import get_jit_class_def, get_jit_def, get_default_args -from torch.nn import Module, ModuleList, Sequential +from torch.nn import Module, ModuleList, Sequential, ModuleDict from torch.serialization import validate_cuda_device from torch._six import PY2, PY37, with_metaclass, string_classes from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \ @@ -32,7 +32,7 @@ # These are imported so users can access them from the `torch.jit` module from torch._jit_internal import Final, _overload, _overload_method # noqa: F401 -from torch._jit_internal import ignore, export # noqa: F401 +from torch._jit_internal import ignore, export, unused # noqa: F401 if sys.version_info[0] > 2: import pathlib @@ -97,6 +97,69 @@ def optimized_execution(should_optimize): DEFAULT_EXTRA_FILES_MAP = torch._C.ExtraFilesMap() +def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP): + """ + Save an offline version of this module for use in a separate process. The saved + module serializes all of the methods, submodules, parameters, and attributes of this + module. It can be loaded into the C++ API using ``torch::jit::load(filename)`` or into the Python + API with :func:`torch.jit.load `. + + To be able to save a module, it must not make any calls to native Python functions. + This means that all submodules must be subclasses of ``torch.jit.ScriptModule`` as well. + + .. DANGER:: + All modules, no matter their device, are always loaded onto the CPU during loading. + This is different from :func:`load `'s semantics and may change in the future. + + Arguments: + m: a ScriptModule to save + f: a file-like object (has to implement write and flush) or a string + containing a file name + _extra_files: Map from filename to contents which will be stored as part of 'f' + + .. warning:: + If you are using Python 2, ``torch.jit.save`` does NOT support ``StringIO.StringIO`` + as a valid file-like object. This is because the write method should return + the number of bytes written; ``StringIO.write()`` does not do this. + + Please use something like ``io.BytesIO`` instead. + + Example: + + .. testcode:: + + import torch + import io + + class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + m = torch.jit.script(MyModule()) + + # Save to file + torch.jit.save(m, 'scriptmodule.pt') + # This line is equivalent to the previous + m.save("scriptmodule.pt") + + # Save to io.BytesIO buffer + buffer = io.BytesIO() + torch.jit.save(m, buffer) + + # Save with extra files + extra_files = torch._C.ExtraFilesMap() + extra_files['foo.txt'] = 'bar' + torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) + """ + if isinstance(f, str) or \ + (sys.version_info[0] == 2 and isinstance(f, unicode)) or \ + (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)): + m.save(f, _extra_files=_extra_files) + else: + ret = m.save_to_buffer(_extra_files=_extra_files) + f.write(ret) + + def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): r""" Load a ``ScriptModule`` previously saved with :func:`torch.jit.save ` @@ -114,11 +177,12 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): filenames given in the map would be loaded and their content would be stored in the provided map. - Returns: A ``ScriptModule`` object. - Example: :: + Example: + + .. testcode:: import torch import io @@ -133,15 +197,28 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): torch.jit.load(buffer) # Load all tensors onto CPU, using a device + buffer.seek(0) torch.jit.load(buffer, map_location=torch.device('cpu')) # Load all tensors onto CPU, using a string + buffer.seek(0) torch.jit.load(buffer, map_location='cpu') # Load with extra files. - files = {'metadata.json' : ''} - torch.jit.load('scriptmodule.pt', _extra_files = files) - print(files['metadata.json']) + extra_files = torch._C.ExtraFilesMap() + extra_files['foo.txt'] = 'bar' + torch.jit.load('scriptmodule.pt', _extra_files=extra_files) + print(extra_files['foo.txt']) + + .. testoutput:: + :hide: + + ... + + .. testcleanup:: + + import os + os.remove("scriptmodule.pt") """ if isinstance(f, string_classes): if not os.path.exists(f): @@ -166,65 +243,6 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): return ScriptModule(_cpp_module=cpp_module) -def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP): - """ - Save an offline version of this module for use in a separate process. The saved - module serializes all of the methods, submodules, parameters, and attributes of this - module. It can be loaded into the C++ API using ``torch::jit::load(filename)`` or into the Python - API with :func:`torch.jit.load `. - - To be able to save a module, it must not make any calls to native Python functions. - This means that all submodules must be subclasses of ``torch.jit.ScriptModule`` as well. - - .. DANGER:: - All modules, no matter their device, are always loaded onto the CPU during loading. - This is different from :func:`load `'s semantics and may change in the future. - - Arguments: - m: a ScriptModule to save - f: a file-like object (has to implement write and flush) or a string - containing a file name - _extra_files: Map from filename to contents which will be stored as part of 'f' - - .. warning:: - If you are using Python 2, ``torch.jit.save`` does NOT support ``StringIO.StringIO`` - as a valid file-like object. This is because the write method should return - the number of bytes written; ``StringIO.write()`` does not do this. - - Please use something like ``io.BytesIO`` instead. - - Example: :: - - import torch - import io - - class MyModule(torch.nn.Module): - def forward(self, x): - return x + 10 - - m = torch.jit.script(MyModule()) - - # Save to file - torch.jit.save(m, 'scriptmodule.pt') - - # Save to io.BytesIO buffer - buffer = io.BytesIO() - torch.jit.save(m, buffer) - - # Save with extra files - extra_files = torch._C.ExtraFilesMap() - extra_files['foo.txt'] = 'bar' - torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) - """ - if isinstance(f, str) or \ - (sys.version_info[0] == 2 and isinstance(f, unicode)) or \ - (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)): - m.save(f, _extra_files=_extra_files) - else: - ret = m.save_to_buffer(_extra_files=_extra_files) - f.write(ret) - - def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_inputs=False): """ Trace a function or model, returning a tuple consisting of the both the @@ -243,10 +261,11 @@ def get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, return_input kwargs (dict): the keyword arguments to pass to the function/module to be traced. - Example: Trace a cell. + Example (trace a cell): - >>> trace, out = jit.trace(nn.LSTMCell(), (input, hidden)) - >>> print(trace) + .. testcode:: + + trace = torch.jit.trace(nn.LSTMCell(), (input, hidden)) """ if kwargs is None: kwargs = {} @@ -767,7 +786,9 @@ def trace(func, original ``nn.Module``. If ``callable`` is a standalone function, ``trace`` returns ``torch._C.Function`` - Example (tracing a function):: + Example (tracing a function): + + .. testcode:: import torch @@ -1056,7 +1077,9 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None): The ``@torch.jit.script`` decorator will construct a ``torch._C.Function`` by compiling the body of the function. - Example (scripting a function):: + Example (scripting a function): + + .. testcode:: import torch @@ -1075,7 +1098,9 @@ def foo(x, y): will construct ``torch.jit.ScriptModule`` that has copies of the attributes, parameters, and methods of the original module. - Example (scripting a simple module with a Parameter):: + Example (scripting a simple module with a Parameter): + + .. testcode:: import torch @@ -1096,9 +1121,11 @@ def forward(self, input): output = self.linear(output) return output - scripted_module = torch.jit.script(MyModule()) + scripted_module = torch.jit.script(MyModule(2, 3)) - Example (scripting a module with traced submodules):: + Example (scripting a module with traced submodules): + + .. testcode:: import torch import torch.nn as nn @@ -1556,16 +1583,21 @@ def __setattr__(self, attr, value): raise RuntimeError("attempting to re-assign constant '{}' in {}".format(attr, type(self).__name__)) def conv_module_to_const(module_value): - if not isinstance(module_value, (ModuleList, Sequential)): + if not isinstance(module_value, (ModuleList, Sequential, ModuleDict)): return module_value - for i in range(len(module_value)): - module_value[i] = conv_module_to_const(module_value[i]) - if isinstance(module_value, Sequential): - return _ConstSequential(module_value) + if isinstance(module_value, ModuleDict): + for key, val in module_value: + module_value[key] = conv_module_to_const(val) + return _ConstModuleDict(module_value) else: - return _ConstModuleList(module_value) - - if isinstance(value, (ModuleList, Sequential)): + for i in range(len(module_value)): + module_value[i] = conv_module_to_const(module_value[i]) + if isinstance(module_value, Sequential): + return _ConstSequential(module_value) + else: + return _ConstModuleList(module_value) + + if isinstance(value, (ModuleList, Sequential, ModuleDict)): # special case for list of modules. Modules need to be registered with their # parent module. To do this, we create a ConstModuleList, which is itself a module, that # contains each of these modules as submodules. The ConstModuleList then @@ -1601,6 +1633,15 @@ def __getstate__(self): def graph_for(self, *args, **kwargs): return self.forward.graph_for(*args, **kwargs) + def extra_repr(self): + return 'original_name={}'.format(self.original_name) + + @property + def original_name(self): + if type(self) == self._c.name: + return '' + return self._c.name + else: class ScriptModule(torch.nn.Module): def __init__(self): @@ -1730,6 +1771,48 @@ def __dir__(self): keys = [key for key in keys if not key.isdigit()] return keys +class _ConstModuleDict(ScriptModule): + def __init__(self, modules): + super(_ConstModuleDict, self).__init__() + + assert isinstance(modules, OrderedDict) + + for key, module in modules.items(): + if isinstance(module, torch.nn.Module): + module = torch.jit._recursive.recursive_script(module) + self.add_module(key, module) + + + def __getitem__(self, key): + return self._modules[key] + + def __contains__(self, key): + return key in self._modules + + def keys(self): + r"""Return an iterable of the ModuleDict keys. + """ + return self._modules.keys() + + def items(self): + r"""Return an iterable of the ModuleDict key/value pairs. + """ + return self._modules.items() + + def values(self): + r"""Return an iterable of the ModuleDict values. + """ + return self._modules.values() + + def __len__(self): + return len(self._modules) + + def __iter__(self): + return iter(self._modules.values()) + + def forward(self): + raise NotImplementedError() + class _ConstSequential(_ConstModuleList): __constants__ = ['mods'] @@ -1749,6 +1832,23 @@ def forward(self, input): return input """) +def is_scripting(): + r""" + Function that returns True when in compilation and False otherwise. This + is useful especially with the @unused decorator to leave code in your + model that is not yet TorchScript compatible. + + @torch.jit.unused + def unsupported_linear_op(x): + return x + + def linear(x): + if not torch.jit.is_scripting(): + return torch.linear(x) + else: + return unsupported_linear_op(x) + """ + return False def _unwrap_optional(x): assert x is not None, "Unwrapping null optional" @@ -1767,6 +1867,9 @@ def _unwrap_optional(x): (_triple, "aten::_triple"), (_unwrap_optional, "aten::_unwrap_optional"), (_wait, 'aten::wait'), + (is_scripting, "aten::is_scripting"), + (OrderedDict, "aten::dict"), + (dict, "aten::dict"), (cudnn.is_acceptable, "aten::cudnn_is_acceptable"), (math.ceil, "aten::ceil"), (math.copysign, "aten::copysign"), @@ -1893,7 +1996,7 @@ def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_ return fn def _check_no_signature(func): - signature = torch.jit.annotations.get_signature(func) + signature = torch.jit.annotations.get_signature(func, None, None) if signature is None: qual_name = _qualified_name(func) raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name)) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 2e71973212520..60b412dac2798 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -3,7 +3,7 @@ import collections import torch._jit_internal as _jit_internal -from torch.nn import Module, ModuleList, Parameter, Sequential +from torch.nn import Module, ModuleList, Parameter, Sequential, ModuleDict from torch._six import get_function_from_type @@ -105,7 +105,7 @@ def recursive_script(mod, exclude_methods=()): if isinstance(mod, torch.jit.ScriptModule): return mod - if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential)): + if isinstance(mod, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): # Create constant versions for the iterable modules return create_constant_iterable_module(mod) @@ -220,7 +220,7 @@ def create_constant_iterable_module(module): modules = collections.OrderedDict() for key, submodule in module._modules.items(): - if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + if isinstance(submodule, (ModuleList, Sequential, ModuleDict)): # Make each item in the module a constant modules[key] = create_constant_iterable_module(submodule) else: @@ -230,6 +230,8 @@ def create_constant_iterable_module(module): return torch.jit._ConstSequential(Sequential(modules)) elif isinstance(module, ModuleList): return torch.jit._ConstModuleList(modules) + elif isinstance(module, ModuleDict): + return torch.jit._ConstModuleDict(modules) else: - raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made " + raise RuntimeError("Only nn.ModuleList, nn.Sequential, and nn.ModuleDict can be made " "into constant modules, found {}".format(module)) diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index 8efd299b0f086..5892da20e621d 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -9,6 +9,7 @@ from torch._C import TensorType, TupleType, FloatType, IntType, \ ListType, StringType, DictType, BoolType, OptionalType, ClassType from textwrap import dedent +from torch._utils_internal import get_source_lines_and_file PY35 = sys.version_info >= (3, 5) @@ -35,9 +36,27 @@ def __getattr__(self, name): 'Dict': Dict, 'Optional': Optional, } - - -def get_signature(fn): +class EvalEnv(object): + env = { + 'torch': Module('torch', {'Tensor': torch.Tensor}), + 'Tensor': torch.Tensor, + 'typing': Module('typing', {'Tuple': Tuple}), + 'Tuple': Tuple, + 'List': List, + 'Dict': Dict, + 'Optional': Optional, + } + + def __init__(self, rcb): + self.rcb = rcb + + def __getitem__(self, name): + if name in self.env: + return self.env[name] + if self.rcb is not None: + return self.rcb(name) + +def get_signature(fn, rcb, loc): # Python 3.5 adds support for the nice annotation syntax, so try that first. if PY35: sig = try_real_annotations(fn) @@ -46,7 +65,7 @@ def get_signature(fn): type_line, source = None, None try: - source = dedent(inspect.getsource(fn)) + source = dedent(''.join(get_source_lines_and_file(fn)[0])) type_line = get_type_line(source) except TypeError: pass @@ -55,7 +74,7 @@ def get_signature(fn): if type_line is None: return None - return parse_type_line(type_line) + return parse_type_line(type_line, rcb, loc) # This is essentially a weaker form of get_signature(), where we don't care if @@ -63,7 +82,7 @@ def get_signature(fn): # a function takes. def get_num_params(fn, loc): try: - source = dedent(inspect.getsource(fn)) + source = dedent(''.join(get_source_lines_and_file(fn)[0])) except (TypeError, IOError): return None if source is None: @@ -86,7 +105,7 @@ def get_num_params(fn, loc): return num_params -def parse_type_line(type_line): +def parse_type_line(type_line, rcb, loc): """Parses a type annotation specified as a comment. Example inputs: @@ -96,7 +115,7 @@ def parse_type_line(type_line): arg_ann_str, ret_ann_str = split_type_line(type_line) try: - arg_ann = eval(arg_ann_str, _eval_env) # noqa: P204 + arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # noqa: P204 except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the argument list of a type annotation: {}".format(str(e))) @@ -104,12 +123,13 @@ def parse_type_line(type_line): arg_ann = (arg_ann,) try: - ret_ann = eval(ret_ann_str, _eval_env) # noqa: P204 + ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # noqa: P204 except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the return type of a type annotation: {}".format(str(e))) - arg_types = [ann_to_type(ann) for ann in arg_ann] - return arg_types, ann_to_type(ret_ann) + resolver = (rcb, loc) + arg_types = [ann_to_type(ann, resolver) for ann in arg_ann] + return arg_types, ann_to_type(ret_ann, resolver) def get_type_line(source): @@ -197,7 +217,9 @@ def as_ann(ann): return arg_types, return_type -def ann_to_type(ann): +def ann_to_type(ann, resolver=None): + # resolver should be a Tuple[Callable, SourceRange] where the Callable + # is a resolutionCallback if ann is None: return TensorType.get() elif ann is torch.Tensor: @@ -225,6 +247,12 @@ def ann_to_type(ann): return BoolType.get() elif hasattr(ann, "__torch_script_class__"): return ClassType(_qualified_name(ann)) + elif resolver is not None: + # Maybe resolve a NamedTuple to a Tuple Type + rcb, loc = resolver + the_type = torch._C._resolve_type(ann.__name__, loc, rcb) + if the_type is not None: + return the_type raise ValueError("Unknown type annotation: '{}'".format(ann)) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 8459c93faefb0..16696206cbfea 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -7,6 +7,7 @@ from textwrap import dedent from torch._six import PY2 from torch._C._jit_tree_views import * +from torch._utils_internal import get_source_lines_and_file # Borrowed from cPython implementation # https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411# @@ -146,9 +147,8 @@ def get_jit_class_def(cls, self_name): method_defs = [get_jit_def(method[1], self_name=self_name) for method in methods] - sourcelines, file_lineno = inspect.getsourcelines(cls) + sourcelines, file_lineno, filename = get_source_lines_and_file(cls) source = ''.join(sourcelines) - filename = inspect.getsourcefile(cls) dedent_src = dedent(source) py_ast = ast.parse(dedent_src) leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0]) @@ -157,9 +157,8 @@ def get_jit_class_def(cls, self_name): def get_jit_def(fn, self_name=None): - sourcelines, file_lineno = inspect.getsourcelines(fn) + sourcelines, file_lineno, filename = get_source_lines_and_file(fn) source = ''.join(sourcelines) - filename = inspect.getsourcefile(fn) dedent_src = dedent(source) py_ast = ast.parse(dedent_src) if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): diff --git a/torch/jit/quantized.py b/torch/jit/quantized.py index 7412bc56991bc..44db61fbe3485 100644 --- a/torch/jit/quantized.py +++ b/torch/jit/quantized.py @@ -470,7 +470,7 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): assert batch_sizes is None result = _VF.quantized_lstm(input, hx, self._get_all_weights(), self.bias, self.num_layers, float(self.dropout), self.training, self.bidirectional, - self.batch_first, dtype=self.dtype) + self.batch_first, dtype=self.dtype, use_dynamic=False) output = result[0] hidden = result[1:] diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 337fc0c52bf04..b64cc3f101e71 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -1,6 +1,12 @@ #include +#include +#include +#include +#include + #include +#include #include #include #include @@ -19,9 +25,31 @@ #include #endif +#include #include #include + +#if GLOO_HAVE_TRANSPORT_TCP #include +#endif + +#if GLOO_HAVE_TRANSPORT_UV +#include +#endif + +// On Linux, check that the tcp transport is available. +#ifdef __linux__ +#if !GLOO_HAVE_TRANSPORT_TCP +#error "Expected the tcp transport to be available on Linux." +#endif +#endif + +// On macOS, check that the uv transport is available. +#ifdef __APPLE__ +#if !GLOO_HAVE_TRANSPORT_UV +#error "Expected the uv transport to be available on macOS." +#endif +#endif #define GENERATE_ALL_TYPES(type, func, args...) \ switch (type) { \ @@ -123,6 +151,11 @@ void setOutput(O& opts, at::Tensor& tensor) { opts.setOutput(getDataPointer(tensor), tensor.numel()); } +template +void setOutput(O& opts, at::Tensor& tensor, std::vector& counts) { + opts.setOutput(getDataPointer(tensor), counts); +} + #ifdef USE_CUDA at::Tensor pinnedLike(at::Tensor& tensor) { @@ -228,6 +261,8 @@ void initializeStreamsEvents( #endif +const auto kLoopbackAddress = "127.0.0.1"; + } // namespace ProcessGroupGloo::SendWork::SendWork( @@ -276,6 +311,147 @@ void ProcessGroupGloo::RecvWork::wait() { ProcessGroupGloo::Options::Options() : timeout(std::chrono::milliseconds(10 * 1000)), threads(2) {} +namespace { + +// Gloo assumes that this machine's hostname can always be resolved +// to an address. If it doesn't it throws a runtime error saying +// that it can't be resolved. Instead of catching it, we choose +// to proactively check if an address can be resolved, so we can +// gracefully fall back to an alternative if it doesn't. +bool doesHostnameResolveToUsableAddress(const std::string& hostname) { + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + struct addrinfo* result; + auto rv = getaddrinfo(hostname.c_str(), nullptr, &hints, &result); + if (rv < 0) { + return false; + } + struct addrinfo* rp; + for (rp = result; rp != nullptr; rp = rp->ai_next) { + auto fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (fd == -1) { + continue; + } + rv = bind(fd, rp->ai_addr, rp->ai_addrlen); + close(fd); + if (rv == -1) { + continue; + } + break; + } + freeaddrinfo(result); + return rp != nullptr; +} + +} // namespace + +#ifdef __linux__ +std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: + createDeviceForInterface(const std::string& interface) { + ::gloo::transport::tcp::attr attr; + attr.iface = interface; + return ::gloo::transport::tcp::CreateDevice(attr); +} +#endif + +#ifdef __APPLE__ +std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: + createDeviceForInterface(const std::string& interface) { + ::gloo::transport::uv::attr attr; + attr.iface = interface; + return ::gloo::transport::uv::CreateDevice(attr); +} +#endif + +#ifdef __linux__ +std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: + createDeviceForHostname(const std::string& hostname) { + ::gloo::transport::tcp::attr attr; + attr.hostname = hostname; + TORCH_CHECK( + doesHostnameResolveToUsableAddress(attr.hostname), + "Cannot resolve ", + hostname, + " to a (local) address"); + return ::gloo::transport::tcp::CreateDevice(attr); +} +#endif + +#ifdef __APPLE__ +std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: + createDeviceForHostname(const std::string& hostname) { + ::gloo::transport::uv::attr attr; + attr.hostname = hostname; + TORCH_CHECK( + doesHostnameResolveToUsableAddress(attr.hostname), + "Cannot resolve ", + hostname, + " to a (local) address"); + return ::gloo::transport::uv::CreateDevice(attr); +} +#endif + +#ifdef __linux__ +std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: + createDefaultDevice() { + ::gloo::transport::tcp::attr attr; + + // Use the hostname to resolve the network address to + // use. Note: if the hostname does not resolve to an address (e.g. + // because of misconfigured /etc/hosts file), this will not work. + std::array buffer{}; + auto rv = gethostname(buffer.data(), buffer.size()); + if (rv != 0) { + throw std::system_error(errno, std::system_category()); + } + attr.hostname = buffer.data(); + + // Use this machine's hostname if it resolves to an address. + if (doesHostnameResolveToUsableAddress(attr.hostname)) { + return ::gloo::transport::tcp::CreateDevice(attr); + } + + // Otherwise, use the loopback address. + TORCH_WARN_ONCE( + "Unable to resolve hostname to a (local) address. ", + "Using the loopback address as fallback. ", + "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); + return createDeviceForHostname(kLoopbackAddress); +} +#endif + +#ifdef __APPLE__ +std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: + createDefaultDevice() { + ::gloo::transport::uv::attr attr; + + // Use the hostname to resolve the network address to + // use. Note: if the hostname does not resolve to an address (e.g. + // because of misconfigured /etc/hosts file), this will not work. + const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX); + auto buffer = std::unique_ptr(new char[hostNameMax]); + auto rv = gethostname(buffer.get(), hostNameMax); + if (rv != 0) { + throw std::system_error(errno, std::system_category()); + } + attr.hostname = buffer.get(); + + // Use this machine's hostname if it resolves to an address. + if (doesHostnameResolveToUsableAddress(attr.hostname)) { + return ::gloo::transport::uv::CreateDevice(attr); + } + + // Otherwise, use the loopback address. + TORCH_WARN_ONCE( + "Unable to resolve hostname to a (local) address. ", + "Using the loopback address as fallback. ", + "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); + return createDeviceForHostname(kLoopbackAddress); +} +#endif + ProcessGroupGloo::ProcessGroupGloo( const std::shared_ptr& store, int rank, @@ -644,7 +820,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { // Construct from an existing metadata tensor to facilitate structured // access to metadata from peers, after gathering it. explicit SparseTensorMetadata(at::Tensor metadata) - : metadata_(metadata), data_(metadata_.data_ptr()) { + : metadata_(metadata), data_(metadata_.data_ptr()) { AT_ASSERT(metadata.scalar_type() == at::kLong); AT_ASSERT(metadata.dim() == 1); AT_ASSERT(metadata.size(0) == dim); @@ -694,7 +870,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { protected: at::Tensor metadata_; - long* data_; + int64_t* data_; }; // Sparse allreduce is implemented with allgather on indices and values. @@ -703,6 +879,14 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { // we run allgather on the nnz, and then allgather with max(nnz). // We could use an allgatherv for this, if it were available. at::Tensor allreduce(std::vector& tensors) { + // TODO: This is a massive hack! There is some confusion about + // Variable/Tensor inside the body of this function. Turning off + // grad smooths over the confusion for now. This fixes + // test/test_c10d.py ProcessGroupGlooTest.test_sparse_allreduce_basics + // + // The correct fix is to stop allocating tensors that are not variables, + // but to conveniently do this c10d must depend on torch not ATen + at::AutoNonVariableTypeMode _no_grad(true); auto input = tensors[0]; // Perform local reduction if we have multiple inputs. @@ -719,7 +903,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { // Sanity check dimensionality across ranks. { const auto expected = metadata[context->rank].sizes(); - for (size_t i = 0; i < context->size; i++) { + for (auto i = 0; i < context->size; i++) { if (i == context->rank) { continue; } @@ -733,11 +917,11 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { auto values = allgather_values(input, metadata); // Perform global reduction. - AT_ASSERT(indices.size() == context->size); - AT_ASSERT(values.size() == context->size); + AT_ASSERT(static_cast(indices.size()) == context->size); + AT_ASSERT(static_cast(values.size()) == context->size); auto output = at::sparse_coo_tensor( indices[0], values[0], input.sizes(), input.options()); - for (size_t i = 1; i < context->size; i++) { + for (auto i = 1; i < context->size; i++) { output += at::sparse_coo_tensor( indices[i], values[i], input.sizes(), input.options()); } @@ -778,7 +962,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { // Allgather metadata gloo::AllgatherOptions opts(context); - opts.setOutput(buffer.data_ptr(), buffer.numel()); + opts.setOutput(buffer.data_ptr(), buffer.numel()); opts.setTag(tag); gloo::allgather(opts); @@ -788,29 +972,38 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector allgather_indices( const at::Tensor& tensor, const std::vector& metadata) { - auto max_nnz = metadata[0].nnz(); - for (size_t i = 1; i < metadata.size(); i++) { - max_nnz = std::max(max_nnz, metadata[i].nnz()); + const auto sparseDim = tensor.sparse_dim(); + + std::vector counts(context->size); + int64_t totalSize = 0; + for (size_t i = 0; i < metadata.size(); i++) { + counts[i] = metadata[i].nnz() * sparseDim; + totalSize += counts[i]; } - // There are #sparse_dim() 1-dimensional tensors with nnz elems per rank. - auto buffer = - at::empty({context->size, tensor.sparse_dim(), max_nnz}, at::kLong); - buffer.select(0, context->rank) - .narrow(1, 0, tensor._nnz()) - .copy_(tensor.indices()); + auto output = at::empty({totalSize}, at::kLong); - // Allgather indices. - gloo::AllgatherOptions opts(context); - opts.setOutput(buffer.data_ptr(), buffer.numel()); + // tensors copied from cuda may not be contiguous, get a contiguous + // tensor before use its data_ptr + auto input = tensor.indices().contiguous(); + + // Allgatherv indices. + gloo::AllgathervOptions opts(context); + opts.setInput(input.data_ptr(), input.numel()); + opts.setOutput(output.data_ptr(), counts); opts.setTag(tag); - gloo::allgather(opts); + gloo::allgatherv(opts); // Compile indices tensor per rank. std::vector indices; indices.reserve(metadata.size()); + size_t offset = 0; for (size_t i = 0; i < metadata.size(); i++) { - indices.push_back(buffer.select(0, i).narrow(1, 0, metadata[i].nnz())); + const auto nnz = metadata[i].nnz(); + const auto numel = sparseDim * nnz; + indices.push_back( + output.narrow(0, offset, numel).reshape({sparseDim, nnz})); + offset += numel; } return indices; @@ -819,34 +1012,47 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { std::vector allgather_values( const at::Tensor& tensor, const std::vector& metadata) { - auto max_nnz = metadata[0].nnz(); - for (size_t i = 1; i < metadata.size(); i++) { - max_nnz = std::max(max_nnz, metadata[i].nnz()); + // There are nnz #dense_dim()-dimensional tensors per rank. + const auto valueShape = tensor.sizes().slice(tensor.sparse_dim()); + size_t denseNumel = 1; + for (auto dim : valueShape) { + denseNumel *= dim; } - // There are nnz #dense_dim()-dimensional tensors per rank. - const auto value_shape = tensor.sizes().slice(tensor.sparse_dim()); - auto buffer_shape = std::vector({context->size, max_nnz}); - std::copy( - value_shape.begin(), - value_shape.end(), - std::back_inserter(buffer_shape)); - auto buffer = at::empty(buffer_shape, tensor.scalar_type()); - buffer.select(0, context->rank) - .narrow(0, 0, tensor._nnz()) - .copy_(tensor.values()); - - // Allgather values. - gloo::AllgatherOptions opts(context); - GENERATE_ALL_TYPES(tensor.scalar_type(), setOutput, opts, buffer); + std::vector counts(context->size); + int64_t totalSize = 0; + for (size_t i = 0; i < metadata.size(); i++) { + counts[i] = metadata[i].nnz() * denseNumel; + totalSize += counts[i]; + } + + auto output = at::empty({totalSize}, tensor.scalar_type()); + + // Allgatherv indices. + gloo::AllgathervOptions opts(context); + // tensors copied from cuda may not be contiguous, get a contiguous + // tensor before use its data_ptr + at::Tensor valueTensor = tensor.values().contiguous(); + GENERATE_ALL_TYPES(valueTensor.scalar_type(), setInput, opts, valueTensor); + GENERATE_ALL_TYPES( + valueTensor.scalar_type(), setOutput, opts, output, counts); opts.setTag(tag); - gloo::allgather(opts); + gloo::allgatherv(opts); // Compile values tensor per rank. std::vector values; values.reserve(metadata.size()); + size_t offset = 0; for (size_t i = 0; i < metadata.size(); i++) { - values.push_back(buffer.select(0, i).narrow(0, 0, metadata[i].nnz())); + const auto nnz = metadata[i].nnz(); + const auto numel = denseNumel * nnz; + auto tensorShape = std::vector({(int64_t)nnz}); + std::copy( + valueShape.begin(), + valueShape.end(), + std::back_inserter(tensorShape)); + values.push_back(output.narrow(0, offset, numel).reshape(tensorShape)); + offset += numel; } return values; diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 8a954c716e3a3..058bf907e5fb1 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -127,6 +127,24 @@ class ProcessGroupGloo : public ProcessGroup { int threads; }; + // Helper functions to create a new device object. + // They are static functions on this class to keep them logically + // separate from the rest of the code base (e.g. torch/csrc/distributed). + + // Create new device instance for specific interface. + static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface( + const std::string& interface); + + // Create new device instance for specific hostname or address. + static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname( + const std::string& hostname); + + // Create new device instance. + // It tries to resolve this machine's hostname and bind to that address. + // If that fails (i.e. the hostname doesn't resolve to an address), it + // falls back to binding to the loopback address. + static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); + explicit ProcessGroupGloo( const std::shared_ptr& store, int rank, diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 40ec1a45a8cf1..85ea14c14e5c2 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -217,21 +217,14 @@ void ProcessGroupNCCL::WorkNCCL::wait() { synchronize(); } -std::unordered_map ProcessGroupNCCL::pgUniqueNCCLIDCnt_; -std::unordered_map - ProcessGroupNCCL::processGroupCounterMap_; - -std::mutex ProcessGroupNCCL::pgTrackingLock_; - ProcessGroupNCCL::ProcessGroupNCCL( const std::shared_ptr& store, int rank, int size, - const std::string& groupName, const std::chrono::milliseconds& opTimeout) : ProcessGroup(rank, size), store_(store), - groupName_(groupName), + ncclCommCounter_(0), terminateWatchdog_(false), opTimeout_(opTimeout) { char* blockingWait = getenv(NCCL_BLOCKING_WAIT); @@ -253,25 +246,11 @@ ProcessGroupNCCL::ProcessGroupNCCL( std::string(NCCL_BLOCKING_WAIT)); } - // Generate the Process Group ID for current PG, this needs to be identical - // for all processes - std::unique_lock lock(pgTrackingLock_); - // Default group is an empty string - const auto groupKey = groupName_ + "_"; - if (processGroupCounterMap_.count(groupKey) == 0) { - processGroupCounterMap_[groupKey] = -1; - } - ++processGroupCounterMap_[groupKey]; - processGroupID_ = std::to_string(processGroupCounterMap_[groupKey]); - groupPgID_ = groupName_ + "_" + processGroupID_; - pgUniqueNCCLIDCnt_[groupPgID_] = -1; ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); } ProcessGroupNCCL::~ProcessGroupNCCL() { - std::unique_lock lock(pgTrackingLock_); - pgUniqueNCCLIDCnt_.erase(groupPgID_); terminateWatchdog_.store(true); watchdogCV_.notify_one(); ncclCommWatchdogThread_.join(); @@ -343,36 +322,22 @@ std::exception_ptr ProcessGroupNCCL::checkForNCCLErrorsInternal( } void ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId* ncclID) { - // Every time when we create a new unique NCCL ID, we need to use a new - // global key to access/update the store. - // The key is a combination of processGroupID_ and the current count of - // NCCL unique ID created - std::unique_lock lock(pgTrackingLock_); - auto groupPgId = groupName_ + "_" + processGroupID_; - const auto uniqueNCCLIDCnt = ++pgUniqueNCCLIDCnt_[groupPgID_]; - - lock.unlock(); - - std::string storeKey = - processGroupID_ + "_" + std::to_string(uniqueNCCLIDCnt); - - // Rank 0 writes to the store as bcast + // For every NCCL communicator that we create we need to broadcast + // a unique ID from rank 0 to all other ranks. This broadcast is + // done by rank 0 setting a key in the store and all other ranks + // retrieving the contents of that key. A single process group + // may create multiple NCCL communicators, so we use a sequence + // number to differentiate between them. + std::string storeKey = std::to_string(ncclCommCounter_++); if (rank_ == 0) { - auto ncclIDVal = std::vector( + auto vec = std::vector( reinterpret_cast(ncclID), reinterpret_cast(ncclID) + NCCL_UNIQUE_ID_BYTES); - store_->set(storeKey, ncclIDVal); - // Other ranks get to the store + store_->set(storeKey, vec); } else { - auto ncclIDVal = store_->get(storeKey); - // Just a sanity check - if (ncclIDVal.size() != NCCL_UNIQUE_ID_BYTES) { - throw std::runtime_error( - "Unexpected NCCL unique ID length received " - "from the store"); - } - // Now put the data back to the input pointer - memcpy(ncclID, ncclIDVal.data(), NCCL_UNIQUE_ID_BYTES); + auto vec = store_->get(storeKey); + AT_CHECK(vec.size() == NCCL_UNIQUE_ID_BYTES); + std::memcpy(ncclID, vec.data(), vec.size()); } } diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 2336a7d825091..64bdea4bbd82c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -120,28 +120,39 @@ class ProcessGroupNCCL : public ProcessGroup { friend class ProcessGroupNCCL; }; - // Constructor will also check the number of available GPUs in the system + // If you wish to create multiple process groups, each with a potentially + // different rank and size, you can do so by passing a new store instance + // to each one. If you have only a single store object, you can + // use the `c10d::PrefixStore` to derive scoped instances. + // This is also what the Python API in torch.distributed does. // - // Group support: + // The process group instance keeps a reference to the store because + // it may be used long after the constructor runs. In fact, the constructor + // doesn't create any NCCL communicators. A single NCCL communicator can + // only be used on a specific set of devices, and are therefore created + // on-demand when a collective runs. If another collective is executed later, + // against a different set of devices, the process group creates another NCCL + // communicator. These NCCL communicators are cached and reused if possible. // - // In order to support multiple NCCL process groups, each of which has - // different group ranks, we need to use groupName to identify each group - // to ensure the correct behavior. In other words, each process group that - // has different group ranks needs to have a different and unique groupName - // to avoid clashing into undefined behaviors. - // - // In Python frontend API of torch.distributed, it guarantees that each group - // will have a unique name to be passed into the ProcessGroupNCCL constructor. - // If you would like to use ProcessGroupNCCL constructor directly, it is - // your reponsibility to do so as well. ProcessGroupNCCL( const std::shared_ptr& store, int rank, int size, - const std::string& groupName = "", const std::chrono::milliseconds& opTimeout = std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis)); + // This constructor includes the deprecated `groupName` argument. + // If you have existing code that uses the `groupName`, you can replace + // it by specifying a `c10d::PrefixStore(groupName, store)` for store. + C10_DEPRECATED ProcessGroupNCCL( + const std::shared_ptr& store, + int rank, + int size, + const std::string& groupName, + const std::chrono::milliseconds& opTimeout = + std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis)) + : ProcessGroupNCCL(store, rank, size, opTimeout) {} + virtual ~ProcessGroupNCCL(); std::shared_ptr broadcast( @@ -257,11 +268,13 @@ class ProcessGroupNCCL : public ProcessGroup { protected: static const int64_t kWatchdogThreadSleepMillis; - // Store that is used to exchange each Ranks's NCCL unique ID + // The store is used to broadcast the NCCL unique ID of rank 0. std::shared_ptr store_; - // The process group name - std::string groupName_; + // The number of NCCL communicators that have been created during + // the lifetime of this process group. This sequence number is + // used to scope keys used in the store. + uint64_t ncclCommCounter_{0}; // The NCCL communicator that the process group has cached. // The key is a list of GPU devices that an operation is operating on @@ -307,18 +320,9 @@ class ProcessGroupNCCL : public ProcessGroup { // The CUDA events used to sync NCCL streams std::unordered_map> ncclEvents_; - // ID of this process group - std::string processGroupID_; - - // Group Prefix and ID of this process group - std::string groupPgID_; - // Device Indexes used for all collectives in this group std::set usedDeviceIdxs_; - // processGroupID tracking - static std::mutex pgTrackingLock_; - // map from the key: "group name + pg counter (ID)" to the // unique NCCL ID count. This needs to be group and pg specific // diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index 242b9a6c93166..f992fedc293c8 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -104,13 +104,6 @@ void TCPStoreDaemon::run() { continue; } - if (fds[fdIdx].revents ^ POLLIN) { - throw std::system_error( - ECONNABORTED, - std::system_category(), - "Unexpected poll revent: " + std::to_string(fds[fdIdx].revents) + - " on socket: " + std::to_string(fds[fdIdx].fd)); - } // Now query the socket that has the event try { query(fds[fdIdx].fd); diff --git a/torch/lib/c10d/Utils.cpp b/torch/lib/c10d/Utils.cpp index 70b03362419c4..ec0cd47598350 100644 --- a/torch/lib/c10d/Utils.cpp +++ b/torch/lib/c10d/Utils.cpp @@ -21,7 +21,8 @@ namespace tcputil { namespace { -constexpr int LISTEN_QUEUE_SIZE = 64; +constexpr int LISTEN_QUEUE_SIZE = 2048; +const std::string kConnectTimeoutMsg = "connect() timed out."; void setSocketNoDelay(int socket) { int flag = 1; @@ -156,9 +157,13 @@ int connect( struct ::addrinfo* nextAddr = addresses.get(); int socket; - // we'll loop over the addresses only if at least of them gave us ECONNREFUSED - // Maybe the host was up, but the server wasn't running. + + // Loop over the addresses if at least one of them gave us ECONNREFUSED + // or ECONNRESET. This may happen if the server hasn't started listening + // yet, or is listening but has its listen backlog exhausted. bool anyRefused = false; + bool anyReset = false; + const auto start = std::chrono::high_resolution_clock::now(); while (true) { try { SYSCHECK_ERR_RETURN_NEG1( @@ -182,12 +187,22 @@ int connect( pfd.fd = socket; pfd.events = POLLOUT; - int numReady = ::poll(&pfd, 1, timeout.count()); + int64_t pollTimeout = -1; + if (timeout != kNoTimeout) { + // calculate remaining time and use that as timeout for poll() + const auto elapsed = std::chrono::high_resolution_clock::now() - start; + const auto remaining = + std::chrono::duration_cast(timeout) - + std::chrono::duration_cast(elapsed); + pollTimeout = std::max( + static_cast(0), static_cast(remaining.count())); + } + int numReady = ::poll(&pfd, 1, pollTimeout); if (numReady < 0) { throw std::system_error(errno, std::system_category()); } else if (numReady == 0) { errno = 0; - throw std::runtime_error("connect() timed out"); + throw std::runtime_error(kConnectTimeoutMsg); } socklen_t errLen = sizeof(errno); @@ -210,9 +225,14 @@ int connect( break; } catch (std::exception& e) { + // ECONNREFUSED happens if the server is not yet listening. if (errno == ECONNREFUSED) { anyRefused = true; } + // ECONNRESET happens if the server's listen backlog is exhausted. + if (errno == ECONNRESET) { + anyReset = true; + } // We need to move to the next address because this was not available // to connect or to create a socket. @@ -220,11 +240,22 @@ int connect( // We have tried all addresses but could not connect to any of them. if (!nextAddr) { - if (!wait || !anyRefused) { + if (!wait || (!anyRefused && !anyReset)) { throw; } + + // if a timeout is specified, check time elapsed to see if we need to + // timeout. A timeout is specified if timeout != kNoTimeout. + if (timeout != kNoTimeout) { + const auto elapsed = + std::chrono::high_resolution_clock::now() - start; + if (elapsed > timeout) { + throw std::runtime_error(kConnectTimeoutMsg); + } + } std::this_thread::sleep_for(std::chrono::seconds(1)); anyRefused = false; + anyReset = false; nextAddr = addresses.get(); } } diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index dc85fa06c526f..28cf9e422c20a 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -1,5 +1,3 @@ -#include - #include #include @@ -51,8 +49,8 @@ class AsyncTest { // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; options.timeout = std::chrono::milliseconds(50); - ::gloo::transport::tcp::attr attr; - options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr)); + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1")); pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>( new ::c10d::ProcessGroupGloo(store, rank, size, options)); diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index 3aaa70aa17a98..f50ef04c7fca8 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -9,8 +9,6 @@ #include #include -#include - #include #include #include @@ -42,8 +40,8 @@ class SignalTest { // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; options.timeout = std::chrono::milliseconds(50); - ::gloo::transport::tcp::attr attr; - options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr)); + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1")); ::c10d::ProcessGroupGloo pg(store, rank, size, options); @@ -127,9 +125,8 @@ class CollectiveTest { // Use tiny timeout to make this test run fast ::c10d::ProcessGroupGloo::Options options; options.timeout = std::chrono::milliseconds(50); - - ::gloo::transport::tcp::attr attr; - options.devices.push_back(::gloo::transport::tcp::CreateDevice(attr)); + options.devices.push_back( + ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1")); pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>( new ::c10d::ProcessGroupGloo(store, rank, size, options)); diff --git a/torch/lib/c10d/test/TestUtils.hpp b/torch/lib/c10d/test/TestUtils.hpp index 26ff7fe2a6fc2..62402b4eeedcb 100644 --- a/torch/lib/c10d/test/TestUtils.hpp +++ b/torch/lib/c10d/test/TestUtils.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 6273a0506494f..9a03814e82ae6 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -1,6 +1,6 @@ import torch import torch.utils.hooks -from torch.namedtensor import _check_serializing_named_tensor +from torch._namedtensor_internals import check_serializing_named_tensor import os import threading import errno @@ -138,7 +138,7 @@ def reduce_tensor(tensor): "If you just want to transfer the data, call detach() on the tensor " "before serializing (e.g., putting it on the queue).") - _check_serializing_named_tensor(tensor) + check_serializing_named_tensor(tensor) torch.utils.hooks.warn_if_has_hooks(tensor) # Note [CUDA IPC and the caching allocator] diff --git a/torch/nn/_intrinsic/quantized/modules/conv_relu.py b/torch/nn/_intrinsic/quantized/modules/conv_relu.py index 5bf6b17e9be1b..b1463d2e4b2bf 100644 --- a/torch/nn/_intrinsic/quantized/modules/conv_relu.py +++ b/torch/nn/_intrinsic/quantized/modules/conv_relu.py @@ -24,33 +24,16 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode) - def weight(self): - return torch.ops.quantized.fbgemm_conv_unpack(self._packed_weight).permute([0, 3, 1, 2]) - - def set_weight(self, w): - self._packed_weight = torch.ops.quantized.fbgemm_conv_prepack(w.permute([0, 2, 3, 1]), - self.stride, - self.padding, - self.dilation, - self.groups) - self.weight_scale = w.q_scale() - def forward(self, input): # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: raise ValueError("Input shape must be `(N, C, H, W)`!") - # Temporary work around for bias - # see Issue:https://github.com/pytorch/pytorch/issues/23874 - bias = self.bias - if bias is not None: - bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * input.q_scale(), 0, torch.qint32) - output = torch.ops.quantized.fbgemm_conv2d_relu(input.permute([0, 2, 3, 1]), - self._packed_weight, bias, - self.stride, self.padding, - self.dilation, self.groups, - float(self.scale), int(self.zero_point)) - return output.permute([0, 3, 1, 2]) + return torch.ops.quantized.conv2d_relu(input, + self._packed_params, + self.stride, self.padding, + self.dilation, self.groups, + float(self.scale), int(self.zero_point)) @classmethod def from_float(cls, mod): diff --git a/torch/nn/_intrinsic/quantized/modules/linear_relu.py b/torch/nn/_intrinsic/quantized/modules/linear_relu.py index 0714e01b5eb5d..1ed5a92508923 100644 --- a/torch/nn/_intrinsic/quantized/modules/linear_relu.py +++ b/torch/nn/_intrinsic/quantized/modules/linear_relu.py @@ -26,13 +26,8 @@ def __init__(self, in_features, out_features, bias=True): super(LinearReLU, self).__init__(in_features, out_features, bias) def forward(self, input): - bias = self.bias - if bias is not None: - bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * input.q_scale(), 0, torch.qint32) - - Y_q = torch.ops.quantized.fbgemm_linear_relu( - input, self._packed_weight, - bias, + Y_q = torch.ops.quantized.linear_relu( + input, self._packed_params, float(self.scale), int(self.zero_point)) return Y_q diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index a7e98b50243e5..b48fea7ca1880 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -331,6 +331,9 @@ def update(self, modules): "; 2 is required") self[m[0]] = m[1] + def forward(self): + raise NotImplementedError() + class ParameterList(Module): r"""Holds parameters in a list. diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 54feb513e0df4..b5b65781cf137 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -1,4 +1,7 @@ from .. import Tensor +import builtins class Parameter(Tensor): + def __init__(self, data: Tensor, requires_grad: builtins.bool): ... + ... diff --git a/torch/nn/quantized/dynamic/modules/linear.py b/torch/nn/quantized/dynamic/modules/linear.py index 3a9139acd7e96..7c0fee7dc60dc 100644 --- a/torch/nn/quantized/dynamic/modules/linear.py +++ b/torch/nn/quantized/dynamic/modules/linear.py @@ -3,8 +3,6 @@ from ....modules.linear import Linear as NNLinear import torch.nn.quantized as nnq -from torch._jit_internal import Optional - class Linear(nnq.Linear): r""" A dynamic quantized linear module with quantized tensor as inputs and outputs. @@ -31,25 +29,17 @@ class Linear(nnq.Linear): torch.Size([128, 30]) """ - __annotations__ = {'bias' : Optional[torch.Tensor]} - def __init__(self, in_features, out_features, bias_=True): super(Linear, self).__init__(in_features, out_features, bias_) # We don't muck around with buffers or attributes or anything here # to keep the module simple. *everything* is simply a Python attribute. # Serialization logic is explicitly handled in the below serialization and # deserialization modules - if bias_: - del self.bias - self.bias = torch.Tensor(out_features).float() - else: - self.bias = None def forward(self, x): # Note that we can handle self.bias == None case. - Y = torch.ops.quantized.fbgemm_linear_dynamic( - x, self._packed_weight, - self.bias) + Y = torch.ops.quantized.linear_dynamic( + x, self._packed_params) return Y.to(x.dtype) @classmethod @@ -75,6 +65,5 @@ def from_float(cls, mod): wt_scale, wt_zp = weight_observer.calculate_qparams() qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale), int(wt_zp), torch.qint8) qlinear = Linear(mod.in_features, mod.out_features) - qlinear.set_weight(qweight) - qlinear.bias = mod.bias + qlinear.set_weight_bias(qweight, mod.bias) return qlinear diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py index ac1e691ff3c61..8af702618bd2b 100644 --- a/torch/nn/quantized/dynamic/modules/rnn.py +++ b/torch/nn/quantized/dynamic/modules/rnn.py @@ -5,7 +5,6 @@ from torch import Tensor # noqa: F401 from torch.nn import _VF from torch._jit_internal import Tuple, Optional, List # noqa: F401 -from torch._jit_internal import _parameter_list from torch.nn.utils.rnn import PackedSequence import numbers @@ -19,13 +18,9 @@ class RNNBase(torch.nn.Module): _FLOAT_MODULE = nn.RNNBase - __constants__ = ['mode', 'input_size', 'hidden_size', 'num_layers', 'bias', - 'batch_first', 'dropout', 'bidirectional', '_packed_weights', - '_quantized_weights'] - def __init__(self, mode, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, - dropout=0., bidirectional=False): + dropout=0., bidirectional=False, dtype=torch.qint8): super(RNNBase, self).__init__() self.mode = mode @@ -36,6 +31,7 @@ def __init__(self, mode, input_size, hidden_size, self.batch_first = batch_first self.dropout = float(dropout) self.bidirectional = bidirectional + self.dtype = dtype num_directions = 2 if bidirectional else 1 if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \ @@ -54,56 +50,68 @@ def __init__(self, mode, input_size, hidden_size, else: raise ValueError("Unrecognized RNN mode: " + mode) - self._all_weights = [] - - packed_weights = [] - quantized_weights = [] - + self._all_weight_names = [] + self._all_weight_values = [] for layer in range(num_layers): for direction in range(num_directions): layer_input_size = input_size if layer == 0 else hidden_size * num_directions - def process_weights(ihhh, layer, suffix, qweight, bias): - weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix) - bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix) - - # for each layer, for each direction we need to quantize and pack - # weights and pack parameters in this order: - # - # w_ih, w_hh, b_ih, b_hh - packed_weight = \ - torch.ops.quantized.fbgemm_linear_prepack(qweight) - params = [packed_weight, bias] - pos_names = ['w', 'b'] - ret_name = ['{}_{}_l{}{}'.format( - name, ihhh, layer, suffix) for name in pos_names] - quantized_weights.append(qweight) - packed_weights.append(ret_name[0]) - return params, ret_name - - w_ih = torch._empty_affine_quantized( - [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8) - w_hh = torch._empty_affine_quantized( - [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8) - b_ih = torch._empty_affine_quantized( - [gate_size], scale=1, zero_point=0, dtype=torch.qint32) - # Second bias vector included for CuDNN compatibility. Only one - # bias vector is needed in standard definition. - b_hh = torch._empty_affine_quantized( - [gate_size], scale=1, zero_point=0, dtype=torch.qint32) + def process_weights(ihhh, layer, suffix, qweight, bias, dtype): + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + packed_weight = \ + torch.ops.quantized.linear_prepack(qweight, bias) + + params = [packed_weight] + pos_names = ['w'] + ret_name = ['{}_{}_l{}{}'.format( + name, ihhh, layer, suffix) for name in pos_names] + return params, ret_name + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight = torch.fbgemm_pack_gemm_matrix_fp16( + qweight) + + params = [packed_weight, bias] + pos_names = ['packed', 'b'] + ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names] + return params, ret_name + + if dtype == torch.qint8: + w_ih = torch._empty_affine_quantized( + [gate_size, layer_input_size], scale=1, zero_point=0, dtype=torch.qint8) + w_hh = torch._empty_affine_quantized( + [gate_size, hidden_size], scale=1, zero_point=0, dtype=torch.qint8) + b_ih = torch._empty_affine_quantized( + [gate_size], scale=1, zero_point=0, dtype=torch.qint32) + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = torch._empty_affine_quantized( + [gate_size], scale=1, zero_point=0, dtype=torch.qint32) + + else: + w_ih = torch.Tensor(gate_size, layer_input_size).float() + w_hh = torch.Tensor(gate_size, hidden_size).float() + b_ih = torch.Tensor(gate_size).float() + # Second bias vector included for CuDNN compatibility. Only one + # bias vector is needed in standard definition. + b_hh = torch.Tensor(gate_size).float() suffix = '_reverse' if direction == 1 else '' ih_params, ih_param_names = process_weights( - 'ih', layer, suffix, w_ih, b_ih) + 'ih', layer, suffix, w_ih, b_ih, dtype) hh_params, hh_param_names = process_weights( - 'hh', layer, suffix, w_hh, b_hh) + 'hh', layer, suffix, w_hh, b_hh, dtype) for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)): - self.register_buffer(ih_name, torch.tensor( - ih) if not isinstance(ih, torch.Tensor) else ih) - self.register_buffer(hh_name, torch.tensor( - hh) if not isinstance(hh, torch.Tensor) else hh) - self._all_weights.extend([ih_name, hh_name]) + self._all_weight_names.extend([ih_name, hh_name]) + self._all_weight_values.extend([ih, hh]) def check_input(self, input, batch_sizes): # type: (Tensor, Optional[Tensor]) -> None @@ -148,49 +156,75 @@ def permute_hidden(self, hx, permutation): return hx return apply_permutation(hx, permutation) - @property - def all_weights(self): - return [getattr(self, weight) for weight in self._all_weights] - - def _get_all_weights_names(self): - return [weight for weight in self._all_weights] - - @_parameter_list(_get_all_weights_names) - def _get_all_weights(self): - return self.all_weights - - def _get_packed_weights_names(self): - return self._packed_weights - - @_parameter_list(_get_packed_weights_names) - def _get_packed_weights(self): - return [getattr(self, name) for name in self._packed_weights] - - def _get_quantized_weights_names(self): - return self._quantized_weights - - @_parameter_list(_get_quantized_weights_names) - def _get_quantized_weights(self): - return [getattr(self, name) for name in self._quantized_weights] + @torch.jit.export + def __getstate__(self): + vals = ( + self.mode, + self.input_size, + self.hidden_size, + self.num_layers, + self.bias, + self.batch_first, + self.dropout, + self.bidirectional, + self._all_weight_names, + self.__overloads__, + self.training, + self.dtype, + ) + + dynamic_vals = torch.jit.annotate(List[Tuple[torch.Tensor, Optional[torch.Tensor]]], + []) + + for i in range(len(self._all_weight_names)): + dynamic_vals.append(torch.ops.quantized.linear_unpack(self._all_weight_values[i])) + return vals, dynamic_vals + + @torch.jit.export + def __setstate__(self, state): + vals, dynamic_vals = state + self.mode = vals[0] + self.input_size = vals[1] + self.hidden_size = vals[2] + self.num_layers = vals[3] + self.bias = vals[4] + self.batch_first = vals[5] + self.dropout = vals[6] + self.bidirectional = vals[7] + self._all_weight_names = vals[8] + self.__overloads__ = vals[9] + self.training = vals[10] + self.dtype = vals[11] + + self._all_weight_values = [] + for i in range(len(self._all_weight_names)): + self._all_weight_values.append(torch.ops.quantized.linear_prepack(*dynamic_vals[i])) @classmethod - def from_float(cls, mod): + def from_float(cls, mod, dtype=torch.qint8): assert type(mod) == torch.nn.LSTM, 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM' assert hasattr( mod, 'qconfig'), 'Input float module must have qconfig defined' - if mod.qconfig is not None and mod.qconfig.weight() is not None: - weight_observer = mod.qconfig.weight() - else: - # We have the circular import issues if we import the qconfig in the beginning of this file: - # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the - # import until we need it. - from torch.quantization.QConfig import default_dynamic_qconfig - weight_observer = default_dynamic_qconfig.weight() - assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' + + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError('Unsupported dtype: {}'.format(dtype)) + + # When dtype = torch.float16, we don't need weight_observer + if dtype == torch.qint8: + if mod.qconfig is not None and mod.qconfig.weight() is not None: + weight_observer = mod.qconfig.weight() + else: + # We have the circular import issues if we import the qconfig in the beginning of this file: + # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the + # import until we need it. + from torch.quantization.QConfig import default_dynamic_qconfig + weight_observer = default_dynamic_qconfig.weight() + assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' if mod.mode == 'LSTM': qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers, - mod.bias, mod.batch_first, mod.dropout, mod.bidirectional) + mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype) num_directions = 2 if mod.bidirectional else 1 @@ -200,52 +234,56 @@ def from_float(cls, mod): if qRNNBase.mode != 'LSTM': raise RuntimeError('Only LSTM is supported for QuantizedRNN') - qRNNBase._all_weights = [] - packed_weights = [] - quantized_weights = [] + qRNNBase._all_weight_names = [] + qRNNBase._all_weight_values = [] for layer in range(qRNNBase.num_layers): for direction in range(num_directions): layer_input_size = qRNNBase.input_size if layer == 0 else qRNNBase.hidden_size * num_directions - def process_weights(ihhh, layer, suffix): + def process_weights(ihhh, layer, suffix, dtype): weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix) bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix) weight = getattr(mod, weight_name) bias = getattr(mod, bias_name) - # for each layer, for each direction we need to quantize and pack - # weights and pack parameters in this order: - # - # w_ih, w_hh, b_ih, b_hh - weight_observer(weight) - wt_scale, wt_zp = weight_observer.calculate_qparams() - qweight = torch.quantize_linear( - weight.float(), float(wt_scale), int(wt_zp), torch.qint8) - packed_weight = \ - torch.ops.quantized.fbgemm_linear_prepack(qweight) - - params = [packed_weight, bias] - pos_names = ['w', 'b'] - ret_name = ['{}_{}_l{}{}'.format( - name, ihhh, layer, suffix) for name in pos_names] - quantized_weights.append(qweight) - packed_weights.append(ret_name[0]) - return params, ret_name + + if dtype == torch.qint8: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # w_ih, w_hh + weight_observer(weight) + wt_scale, wt_zp = weight_observer.calculate_qparams() + qweight = torch.quantize_linear( + weight.float(), float(wt_scale), int(wt_zp), torch.qint8) + packed_weight = \ + torch.ops.quantized.linear_prepack(qweight, bias) + + params = [packed_weight] + pos_names = ['w'] + ret_name = ['{}_{}_l{}{}'.format( + name, ihhh, layer, suffix) for name in pos_names] + return params, ret_name + else: + # for each layer, for each direction we need to quantize and pack + # weights and pack parameters in this order: + # + # packed_ih, packed_hh, b_ih, b_hh + packed_weight = torch.fbgemm_pack_gemm_matrix_fp16( + weight.float()) + + params = [packed_weight, bias] + pos_names = ['packed', 'b'] + ret_name = ['{}_{}_l{}{}'.format(name, ihhh, layer, suffix) for name in pos_names] + return params, ret_name suffix = '_reverse' if direction == 1 else '' - ih_params, ih_param_names = process_weights('ih', layer, suffix) - hh_params, hh_param_names = process_weights('hh', layer, suffix) + ih_params, ih_param_names = process_weights('ih', layer, suffix, dtype) + hh_params, hh_param_names = process_weights('hh', layer, suffix, dtype) for (ih, ih_name), (hh, hh_name) in zip(zip(ih_params, ih_param_names), zip(hh_params, hh_param_names)): - qRNNBase.register_buffer(ih_name, torch.tensor( - ih) if not isinstance(ih, torch.Tensor) else ih) - qRNNBase.register_buffer(hh_name, torch.tensor( - hh) if not isinstance(hh, torch.Tensor) else hh) - qRNNBase._all_weights.extend([ih_name, hh_name]) - - qRNNBase._packed_weights = packed_weights - # DO WE NEED _quantized_weights? @jianyuh: will remove _quantized_weight as now we support the fbgemm_linear_unpack function - qRNNBase._quantized_weights = quantized_weights + qRNNBase._all_weight_names.extend([ih_name, hh_name]) + qRNNBase._all_weight_values.extend([ih, hh]) return qRNNBase @@ -275,14 +313,15 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): self.check_forward_args(input, hx, batch_sizes) assert batch_sizes is None - result = _VF.quantized_lstm(input, hx, self._get_all_weights(), self.bias, self.num_layers, + result = _VF.quantized_lstm(input, hx, self._all_weight_values, self.bias, self.num_layers, float(self.dropout), self.training, self.bidirectional, - self.batch_first, dtype=torch.int8, use_dynamic=True) + self.batch_first, dtype=self.dtype, use_dynamic=True) output = result[0] hidden = result[1:] return output, hidden + @torch.jit.export def forward_tensor(self, input, hx=None): # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]] batch_sizes = None @@ -295,6 +334,7 @@ def forward_tensor(self, input, hx=None): return output, self.permute_hidden(hidden, unsorted_indices) + @torch.jit.export def forward_packed(self, input, hx=None): # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa input, batch_sizes, sorted_indices, unsorted_indices = input @@ -315,7 +355,7 @@ def permute_hidden(self, hx, permutation): return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation) def check_forward_args(self, input, hidden, batch_sizes): - # type : (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None + # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor])->None self.check_input(input, batch_sizes) expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) @@ -324,6 +364,7 @@ def check_forward_args(self, input, hidden, batch_sizes): self.check_hidden_size(hidden[1], expected_hidden_size, 'Expected hidden[1] size {}, got {}') + @torch.jit.ignore def forward(self, input, hx=None): if isinstance(input, PackedSequence): return self.forward_packed(input, hx) @@ -331,5 +372,5 @@ def forward(self, input, hx=None): return self.forward_tensor(input, hx) @classmethod - def from_float(cls, mod): - return super(LSTM, cls).from_float(mod) + def from_float(cls, mod, dtype=torch.qint8): + return super(LSTM, cls).from_float(mod, dtype) diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index 6efbbc4167567..ceb7dd2273a97 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -38,7 +38,7 @@ def linear(input, weight, bias=None, scale=None, zero_point=None): Args: input (Tensor): Quantized input of type `torch.quint8` weight (Tensor): Quantized weight of type `torch.qint8` - bias (Tensor): None or Quantized bias of type `torch.qint32` + bias (Tensor): None or fp32 bias of type `torch.float` scale (double): output scale. If None, derived from the input scale zero_point (long): output zero point. If None, derived from the input zero_point @@ -53,11 +53,9 @@ def linear(input, weight, bias=None, scale=None, zero_point=None): scale = input.q_scale() if zero_point is None: zero_point = input.q_zero_point() - _packed_weight = torch.ops.quantized.fbgemm_linear_prepack(weight) - if bias is not None: - bias = torch.quantize_linear(bias.dequantize(), weight.q_scale() * input.q_scale(), 0, torch.qint32) - return torch.ops.quantized.fbgemm_linear(input, _packed_weight, bias, scale, - zero_point) + _packed_params = torch.ops.quantized.linear_prepack(weight, bias) + return torch.ops.quantized.linear(input, _packed_params, scale, + zero_point) def conv2d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1, @@ -79,7 +77,7 @@ def conv2d(input, weight, bias, Args: input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)` weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)` - bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.int32`. + bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`. stride: the stride of the convolving kernel. Can be a single number or a tuple `(sH, sW)`. Default: 1 padding: implicit paddings on both sides of the input. Can be a @@ -105,8 +103,7 @@ def conv2d(input, weight, bias, >>> >>> q_filters = torch.quantize_linear(filters, scale, zero_point, dtype) >>> q_inputs = torch.quantize_linear(inputs, scale, zero_point, dtype) - >>> q_bias = torch.quantize_linear(bias, scale, zero_point, torch.quint8) - >>> qF.conv2d(q_inputs, q_filters, q_bias, scale, zero_point, padding=1) + >>> qF.conv2d(q_inputs, q_filters, bias, scale, zero_point, padding=1) """ # noqa: E501 if padding_mode != 'zeros': raise NotImplementedError("Only zero-padding is supported!") @@ -116,14 +113,12 @@ def conv2d(input, weight, bias, padding = _pair(padding) dilation = _pair(dilation) - prepacked_weight = torch.ops.quantized.fbgemm_conv_prepack( - weight.permute([0, 2, 3, 1]), stride, padding, dilation, groups) - if bias is not None: - bias = torch.quantize_linear(bias.dequantize(), scale=weight.q_scale() * input.q_scale(), zero_point=0, dtype=torch.qint32) - return torch.ops.quantized.fbgemm_conv2d(input.permute([0, 2, 3, 1]), - prepacked_weight, bias, - stride, padding, dilation, - groups, scale, zero_point).permute([0, 3, 1, 2]) + prepacked_weight = torch.ops.quantized.conv_prepack( + weight, bias, stride, padding, dilation, groups) + return torch.ops.quantized.conv2d(input, + prepacked_weight, + stride, padding, dilation, + groups, scale, zero_point) def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False): diff --git a/torch/nn/quantized/modules/conv.py b/torch/nn/quantized/modules/conv.py index e0e4010b3eb26..ce27fe37c5e56 100644 --- a/torch/nn/quantized/modules/conv.py +++ b/torch/nn/quantized/modules/conv.py @@ -14,8 +14,6 @@ from torch._ops import ops from torch.nn.modules.utils import _pair -from torch._jit_internal import Optional - class Conv2d(torch.nn.Module): r"""Applies a 2D convolution over a quantized input signal composed of several quantized input planes. @@ -54,7 +52,6 @@ class Conv2d(torch.nn.Module): """ _FLOAT_MODULE = nn.Conv2d - __annotations__ = {'bias' : Optional[torch.Tensor]} def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, @@ -84,14 +81,11 @@ def __init__(self, in_channels, out_channels, kernel_size, stride=1, self.kernel_size[1]], scale=1, zero_point=0, dtype=torch.qint8) self.weight_scale = 1.0 - self.set_weight(qweight) + bias_float = None if bias: - self.bias = torch._empty_affine_quantized([out_channels], - scale=1.0, zero_point=0, - dtype=torch.qint32) - else: - self.bias = None + bias_float = torch.zeros(out_channels, dtype=torch.float) + self.set_weight_bias(qweight, bias_float) self.scale = 1.0 self.zero_point = 0 @@ -104,35 +98,37 @@ def extra_repr(self): s += ', dilation={dilation}' if self.groups != 1: s += ', groups={groups}' - if self.bias is None: + if self.bias() is None: s += ', bias=False' return s.format(**self.__dict__) - def set_weight(self, w): - self._packed_weight = torch.ops.quantized.fbgemm_conv_prepack( - w.permute([0, 2, 3, 1]), self.stride, self.padding, self.dilation, self.groups) + def set_weight_bias(self, w, b): + # type: (torch.Tensor, Optional[torch.Tensor]) -> None + self._packed_params = torch.ops.quantized.conv_prepack( + w, b, self.stride, self.padding, self.dilation, self.groups) self.weight_scale = w.q_scale() + def _weight_bias(self): + return torch.ops.quantized.conv_unpack(self._packed_params) + def weight(self): - return torch.ops.quantized.fbgemm_conv_unpack( - self._packed_weight).permute([0, 3, 1, 2]) + (w, b) = torch.ops.quantized.conv_unpack(self._packed_params) + return w + + def bias(self): + (w, b) = torch.ops.quantized.conv_unpack(self._packed_params) + return b def forward(self, input): # Temporarily using len(shape) instead of ndim due to JIT issue # https://github.com/pytorch/pytorch/issues/23890 if len(input.shape) != 4: raise ValueError("Input shape must be `(N, C, H, W)`!") - # Temporary work around for bias - # see Issue:https://github.com/pytorch/pytorch/issues/23874 - bias = self.bias - if bias is not None: - bias = torch.quantize_linear(bias.dequantize(), self.weight_scale * input.q_scale(), 0, torch.qint32) - output = ops.quantized.fbgemm_conv2d(input.permute([0, 2, 3, 1]), - self._packed_weight, bias, - self.stride, self.padding, - self.dilation, self.groups, - self.scale, self.zero_point) - return output.permute([0, 3, 1, 2]) + return ops.quantized.conv2d(input, + self._packed_params, + self.stride, self.padding, + self.dilation, self.groups, + self.scale, self.zero_point) # ===== Serialization methods ===== # The special consideration here is that we have to unpack the weights into their @@ -141,13 +137,15 @@ def forward(self, input): # from the QTensor weight. def _save_to_state_dict(self, destination, prefix, keep_vars): super(Conv2d, self)._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + 'weight'] = self.weight() + (w, b) = self._weight_bias() + destination[prefix + 'weight'] = w destination[prefix + 'scale'] = torch.tensor(self.scale) destination[prefix + 'zero_point'] = torch.tensor(self.zero_point) - destination[prefix + 'bias'] = self.bias + destination[prefix + 'bias'] = b @torch.jit.export def __getstate__(self): + (w, b) = self._weight_bias() return ( self.in_channels, self.out_channels, @@ -159,8 +157,8 @@ def __getstate__(self): self.output_padding, self.groups, self.padding_mode, - self.weight(), - self.bias, + w, + b, self.scale, self.zero_point ) @@ -170,10 +168,8 @@ def __getstate__(self): # weight into its packed format for use by the FBGEMM ops. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self.set_weight(state_dict[prefix + 'weight']) + self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias']) state_dict.pop(prefix + 'weight') - - self.bias = state_dict[prefix + 'bias'] state_dict.pop(prefix + 'bias') @@ -199,8 +195,7 @@ def __setstate__(self, state): self.output_padding = state[7] self.groups = state[8] self.padding_mode = state[9] - self.set_weight(state[10]) - self.bias = state[11] + self.set_weight_bias(state[10], state[11]) self.scale = state[12] self.zero_point = state[13] @@ -238,9 +233,6 @@ def from_float(cls, mod): act_scale, act_zp = activation_observer.calculate_qparams() assert weight_observer.dtype == torch.qint8, 'Weight observer must have a dtype of qint8' wt_scale, wt_zp = weight_observer.calculate_qparams() - # Scale bias to activation_scale/2^16, this quantizes bias - # to about 24 bits of precision - bias_scale = float(act_scale / (2**16)) qweight = torch.quantize_linear( mod.weight.float(), @@ -248,12 +240,7 @@ def from_float(cls, mod): qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride, mod.padding, mod.dilation, mod.groups, mod.bias is not None, mod.padding_mode) - qconv.set_weight(qweight) - if mod.bias is not None: - qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32) - else: - qbias = None - qconv.bias = qbias + qconv.set_weight_bias(qweight, mod.bias) qconv.scale = float(act_scale) qconv.zero_point = int(act_zp) diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index 6ca8eee6f99e9..ebcbae567ea58 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -100,8 +100,6 @@ class Linear(torch.nn.Module): """ _FLOAT_MODULE = nn.Linear - __annotations__ = {'bias' : Optional[torch.Tensor]} - def __init__(self, in_features, out_features, bias_=True): super(Linear, self).__init__() # We don't muck around with buffers or attributes or anything here @@ -110,34 +108,27 @@ def __init__(self, in_features, out_features, bias_=True): # deserialization modules self.in_features = in_features self.out_features = out_features + bias = None if bias_: - self.bias = torch._empty_affine_quantized( - [out_features], scale=1, zero_point=0, dtype=torch.qint32) - else: - self.bias = None + bias = torch.zeros(out_features, dtype=torch.float) + qweight = torch._empty_affine_quantized( [out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8) - self.set_weight(qweight) + self.set_weight_bias(qweight, bias) self.weight_scale = 1.0 self.scale = 1.0 self.zero_point = 0 def extra_repr(self): - return 'in_features={}, out_features={}, bias={}, scale={}, zero_point={}'.format( - self.in_features, self.out_features, self.bias is not None, self.scale, self.zero_point + return 'in_features={}, out_features={}, scale={}, zero_point={}'.format( + self.in_features, self.out_features, self.scale, self.zero_point ) def forward(self, x): - # Temporary work around for bias - # see Issue:https://github.com/pytorch/pytorch/issues/23874 - bias = self.bias - if bias is not None: - bias = torch.quantize_linear(bias.dequantize(), float(self.weight_scale) * x.q_scale(), 0, torch.qint32) - - return torch.ops.quantized.fbgemm_linear( - x, self._packed_weight, bias, self.scale, self.zero_point) + return torch.ops.quantized.linear( + x, self._packed_params, self.scale, self.zero_point) # ===== Serialization methods ===== # The special consideration here is that we have to unpack the weights into their @@ -146,18 +137,20 @@ def forward(self, x): # from the QTensor weight. def _save_to_state_dict(self, destination, prefix, keep_vars): super(Linear, self)._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + 'weight'] = self.weight() + (w, b) = self._weight_bias() + destination[prefix + 'weight'] = w destination[prefix + 'scale'] = torch.tensor(self.scale) destination[prefix + 'zero_point'] = torch.tensor(self.zero_point) - destination[prefix + 'bias'] = self.bias + destination[prefix + 'bias'] = b @torch.jit.export def __getstate__(self): + (w, b) = self._weight_bias() return ( self.in_features, self.out_features, - self.bias, - self.weight(), + b, + w, self.scale, self.zero_point ) @@ -167,10 +160,8 @@ def __getstate__(self): # weight into its packed format for use by the FBGEMM ops. def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - self.set_weight(state_dict[prefix + 'weight']) + self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias']) state_dict.pop(prefix + 'weight') - - self.bias = state_dict[prefix + 'bias'] state_dict.pop(prefix + 'bias') self.scale = float(state_dict[prefix + 'scale']) @@ -187,18 +178,26 @@ def __setstate__(self, state): # type: (Tuple[int, int, Optional[torch.Tensor], torch.Tensor, float, int]) -> None self.in_features = state[0] self.out_features = state[1] - self.bias = state[2] - self.set_weight(state[3]) + self.set_weight_bias(state[3], state[2]) self.scale = state[4] self.zero_point = state[5] # Function rather than property to make sure that JIT serialization doesn't # register this as an attribute + def _weight_bias(self): + return torch.ops.quantized.linear_unpack(self._packed_params) + def weight(self): - return torch.ops.quantized.fbgemm_linear_unpack(self._packed_weight) + (w, b) = torch.ops.quantized.linear_unpack(self._packed_params) + return w - def set_weight(self, w): - self._packed_weight = torch.ops.quantized.fbgemm_linear_prepack(w) + def bias(self): + (w, b) = torch.ops.quantized.linear_unpack(self._packed_params) + return b + + def set_weight_bias(self, w, b): + # type: (torch.Tensor, Optional[torch.Tensor]) -> None + self._packed_params = torch.ops.quantized.linear_prepack(w, b) self.weight_scale = w.q_scale() @classmethod @@ -230,17 +229,9 @@ def from_float(cls, mod): act_scale, act_zp = activation_observer.calculate_qparams() assert weight_observer.dtype == torch.qint8, 'Weight observer must have dtype torch.qint8' wt_scale, wt_zp = weight_observer.calculate_qparams() - # Scale bias to activation_scale/2^16, this quantizes bias - # to about 24 bits of precision - bias_scale = float(act_scale / (2**16)) qweight = torch.quantize_linear(mod.weight.float(), float(wt_scale), int(wt_zp), torch.qint8) - if mod.bias is not None: - qbias = torch.quantize_linear(mod.bias.float(), bias_scale, 0, torch.qint32) - else: - qbias = None qlinear = cls(mod.in_features, mod.out_features) - qlinear.set_weight(qweight) - qlinear.bias = qbias + qlinear.set_weight_bias(qweight, mod.bias) qlinear.scale = float(act_scale) qlinear.zero_point = int(act_zp) return qlinear diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 567f7f08edbce..bfb9768a65dff 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -76,6 +76,13 @@ def __new__(cls, data, batch_sizes=None, sorted_indices=None, unsorted_indices=N # support being called as `PackedSequence(data, batch_sizes, sorted_indices)` if batch_sizes is not None: + if batch_sizes.device.type != 'cpu': + raise ValueError( + "batch_sizes should always be on CPU. " + "Instances of PackedSequence should never be created manually. " + "They should be instantiated by functions like pack_sequence " + "and pack_padded_sequences in nn.utils.rnn. " + "https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_sequence") return super(PackedSequence, cls).__new__( cls, data, batch_sizes, sorted_indices, unsorted_indices) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index a164d833b1ada..2c76ddc9f743d 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -11,13 +11,28 @@ # This file exports ONNX ops for opset 11 black_listed_operators = [ - "eq", "ne", "scatter", "clamp", "clamp_min", "clamp_max", "sort", "topk", "hardtanh" + "eq", "ne", "scatter", "sort", "topk", "hardtanh" ] for black_listed_op in black_listed_operators: vars()[black_listed_op] = _black_list_in_opset(black_listed_op) +def clamp(g, self, min, max): + dtype = self.type().scalarType() + + def _cast_if_not_none(tensor, dtype): + if tensor is not None and not tensor.node().mustBeNone(): + return g.op("Cast", tensor, to_i=sym_help.cast_pytorch_to_onnx[dtype]) + else: + return tensor + + if dtype is not None: + min = _cast_if_not_none(min, dtype) + max = _cast_if_not_none(max, dtype) + return g.op("Clip", self, min, max) + + @parse_args('v', 'i') def pixel_shuffle(g, self, upscale_factor): dims = self.type().sizes() @@ -46,3 +61,7 @@ def _unique2(g, self, sorted, return_inverse, return_counts): def unique_dim(g, self, dim, sorted, return_inverse, return_counts): u, indices, inverse_indices, counts = g.op("Unique", self, axis_i=dim, sorted_i=sorted, outputs=4) return u, inverse_indices, counts + + +def round(g, self): + return g.op("Round", self) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 9d4a6b965e84e..7a7115184ad9f 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -86,40 +86,30 @@ def add(g, self, other, alpha=None): # default alpha arg is to allow no-alpha add (aten add st overload no alpha) if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1: return _unimplemented("add", "alpha != 1") - # See Note [Pointwise by scalar] - other = sym_help._maybe_get_scalar(other) - return g.op("Add", self, sym_help._if_scalar_type_as(g, other, self)) + return g.op("Add", self, other) def sub(g, self, other, alpha=None): # default alpha arg is to allow no-alpha sub (aten sub st overload no alpha) if alpha and sym_help._scalar(sym_help._maybe_get_scalar(alpha)) != 1: return _unimplemented("sub", "alpha != 1") - # See Note [Pointwise by scalar]. Note that self or other may be scalars. - other = sym_help._maybe_get_scalar(other) - return g.op("Sub", self, sym_help._if_scalar_type_as(g, other, self)) + return g.op("Sub", self, other) def rsub(g, self, other, alpha=None): - other = sym_help._maybe_get_scalar(other) - other = sym_help._if_scalar_type_as(g, other, self) return sub(g, other, self, alpha=alpha) def mul(g, self, other): - # See Note [Pointwise by scalar] - other = sym_help._maybe_get_scalar(other) - return g.op("Mul", self, sym_help._if_scalar_type_as(g, other, self)) + return g.op("Mul", self, other) def div(g, self, other): - # See Note [Pointwise by scalar] - other = sym_help._maybe_get_scalar(other) - return g.op("Div", self, sym_help._if_scalar_type_as(g, other, self)) + return g.op("Div", self, other) def reciprocal(g, self): - return g.op("Div", sym_help._if_scalar_type_as(g, torch.ones(1), self), self) + return g.op("Div", torch.ones(1), self) @parse_args('v', 'i') @@ -137,8 +127,7 @@ def stack(g, tensor_list, dim): def mm(g, self, other): # Create a dummy C tensor. Only needed for API purposes, the value is # since beta = 0 - ty = sym_help._try_get_scalar_type(self, other).lower() - C = g.constant(0, [1], ty) + C = g.op("Constant", value_t=torch.tensor([1])) return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0) @@ -163,6 +152,10 @@ def sqrt(g, self): return g.op("Sqrt", self) +def rsqrt(g, self): + return div(g, sym_help._if_scalar_type_as(g, torch.ones(1), self), sqrt(g, self)) + + def tanh(g, self): return g.op("Tanh", self) @@ -784,8 +777,7 @@ def gt(g, input, other): def gt_impl(g, input, other): - other = sym_help._maybe_get_scalar(other) - return g.op("Greater", input, sym_help._if_scalar_type_as(g, other, input)) + return g.op("Greater", input, other) def lt(g, input, other): @@ -793,20 +785,17 @@ def lt(g, input, other): def lt_impl(g, input, other): - other = sym_help._maybe_get_scalar(other) - return g.op("Less", input, sym_help._if_scalar_type_as(g, other, input)) + return g.op("Less", input, other) @wrap_logical_op_with_negation def ge(g, input, other): - other = sym_help._maybe_get_scalar(other) - return lt_impl(g, input, sym_help._if_scalar_type_as(g, other, input)) + return lt_impl(g, input, other) @wrap_logical_op_with_negation def le(g, input, other): - other = sym_help._maybe_get_scalar(other) - return gt_impl(g, input, sym_help._if_scalar_type_as(g, other, input)) + return gt_impl(g, input, other) @wrap_logical_op_with_cast_to_and_from('Bool') @@ -1028,9 +1017,12 @@ def log(g, self): return g.op("Log", self) +def log1p(g, self): + return log(g, add(g, sym_help._if_scalar_type_as(g, torch.ones(1), self), self)) + + def pow(g, self, exponent): - exponent = sym_help._maybe_get_scalar(exponent) - return g.op("Pow", self, sym_help._if_scalar_type_as(g, exponent, self)) + return g.op("Pow", self, exponent) def clamp(g, self, min, max): @@ -1895,9 +1887,9 @@ def try_mask_to_index(index): @parse_args('v', 'is', 'i') def frobenius_norm(g, self, dim=None, keepdim=False): - sqrt = g.op('Mul', self, self) - sumsqrt = g.op('ReduceSum', sqrt, axes_i=dim, keepdims_i=keepdim) - return g.op('Sqrt', sumsqrt) + sqr = g.op('Mul', self, self) + sumsqr = g.op('ReduceSum', sqr, axes_i=dim, keepdims_i=keepdim) + return g.op('Sqrt', sumsqr) @parse_args('v', 'i', 'b', 'v') @@ -1911,3 +1903,10 @@ def multinomial(g, input, num_samples, replacement=False, generator=None): return g.op("Multinomial", log_input, dtype_i=sym_help.cast_pytorch_to_onnx['Long'], sample_size_i=num_samples) + + +def gelu(g, self): + _sqrt2 = 1.4142135623730951 + erf = g.op('Erf', div(g, self, torch.tensor(_sqrt2))) + erf_plusone = add(g, erf, g.op('Constant', value_t=torch.tensor(1, dtype=torch.float))) + return mul(g, mul(g, self, erf_plusone), g.op('Constant', value_t=torch.tensor(0.5, dtype=torch.float))) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 5cd0d1454df07..ca81bc00706ba 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -131,6 +131,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa graph = torch._C._jit_pass_onnx(graph, operator_export_type) torch._C._jit_pass_lint(graph) + + torch._C._jit_pass_onnx_scalar_type_analysis(graph) + torch._C._jit_pass_lint(graph) + from torch.onnx.symbolic_helper import _export_onnx_opset_version torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version) torch._C._jit_pass_lint(graph) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index ea18caa276f9d..8e75a4b7667bb 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -1,8 +1,9 @@ import types import math from torch._six import inf -from functools import partial, wraps +from functools import wraps import warnings +import weakref from bisect import bisect_right from .optimizer import Optimizer @@ -29,15 +30,32 @@ def __init__(self, optimizer, last_epoch=-1): # Following https://github.com/pytorch/pytorch/issues/20124 # We would like to ensure that `lr_scheduler.step()` is called after # `optimizer.step()` - def with_counter(func, opt): + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + @wraps(func) def wrapper(*args, **kwargs): - opt._step_count += 1 - return func(*args, **kwargs) + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. wrapper._with_counter = True return wrapper - self.optimizer.step = with_counter(self.optimizer.step, self.optimizer) + self.optimizer.step = with_counter(self.optimizer.step) self.optimizer._step_count = 0 self._step_count = 0 self.step(last_epoch) @@ -366,7 +384,6 @@ def __init__(self, optimizer, mode='min', factor=0.1, patience=10, self.best = None self.num_bad_epochs = None self.mode_worse = None # the worse value for the chosen mode - self.is_better = None self.eps = eps self.last_epoch = -1 self._init_is_better(mode=mode, threshold=threshold, @@ -415,20 +432,20 @@ def _reduce_lr(self, epoch): def in_cooldown(self): return self.cooldown_counter > 0 - def _cmp(self, mode, threshold_mode, threshold, a, best): - if mode == 'min' and threshold_mode == 'rel': - rel_epsilon = 1. - threshold + def is_better(self, a, best): + if self.mode == 'min' and self.threshold_mode == 'rel': + rel_epsilon = 1. - self.threshold return a < best * rel_epsilon - elif mode == 'min' and threshold_mode == 'abs': - return a < best - threshold + elif self.mode == 'min' and self.threshold_mode == 'abs': + return a < best - self.threshold - elif mode == 'max' and threshold_mode == 'rel': - rel_epsilon = threshold + 1. + elif self.mode == 'max' and self.threshold_mode == 'rel': + rel_epsilon = self.threshold + 1. return a > best * rel_epsilon else: # mode == 'max' and epsilon_mode == 'abs': - return a > best + threshold + return a > best + self.threshold def _init_is_better(self, mode, threshold, threshold_mode): if mode not in {'min', 'max'}: @@ -441,10 +458,12 @@ def _init_is_better(self, mode, threshold, threshold_mode): else: # mode == 'max': self.mode_worse = -inf - self.is_better = partial(self._cmp, mode, threshold_mode, threshold) + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode def state_dict(self): - return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}} + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} def load_state_dict(self, state_dict): self.__dict__.update(state_dict) @@ -611,6 +630,7 @@ def __init__(self, self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) super(CyclicLR, self).__init__(optimizer, last_epoch) + self.base_lrs = base_lrs def _format_param(self, name, optimizer, param): """Return correctly formatted lr/momentum for each param group.""" diff --git a/torch/quantization/QConfig.py b/torch/quantization/QConfig.py index 362276dab297b..614148efe8d80 100644 --- a/torch/quantization/QConfig.py +++ b/torch/quantization/QConfig.py @@ -4,14 +4,17 @@ from .fake_quantize import * QConfig = namedtuple('QConfig', - ['weight', 'activation']) + ['activation', 'weight']) -default_qconfig = QConfig(default_weight_observer(), - default_observer()) +default_qconfig = QConfig(activation=default_observer(), + weight=default_weight_observer()) + +default_debug_qconfig = QConfig(weight=default_weight_observer(), + activation=default_debug_observer()) QConfig_dynamic = namedtuple('QConfig_dynamic', ['weight']) -default_dynamic_qconfig = QConfig_dynamic(default_weight_observer()) +default_dynamic_qconfig = QConfig_dynamic(weight=default_weight_observer()) -default_qat_qconfig = QConfig(default_weight_fake_quant(), - default_fake_quant()) +default_qat_qconfig = QConfig(activation=default_fake_quant(), + weight=default_weight_fake_quant()) diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index 8b6a8d6d1f94e..4c60294fc630c 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -26,7 +26,7 @@ def default_eval_fn(model, calib_data): 'Observer', 'WeightObserver', 'observer', 'default_observer', 'default_weight_observer', # QConfig - 'QConfig', 'default_qconfig', + 'QConfig', 'default_qconfig', 'default_dynamic_qconfig', # QAT utilities 'default_qat_qconfig', 'prepare_qat', 'quantize_qat', # module transformations diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py index 8cf1e15df3b8f..6faf479e3a6d4 100644 --- a/torch/quantization/observer.py +++ b/torch/quantization/observer.py @@ -1,14 +1,17 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import torch -import torch.nn as nn +import math +import warnings from abc import ABCMeta, abstractmethod from functools import partial -import warnings -from torch._jit_internal import Optional +import torch +import torch.nn as nn +from torch._jit_internal import List, Optional + + +ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3: -ABC = ABCMeta(str('ABC'), (object,), {}) # compatible with Python 2 *and* 3: class ObserverBase(ABC, nn.Module): r"""Observer base Module @@ -20,16 +23,23 @@ class ObserverBase(ABC, nn.Module): the collected statistics. """ - def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine): + def __init__( + self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False + ): super(ObserverBase, self).__init__() self.dtype = dtype self.qscheme = qscheme + self.reduce_range = reduce_range + self.eps = torch.finfo(torch.float32).eps assert self.qscheme in ( torch.per_tensor_affine, torch.per_tensor_symmetric, - ), "Default Observer only works for per_tensor_affine and \ - per_tensor_symmetric quantization scheme" + torch.per_channel_affine, + torch.per_channel_symmetric, + ), "Default Observer only works for per_tensor_affine, \ + per_tensor_symmetric, per_channel_affine and \ + per_channel_symmetric quantization scheme" assert self.dtype in ( torch.qint8, torch.quint8, @@ -43,6 +53,35 @@ def forward(self, x): def calculate_qparams(self, **kwargs): pass + def _calculate_per_channel_qparams(self, min_vals, max_vals): + # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor] + """ + Given min and max value tensors, this function calculates per channel + quantization parameters + """ + if min_vals is None or max_vals is None: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) + return torch.tensor([1.0]), torch.tensor([0]) + + for i in range(len(min_vals)): + assert ( + min_vals[i] <= max_vals[i] + ), "min {} should be less than max {}".format(min_vals[i], max_vals[i]) + + scales = torch.ones(min_vals.size()) + zero_points = torch.ones(min_vals.size()) + for i in range(len(scales)): + qparam = self._calculate_qparams( + min_vals[i], max_vals[i] + ) + scales[i] = float(qparam[0]) + zero_points[i] = int(qparam[1]) + + return scales, zero_points + def _calculate_qparams(self, min_val, max_val): # type: (Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor] """ @@ -50,8 +89,10 @@ def _calculate_qparams(self, min_val, max_val): """ if max_val is None or min_val is None: - warnings.warn("must run observer before calling calculate_qparams.\ - Returning default scale and zero point ") + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) return torch.tensor([1.0]), torch.tensor([0]) assert min_val <= max_val, "min {} should be less than max {}".format( @@ -59,9 +100,15 @@ def _calculate_qparams(self, min_val, max_val): ) if self.dtype == torch.qint8: - qmin, qmax = -128, 127 + if self.reduce_range: + qmin, qmax = -64, 63 + else: + qmin, qmax = -128, 127 else: - qmin, qmax = 0, 255 + if self.reduce_range: + qmin, qmax = 0, 127 + else: + qmin, qmax = 0, 255 max_val, min_val = float(max_val), float(min_val) min_val = min(0.0, min_val) @@ -70,7 +117,7 @@ def _calculate_qparams(self, min_val, max_val): scale = 1.0 zero_point = 0 else: - if self.qscheme == torch.per_tensor_symmetric: + if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric: max_val = max(-min_val, max_val) scale = max_val / ((qmax - qmin) / 2) scale = max(scale, self.eps) @@ -94,12 +141,30 @@ class MinMaxObserver(ObserverBase): calculate_qparams will calculate scale and zero_point """ - __annotations__ = {'min_val' : Optional[torch.Tensor], 'max_val' : Optional[torch.Tensor]} + __annotations__ = { + "min_val": Optional[torch.Tensor], + "max_val": Optional[torch.Tensor], + } def __init__(self, **kwargs): + # For x86 quantized kernels, we need to ensure that the vpmaddubsw instruction + # does not overflow. We allow for a reduce_range argument to observers that + # reduces the quantized range to (0,127) or (-64, 63). For more details see + # aten/src/ATen/native/quantized/cpu/qconv.cpp + # This is not the optimal choice for non x86 backends as + # lose a bit of precision for activations. + # super(MinMaxObserver, self).__init__(**kwargs) self.min_val = None self.max_val = None + if ( + self.qscheme == torch.per_tensor_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) def forward(self, x): min_val = self.min_val @@ -120,14 +185,351 @@ def calculate_qparams(self): @torch.jit.export def extra_repr(self): - return 'min_val={}, max_val={}'.format(self.min_val, self.max_val) + return "min_val={}, max_val={}".format(self.min_val, self.max_val) + + +class PerChannelMinMaxObserver(ObserverBase): + r"""Per Channel Observer Module + The module will record the running average of max and min value for each + channel of the observed Tensor and calculate_qparams will calculate + scales and zero_points for each channel + """ + + def __init__(self, ch_axis=0, **kwargs): + super(PerChannelMinMaxObserver, self).__init__(**kwargs) + self.ch_axis = ch_axis + self.min_vals = None + self.max_vals = None + if ( + self.qscheme == torch.per_channel_symmetric + and self.reduce_range + and self.dtype == torch.quint8 + ): + raise NotImplementedError( + "Cannot reduce range for symmetric quantization for quint8" + ) + + def forward(self, x): + with torch.no_grad(): + min_vals = self.min_vals + max_vals = self.max_vals + x_dim = x.size() + + new_axis_list = list(range(len(x_dim))) + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(tuple(new_axis_list)) + y = torch.flatten(y, start_dim=1) + if min_vals is None or max_vals is None: + min_vals = torch.min(y, 1)[0] + max_vals = torch.max(y, 1)[0] + else: + min_vals = torch.min(torch.min(y, 1)[0], min_vals) + max_vals = torch.max(torch.max(y, 1)[0], max_vals) + self.min_vals = min_vals + self.max_vals = max_vals + return x + + def calculate_qparams(self): + return self._calculate_per_channel_qparams(self.min_vals, self.max_vals) + + def extra_repr(self): + return "min_val={}, max_val={}".format(self.min_vals, self.max_vals) + + + +class HistogramObserver(ObserverBase): + r""" + The module records the running histogram of tensor values along with + min/max values. calculate_qparams will calculate scale and zero_point + """ + + __annotations__ = { + "min_val": Optional[torch.Tensor], + "max_val": Optional[torch.Tensor], + "histogram": Optional[torch.Tensor], + } + + def __init__(self, bins=2048, **kwargs): + # bins: The number of bins used for histogram calculation. + super(HistogramObserver, self).__init__(**kwargs) + self.bins = bins + self.histogram = None + self.min_val = None + self.max_val = None + + @staticmethod + def _get_norm(delta_begin, delta_end, density, norm_type): + """ + Compute the norm of the values uniformaly distributed between + delta_begin and delta_end. + + norm = density * (integral_{begin, end} x^2) + = density * (end^3 - begin^3) / 3 + """ + assert norm_type == "L2", "Only L2 norms are currently supported" + norm = 0.0 + if norm_type == "L2": + norm = ( + delta_end * delta_end * delta_end + - delta_begin * delta_begin * delta_begin + ) / 3 + return density * norm + + def _compute_quantization_error(self, next_start_bin, next_end_bin, norm_type): + """ + Compute the quantization error if we use start_bin to end_bin as the + min and max to do the quantization. + """ + dst_nbins = 2 ** torch.iinfo(self.dtype).bits + bin_width = (self.max_val.item() - self.min_val.item()) / self.bins + + norm = 0.0 + dst_bin_width = bin_width * (next_end_bin - next_start_bin + 1) / dst_nbins + for src_bin in range(self.bins): + # distances from the beginning of first dst_bin to the beginning and + # end of src_bin + src_bin_begin = (src_bin - next_start_bin) * bin_width + src_bin_end = src_bin_begin + bin_width + + # which dst_bins the beginning and end of src_bin belong to? + dst_bin_of_begin = min( + dst_nbins - 1, max(0.0, math.floor(src_bin_begin / dst_bin_width)) + ) + dst_bin_of_end = min( + dst_nbins - 1, max(0.0, math.floor(src_bin_end / dst_bin_width)) + ) + dst_bin_of_begin_center = ( + dst_bin_of_begin * dst_bin_width + dst_bin_width / 2 + ) + + density = self.histogram[src_bin] / bin_width + if dst_bin_of_begin == dst_bin_of_end: + # if src_bin is entirely within 1 dst_bin + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = src_bin_end - dst_bin_of_begin_center + norm = norm + self._get_norm(delta_begin, delta_end, density, norm_type) + else: + delta_begin = src_bin_begin - dst_bin_of_begin_center + delta_end = dst_bin_width / 2 + norm = norm + self._get_norm(delta_begin, delta_end, density, norm_type) + + norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * self._get_norm( + -dst_bin_width / 2, dst_bin_width / 2, density, norm_type + ) + + dst_bin_of_end_center = ( + dst_bin_of_end * dst_bin_width + dst_bin_width / 2 + ) + + delta_begin = -dst_bin_width / 2 + delta_end = src_bin_end - dst_bin_of_end_center + norm = norm + self._get_norm(delta_begin, delta_end, density, norm_type) + return norm + + def _non_linear_param_search(self): + """ + An approximation for L2 error minimization for selecting min/max. + By selecting new min/max, we filter out outliers in input distribution. + This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in + caffe2/quantization/server/norm_minimization.cc + """ + assert self.histogram.size()[0] == self.bins, "bins mistmatch" + bin_width = (self.max_val - self.min_val) / self.bins + + # cumulative sum + total = sum(self.histogram) + cSum = torch.cumsum(self.histogram, dim=0) + + stepsize = 1e-5 + alpha = 0.0 + beta = 1.0 + start_bin = 0 + end_bin = self.bins - 1 + norm_min = float("inf") + + while alpha < beta: + next_alpha = alpha + stepsize + next_beta = beta - stepsize + + # find the left and right bins between the quantile bounds + l = start_bin + r = end_bin + while l < end_bin and cSum[l] < next_alpha * total: + l = l + 1 + while r > start_bin and cSum[r] > next_beta * total: + r = r - 1 + + next_start_bin = start_bin + next_end_bin = end_bin + if (l - start_bin) > (end_bin - r): + next_start_bin = l + alpha = next_alpha + else: + next_end_bin = r + beta = next_beta + + if next_start_bin == start_bin and next_end_bin == end_bin: + continue + + # calculate the quantization error using next_start_bin and next_end_bin + norm = self._compute_quantization_error(next_start_bin, next_end_bin, "L2") + + if norm > norm_min: + break + norm_min = norm + start_bin = next_start_bin + end_bin = next_end_bin + + new_min = self.min_val + bin_width * start_bin + new_max = self.min_val + bin_width * (end_bin + 1) + return new_min, new_max + + def _combine_histograms( + self, dst_histogram, dst_min, dst_max, src_histogram, src_min, src_max + ): + bins_dst = dst_histogram.size()[0] + bins_src = src_histogram.size()[0] + + dst_bin_width = (dst_max - dst_min) / bins_dst + src_bin_width = (src_max - src_min) / bins_src + + for i in range(bins_src): + src_bin_count = src_histogram[i].item() + if src_bin_count == 0: + continue + + src_bin_begin = src_min + src_bin_width * i + src_bin_end = src_bin_begin + src_bin_width + + dst_bin = 0 + if dst_bin_width: + dst_bin = int((src_bin_begin - dst_min) / dst_bin_width) + + dst_bin_begin = dst_min + dst_bin_width * dst_bin + dst_bin_end = dst_bin_begin + dst_bin_width + + dst_bin2 = 0 + if dst_bin_width: + dst_bin2 = min( + int((src_bin_end - dst_min) / dst_bin_width), bins_dst - 1 + ) + + assert dst_bin2 <= dst_bin + 2, "1 src_bin is mapped to at most 2 dst_bins" + # dst_bin_cnt is the count from src_bin that should go to dst_bin + # the remainder should go to dst_bin2 + dst_bin_cnt = 0 + if src_bin_width == 0 or dst_bin_width == 0: + dst_bin_cnt = src_bin_count + else: + # We divide counts in src_bin in proportion to range overlap with dst_bin + dst_bin_cnt = min( + round( + (dst_bin_end - src_bin_begin) / src_bin_width * src_bin_count + ), + src_bin_count, + ) + + dst_histogram[dst_bin] += dst_bin_cnt + + # remaining should go to dst_bin2 + if dst_bin_cnt < src_bin_count: + dst_histogram[dst_bin2] += src_bin_count - dst_bin_cnt + + def forward(self, x): + with torch.no_grad(): + min_val = self.min_val + max_val = self.max_val + histogram = self.histogram + if min_val is None or max_val is None or histogram is None: + min_val = torch.min(x) + max_val = torch.max(x) + self.min_val = min_val + self.max_val = max_val + self.histogram = torch.histc(x, self.bins, min=min_val, max=max_val) + else: + new_min = torch.min(x) + new_max = torch.max(x) + new_histogram = torch.histc(x, self.bins, min=new_min, max=new_max) + # combine the existing histogram and new histogram into 1 histogram + combined_histogram = torch.zeros_like(self.histogram) + combined_min = torch.min(new_min, self.min_val) + combined_max = torch.max(new_max, self.max_val) + self._combine_histograms( + combined_histogram, + combined_min.item(), + combined_max.item(), + self.histogram, + self.min_val.item(), + self.max_val.item(), + ) + self._combine_histograms( + combined_histogram, + combined_min.item(), + combined_max.item(), + new_histogram, + new_min.item(), + new_max.item(), + ) + self.histogram = combined_histogram + self.min_val = combined_min + self.max_val = combined_max + return x + + def calculate_qparams(self): + if self.histogram is None: + warnings.warn( + "must run observer before calling calculate_qparams.\ + Returning default scale and zero point " + ) + return torch.tensor([1.0]), torch.tensor([0]) + assert self.bins == len(self.histogram), ( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) + + new_min, new_max = self._non_linear_param_search() + + return self._calculate_qparams(new_min.item(), new_max.item()) + + +class TensorObserver(ObserverBase): + r""" + The module is mainly for debug and records the tensor values during runtime + """ + __annotations__ = {"tensor_val": List[Optional[torch.Tensor]]} + + def __init__(self, **kwargs): + super(TensorObserver, self).__init__(**kwargs) + self.tensor_val = [] + + def forward(self, x): + self.tensor_val.append(x.clone()) + return x + + @torch.jit.export + def calculate_qparams(self): + raise Exception("calculate_qparams should not be called for TensorObserver") + + @torch.jit.export + def get_tensor_value(self): + return self.tensor_val + def observer(observer_cls, **kwargs): return partial(observer_cls, **kwargs) + def default_observer(**kwargs): + # Restrict activations to be in the range (0,127) + kwargs.setdefault("reduce_range", True) return observer(MinMaxObserver, **kwargs) + +def default_debug_observer(**kwargs): + return observer(TensorObserver, **kwargs) + + def default_weight_observer(**kwargs): kwargs.setdefault("dtype", torch.qint8) kwargs.setdefault("qscheme", torch.per_tensor_symmetric) diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 982f2be0aea84..5c0fee881330f 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals import copy +import torch import torch.nn as nn import torch.nn._intrinsic as nni import torch.nn._intrinsic.quantized as nniq @@ -11,7 +12,7 @@ import torch.nn.qat as nnqat -DEFAULT_SKIP_LIST = [nn.Identity, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d] +DEFAULT_SKIP_LIST = [nn.Dropout, nn.Identity, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d] def propagate_qconfig_helper(module, qconfig_dict, skip_list=DEFAULT_SKIP_LIST, qconfig_parent=None, prefix=''): r"""This is a helper function for `propagate_qconfig` @@ -227,6 +228,7 @@ def forward(self, x): DEFAULT_DYNAMIC_MODULE_MAPPING = { nn.Linear: nnqd.Linear, + nn.LSTM: nnqd.LSTM, } def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING): @@ -256,9 +258,10 @@ def quantize(model, run_fn, run_args, mapping=DEFAULT_MODULE_MAPPING): DEFAULT_QCONFIG_DICT = { nn.Linear : default_dynamic_qconfig, + nn.LSTM : default_dynamic_qconfig, } -def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING): +def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_DYNAMIC_MODULE_MAPPING, dtype=torch.qint8): r"""Converts a float model to dynamic quantized model. Perform dynamic training and output a quantized model. @@ -266,7 +269,7 @@ def quantize_dynamic(model, qconfig_dict=DEFAULT_QCONFIG_DICT, mapping=DEFAULT_D model = copy.deepcopy(model) model.eval() propagate_qconfig(model, qconfig_dict) - convert(model, mapping) + convert(model, mapping, dtype) return model def prepare_qat(model): @@ -292,7 +295,7 @@ def quantize_qat(model, run_fn, run_args): convert(model) return model -def convert(module, mapping=DEFAULT_MODULE_MAPPING): +def convert(module, mapping=DEFAULT_MODULE_MAPPING, dtype=torch.qint8): r"""Converts the float module with observers(where we can get quantization parameters) to a quantized module. Args: @@ -309,13 +312,13 @@ def convert(module, mapping=DEFAULT_MODULE_MAPPING): for name, mod in module.named_children(): if type(mod) not in SWAPPABLE_MODULES: - convert(mod, mapping) - reassign[name] = swap_module(mod, mapping) + convert(mod, mapping, dtype) + reassign[name] = swap_module(mod, mapping, dtype) for key, value in reassign.items(): module._modules[key] = value -def swap_module(mod, mapping): +def swap_module(mod, mapping, dtype=torch.qint8): r"""Swaps the module if it has a quantized counterpart and it has an `observer` attached. @@ -329,5 +332,35 @@ def swap_module(mod, mapping): new_mod = mod if hasattr(mod, 'qconfig') and mod.qconfig is not None: if type(mod) in mapping: - new_mod = mapping[type(mod)].from_float(mod) + supported_scalar_types = [torch.qint8, torch.float16] + if dtype not in supported_scalar_types: + raise RuntimeError('Unsupported dtype: {}'.format(dtype)) + if dtype == torch.qint8: + new_mod = mapping[type(mod)].from_float(mod) + elif dtype == torch.float16: + # We want to support float16 dynamic quantization + new_mod = mapping[type(mod)].from_float(mod, dtype) return new_mod + +def dump_tensor(mod, target_dict, prefix=""): + r"""Traverse the modules and save the weight and stored activation to given dict. + This is mainly used for quantization accuracy debug + Args: + mod: the top module we want to save all tensors + prefix: the prefix for the current module + target_dict: the dictionary used to save the tensors + """ + def get_prefix(prefix): + return prefix if prefix == "" else prefix + '.' + + weight_unpack = getattr(mod, "weight", None) + if weight_unpack is not None and callable(weight_unpack): + target_dict[get_prefix(prefix) + 'weight'] = mod.weight() + elif hasattr(mod, 'weight'): + target_dict[get_prefix(prefix) + 'weight'] = mod.weight + + if hasattr(mod, 'observer'): + target_dict[get_prefix(prefix) + 'activation'] = mod.observer.get_tensor_value() + for name, child in mod.named_children(): + module_prefix = get_prefix(prefix) + name if prefix else name + dump_tensor(child, target_dict, module_prefix) diff --git a/torch/script.h b/torch/script.h index 274609df31263..b4bd7c2618d35 100644 --- a/torch/script.h +++ b/torch/script.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/torch/serialization.py b/torch/serialization.py index 8b168b694cb96..a3e0bc9b01ef6 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1,5 +1,4 @@ import difflib -import inspect import os import io import shutil @@ -12,6 +11,7 @@ from contextlib import closing, contextmanager from ._utils import _import_dotted_name from ._six import string_classes as _string_classes +from torch._utils_internal import get_source_lines_and_file if sys.version_info[0] == 2: import cPickle as pickle else: @@ -285,8 +285,8 @@ def persistent_id(obj): serialized_container_types[obj] = True source_file = source = None try: - source_file = inspect.getsourcefile(obj) - source = inspect.getsource(obj) + source_lines, _, source_file = get_source_lines_and_file(obj) + source = ''.join(obj) except Exception: # saving the source is optional, so we can ignore any errors warnings.warn("Couldn't retrieve source code for container of " "type " + obj.__name__ + ". It won't be checked " @@ -449,7 +449,7 @@ def restore_location(storage, location): def _check_container_source(container_type, source_file, original_source): try: - current_source = inspect.getsource(container_type) + current_source = ''.join(get_source_lines_and_file(container_type)[0]) except Exception: # saving the source is optional, so we can ignore any errors warnings.warn("Couldn't retrieve source code for container of " "type " + container_type.__name__ + ". It won't be checked " diff --git a/torch/tensor.py b/torch/tensor.py index a7724a4a33ea6..2da187d311136 100644 --- a/torch/tensor.py +++ b/torch/tensor.py @@ -1,7 +1,8 @@ import sys import torch import torch._C as _C -from torch.namedtensor import _update_names, _check_serializing_named_tensor +from torch._namedtensor_internals import update_names, check_serializing_named_tensor, resolve_ellipsis +from torch._namedtensor_internals import unzip_namedshape from collections import OrderedDict import torch.utils.hooks as hooks import warnings @@ -37,7 +38,7 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - _check_serializing_named_tensor(self) + check_serializing_named_tensor(self) # See Note [Don't serialize hooks] torch.utils.hooks.warn_if_has_hooks(self) if self.is_quantized: @@ -427,7 +428,11 @@ def __contains__(self, element): """ if isinstance(element, (torch.Tensor, Number)): return (element == self).any().item() - return NotImplemented + + raise RuntimeError( + "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." % + type(element) + ) @property def __cuda_array_interface__(self): @@ -481,23 +486,34 @@ def __cuda_array_interface__(self): return dict(typestr=typestr, shape=shape, strides=strides, data=data, version=1) + def refine_names(self, *names): + names = resolve_ellipsis(names, self.names, 'refine_names') + return super(Tensor, self).refine_names(names) + + def align_to(self, *names): + return super(Tensor, self).align_to(resolve_ellipsis(names, self.names, 'align_to')) + + def unflatten(self, dim, namedshape): + names, sizes = unzip_namedshape(namedshape) + return super(Tensor, self).unflatten(dim, sizes, names) + def names_(self, *names, **rename_map): - # Note [names_ / view_names API] + # Note [names_ / renamed API] # The Python API for these is different from the C++ API. In Python: - # 1) tensor.view_names(*names) takes a vararglist of names - # 2) tensor.view_names(**rename_map) takes a map of names to rename. + # 1) tensor.renamed(*names) takes a vararglist of names + # 2) tensor.renamed(**rename_map) takes a map of names to rename. # C++ is static, making it difficult to implement similar behavior. - return _update_names(self, names, rename_map, inplace=True) + return update_names(self, names, rename_map, inplace=True) - def view_names(self, *names, **rename_map): - # See Note [names_ / view_names API] - return _update_names(self, names, rename_map, inplace=False) + def renamed(self, *names, **rename_map): + # See Note [names_ / renamed API] + return update_names(self, names, rename_map, inplace=False) def _update_names(self, names, inplace): - # See Note [names_ / view_names API] + # See Note [names_ / renamed API] if inplace: return super(Tensor, self).names_(names) else: - return super(Tensor, self).view_names(names) + return super(Tensor, self).renamed(names) __module__ = 'torch' diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index e80ac6853e945..c09db87930d1e 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -14,7 +14,6 @@ def default_convert(data): r"""Converts each NumPy array data field into a tensor""" - elem_type = type(data) if isinstance(data, torch.Tensor): return data @@ -28,7 +27,7 @@ def default_convert(data): elif isinstance(data, container_abcs.Mapping): return {key: default_convert(data[key]) for key in data} elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple - return elem_type(default_convert(d) for d in data) + return elem_type(*(default_convert(d) for d in data)) elif isinstance(data, container_abcs.Sequence) and not isinstance(data, string_classes): return [default_convert(d) for d in data] else: diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index b544c4b9ec66b..53955511a2b4b 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -199,10 +199,10 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, drop_last = False elif batch_size is None: # no auto_collation - if shuffle or sampler is not None or drop_last: + if shuffle or drop_last: raise ValueError('batch_size=None option disables auto-batching ' 'and is mutually exclusive with ' - 'shuffle, sampler, and drop_last') + 'shuffle, and drop_last') if sampler is None: # give default samplers if self._dataset_kind == _DatasetKind.Iterable: @@ -539,7 +539,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): # `pin_memory_thread_done_event`: # A `threading.Event` for a similar purpose to that of # `workers_done_event`, but is for the `pin_memory_thread`. The reason - # that separate events are neede is that `pin_memory_thread` reads from + # that separate events are needed is that `pin_memory_thread` reads from # the output queue of the workers. But the workers, upon seeing that # `workers_done_event` is set, only wants to see the final `None`, and is # not required to flush all data in the output queue (e.g., it may call diff --git a/torch/utils/mkldnn.py b/torch/utils/mkldnn.py index b7a0d8373e365..73524f8a52998 100644 --- a/torch/utils/mkldnn.py +++ b/torch/utils/mkldnn.py @@ -150,19 +150,3 @@ def m_fn_rec(m): return new_m return m_fn_rec(module) - - -# **** WARNING: This is used to temporarily disable MKL-DNN convolution due -# to a bug: https://github.com/pytorch/pytorch/issues/23825 -# Once this bug is fixed, this context manager as well as its callsites -# should be removed! - -from contextlib import contextmanager - -@contextmanager -def disable_mkldnn_conv(): - torch._C._disable_mkldnn_conv() - try: - yield - finally: - torch._C._enable_mkldnn_conv() diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index b71893b5f7d61..02522912ee2c7 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -151,6 +151,11 @@ def populate_namespace_from_OP_to_IO(self): self.unique_name_to_scoped_name[input_node_id] = node.scopeName + '/' + input_node_id for key, node in self.nodes_io.items(): + if type(node) == NodeBase: + self.unique_name_to_scoped_name[key] = node.scope + '/' + node.debugName + if hasattr(node, 'input_or_output'): + self.unique_name_to_scoped_name[key] = node.input_or_output + '/' + node.debugName + if hasattr(node, 'scope') and node.scope is not None: self.unique_name_to_scoped_name[key] = node.scope + '/' + node.debugName if node.scope == '' and self.shallowest_scope_name: