diff --git a/.bc-linter.yml b/.bc-linter.yml index cafa3a51c3ac1..7671090ff0351 100644 --- a/.bc-linter.yml +++ b/.bc-linter.yml @@ -13,3 +13,4 @@ exclude: - "**/benchmarks/**" - "**/test_*.py" - "**/*_test.py" + - "tools/**" diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index a23c85bc60a50..d0500b89780ce 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -195,13 +195,16 @@ case "$tag" in NINJA_VERSION=1.9.0 TRITON=yes ;; - pytorch-linux-jammy-xpu-n-py3) + pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 VISION=yes XPU_VERSION=2025.2 NINJA_VERSION=1.9.0 TRITON=yes + if [[ $tag =~ "benchmarks" ]]; then + INDUCTOR_BENCHMARKS=yes + fi ;; pytorch-linux-jammy-py3-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 diff --git a/.ci/docker/common/install_acl.sh b/.ci/docker/common/install_acl.sh index 0b865e5bc6f8d..4a98a0eaf146c 100755 --- a/.ci/docker/common/install_acl.sh +++ b/.ci/docker/common/install_acl.sh @@ -3,7 +3,7 @@ set -eux -ACL_VERSION=${ACL_VERSION:-"v25.02"} +ACL_VERSION=${ACL_VERSION:-"v52.6.0"} ACL_INSTALL_DIR="/acl" # Clone ACL diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 481de54a50f2c..41335a0dc370f 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -49,12 +49,20 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then export SYSROOT_DEP="sysroot_linux-64=2.17" fi +# Install correct Python version +# Also ensure sysroot is using a modern GLIBC to match system compilers +if [ "$ANACONDA_PYTHON_VERSION" = "3.14" ]; then + as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\ + python="3.14.0" \ + ${SYSROOT_DEP} \ + -c conda-forge +else # Install correct Python version # Also ensure sysroot is using a modern GLIBC to match system compilers as_jenkins conda create -n py_$ANACONDA_PYTHON_VERSION -y\ python="$ANACONDA_PYTHON_VERSION" \ ${SYSROOT_DEP} - +fi # libstdcxx from conda default channels are too old, we need GLIBCXX_3.4.30 # which is provided in libstdcxx 12 and up. conda_install libstdcxx-ng=12.3.0 --update-deps -c conda-forge diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 7878311c15b08..9376d259d9cca 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -40,11 +40,7 @@ EOF # Default url values rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" - amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/ubuntu" - - # Add amdgpu repository UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` - echo "deb [arch=amd64] ${amdgpu_baseurl} ${UBUNTU_VERSION_NAME} main" > /etc/apt/sources.list.d/amdgpu.list # Add rocm repository wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - diff --git a/.ci/docker/common/install_rocm_magma.sh b/.ci/docker/common/install_rocm_magma.sh index 9bf45e6f1b0a9..2d03c6186b8e5 100644 --- a/.ci/docker/common/install_rocm_magma.sh +++ b/.ci/docker/common/install_rocm_magma.sh @@ -12,8 +12,8 @@ function do_install() { rocm_version_nodot=${rocm_version//./} - # https://github.com/icl-utk-edu/magma/pull/65 - MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec + # post merge of https://github.com/icl-utk-edu/magma/pull/65 + MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f magma_archive="magma-rocm${rocm_version_nodot}-${MAGMA_VERSION}-1.tar.bz2" rocm_dir="/opt/rocm" diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index 4803cb778c905..bcc249633faa5 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -149,7 +149,7 @@ FROM cpu_final as rocm_final ARG ROCM_VERSION=6.0 ARG PYTORCH_ROCM_ARCH ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} -ARG DEVTOOLSET_VERSION=11 +ARG DEVTOOLSET_VERSION=13 ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib" # Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path, # below workaround helps avoid error diff --git a/.ci/docker/manywheel/build.sh b/.ci/docker/manywheel/build.sh index b4b5059973037..ac385ce4b29fd 100755 --- a/.ci/docker/manywheel/build.sh +++ b/.ci/docker/manywheel/build.sh @@ -97,7 +97,7 @@ case ${image} in manylinux2_28-builder:xpu) TARGET=xpu_final GPU_IMAGE=amd64/almalinux:8 - DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=11" + DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13" MANY_LINUX_VERSION="2_28" ;; *) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 93d32b803b199..bdc34b4864cd7 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -136,10 +136,11 @@ numba==0.61.2 ; python_version > "3.9" #test_nn.py, test_namedtensor.py, test_linalg.py, test_jit_cuda_fuser.py, #test_jit.py, test_indexing.py, test_datapipe.py, test_dataloader.py, #test_binary_ufuncs.py -numpy==2.0.2 ; python_version == "3.9" -numpy==2.1.2 ; python_version > "3.9" +numpy==2.1.2; python_version > "3.9" and python_version < "3.14" +numpy==2.3.4; python_version >= "3.14" -pandas==2.2.3 +pandas==2.2.3; python_version >= "3.9" and python_version < "3.14" +pandas==2.3.3; python_version >= "3.14" #onnxruntime #Description: scoring engine for Open Neural Network Exchange (ONNX) models @@ -151,7 +152,8 @@ opt-einsum==3.3 #Pinned versions: 3.3 #test that import: test_linalg.py -optree==0.13.0 +optree==0.13.0 ; python_version < "3.14" +optree==0.17.0 ; python_version >= "3.14" #Description: A library for tree manipulation #Pinned versions: 0.13.0 #test that import: test_vmap.py, test_aotdispatch.py, test_dynamic_shapes.py, @@ -249,8 +251,8 @@ scikit-image==0.22.0 #Pinned versions: 0.20.3 #test that import: -scipy==1.13.1 ; python_version == "3.9" -scipy==1.14.1 ; python_version > "3.9" +scipy==1.14.1 ; python_version > "3.9" and python_version < "3.14" +scipy==1.16.2 ; python_version >= "3.14" # Pin SciPy because of failing distribution tests (see #60347) #Description: scientific python #Pinned versions: 1.10.1 @@ -321,7 +323,8 @@ pywavelets==1.7.0 ; python_version >= "3.12" #Pinned versions: 1.4.1 #test that import: -lxml==5.3.0 +lxml==5.3.0 ; python_version < "3.14" +lxml==6.0.2 ; python_version >= "3.14" #Description: This is a requirement of unittest-xml-reporting PyGithub==2.3.0 @@ -331,7 +334,9 @@ sympy==1.13.3 #Pinned versions: #test that import: -onnx==1.19.1 +onnx==1.19.1 ; python_version < "3.14" +# Unpin once Python 3.14 is supported. See onnxruntime issue 26309. +onnx==1.18.0 ; python_version == "3.14" #Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: @@ -356,7 +361,7 @@ pwlf==2.2.1 #test that import: test_sac_estimator.py # To build PyTorch itself -pyyaml==6.0.2 +pyyaml==6.0.3 pyzstd setuptools==78.1.1 packaging==23.1 diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index 8765249688ce5..af11992a91646 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -54,12 +54,15 @@ ENV OPENSSL_DIR /opt/openssl RUN rm install_openssl.sh ARG INDUCTOR_BENCHMARKS +ARG ANACONDA_PYTHON_VERSION +ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/huggingface-requirements.txt huggingface-requirements.txt COPY ci_commit_pins/timm.txt timm.txt +COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi -RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt +RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt # Install XPU Dependencies ARG XPU_VERSION diff --git a/.ci/lumen_cli/pyproject.toml b/.ci/lumen_cli/pyproject.toml index 59948edabb53a..9f5b87634a3f3 100644 --- a/.ci/lumen_cli/pyproject.toml +++ b/.ci/lumen_cli/pyproject.toml @@ -6,7 +6,7 @@ dependencies = [ "GitPython==3.1.45", "docker==7.1.0", "pytest==7.3.2", - "uv==0.9.5" + "uv==0.9.6" ] [tool.setuptools] diff --git a/.ci/magma-rocm/Makefile b/.ci/magma-rocm/Makefile index 9fca7ad544617..0e71b0467a9ef 100644 --- a/.ci/magma-rocm/Makefile +++ b/.ci/magma-rocm/Makefile @@ -1,7 +1,7 @@ SHELL=/usr/bin/env bash DOCKER_CMD ?= docker -DESIRED_ROCM ?= 7.0 +DESIRED_ROCM ?= 7.1 DESIRED_ROCM_SHORT = $(subst .,,$(DESIRED_ROCM)) PACKAGE_NAME = magma-rocm # inherit this from underlying docker image, do not pass this env var to docker @@ -16,6 +16,7 @@ DOCKER_RUN = set -eou pipefail; ${DOCKER_CMD} run --rm -i \ magma-rocm/build_magma.sh .PHONY: all +all: magma-rocm71 all: magma-rocm70 all: magma-rocm64 @@ -24,6 +25,11 @@ clean: $(RM) -r magma-* $(RM) -r output +.PHONY: magma-rocm71 +magma-rocm71: DESIRED_ROCM := 7.1 +magma-rocm71: + $(DOCKER_RUN) + .PHONY: magma-rocm70 magma-rocm70: DESIRED_ROCM := 7.0 magma-rocm70: diff --git a/.ci/magma-rocm/build_magma.sh b/.ci/magma-rocm/build_magma.sh index c7c7780227ea5..7d95fed873dc0 100755 --- a/.ci/magma-rocm/build_magma.sh +++ b/.ci/magma-rocm/build_magma.sh @@ -6,8 +6,8 @@ set -eou pipefail # The script expects DESIRED_CUDA and PACKAGE_NAME to be set ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -# https://github.com/icl-utk-edu/magma/pull/65 -MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec +# post merge of https://github.com/icl-utk-edu/magma/pull/65 +MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f # Folders for the build PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata @@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE # Fetch magma sources and verify checksum pushd ${PACKAGE_DIR} -git clone https://github.com/jeffdaily/magma +git clone https://github.com/icl-utk-edu/magma pushd magma git checkout ${MAGMA_VERSION} popd diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index cae81a2568d5c..d66aa1120fb30 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -426,7 +426,7 @@ fi if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then # export test times so that potential sharded tests that'll branch off this build will use consistent data # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build - python tools/stats/export_test_times.py + PYTHONPATH=. python tools/stats/export_test_times.py fi # don't do this for bazel or s390x or riscv64 as they don't use sccache if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 40dc90f2eb24f..26996b5a32d56 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -572,6 +572,8 @@ fi if [[ "${TEST_CONFIG}" == *cpu* ]]; then DYNAMO_BENCHMARK_FLAGS+=(--device cpu) +elif [[ "${TEST_CONFIG}" == *xpu* ]]; then + DYNAMO_BENCHMARK_FLAGS+=(--device xpu) else DYNAMO_BENCHMARK_FLAGS+=(--device cuda) fi @@ -665,6 +667,8 @@ test_perf_for_dashboard() { device=cuda_b200 elif [[ "${TEST_CONFIG}" == *rocm* ]]; then device=rocm + elif [[ "${TEST_CONFIG}" == *xpu* ]]; then + device=xpu fi for mode in "${modes[@]}"; do @@ -1649,7 +1653,7 @@ test_operator_microbenchmark() { cd "${TEST_DIR}"/benchmarks/operator_benchmark - for OP_BENCHMARK_TESTS in matmul mm addmm bmm; do + for OP_BENCHMARK_TESTS in matmul mm addmm bmm conv; do $TASKSET python -m pt.${OP_BENCHMARK_TESTS}_test --tag-filter long \ --output-json-for-dashboard "${TEST_REPORTS_DIR}/operator_microbenchmark_${OP_BENCHMARK_TESTS}_compile.json" \ --benchmark-name "PyTorch operator microbenchmark" --use-compile @@ -1757,7 +1761,7 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then else # Do this after checkout_install_torchbench to ensure we clobber any # nightlies that torchbench may pull in - if [[ "${TEST_CONFIG}" != *cpu* ]]; then + if [[ "${TEST_CONFIG}" != *cpu* && "${TEST_CONFIG}" != *xpu* ]]; then install_torchrec_and_fbgemm fi PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" diff --git a/.clang-tidy b/.clang-tidy index 71ffdf8cb224c..2b8eb00bb03f5 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -60,9 +60,11 @@ performance-*, readability-container-size-empty, readability-delete-null-pointer, readability-duplicate-include, +readability-named-parameter, readability-misplaced-array-index, readability-redundant*, readability-simplify-subscript-expr, +readability-static-definition-in-anonymous-namespace readability-string-compare, -readability-redundant-access-specifiers, -readability-redundant-control-flow, diff --git a/.claude/skills/add-uint-support/SKILL.md b/.claude/skills/add-uint-support/SKILL.md new file mode 100644 index 0000000000000..a4859fdeae55c --- /dev/null +++ b/.claude/skills/add-uint-support/SKILL.md @@ -0,0 +1,319 @@ +--- +name: add-uint-support +description: Add unsigned integer (uint) type support to PyTorch operators by updating AT_DISPATCH macros. Use when adding support for uint16, uint32, uint64 types to operators, kernels, or when user mentions enabling unsigned types, barebones unsigned types, or uint support. +--- + +# Add Unsigned Integer (uint) Support to Operators + +This skill helps add support for unsigned integer types (uint16, uint32, uint64) to PyTorch operators by updating their AT_DISPATCH macros. + +## When to use this skill + +Use this skill when: +- Adding uint16, uint32, or uint64 support to an operator +- User mentions "unsigned types", "uint support", "barebones unsigned types" +- Enabling support for kUInt16, kUInt32, kUInt64 in kernels +- Working with operator implementations that need expanded type coverage + +## Quick reference + +**Add unsigned types to existing dispatch:** +```cpp +// Before +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES)); + +// After (method 1: add unsigned types explicitly) +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + +// After (method 2: use V2 integral types if AT_INTEGRAL_TYPES present) +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)); +``` + +## Type group reference + +**Unsigned type groups:** +- `AT_BAREBONES_UNSIGNED_TYPES`: kUInt16, kUInt32, kUInt64 +- `AT_INTEGRAL_TYPES_V2`: AT_INTEGRAL_TYPES + AT_BAREBONES_UNSIGNED_TYPES + +**Relationship:** +```cpp +AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort +AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64 +AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + BAREBONES_UNSIGNED_TYPES +``` + +## Instructions + +### Step 1: Determine if conversion to V2 is needed + +Check if the file uses AT_DISPATCH_V2: + +**If using old AT_DISPATCH:** +- First convert to AT_DISPATCH_V2 using the at-dispatch-v2 skill +- Then proceed with adding uint support + +**If already using AT_DISPATCH_V2:** +- Proceed directly to Step 2 + +### Step 2: Analyze the current dispatch macro + +Identify what type groups are currently in use: + +```cpp +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + // body +}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16); + ^^^^^^^^^^^^^^^^^^^^^^^^^ + Current type coverage +``` + +Common patterns: +- `AT_EXPAND(AT_ALL_TYPES)` → includes AT_INTEGRAL_TYPES + AT_FLOATING_TYPES +- `AT_EXPAND(AT_INTEGRAL_TYPES)` → signed integers only +- `AT_EXPAND(AT_FLOATING_TYPES)` → floating point types + +### Step 3: Choose the uint addition method + +Two approaches: + +**Method 1: Add AT_BAREBONES_UNSIGNED_TYPES explicitly** +- Use when: You want to be explicit about adding uint support +- Add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the type list + +**Method 2: Substitute AT_INTEGRAL_TYPES with AT_INTEGRAL_TYPES_V2** +- Use when: The dispatch already uses `AT_EXPAND(AT_INTEGRAL_TYPES)` +- More concise: replaces one type group with its superset +- Only applicable if AT_INTEGRAL_TYPES is present + +### Step 4: Apply the transformation + +**Method 1 example:** +```cpp +// Before +AT_DISPATCH_V2( + dtype, + "min_values_cuda", + AT_WRAP([&]() { + kernel_impl(iter); + }), + AT_EXPAND(AT_ALL_TYPES), + kBFloat16, kHalf, kBool +); + +// After (add unsigned types) +AT_DISPATCH_V2( + dtype, + "min_values_cuda", + AT_WRAP([&]() { + kernel_impl(iter); + }), + AT_EXPAND(AT_ALL_TYPES), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + kBFloat16, kHalf, kBool +); +``` + +**Method 2 example:** +```cpp +// Before +AT_DISPATCH_V2( + dtype, + "integral_op", + AT_WRAP([&]() { + kernel(); + }), + AT_EXPAND(AT_INTEGRAL_TYPES) +); + +// After (substitute with V2) +AT_DISPATCH_V2( + dtype, + "integral_op", + AT_WRAP([&]() { + kernel(); + }), + AT_EXPAND(AT_INTEGRAL_TYPES_V2) +); +``` + +### Step 5: Handle AT_ALL_TYPES vs individual type groups + +If the dispatch uses `AT_EXPAND(AT_ALL_TYPES)`: +- `AT_ALL_TYPES` = `AT_INTEGRAL_TYPES` + `AT_FLOATING_TYPES` +- To add uint: add `AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)` to the list + +If the dispatch separately lists INTEGRAL and FLOATING: +```cpp +// Before +AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES) + +// After (Method 2 preferred) +AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES) +``` + +### Step 6: Verify all dispatch sites + +Check the file for ALL dispatch macros that need uint support: +- Some operators have multiple dispatch sites (CPU, CUDA, different functions) +- Apply the transformation consistently across all sites +- Ensure each gets the same type coverage updates + +### Step 7: Validate the changes + +Check that: +- [ ] AT_DISPATCH_V2 format is used (not old AT_DISPATCH) +- [ ] Unsigned types are added via one of the two methods +- [ ] All relevant dispatch sites in the file are updated +- [ ] Type groups use `AT_EXPAND()` +- [ ] Arguments are properly formatted and comma-separated + +## Common patterns + +### Pattern 1: AT_ALL_TYPES + extras + +```cpp +// Before +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16); + +// After +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); +``` + +### Pattern 2: Separate INTEGRAL + FLOATING + +```cpp +// Before +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)); + +// After +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_INTEGRAL_TYPES_V2), AT_EXPAND(AT_FLOATING_TYPES)); +``` + +### Pattern 3: Old dispatch needs conversion first + +```cpp +// Before (needs v2 conversion first) +AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() { + kernel(); +}); + +// After v2 conversion +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16); + +// After adding uint support +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); +``` + +## Multiple dispatch sites example + +For a file with multiple functions: + +```cpp +void min_values_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() { + impl(iter); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf); + // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + // Added uint support +} + +void min_launch_kernel(TensorIterator &iter) { + AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() { + gpu_reduce_kernel(iter); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf); + // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + // Added uint support here too +} +``` + +## Decision tree + +Use this decision tree to determine the approach: + +``` +Is the file using AT_DISPATCH_V2? +├─ No → Use at-dispatch-v2 skill first, then continue +└─ Yes + └─ Does it use AT_EXPAND(AT_INTEGRAL_TYPES)? + ├─ Yes → Replace with AT_EXPAND(AT_INTEGRAL_TYPES_V2) + └─ No → Add AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) to type list +``` + +## Edge cases + +### Case 1: Dispatch with only floating types + +If the operator only supports floating point types, don't add uint support: + +```cpp +// Leave as-is - floating point only operator +AT_DISPATCH_V2(dtype, "float_op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_FLOATING_TYPES), kHalf); +``` + +### Case 2: Complex types present + +Unsigned types work alongside complex types: + +```cpp +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + AT_EXPAND(AT_COMPLEX_TYPES), + kHalf, kBFloat16); +``` + +### Case 3: Already has uint support + +Check if uint types are already present: +- If `AT_INTEGRAL_TYPES_V2` is used → already has uint support +- If `AT_BAREBONES_UNSIGNED_TYPES` is already in list → already has uint support +- Skip the file if uint support is already present + +## Workflow + +When asked to add uint support: + +1. Read the target file +2. Check if using AT_DISPATCH_V2: + - If not → use at-dispatch-v2 skill first +3. Identify all dispatch macro sites +4. For each dispatch: + - Analyze current type groups + - Choose method (add BAREBONES_UNSIGNED or upgrade to V2) + - Apply transformation with Edit tool +5. Show the user the changes +6. Explain what was modified + +## Important notes + +- Always check if v2 conversion is needed first +- Apply changes consistently across all dispatch sites in the file +- Method 2 (AT_INTEGRAL_TYPES_V2) is cleaner when applicable +- Method 1 (explicit AT_BAREBONES_UNSIGNED_TYPES) is more explicit +- Unsigned types are: kUInt16, kUInt32, kUInt64 (not kByte which is uint8) +- Some operators may not semantically support unsigned types - use judgment + +## Testing + +After adding uint support, the operator should accept uint16, uint32, and uint64 tensors. The user is responsible for functional testing. diff --git a/.claude/skills/at-dispatch-v2/SKILL.md b/.claude/skills/at-dispatch-v2/SKILL.md new file mode 100644 index 0000000000000..eb9946c1d03b2 --- /dev/null +++ b/.claude/skills/at-dispatch-v2/SKILL.md @@ -0,0 +1,305 @@ +--- +name: at-dispatch-v2 +description: Convert PyTorch AT_DISPATCH macros to AT_DISPATCH_V2 format in ATen C++ code. Use when porting AT_DISPATCH_ALL_TYPES_AND*, AT_DISPATCH_FLOATING_TYPES*, or other dispatch macros to the new v2 API. For ATen kernel files, CUDA kernels, and native operator implementations. +--- + +# AT_DISPATCH to AT_DISPATCH_V2 Converter + +This skill helps convert PyTorch's legacy AT_DISPATCH macros to the new AT_DISPATCH_V2 format, as defined in `aten/src/ATen/Dispatch_v2.h`. + +## When to use this skill + +Use this skill when: +- Converting AT_DISPATCH_* macros to AT_DISPATCH_V2 +- Porting ATen kernels to use the new dispatch API +- Working with files in `aten/src/ATen/native/` that use dispatch macros +- User mentions "AT_DISPATCH", "dispatch v2", "Dispatch_v2.h", or macro conversion + +## Quick reference + +**Old format:** +```cpp +AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, dtype, "kernel_name", [&]() { + // lambda body +}); +``` + +**New format:** +```cpp +AT_DISPATCH_V2(dtype, "kernel_name", AT_WRAP([&]() { + // lambda body +}), AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool); +``` + +## Key transformations + +1. **Reorder arguments**: `scalar_type` and `name` come first, then lambda, then types +2. **Wrap the lambda**: Use `AT_WRAP(lambda)` to handle internal commas +3. **Expand type groups**: Use `AT_EXPAND(AT_ALL_TYPES)` instead of implicit expansion +4. **List individual types**: Add extra types (kHalf, kBFloat16, etc.) after expanded groups +5. **Add include**: `#include ` near other Dispatch includes + +## Instructions + +### Step 1: Add the Dispatch_v2.h include + +Add the v2 header near the existing `#include `: + +```cpp +#include +#include +``` + +Keep the old Dispatch.h include for now (other code may still need it). + +### Step 2: Identify the old dispatch pattern + +Common patterns to convert: + +- `AT_DISPATCH_ALL_TYPES_AND{2,3,4}(type1, type2, ..., scalar_type, name, lambda)` +- `AT_DISPATCH_FLOATING_TYPES_AND{2,3}(type1, type2, ..., scalar_type, name, lambda)` +- `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND{2,3}(type1, ..., scalar_type, name, lambda)` +- `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3}(type1, ..., scalar_type, name, lambda)` + +### Step 3: Map the old macro to type groups + +Identify which type group macro corresponds to the base types: + +| Old macro base | AT_DISPATCH_V2 type group | +|----------------|---------------------------| +| `ALL_TYPES` | `AT_EXPAND(AT_ALL_TYPES)` | +| `FLOATING_TYPES` | `AT_EXPAND(AT_FLOATING_TYPES)` | +| `INTEGRAL_TYPES` | `AT_EXPAND(AT_INTEGRAL_TYPES)` | +| `COMPLEX_TYPES` | `AT_EXPAND(AT_COMPLEX_TYPES)` | +| `ALL_TYPES_AND_COMPLEX` | `AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX)` | + +For combined patterns, use multiple `AT_EXPAND()` entries: +```cpp +// Old: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(...) +// New: AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), type1, type2 +``` + +### Step 4: Extract the individual types + +From `AT_DISPATCH_*_AND2(type1, type2, ...)` or `AT_DISPATCH_*_AND3(type1, type2, type3, ...)`, extract the individual types (type1, type2, etc.). + +These become the trailing arguments after the type group: +```cpp +AT_DISPATCH_V2(..., AT_EXPAND(AT_ALL_TYPES), kBFloat16, kHalf, kBool) + ^^^^^^^^^^^^^^^^^^^^^^^^ + Individual types from AND3 +``` + +### Step 5: Transform to AT_DISPATCH_V2 + +Apply the transformation: + +**Pattern:** +```cpp +AT_DISPATCH_V2( + scalar_type, // 1st: The dtype expression + "name", // 2nd: The debug string + AT_WRAP(lambda), // 3rd: The lambda wrapped in AT_WRAP + type_groups, // 4th+: Type groups with AT_EXPAND() + individual_types // Last: Individual types +) +``` + +**Example transformation:** +```cpp +// BEFORE +AT_DISPATCH_ALL_TYPES_AND3( + kBFloat16, kHalf, kBool, + iter.dtype(), + "min_values_cuda", + [&]() { + min_values_kernel_cuda_impl(iter); + } +); + +// AFTER +AT_DISPATCH_V2( + iter.dtype(), + "min_values_cuda", + AT_WRAP([&]() { + min_values_kernel_cuda_impl(iter); + }), + AT_EXPAND(AT_ALL_TYPES), + kBFloat16, kHalf, kBool +); +``` + +### Step 6: Handle multi-line lambdas + +For lambdas with internal commas or complex expressions, AT_WRAP is essential: + +```cpp +AT_DISPATCH_V2( + dtype, + "complex_kernel", + AT_WRAP([&]() { + gpu_reduce_kernel( + iter, + MinOps{}, + thrust::pair(upper_bound(), 0) // Commas inside! + ); + }), + AT_EXPAND(AT_ALL_TYPES) +); +``` + +### Step 7: Verify the conversion + +Check that: +- [ ] `AT_WRAP()` wraps the entire lambda +- [ ] Type groups use `AT_EXPAND()` +- [ ] Individual types don't have `AT_EXPAND()` (just `kBFloat16`, not `AT_EXPAND(kBFloat16)`) +- [ ] Argument order is: scalar_type, name, lambda, types +- [ ] Include added: `#include ` + +## Type group reference + +Available type group macros (use with `AT_EXPAND()`): + +```cpp +AT_INTEGRAL_TYPES // kByte, kChar, kInt, kLong, kShort +AT_FLOATING_TYPES // kDouble, kFloat +AT_COMPLEX_TYPES // kComplexDouble, kComplexFloat +AT_QINT_TYPES // kQInt8, kQUInt8, kQInt32 +AT_ALL_TYPES // INTEGRAL_TYPES + FLOATING_TYPES +AT_ALL_TYPES_AND_COMPLEX // ALL_TYPES + COMPLEX_TYPES +AT_INTEGRAL_TYPES_V2 // INTEGRAL_TYPES + unsigned types +AT_BAREBONES_UNSIGNED_TYPES // kUInt16, kUInt32, kUInt64 +AT_FLOAT8_TYPES // Float8 variants +``` + +## Common patterns + +### Pattern: AT_DISPATCH_ALL_TYPES_AND2 + +```cpp +// Before +AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, dtype, "op", [&]() { + kernel(data); +}); + +// After +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(data); +}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBFloat16); +``` + +### Pattern: AT_DISPATCH_FLOATING_TYPES_AND3 + +```cpp +// Before +AT_DISPATCH_FLOATING_TYPES_AND3(kHalf, kBFloat16, kFloat8_e4m3fn, + tensor.scalar_type(), "float_op", [&] { + process(tensor); +}); + +// After +AT_DISPATCH_V2(tensor.scalar_type(), "float_op", AT_WRAP([&] { + process(tensor); +}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn); +``` + +### Pattern: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2 + +```cpp +// Before +AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( + kComplexHalf, kHalf, + self.scalar_type(), + "complex_op", + [&] { + result = compute(self); + } +); + +// After +AT_DISPATCH_V2( + self.scalar_type(), + "complex_op", + AT_WRAP([&] { + result = compute(self); + }), + AT_EXPAND(AT_ALL_TYPES), + AT_EXPAND(AT_COMPLEX_TYPES), + kComplexHalf, + kHalf +); +``` + +## Edge cases + +### Case 1: No extra types (rare) + +```cpp +// Before +AT_DISPATCH_ALL_TYPES(dtype, "op", [&]() { kernel(); }); + +// After +AT_DISPATCH_V2(dtype, "op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_ALL_TYPES)); +``` + +### Case 2: Many individual types (AND4, AND5, etc.) + +```cpp +// Before +AT_DISPATCH_FLOATING_TYPES_AND4(kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, + dtype, "float8_op", [&]() { kernel(); }); + +// After +AT_DISPATCH_V2(dtype, "float8_op", AT_WRAP([&]() { + kernel(); +}), AT_EXPAND(AT_FLOATING_TYPES), kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2); +``` + +### Case 3: Lambda with no captures + +```cpp +// Before +AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, dtype, "op", []() { + static_kernel(); +}); + +// After +AT_DISPATCH_V2(dtype, "op", AT_WRAP([]() { + static_kernel(); +}), AT_EXPAND(AT_ALL_TYPES), kHalf, kBool); +``` + +## Benefits of AT_DISPATCH_V2 + +1. **No arity in macro name**: Don't need different macros for AND2, AND3, AND4 +2. **Composable type sets**: Mix and match type groups with `AT_EXPAND()` +3. **Extensible**: Easy to add more types without hitting macro limits +4. **Clearer**: Type groups are explicit, not implicit in macro name + +## Important notes + +- Keep `#include ` - other code may need it +- The `AT_WRAP()` is mandatory - prevents comma parsing issues in the lambda +- Type groups need `AT_EXPAND()`, individual types don't +- The v2 API is in `aten/src/ATen/Dispatch_v2.h` - refer to it for full docs +- See the header file for the Python script to regenerate the macro implementation + +## Workflow + +When asked to convert AT_DISPATCH macros: + +1. Read the file to identify all AT_DISPATCH uses +2. Add `#include ` if not present +3. For each dispatch macro: + - Identify the pattern and extract components + - Map the base type group + - Extract individual types + - Construct the AT_DISPATCH_V2 call + - Apply with Edit tool +4. Show the user the complete converted file +5. Explain what was changed + +Do NOT compile or test the code - focus on accurate conversion only. diff --git a/.github/actions/diskspace-cleanup/action.yml b/.github/actions/diskspace-cleanup/action.yml index 7291adb59a18d..602b6946f5ec1 100644 --- a/.github/actions/diskspace-cleanup/action.yml +++ b/.github/actions/diskspace-cleanup/action.yml @@ -27,7 +27,9 @@ runs: docker system prune -af diskspace_new=$(df -H --output=pcent ${docker_root_dir} | sed -n 2p | sed 's/%//' | sed 's/ //') if [[ "$diskspace_new" -gt "$diskspace_cutoff" ]] ; then - echo "Error: Available diskspace is less than $diskspace_cutoff percent. Not enough diskspace." + diskspace_cutoff_int=$((diskspace_cutoff + 0)) + difference=$((100 - diskspace_cutoff_int)) + echo "Error: Available diskspace is less than $difference percent. Not enough diskspace." echo "$msg" exit 1 else diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 8af554d56ee57..966f6bcfc0d94 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -69bbe7363897764f9e758d851cd0340147d27f94 +3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2 diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 5d9b8d5d171ef..183e9fb4b06e1 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -218d2ab791d437309f91e0486eb9fa7f00badc17 +cfbc5c2f1c798991715a6b06bb3ce46478c4487c diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 280d5ab77009f..01f0673fcf802 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -df6798dfb931ce7c7fe5bed2447cd1092a5981af +c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9 diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 74b0d243859a2..c15ba606398f6 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -19,6 +19,7 @@ ciflow_push_tags: - ciflow/inductor-perf-test-nightly-rocm-mi300 - ciflow/inductor-perf-test-nightly-rocm-mi355 - ciflow/inductor-perf-test-nightly-x86-zen +- ciflow/inductor-perf-test-nightly-xpu - ciflow/inductor-periodic - ciflow/inductor-rocm - ciflow/linux-aarch64 @@ -26,6 +27,7 @@ ciflow_push_tags: - ciflow/nightly - ciflow/op-benchmark - ciflow/periodic +- ciflow/periodic-rocm-mi200 - ciflow/periodic-rocm-mi300 - ciflow/pull - ciflow/quantization-periodic diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index fd04922f39999..0db11452873fe 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -11,11 +11,17 @@ * Latest XPU """ +import json import os +import re +from pathlib import Path from typing import Optional -# NOTE: Please also update the CUDA sources in `PIP_SOURCES` in tools/nightly.py when changing this +SCRIPT_DIR = Path(__file__).absolute().parent +REPO_ROOT = SCRIPT_DIR.parent.parent + + CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"] CUDA_STABLE = "12.8" CUDA_ARCHES_FULL_VERSION = { @@ -31,8 +37,7 @@ "13.0": "9", } -# NOTE: Please also update the ROCm sources in `PIP_SOURCES` in tools/nightly.py when changing this -ROCM_ARCHES = ["6.4", "7.0"] +ROCM_ARCHES = ["7.0", "7.1"] XPU_ARCHES = ["xpu"] @@ -137,9 +142,48 @@ } -def get_nccl_wheel_version(arch_version: str) -> str: - import re +# Used by tools/nightly.py +PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly" +NIGHTLY_SOURCE_MATRIX = { + "cpu": dict( + name="cpu", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu", + supported_platforms=["Linux", "macOS", "Windows"], + accelerator="cpu", + ) +} +CUDA_NIGHTLY_SOURCE_MATRIX = { + f"cuda-{major}.{minor}": dict( + name=f"cuda-{major}.{minor}", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu{major}{minor}", + supported_platforms=["Linux", "Windows"], + accelerator="cuda", + ) + for major, minor in (map(int, version.split(".")) for version in CUDA_ARCHES) +} +ROCM_NIGHTLY_SOURCE_MATRIX = { + f"rocm-{major}.{minor}": dict( + name=f"rocm-{major}.{minor}", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm{major}.{minor}", + supported_platforms=["Linux"], + accelerator="rocm", + ) + for major, minor in (map(int, version.split(".")) for version in ROCM_ARCHES) +} +XPU_NIGHTLY_SOURCE_MATRIX = { + "xpu": dict( + name="xpu", + index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/xpu", + supported_platforms=["Linux"], + accelerator="xpu", + ) +} +NIGHTLY_SOURCE_MATRIX.update(CUDA_NIGHTLY_SOURCE_MATRIX) +NIGHTLY_SOURCE_MATRIX.update(ROCM_NIGHTLY_SOURCE_MATRIX) +NIGHTLY_SOURCE_MATRIX.update(XPU_NIGHTLY_SOURCE_MATRIX) + +def get_nccl_wheel_version(arch_version: str) -> str: requirements = map( str.strip, re.split("[;|]", PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version]) ) @@ -147,17 +191,14 @@ def get_nccl_wheel_version(arch_version: str) -> str: def read_nccl_pin(arch_version: str) -> str: - from pathlib import Path - - nccl_pin_path = os.path.join( - Path(__file__).absolute().parents[2], - ".ci", - "docker", - "ci_commit_pins", - f"nccl-cu{arch_version[:2]}.txt", + nccl_pin_path = ( + REPO_ROOT + / ".ci" + / "docker" + / "ci_commit_pins" + / f"nccl-cu{arch_version[:2]}.txt" ) - with open(nccl_pin_path) as f: - return f.read().strip() + return nccl_pin_path.read_text().strip() def validate_nccl_dep_consistency(arch_version: str) -> None: @@ -165,7 +206,8 @@ def validate_nccl_dep_consistency(arch_version: str) -> None: wheel_ver = get_nccl_wheel_version(arch_version) if not nccl_release_tag.startswith(f"v{wheel_ver}"): raise RuntimeError( - f"{arch_version} NCCL release tag version {nccl_release_tag} does not correspond to wheel version {wheel_ver}" + f"{arch_version} NCCL release tag version {nccl_release_tag} " + f"does not correspond to wheel version {wheel_ver}" ) @@ -412,7 +454,14 @@ def generate_wheels_matrix( return ret -validate_nccl_dep_consistency("13.0") -validate_nccl_dep_consistency("12.9") -validate_nccl_dep_consistency("12.8") -validate_nccl_dep_consistency("12.6") +arch_version = "" +for arch_version in CUDA_ARCHES: + validate_nccl_dep_consistency(arch_version) +del arch_version + + +if __name__ == "__main__": + # Used by tools/nightly.py + (SCRIPT_DIR / "nightly_source_matrix.json").write_text( + json.dumps(NIGHTLY_SOURCE_MATRIX, indent=4) + "\n" + ) diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index 7aa7608924487..e68bc6ead3a26 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -38,6 +38,10 @@ on: default: "" description: | List of tests to include (empty string implies default list) + dashboard-tag: + required: false + type: string + default: "" disable-monitor: description: | [Experimental] Disable utilization monitoring for tests. @@ -58,6 +62,11 @@ on: required: false type: number default: 1 + secrets: + HUGGING_FACE_HUB_TOKEN: + required: false + description: | + HF Auth token to avoid rate limits when downloading models or datasets from hub permissions: id-token: write contents: read @@ -196,6 +205,8 @@ jobs: PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }} + DASHBOARD_TAG: ${{ inputs.dashboard-tag }} + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }} run: | # Fetch aws credential from IMDs @@ -246,6 +257,8 @@ jobs: -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ -e TESTS_TO_INCLUDE \ -e ZE_AFFINITY_MASK \ + -e HUGGING_FACE_HUB_TOKEN \ + -e DASHBOARD_TAG \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --ulimit stack=10485760:83886080 \ --ulimit core=0 \ diff --git a/.github/workflows/build-almalinux-images.yml b/.github/workflows/build-almalinux-images.yml index 8318286cccbee..d1262ace0cde8 100644 --- a/.github/workflows/build-almalinux-images.yml +++ b/.github/workflows/build-almalinux-images.yml @@ -36,7 +36,7 @@ jobs: runs-on: linux.9xlarge.ephemeral strategy: matrix: - tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm6.4", "rocm7.0", "cpu"] + tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "rocm7.0", "rocm7.1", "cpu"] steps: - name: Build docker image uses: pytorch/pytorch/.github/actions/binary-docker-build@main diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml index c67281e0a112b..09b17c8744f7d 100644 --- a/.github/workflows/build-libtorch-images.yml +++ b/.github/workflows/build-libtorch-images.yml @@ -52,8 +52,8 @@ jobs: { tag: "cuda12.9" }, { tag: "cuda12.8" }, { tag: "cuda12.6" }, - { tag: "rocm6.4" }, { tag: "rocm7.0" }, + { tag: "rocm7.1" }, { tag: "cpu" }, ] steps: diff --git a/.github/workflows/build-magma-rocm-linux.yml b/.github/workflows/build-magma-rocm-linux.yml index eaeb741e56394..1913229a66805 100644 --- a/.github/workflows/build-magma-rocm-linux.yml +++ b/.github/workflows/build-magma-rocm-linux.yml @@ -34,7 +34,7 @@ jobs: id-token: write strategy: matrix: - rocm_version: ["70", "64"] + rocm_version: ["71", "70"] steps: - name: Checkout PyTorch uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index a5c5c387adb82..c4952c3df0f19 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -54,8 +54,8 @@ jobs: { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" }, - { name: "manylinux2_28-builder", tag: "rocm6.4", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" }, + { name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cpu", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28_aarch64-builder", tag: "cpu-aarch64", runner: "linux.arm64.2xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "xpu", runner: "linux.9xlarge.ephemeral" }, diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 9e4144ae56c2d..f8fe484b042ff 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -55,7 +55,7 @@ jobs: docker-image: ["pytorch/manylinux2_28-builder:cpu"] include: - device: "rocm" - rocm_version: "7.0" + rocm_version: "7.1" runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" - device: "cuda" rocm_version: "" @@ -159,12 +159,7 @@ jobs: WITH_CLANG_LDD="--with-clang-ldd" fi - if [[ "${BUILD_DEVICE}" == xpu ]]; then - docker exec -t "${container_name}" bash -c "dnf install -y gcc-toolset-13-gcc-c++" - docker exec -t "${container_name}" bash -c "source /opt/rh/gcc-toolset-13/enable && ${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE" - else - docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD" - fi + docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD" if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl" diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index ca257ee8225ad..6fbe2e846d40b 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -57,6 +57,7 @@ jobs: pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11, pytorch-linux-jammy-py3.10-clang12, pytorch-linux-jammy-py3.13-clang12, + pytorch-linux-jammy-py3.14-clang12, pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, pytorch-linux-jammy-rocm-n-py3-benchmarks, @@ -66,6 +67,7 @@ jobs: pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-xpu-n-1-py3, pytorch-linux-jammy-xpu-n-py3, + pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks, pytorch-linux-jammy-py3-clang18-asan, pytorch-linux-jammy-py3-clang12-onnx, pytorch-linux-jammy-linter, diff --git a/.github/workflows/generated-linux-binary-libtorch-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-nightly.yml index 7f3277ef64a12..446415807f204 100644 --- a/.github/workflows/generated-linux-binary-libtorch-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-nightly.yml @@ -384,7 +384,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm6_4-shared-with-deps-release-build: + libtorch-rocm7_0-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -393,23 +393,23 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: libtorch-rocm6_4-shared-with-deps-release + build_name: libtorch-rocm7_0-shared-with-deps-release build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm6_4-shared-with-deps-release-test: # Testing + libtorch-rocm7_0-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-rocm6_4-shared-with-deps-release-build + - libtorch-rocm7_0-shared-with-deps-release-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -418,12 +418,12 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps permissions: @@ -435,7 +435,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-rocm6_4-shared-with-deps-release + name: libtorch-rocm7_0-shared-with-deps-release path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -466,7 +466,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: libtorch-cxx11-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -479,30 +479,30 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - libtorch-rocm6_4-shared-with-deps-release-upload: # Uploading + libtorch-rocm7_0-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-rocm6_4-shared-with-deps-release-test + needs: libtorch-rocm7_0-shared-with-deps-release-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-rocm6_4-shared-with-deps-release + build_name: libtorch-rocm7_0-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - libtorch-rocm7_0-shared-with-deps-release-build: + libtorch-rocm7_1-shared-with-deps-release-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -511,23 +511,23 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: libtorch-rocm7_0-shared-with-deps-release + build_name: libtorch-rocm7_1-shared-with-deps-release build_environment: linux-binary-libtorch secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-rocm7_0-shared-with-deps-release-test: # Testing + libtorch-rocm7_1-shared-with-deps-release-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-rocm7_0-shared-with-deps-release-build + - libtorch-rocm7_1-shared-with-deps-release-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -536,12 +536,12 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps permissions: @@ -553,7 +553,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-rocm7_0-shared-with-deps-release + name: libtorch-rocm7_1-shared-with-deps-release path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -584,7 +584,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: libtorch-cxx11-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -597,25 +597,25 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - libtorch-rocm7_0-shared-with-deps-release-upload: # Uploading + libtorch-rocm7_1-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-rocm7_0-shared-with-deps-release-test + needs: libtorch-rocm7_1-shared-with-deps-release-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: libtorch-cxx11-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-rocm7_0-shared-with-deps-release + build_name: libtorch-rocm7_1-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 5fcf4e0bd176f..21c1d5caa3829 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -373,7 +373,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-rocm6_4-build: + manywheel-py3_10-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -382,22 +382,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_10-rocm6_4 + build_name: manywheel-py3_10-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-rocm6_4-test: # Testing + manywheel-py3_10-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-rocm6_4-build + - manywheel-py3_10-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -406,12 +406,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.10" permissions: id-token: write @@ -422,7 +422,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_10-rocm6_4 + name: manywheel-py3_10-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -453,7 +453,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -466,29 +466,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_10-rocm6_4-upload: # Uploading + manywheel-py3_10-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-rocm6_4-test + needs: manywheel-py3_10-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-rocm6_4 + build_name: manywheel-py3_10-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-rocm7_0-build: + manywheel-py3_10-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -497,22 +497,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_10-rocm7_0 + build_name: manywheel-py3_10-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-rocm7_0-test: # Testing + manywheel-py3_10-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-rocm7_0-build + - manywheel-py3_10-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -521,12 +521,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.10" permissions: id-token: write @@ -537,7 +537,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_10-rocm7_0 + name: manywheel-py3_10-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -568,7 +568,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -581,24 +581,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_10-rocm7_0-upload: # Uploading + manywheel-py3_10-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-rocm7_0-test + needs: manywheel-py3_10-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-rocm7_0 + build_name: manywheel-py3_10-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -1039,7 +1039,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-rocm6_4-build: + manywheel-py3_11-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1048,22 +1048,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_11-rocm6_4 + build_name: manywheel-py3_11-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-rocm6_4-test: # Testing + manywheel-py3_11-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-rocm6_4-build + - manywheel-py3_11-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -1072,12 +1072,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.11" permissions: id-token: write @@ -1088,7 +1088,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_11-rocm6_4 + name: manywheel-py3_11-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -1119,7 +1119,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -1132,29 +1132,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_11-rocm6_4-upload: # Uploading + manywheel-py3_11-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-rocm6_4-test + needs: manywheel-py3_11-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-rocm6_4 + build_name: manywheel-py3_11-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-rocm7_0-build: + manywheel-py3_11-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1163,22 +1163,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_11-rocm7_0 + build_name: manywheel-py3_11-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-rocm7_0-test: # Testing + manywheel-py3_11-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-rocm7_0-build + - manywheel-py3_11-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -1187,12 +1187,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.11" permissions: id-token: write @@ -1203,7 +1203,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_11-rocm7_0 + name: manywheel-py3_11-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -1234,7 +1234,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -1247,24 +1247,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_11-rocm7_0-upload: # Uploading + manywheel-py3_11-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-rocm7_0-test + needs: manywheel-py3_11-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-rocm7_0 + build_name: manywheel-py3_11-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -1705,7 +1705,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-rocm6_4-build: + manywheel-py3_12-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1714,22 +1714,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_12-rocm6_4 + build_name: manywheel-py3_12-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-rocm6_4-test: # Testing + manywheel-py3_12-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-rocm6_4-build + - manywheel-py3_12-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -1738,12 +1738,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.12" permissions: id-token: write @@ -1754,7 +1754,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_12-rocm6_4 + name: manywheel-py3_12-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -1785,7 +1785,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -1798,29 +1798,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_12-rocm6_4-upload: # Uploading + manywheel-py3_12-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-rocm6_4-test + needs: manywheel-py3_12-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm6_4 + build_name: manywheel-py3_12-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-rocm7_0-build: + manywheel-py3_12-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1829,22 +1829,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_12-rocm7_0 + build_name: manywheel-py3_12-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-rocm7_0-test: # Testing + manywheel-py3_12-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-rocm7_0-build + - manywheel-py3_12-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -1853,12 +1853,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.12" permissions: id-token: write @@ -1869,7 +1869,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_12-rocm7_0 + name: manywheel-py3_12-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -1900,7 +1900,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -1913,24 +1913,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_12-rocm7_0-upload: # Uploading + manywheel-py3_12-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-rocm7_0-test + needs: manywheel-py3_12-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-rocm7_0 + build_name: manywheel-py3_12-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -2371,7 +2371,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-rocm6_4-build: + manywheel-py3_13-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2380,22 +2380,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_13-rocm6_4 + build_name: manywheel-py3_13-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-rocm6_4-test: # Testing + manywheel-py3_13-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-rocm6_4-build + - manywheel-py3_13-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -2404,12 +2404,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.13" permissions: id-token: write @@ -2420,7 +2420,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_13-rocm6_4 + name: manywheel-py3_13-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -2451,7 +2451,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -2464,29 +2464,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_13-rocm6_4-upload: # Uploading + manywheel-py3_13-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-rocm6_4-test + needs: manywheel-py3_13-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-rocm6_4 + build_name: manywheel-py3_13-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-rocm7_0-build: + manywheel-py3_13-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2495,22 +2495,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_13-rocm7_0 + build_name: manywheel-py3_13-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-rocm7_0-test: # Testing + manywheel-py3_13-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-rocm7_0-build + - manywheel-py3_13-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -2519,12 +2519,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13" permissions: id-token: write @@ -2535,7 +2535,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_13-rocm7_0 + name: manywheel-py3_13-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -2566,7 +2566,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -2579,24 +2579,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_13-rocm7_0-upload: # Uploading + manywheel-py3_13-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-rocm7_0-test + needs: manywheel-py3_13-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-rocm7_0 + build_name: manywheel-py3_13-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -3037,7 +3037,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-rocm6_4-build: + manywheel-py3_13t-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3046,22 +3046,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_13t-rocm6_4 + build_name: manywheel-py3_13t-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-rocm6_4-test: # Testing + manywheel-py3_13t-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-rocm6_4-build + - manywheel-py3_13t-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -3070,12 +3070,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.13t" permissions: id-token: write @@ -3086,7 +3086,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_13t-rocm6_4 + name: manywheel-py3_13t-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -3117,7 +3117,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -3130,29 +3130,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_13t-rocm6_4-upload: # Uploading + manywheel-py3_13t-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-rocm6_4-test + needs: manywheel-py3_13t-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-rocm6_4 + build_name: manywheel-py3_13t-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-rocm7_0-build: + manywheel-py3_13t-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3161,22 +3161,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_13t-rocm7_0 + build_name: manywheel-py3_13t-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-rocm7_0-test: # Testing + manywheel-py3_13t-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-rocm7_0-build + - manywheel-py3_13t-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -3185,12 +3185,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13t" permissions: id-token: write @@ -3201,7 +3201,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_13t-rocm7_0 + name: manywheel-py3_13t-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -3232,7 +3232,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -3245,24 +3245,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_13t-rocm7_0-upload: # Uploading + manywheel-py3_13t-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-rocm7_0-test + needs: manywheel-py3_13t-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-rocm7_0 + build_name: manywheel-py3_13t-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -3703,7 +3703,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14-rocm6_4-build: + manywheel-py3_14-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3712,22 +3712,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_14-rocm6_4 + build_name: manywheel-py3_14-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-rocm6_4-test: # Testing + manywheel-py3_14-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14-rocm6_4-build + - manywheel-py3_14-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -3736,12 +3736,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.14" permissions: id-token: write @@ -3752,7 +3752,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_14-rocm6_4 + name: manywheel-py3_14-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -3783,7 +3783,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -3796,29 +3796,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_14-rocm6_4-upload: # Uploading + manywheel-py3_14-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14-rocm6_4-test + needs: manywheel-py3_14-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-rocm6_4 + build_name: manywheel-py3_14-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14-rocm7_0-build: + manywheel-py3_14-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3827,22 +3827,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_14-rocm7_0 + build_name: manywheel-py3_14-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-rocm7_0-test: # Testing + manywheel-py3_14-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14-rocm7_0-build + - manywheel-py3_14-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -3851,12 +3851,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14" permissions: id-token: write @@ -3867,7 +3867,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_14-rocm7_0 + name: manywheel-py3_14-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -3898,7 +3898,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -3911,24 +3911,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_14-rocm7_0-upload: # Uploading + manywheel-py3_14-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14-rocm7_0-test + needs: manywheel-py3_14-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-rocm7_0 + build_name: manywheel-py3_14-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml @@ -4369,7 +4369,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14t-rocm6_4-build: + manywheel-py3_14t-rocm7_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -4378,22 +4378,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_14t-rocm6_4 + build_name: manywheel-py3_14t-rocm7_0 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-rocm6_4-test: # Testing + manywheel-py3_14t-rocm7_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14t-rocm6_4-build + - manywheel-py3_14t-rocm7_0-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -4402,12 +4402,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.14t" permissions: id-token: write @@ -4418,7 +4418,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_14t-rocm6_4 + name: manywheel-py3_14t-rocm7_0 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -4449,7 +4449,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm6.4 + custom-tag-prefix: rocm7.0 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -4462,29 +4462,29 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_14t-rocm6_4-upload: # Uploading + manywheel-py3_14t-rocm7_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14t-rocm6_4-test + needs: manywheel-py3_14t-rocm7_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm6.4 - GPU_ARCH_VERSION: "6.4" + DESIRED_CUDA: rocm7.0 + GPU_ARCH_VERSION: "7.0" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm6.4 + DOCKER_IMAGE_TAG_PREFIX: rocm7.0 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-rocm6_4 + build_name: manywheel-py3_14t-rocm7_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14t-rocm7_0-build: + manywheel-py3_14t-rocm7_1-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -4493,22 +4493,22 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" timeout-minutes: 300 - build_name: manywheel-py3_14t-rocm7_0 + build_name: manywheel-py3_14t-rocm7_1 build_environment: linux-binary-manywheel secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-rocm7_0-test: # Testing + manywheel-py3_14t-rocm7_1-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14t-rocm7_0-build + - manywheel-py3_14t-rocm7_1-build - get-label-type runs-on: linux.rocm.gpu.mi250 timeout-minutes: 240 @@ -4517,12 +4517,12 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm SKIP_ALL_TESTS: 1 DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14t" permissions: id-token: write @@ -4533,7 +4533,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: manywheel-py3_14t-rocm7_0 + name: manywheel-py3_14t-rocm7_1 path: "${{ runner.temp }}/artifacts/" - name: Checkout PyTorch uses: actions/checkout@v4 @@ -4564,7 +4564,7 @@ jobs: with: docker-registry: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') && '308535385114.dkr.ecr.us-east-1.amazonaws.com' || 'docker.io' }} docker-image-name: manylinux2_28-builder - custom-tag-prefix: rocm7.0 + custom-tag-prefix: rocm7.1 docker-build-dir: .ci/docker working-directory: pytorch - name: Pull Docker image @@ -4577,24 +4577,24 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - name: Teardown ROCm uses: ./.github/actions/teardown-rocm - manywheel-py3_14t-rocm7_0-upload: # Uploading + manywheel-py3_14t-rocm7_1-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14t-rocm7_0-test + needs: manywheel-py3_14t-rocm7_1-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: rocm7.0 - GPU_ARCH_VERSION: "7.0" + DESIRED_CUDA: rocm7.1 + GPU_ARCH_VERSION: "7.1" GPU_ARCH_TYPE: rocm DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: rocm7.0 + DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-rocm7_0 + build_name: manywheel-py3_14t-rocm7_1 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/inductor-perf-test-nightly-xpu.yml b/.github/workflows/inductor-perf-test-nightly-xpu.yml new file mode 100644 index 0000000000000..c2db8c310e368 --- /dev/null +++ b/.github/workflows/inductor-perf-test-nightly-xpu.yml @@ -0,0 +1,148 @@ +name: inductor-perf-nightly-xpu + +on: + push: + tags: + - ciflow/inductor-perf-test-nightly-xpu/* + schedule: + - cron: 30 17 * * * + workflow_dispatch: + inputs: + training: + description: Run training (on by default)? + required: false + type: boolean + default: true + inference: + description: Run inference (on by default)? + required: false + type: boolean + default: true + default: + description: Run inductor_default? + required: false + type: boolean + default: false + dynamic: + description: Run inductor_dynamic_shapes? + required: false + type: boolean + default: false + cppwrapper: + description: Run inductor_cpp_wrapper? + required: false + type: boolean + default: false + cudagraphs: + description: Run inductor_cudagraphs? + required: false + type: boolean + default: false + freezing_cudagraphs: + description: Run inductor_cudagraphs with freezing for inference? + required: false + type: boolean + default: false + aotinductor: + description: Run aot_inductor for inference? + required: false + type: boolean + default: false + maxautotune: + description: Run inductor_max_autotune? + required: false + type: boolean + default: false + benchmark_configs: + description: The list of configs used the benchmark + required: false + type: string + default: inductor_huggingface_perf,inductor_timm_perf,inductor_torchbench_perf,cachebench + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: read-all + +jobs: + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + opt_out_experiments: lf + + xpu-n-py3_10-inductor-benchmark-build: + name: xpu-n-py3.10-inductor-benchmark + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-xpu-n-py3.10 + docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks + runner: linux.c7i.12xlarge + test-matrix: | + { include: [ + { config: "inductor_huggingface_perf_xpu", shard: 1, num_shards: 5, runner: "linux.idc.xpu" }, + { config: "inductor_huggingface_perf_xpu", shard: 2, num_shards: 5, runner: "linux.idc.xpu" }, + { config: "inductor_huggingface_perf_xpu", shard: 3, num_shards: 5, runner: "linux.idc.xpu" }, + { config: "inductor_huggingface_perf_xpu", shard: 4, num_shards: 5, runner: "linux.idc.xpu" }, + { config: "inductor_huggingface_perf_xpu", shard: 5, num_shards: 5, runner: "linux.idc.xpu" }, + { config: "inductor_timm_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_timm_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_timm_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_timm_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_timm_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_timm_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_torchbench_perf_xpu", shard: 1, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_torchbench_perf_xpu", shard: 2, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_torchbench_perf_xpu", shard: 3, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_torchbench_perf_xpu", shard: 4, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_torchbench_perf_xpu", shard: 5, num_shards: 6, runner: "linux.idc.xpu" }, + { config: "inductor_torchbench_perf_xpu", shard: 6, num_shards: 6, runner: "linux.idc.xpu" }, + ]} + secrets: inherit + + xpu-n-py3_10-inductor-benchmark-test-nightly: + permissions: + id-token: write + contents: read + if: github.event_name != 'workflow_dispatch' + name: xpu-n-py3.10-inductor-benchmark + uses: ./.github/workflows/_xpu-test.yml + needs: xpu-n-py3_10-inductor-benchmark-build + with: + build-environment: linux-jammy-xpu-n-py3.10 + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false + docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }} + test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }} + timeout-minutes: 720 + # Disable monitor in perf tests for more investigation + disable-monitor: true + monitor-log-interval: 10 + monitor-data-collect-interval: 2 + secrets: inherit + + xpu-n-py3_10-inductor-benchmark-test: + permissions: + id-token: write + contents: read + if: github.event_name == 'workflow_dispatch' + name: xpu-n-py3.10-inductor-test + uses: ./.github/workflows/_xpu-test.yml + needs: xpu-n-py3_10-inductor-benchmark-build + with: + build-environment: linux-jammy-xpu-n-py3.10 + dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }} + docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }} + test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }} + timeout-minutes: 720 + disable-monitor: false + monitor-log-interval: 15 + monitor-data-collect-interval: 4 + secrets: inherit diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index b1bb7972d67de..b2ff53a645481 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -1,9 +1,10 @@ name: inductor-rocm on: + schedule: + - cron: 0 * * * * push: branches: - - main - release/* tags: - ciflow/inductor-rocm/* diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 729b111574851..01f0434aa5023 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -76,11 +76,12 @@ jobs: # NOTE: mypy needs its own job because it depends on --all-files, without assessing all files it sometimes # fails to find types when it should - lintrunner-mypy: + # NOTE: We should be able to disable this and consolidate with Pyrefly + lintrunner-pyrefly: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - name: lintrunner-mypy-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} + name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} needs: [get-label-type, get-changed-files] - # Only run if there are changed files relevant to mypy + # Only run if there are changed files relevant to pyrefly if: | github.repository_owner == 'pytorch' && ( needs.get-changed-files.outputs.changed-files == '*' || @@ -98,8 +99,8 @@ jobs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" - echo "Running mypy" - ADDITIONAL_LINTRUNNER_ARGS="--take MYPY,MYPYSTRICT --all-files" .github/scripts/lintrunner.sh + echo "Running pyrefly" + ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh lintrunner-noclang: uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -118,9 +119,9 @@ jobs: CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" echo "Running all other linters" if [ "$CHANGED_FILES" = '*' ]; then - ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY --all-files" .github/scripts/lintrunner.sh + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY --all-files" .github/scripts/lintrunner.sh else - ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,MYPY,MYPYSTRICT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh + ADDITIONAL_LINTRUNNER_ARGS="--skip CLANGTIDY,CLANGFORMAT,PYREFLY ${CHANGED_FILES}" .github/scripts/lintrunner.sh fi quick-checks: diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 696c5b68b475b..0682dd2144afd 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -41,7 +41,7 @@ jobs: uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-gcc11 docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 secrets: inherit diff --git a/.github/workflows/periodic-rocm-mi200.yml b/.github/workflows/periodic-rocm-mi200.yml new file mode 100644 index 0000000000000..6b65bf05cbde0 --- /dev/null +++ b/.github/workflows/periodic-rocm-mi200.yml @@ -0,0 +1,84 @@ +name: periodic-rocm-mi200 + +on: + schedule: + # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. + # Also run less frequently on weekends. + - cron: 45 0,8,16 * * 1-5 + - cron: 45 4 * * 0,6 + - cron: 45 4,12,20 * * 1-5 + - cron: 45 12 * * 0,6 + - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests + push: + tags: + - ciflow/periodic/* + - ciflow/periodic-rocm-mi200/* + branches: + - release/* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + llm-td: + if: github.repository_owner == 'pytorch' + name: before-test + uses: ./.github/workflows/llm_td_retrieval.yml + permissions: + id-token: write + contents: read + + target-determination: + name: before-test + uses: ./.github/workflows/target_determination.yml + needs: llm-td + permissions: + id-token: write + contents: read + + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-rocm-py3_10-build: + name: linux-jammy-rocm-py3.10 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-rocm-py3.10 + docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 + test-matrix: | + { include: [ + { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] }, + { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.4", owners: ["module:rocm", "oncall:distributed"] }, + ]} + secrets: inherit + + linux-jammy-rocm-py3_10-test: + permissions: + id-token: write + contents: read + name: linux-jammy-rocm-py3.10 + uses: ./.github/workflows/_rocm-test.yml + needs: + - linux-jammy-rocm-py3_10-build + - target-determination + with: + build-environment: linux-jammy-rocm-py3.10 + docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} + secrets: inherit diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 0c4668aa89c6b..5a90db9ab5737 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -204,37 +204,6 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }} secrets: inherit - linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 - test-matrix: | - { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] }, - { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] }, - { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.mi250.4", owners: ["module:rocm", "oncall:distributed"] }, - ]} - secrets: inherit - - linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read - name: linux-jammy-rocm-py3.10 - uses: ./.github/workflows/_rocm-test.yml - needs: - - linux-jammy-rocm-py3_10-build - - target-determination - with: - build-environment: linux-jammy-rocm-py3.10 - docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }} - secrets: inherit - linux-jammy-cuda12_8-py3-gcc11-slow-gradcheck-build: name: linux-jammy-cuda12.8-py3-gcc11-slow-gradcheck uses: ./.github/workflows/_linux-build.yml diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 7fdfab476705b..e3af55e736503 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -66,10 +66,10 @@ jobs: { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "distributed", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "distributed", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, ]} secrets: inherit @@ -167,8 +167,8 @@ jobs: docker-image-name: ci-image:pytorch-linux-jammy-py3-clang12-onnx test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, + { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, ]} secrets: inherit diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 50a791432dc97..ffe6efbe0433c 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -3,13 +3,13 @@ name: rocm on: push: branches: - - main - release/* tags: - ciflow/rocm/* workflow_dispatch: schedule: - cron: 29 8 * * * # about 1:29am PDT + - cron: 0 * * * * concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index f77b6081b776a..24c3ab3db84f3 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -6,6 +6,7 @@ on: - pull - trunk - periodic + - periodic-rocm-mi200 - periodic-rocm-mi300 - inductor - unstable diff --git a/.gitignore b/.gitignore index 447ef777e9291..d1b3b17445dac 100644 --- a/.gitignore +++ b/.gitignore @@ -143,6 +143,7 @@ scripts/release_notes/*.json sccache-stats*.json lint.json merge_record.json +.github/scripts/nightly_source_matrix.json # These files get copied over on invoking setup.py torchgen/packaged/* @@ -397,3 +398,4 @@ CLAUDE.local.md /test_*.py /debug_*.py CLAUDE_CONTEXT/ +/.claude/settings.local.json diff --git a/.lintrunner.toml b/.lintrunner.toml index 26ade791a1bde..92ff683a7a0c8 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -121,94 +121,6 @@ command = [ ] is_formatter = true -[[linter]] -code = 'MYPY' -include_patterns = [ - 'setup.py', - 'functorch/dim/**/*.py', - 'torch/**/*.py', - 'torch/**/*.pyi', - 'caffe2/**/*.py', - 'caffe2/**/*.pyi', - 'test/test_bundled_images.py', - 'test/test_bundled_inputs.py', - 'test/test_complex.py', - 'test/test_datapipe.py', - 'test/test_futures.py', - 'test/test_numpy_interop.py', - 'test/test_torch.py', - 'test/test_type_hints.py', - 'test/test_type_info.py', - 'test/test_utils.py', -] -exclude_patterns = [ - '**/fb/**', -] -command = [ - 'python3', - 'tools/linter/adapters/mypy_linter.py', - '--config=mypy.ini', - '--', - '@{{PATHSFILE}}' -] -init_command = [ - 'python3', - 'tools/linter/adapters/pip_init.py', - '--dry-run={{DRYRUN}}', - 'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"', - 'numpy==2.1.0 ; python_version >= "3.12"', - 'expecttest==0.3.0', - 'mypy==1.16.0', - 'sympy==1.13.3', - 'types-requests==2.27.25', - 'types-pyyaml==6.0.2', - 'types-tabulate==0.8.8', - 'types-protobuf==5.29.1.20250403', - 'types-setuptools==79.0.0.20250422', - 'types-jinja2==2.11.9', - 'types-colorama==0.4.6', - 'filelock==3.18.0', - 'junitparser==2.1.1', - 'rich==14.1.0', - 'pyyaml==6.0.2', - 'optree==0.13.0', - 'dataclasses-json==0.6.7', - 'pandas==2.2.3', -] - -[[linter]] -code = 'MYPYSTRICT' -include_patterns = [ - '.github/**/*.py', - 'benchmarks/instruction_counts/**/*.py', - 'tools/**/*.py', - 'torchgen/**/*.py', - 'torch/utils/_pytree.py', - 'torch/utils/_cxx_pytree.py', - 'torch/utils/benchmark/utils/common.py', - 'torch/utils/benchmark/utils/timer.py', - 'torch/utils/benchmark/utils/valgrind_wrapper/**/*.py', -] -exclude_patterns = [ - # (linbinyu) copied from internal repo - '**/fb/**', - 'tools/code_analyzer/gen_operators_yaml.py', - 'tools/dynamo/verify_dynamo.py', - 'tools/gen_vulkan_spv.py', - 'tools/test/gen_operators_yaml_test.py', - 'tools/test/gen_oplist_test.py', - 'tools/test/test_selective_build.py', - 'tools/experimental/torchfuzz/**', -] -command = [ - 'python3', - 'tools/linter/adapters/mypy_linter.py', - '--config=mypy-strict.ini', - '--code=MYPYSTRICT', - '--', - '@{{PATHSFILE}}' -] - [[linter]] code = 'PYREFLY' @@ -230,6 +142,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', + 'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"', 'numpy==2.1.0 ; python_version >= "3.12"', 'expecttest==0.3.0', 'pyrefly==0.36.2', diff --git a/CMakeLists.txt b/CMakeLists.txt index 991ea336a175b..ca1e4164be9b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -374,7 +374,7 @@ cmake_dependent_option( "Build the lazy Torchscript backend, not compatible with mobile builds" ON "NOT INTERN_BUILD_MOBILE" OFF) cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF) -cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" +cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder" OFF "USE_CUDA" OFF) cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON "CPU_AARCH64" OFF) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4c46077f9db71..9df55ca6acd5c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,7 +11,6 @@ aspects of contributing to PyTorch. - [Developing PyTorch](#developing-pytorch) - - [Setup the development environment](#setup-the-development-environment) - [Tips and Debugging](#tips-and-debugging) - [Nightly Checkout & Pull](#nightly-checkout--pull) - [Codebase structure](#codebase-structure) @@ -67,23 +66,6 @@ aspects of contributing to PyTorch. Follow the instructions for [installing PyTorch from source](https://github.com/pytorch/pytorch#from-source). If you get stuck when developing PyTorch on your machine, check out the [tips and debugging](#tips-and-debugging) section below for common solutions. -### Setup the development environment - -First, you need to [fork the PyTorch project on GitHub](https://github.com/pytorch/pytorch/fork) and follow the instructions at [Connecting to GitHub with SSH](https://docs.github.com/en/authentication/connecting-to-github-with-ssh) to setup your SSH authentication credentials. - -Then clone the PyTorch project and setup the development environment: - -```bash -git clone git@github.com:/pytorch.git -cd pytorch -git remote add upstream git@github.com:pytorch/pytorch.git - -make setup-env -# Or run `make setup-env-cuda` for pre-built CUDA binaries -# Or run `make setup-env-rocm` for pre-built ROCm binaries -source venv/bin/activate # or `. .\venv\Scripts\activate` on Windows -``` - ### Tips and Debugging * If you want to have no-op incremental rebuilds (which are fast), see [Make no-op build fast](#make-no-op-build-fast) below. diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp index 44ad24b81755f..4d3dafc65663e 100644 --- a/aten/src/ATen/CPUGeneratorImpl.cpp +++ b/aten/src/ATen/CPUGeneratorImpl.cpp @@ -181,7 +181,7 @@ c10::intrusive_ptr CPUGeneratorImpl::get_state() const { static const size_t size = sizeof(CPUGeneratorImplState); static_assert(std::is_standard_layout_v, "CPUGeneratorImplState is not a PODType"); - auto state_tensor = at::detail::empty_cpu({(int64_t)size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); + auto state_tensor = at::detail::empty_cpu({static_cast(size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); auto rng_state = state_tensor.data_ptr(); // accumulate generator data to be copied into byte tensor diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 3310abfb41d54..a354b41912406 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -223,7 +223,7 @@ void Context::setSDPPriorityOrder(const std::vector& order) { "setSDPPriority order expected ", sdp_priority_order.size() - 1, " but got ", at::num_sdp_backends, " unique backends specified in priority order."); for (uint32_t i = 0; i < order.size(); i++) { - sdp_priority_order[i] = (at::SDPBackend) order[i]; + sdp_priority_order[i] = static_cast(order[i]); } } @@ -825,6 +825,14 @@ void Context::setDisplayVmapFallbackWarnings(bool enabled) { display_vmap_fallback_warnings_ = enabled; } +bool Context::warnOnAccumulateGradStreamMismatch() const { + return warn_on_accumulate_grad_stream_mismatch_; +} + +void Context::setWarnOnAccumulateGradStreamMismatch(bool enabled) { + warn_on_accumulate_grad_stream_mismatch_ = enabled; +} + bool Context::isDefaultMobileCPUAllocatorSet() { return prev_allocator_ptr_ != nullptr; } diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index a4a26b5671e59..6807e527eb75f 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -404,6 +404,9 @@ class TORCH_API Context { void setDisplayVmapFallbackWarnings(bool enabled); bool areVmapFallbackWarningsEnabled() const; + void setWarnOnAccumulateGradStreamMismatch(bool enabled); + bool warnOnAccumulateGradStreamMismatch() const; + bool isDefaultMobileCPUAllocatorSet(); void setDefaultMobileCPUAllocator(); void unsetDefaultMobileCPUAllocator(); @@ -494,6 +497,7 @@ class TORCH_API Context { bool release_original_weights = false; #endif bool display_vmap_fallback_warnings_ = false; + bool warn_on_accumulate_grad_stream_mismatch_ = true; std::atomic quantized_engine = at::QEngine::NoQEngine; bool enable_sparse_tensor_invariant_checks = false; bool allow_fp16_reduction_cpu = false; diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 15a862274f003..40ad61cbd6455 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -197,6 +197,7 @@ inline at::ScalarType scalar_type(at::ScalarType s) { /* don't use TYPE again in case it is an expensive or side-effect op */ \ at::ScalarType _st = ::detail::scalar_type(the_type); \ RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \ + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \ switch (_st) { \ __VA_ARGS__ \ default: \ @@ -208,6 +209,7 @@ inline at::ScalarType scalar_type(at::ScalarType s) { toString(_st), \ "'"); \ } \ + C10_DIAGNOSTIC_POP() \ }() #define AT_DISPATCH_CASE_FLOATING_TYPES(...) \ diff --git a/aten/src/ATen/MapAllocator.cpp b/aten/src/ATen/MapAllocator.cpp index ed697c32b58a8..d8ad62c8c62a4 100644 --- a/aten/src/ATen/MapAllocator.cpp +++ b/aten/src/ATen/MapAllocator.cpp @@ -252,13 +252,13 @@ MapAllocator::MapAllocator(WithFd /*unused*/, std::string_view filename, int fd, if (!(flags_ & ALLOCATOR_MAPPED_FROMFD)) { if (flags_ & ALLOCATOR_MAPPED_SHARED) { // NOLINTNEXTLINE(bugprone-assignment-in-if-condition) - if ((fd = open(filename_.c_str(), flags, (mode_t)0600)) == -1) { + if ((fd = open(filename_.c_str(), flags, static_cast(0600))) == -1) { TORCH_CHECK(false, "unable to open file <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")"); } } else if (flags_ & ALLOCATOR_MAPPED_SHAREDMEM) { #ifdef HAVE_SHM_OPEN // NOLINTNEXTLINE(bugprone-assignment-in-if-condition) - if((fd = shm_open(filename_.c_str(), flags, (mode_t)0600)) == -1) { + if((fd = shm_open(filename_.c_str(), flags, static_cast(0600))) == -1) { TORCH_CHECK(false, "unable to open shared memory object <", filename_, "> in read-write mode: ", c10::utils::str_error(errno), " (", errno, ")"); } #else @@ -503,7 +503,7 @@ RefcountedMapAllocator::RefcountedMapAllocator(WithFd /*unused*/, const char *fi void RefcountedMapAllocator::initializeAlloc() { TORCH_CHECK(base_ptr_, "base_ptr_ is null"); - MapInfo *map_info = (MapInfo*)base_ptr_; + MapInfo *map_info = static_cast(base_ptr_); #ifdef _WIN32 ReleaseContext* r_ctx = new ReleaseContext; @@ -539,7 +539,7 @@ void RefcountedMapAllocator::close() { } #else /* _WIN32 */ - MapInfo *info = (MapInfo*)(data); + MapInfo *info = static_cast(data); if (--info->refcount == 0) { #ifdef HAVE_SHM_UNLINK if (shm_unlink(filename_.c_str()) == -1) { diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp index b10d5c7d1fc3f..8614405bcdcf5 100644 --- a/aten/src/ATen/TensorIterator.cpp +++ b/aten/src/ATen/TensorIterator.cpp @@ -862,7 +862,7 @@ void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) { shape_[dim] = size; view_offsets_[dim] += start; for (auto& op : operands_) { - op.data = ((char*)op.data) + op.stride_bytes[dim] * start; + op.data = (static_cast(op.data)) + op.stride_bytes[dim] * start; } if (size == 1 && !is_reduction_) { coalesce_dimensions(); @@ -873,7 +873,7 @@ void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indic TORCH_INTERNAL_ASSERT(start_dim <= ndim()); for (const auto i : c10::irange(start_dim, ndim())) { for (auto& op : operands_) { - op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim]; + op.data = (static_cast(op.data)) + op.stride_bytes[i] * indices[i - start_dim]; } shape_[i] = 1; } diff --git a/aten/src/ATen/TensorIteratorInternal.h b/aten/src/ATen/TensorIteratorInternal.h index ec0cb6c8fdfcb..9792494c8b5ab 100644 --- a/aten/src/ATen/TensorIteratorInternal.h +++ b/aten/src/ATen/TensorIteratorInternal.h @@ -41,7 +41,7 @@ inline void serial_for_each( IntArrayRef strides, char** base_ptrs, size_t ntensors, - typename TensorIteratorBase::loop2d_t loop, + TensorIteratorBase::loop2d_t loop, Range range) { const auto ndim = shape.size(); TORCH_INTERNAL_ASSERT_DEBUG_ONLY( diff --git a/aten/src/ATen/VmapModeRegistrations.cpp b/aten/src/ATen/VmapModeRegistrations.cpp index 8d85032e4d07a..ca5a87bf2d253 100644 --- a/aten/src/ATen/VmapModeRegistrations.cpp +++ b/aten/src/ATen/VmapModeRegistrations.cpp @@ -72,10 +72,16 @@ TORCH_LIBRARY_IMPL(aten, VmapMode, m) { m.impl("random_", unsupportedRandomOp_>); m.impl("rand_like", unsupportedRandomOp>); + m.impl("rand_like.generator", unsupportedRandomOp, TENSOROPTIONS, std::optional>); m.impl("randn_like", unsupportedRandomOp>); + m.impl("randn_like.generator", unsupportedRandomOp, TENSOROPTIONS, std::optional>); m.impl("randint_like", unsupportedRandomOp>); + m.impl("randint_like.Tensor", unsupportedRandomOp>); m.impl("randint_like.low_dtype", unsupportedRandomOp>); + m.impl("randint_like.generator", unsupportedRandomOp, TENSOROPTIONS, std::optional>); + m.impl("randint_like.Tensor_generator", unsupportedRandomOp, TENSOROPTIONS, std::optional>); + m.impl("randint_like.low_generator_dtype", unsupportedRandomOp, TENSOROPTIONS, std::optional>); m.impl("rand", unsupportedRandomOp); m.impl("rand.generator", unsupportedRandomOp, TENSOROPTIONS>); diff --git a/aten/src/ATen/core/IListRef.h b/aten/src/ATen/core/IListRef.h index aa90faf838786..a11a78c03a3bb 100644 --- a/aten/src/ATen/core/IListRef.h +++ b/aten/src/ATen/core/IListRef.h @@ -190,12 +190,14 @@ class IListRef; * it to a function (e.g. `ImplT::(this_)`). */ #define TORCH_ILISTREF_UNWRAP(TAG, BODY) \ + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \ switch (TAG) { \ TORCH_ILISTREF_FORALL_TAGS(TORCH_ILISTREF_UNWRAP_CASE, BODY) \ break; \ default: \ TORCH_INTERNAL_ASSERT(false, "invalid IListRef tag."); \ - } + } \ + C10_DIAGNOSTIC_POP() enum class IListRefTag { #define DEFINE_TAG(tag, ...) tag, diff --git a/aten/src/ATen/core/TransformationHelper.h b/aten/src/ATen/core/TransformationHelper.h index dad18bd019bbe..dabba95a2f815 100644 --- a/aten/src/ATen/core/TransformationHelper.h +++ b/aten/src/ATen/core/TransformationHelper.h @@ -56,7 +56,7 @@ C10_HOST_DEVICE inline T uniform_int_full_range(V val) { * in this overloaded version */ template -C10_HOST_DEVICE inline std::enable_if_t), T>uniform_int(V val) { +C10_HOST_DEVICE inline std::enable_if_t, T>uniform_int(V val) { if constexpr (std::is_same_v) { return static_cast(val & 1); } else if constexpr (std::is_same_v) { diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index bb981c1d4efd2..0ce1b77dece0e 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -114,25 +114,25 @@ inline typename remove_symint::type unpackSymInt(T x) { } template <> -inline typename remove_symint::type unpackSymInt(c10::SymInt x) { +inline remove_symint::type unpackSymInt(c10::SymInt x) { return x.guard_int(__FILE__, __LINE__); } template <> -inline typename remove_symint::type unpackSymInt( +inline remove_symint::type unpackSymInt( c10::SymIntArrayRef x) { return C10_AS_INTARRAYREF_SLOW(x); } template <> -inline typename remove_symint>::type unpackSymInt( +inline remove_symint>::type unpackSymInt( std::optional x) { return x.has_value() ? std::make_optional(x->guard_int(__FILE__, __LINE__)) : std::nullopt; } template <> -inline typename remove_symint::type unpackSymInt( +inline remove_symint::type unpackSymInt( at::OptionalSymIntArrayRef x) { return x.has_value() ? std::make_optional(C10_AS_INTARRAYREF_SLOW(*x)) : std::nullopt; diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h index 34b1514f32cdb..5f11d1715f34a 100644 --- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h +++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h @@ -631,8 +631,8 @@ call_functor_with_args_from_stack_( Stack* stack, std::index_sequence /*unused*/, guts::typelist::typelist* /*unused*/) { - (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would - // be unused and we have to silence the compiler warning. + (void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would + // be unused and we have to silence the compiler warning. // We're explicitly filtering out DispatchKeySet from the argument list. // Some kernels take a DispatchKeySet as their first argument in order to diff --git a/aten/src/ATen/core/enum_type.h b/aten/src/ATen/core/enum_type.h index e292f58487fbd..583c4dbecbe7a 100644 --- a/aten/src/ATen/core/enum_type.h +++ b/aten/src/ATen/core/enum_type.h @@ -18,6 +18,7 @@ struct TORCH_API EnumType : public NamedType { TypePtr value, std::vector enum_names_values, std::weak_ptr<::torch::jit::CompilationUnit> cu) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (value->kind()) { case TypeKind::IntType: case TypeKind::FloatType: @@ -34,6 +35,7 @@ struct TORCH_API EnumType : public NamedType { value->str(), "', only int, float and string are supported"); } + C10_DIAGNOSTIC_POP() } std::string str() const override { diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index bb01c47e055a8..1ff8dd0410949 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -601,8 +601,8 @@ std::ostream& IValue::repr( double d = v.toDouble(); int c = std::fpclassify(d); if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) { - int64_t i = int64_t(d); - if (double(i) == d) { + int64_t i = static_cast(d); + if (static_cast(i) == d) { // -0.0 (signed zero) needs to be parsed as -0. if (i == 0 && std::signbit(d)) { return out << "-" << i << "."; @@ -799,8 +799,8 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { double d = v.toDouble(); int c = std::fpclassify(d); if (c == FP_NORMAL || c == FP_ZERO) { - int64_t i = int64_t(d); - if (double(i) == d) { + int64_t i = static_cast(d); + if (static_cast(i) == d) { return out << i << "."; } } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 35c0d3530adcc..666d1ade5789c 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -41,7 +41,7 @@ void standardizeVectorForUnion(std::vector* to_flatten); inline bool is_contiguous_strides( const IntArrayRef sizes, const IntArrayRef strides) { - int n_dim = static_cast(sizes.size()); + size_t n_dim = sizes.size(); if (n_dim == 0) { return true; } @@ -50,7 +50,7 @@ inline bool is_contiguous_strides( return false; } - for (int i = n_dim - 2; i >= 0; i--) { + for (int i = static_cast(n_dim) - 2; i >= 0; i--) { if (strides[i] != strides[i + 1] * sizes[i + 1]) { return false; } @@ -922,6 +922,7 @@ struct TORCH_API DictType : public SharedType { if (auto dyn = key->castRaw()) { kind = dyn->dynamicKind(); } + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (kind) { case TypeKind::AnyType: case TypeKind::IntType: @@ -938,6 +939,7 @@ struct TORCH_API DictType : public SharedType { key->str(), "', only int, float, complex, Tensor, device and string keys are supported"); } + C10_DIAGNOSTIC_POP() } // aligned with the format in FunctionSchema @@ -2371,7 +2373,7 @@ struct TORCH_API AnyClassType : public Type { }; template<> -inline typename detail::CastReturnType::type Type::cast() { +inline detail::CastReturnType::type Type::cast() { if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { return std::static_pointer_cast(static_cast(this)->shared_from_this()); @@ -2380,7 +2382,7 @@ inline typename detail::CastReturnType::type Type::cast() } template<> -inline typename detail::CastConstReturnType::type Type::cast() const { +inline detail::CastConstReturnType::type Type::cast() const { if (kind() == TypeKind::TupleType || kind() == TypeKind::FunctionType || kind() == TypeKind::ClassType || kind() == TypeKind::InterfaceType) { return std::static_pointer_cast(static_cast(this)->shared_from_this()); diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h index aae7f2a79c2ea..5022a76f3dbcc 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_bfloat16_neon.h @@ -19,6 +19,13 @@ inline namespace CPU_CAPABILITY { #error "Big endian is not supported." #endif +// GCC does not properly optimize bf16 operators +#if defined(__ARM_FEATURE_BF16) && (__clang_major__ >= 19) +#define BF16_ARITHMETIC_SUPPORTED() 1 +#else +#define BF16_ARITHMETIC_SUPPORTED() 0 +#endif + // Unlike the float16_t family of types, bfloat16_t is not available // when we're not targeting bfloat16 hardware support on some // platforms (but not Mac, so we have to be careful not to shadow the @@ -352,18 +359,35 @@ class Vectorized : public Vectorized16< other, &Vectorized::name); \ } - DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) Vectorized frac() const; DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(trunc) DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(sqrt) #ifdef __ARM_FEATURE_BF16 + // Flip sign bit Vectorized neg() const { - return -values; + return vreinterpretq_bf16_s16(vreinterpretq_s16_bf16(values) ^ (-32768)); } + // Fast reciprocal is fine because we are truncating results Vectorized reciprocal() const { - return 1.0f / values; + auto x = vcvtq_low_f32_bf16(values); + auto y = vcvtq_high_f32_bf16(values); + x = vrecpeq_f32(x); + y = vrecpeq_f32(y); + return vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(x), y); + } + // Clearing the sign bit + Vectorized abs() const { + return vreinterpretq_bf16_u16(vreinterpretq_u16_bf16(values) & 0x7FFF); } +#else + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(abs) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) + DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) +#endif + +// These functions are optimized on clang-21+ +#if BF16_ARITHMETIC_SUPPORTED() && (__clang_major__ >= 21) Vectorized operator==( const Vectorized& other) const { return values == other.values; @@ -394,8 +418,6 @@ class Vectorized : public Vectorized16< return values >= other.values; } #else - DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(neg) - DEFINE_UNARY_ELEMENTWISE_FUNC_VIA_FLOAT_METHOD(reciprocal) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator==) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator!=) DEFINE_BINARY_COMPARISON_OPERATOR_VIA_FLOAT_METHOD(operator<) @@ -451,7 +473,7 @@ template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; return x + y; @@ -464,7 +486,7 @@ template <> Vectorized inline operator-( const Vectorized& a, const Vectorized& b) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; return x - y; @@ -477,7 +499,7 @@ template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; return x * y; @@ -490,7 +512,7 @@ template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; return x / y; @@ -607,7 +629,7 @@ Vectorized inline fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; bfloat16x8_t z = c; @@ -627,7 +649,7 @@ Vectorized inline fnmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; bfloat16x8_t z = c; @@ -643,7 +665,7 @@ Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; bfloat16x8_t z = c; @@ -659,7 +681,7 @@ Vectorized inline fnmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { -#ifdef __ARM_FEATURE_BF16 +#if BF16_ARITHMETIC_SUPPORTED() bfloat16x8_t x = a; bfloat16x8_t y = b; bfloat16x8_t z = c; diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h index b2e6016bcc12e..64e9588c32881 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h @@ -6,9 +6,9 @@ namespace at::vec { inline namespace CPU_CAPABILITY { #if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) -// Enable auto-vectorization for GCC-13+ and clang-17+ +// Enable auto-vectorization for clang-17+ // GCC-12 has a bug: gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 -#if __GNUC__ > 12 || (defined(__clang__) && (__clang_major__ >= 17)) +#if defined(__clang__) && (__clang_major__ >= 17) template inline void convertImpl( diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h index 67760ec967aa1..c479fc2e4aeb2 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -309,7 +309,7 @@ class Vectorized { DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) // Implementation copied from Arm Optimized Routine // https://github.com/ARM-software/optimized-routines/blob/master/math/aarch64/advsimd/expf.c - Vectorized exp_u20() const { + inline Vectorized vexpq_f32_u20() const { // bail out to sleef if it's a special case: // i.e. there's an input s.t. |input| > 87.3.... const float32x4_t special_bound = vdupq_n_f32(0x1.5d5e2ap+6f); @@ -348,6 +348,9 @@ class Vectorized { return vfmaq_f32(scale, poly, scale); } + Vectorized exp_u20() const { + return vexpq_f32_u20(); + } Vectorized fexp_u20() const { return exp_u20(); } @@ -634,7 +637,7 @@ inline Vectorized Vectorized::erf() const { // - exp(- x * x) auto pow_2 = (*this) * (*this); auto neg_pow_2 = pow_2 ^ neg_zero_vec; - auto tmp4 = neg_pow_2.exp(); + auto tmp4 = neg_pow_2.vexpq_f32_u20(); auto tmp5 = tmp4 ^ neg_zero_vec; // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) auto tmp6 = t * tmp5; diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 2b70564b9ca81..31a881ff28665 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -514,7 +514,7 @@ struct Vectorized : public Vectorizedqi { using float_vec_return_type = std::array, kFloatNumVecs>; using int_vec_return_type = std::array, kIntNumVecs>; - using value_type = typename c10::qint8::underlying; + using value_type = c10::qint8::underlying; public: using Vectorizedqi::Vectorizedqi; @@ -727,7 +727,7 @@ struct Vectorized : public Vectorizedqi { using float_vec_return_type = std::array, kFloatNumVecs>; using int_vec_return_type = std::array, kIntNumVecs>; - using value_type = typename c10::quint8::underlying; + using value_type = c10::quint8::underlying; public: using Vectorizedqi::Vectorizedqi; diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h index 64ba47e0f0646..e723cba61bf2d 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -567,7 +567,7 @@ struct Vectorized : public Vectorizedqi { using float_vec_return_type = std::array, 4>; using int_vec_return_type = std::array, 4>; - using value_type = typename c10::qint8::underlying; + using value_type = c10::qint8::underlying; public: using Vectorizedqi::Vectorizedqi; @@ -804,7 +804,7 @@ struct Vectorized : public Vectorizedqi { using float_vec_return_type = std::array, 4>; using int_vec_return_type = std::array, 4>; - using value_type = typename c10::quint8::underlying; + using value_type = c10::quint8::underlying; public: using Vectorizedqi::Vectorizedqi; diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index b4441981b3d87..8741d099ea914 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -672,7 +672,7 @@ struct Vectorized { return map(std::sqrt); } Vectorized reciprocal() const { - return map([](T x) { return (T)(1) / x; }); + return map([](T x) { return (T)1 / x; }); } Vectorized rsqrt() const { return map([](T x) { return (T)1 / std::sqrt(x); }); diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index 26547e99a1b57..143f45b2e97c1 100644 --- a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -46,7 +46,7 @@ inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) { parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) { map( [](const Vectorized& x) { - return Vectorized((scalar_t)(1)) / x.sqrt(); + return Vectorized((scalar_t)1) / x.sqrt(); }, out + begin, in + begin, diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index 2e387fbc264d7..adef16d2deda2 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -194,8 +194,8 @@ void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) { void CUDAGeneratorState::capture_prologue() { capturing_ = true; offset_intragraph_ = 0; - seed_extragraph_.fill_(int64_t(seed_)); - offset_extragraph_.fill_(int64_t(0)); + seed_extragraph_.fill_(static_cast(seed_)); + offset_extragraph_.fill_(0); } /** @@ -216,8 +216,8 @@ void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) { at::cuda::assertNotCapturing( "Cannot prepare for replay during capturing stage."); if (wholegraph_increment) { - seed_extragraph_.fill_(int64_t(seed_)); - offset_extragraph_.fill_(int64_t(philox_offset_per_thread_)); + seed_extragraph_.fill_(static_cast(seed_)); + offset_extragraph_.fill_(static_cast(philox_offset_per_thread_)); // Applies the total increment achieved during previous captures to update the // offset. increase(wholegraph_increment); @@ -329,7 +329,7 @@ c10::intrusive_ptr CUDAGeneratorImpl::get_state() const { constexpr size_t offset_size = sizeof(int64_t); constexpr size_t total_size = seed_size + offset_size; - auto state_tensor = at::detail::empty_cpu({(int64_t)total_size}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); + auto state_tensor = at::detail::empty_cpu({static_cast(total_size)}, ScalarType::Byte, std::nullopt, std::nullopt, std::nullopt, std::nullopt); auto rng_state = state_tensor.data_ptr(); auto current_seed = this->current_seed(); auto offset = static_cast(this->philox_offset_per_thread()); // Note that old THCGeneratorState had offset as std::atomic diff --git a/aten/src/ATen/cuda/CUDAGreenContext.cpp b/aten/src/ATen/cuda/CUDAGreenContext.cpp index 6108f6e96a818..8aa05b80f82f9 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.cpp +++ b/aten/src/ATen/cuda/CUDAGreenContext.cpp @@ -1,78 +1,90 @@ #include -namespace at::cuda { - GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { -#if CUDA_HAS_GREEN_CONTEXT - int driver_version; - C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); - TORCH_CHECK( - driver_version >= 12080, "cuda driver too old to use green context!"); - CUcontext pctx = nullptr; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); - if (C10_UNLIKELY(!pctx)) { - TORCH_WARN( - "Attempted to create a green context but" - " there was no primary context! Creating a primary context..."); - - cudaFree(0); - } - - CUdevice device; - device_id_ = device_id; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); - - // Get device resources - CUdevResource device_resource; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( - device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); - - // Split resources - std::vector result(1); - auto result_data = result.data(); - unsigned int nb_groups = 1; - CUdevResource remaining; +#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12030) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) +#include +#include +#include +#define HAS_CUDA_GREEN_CONTEXT() 1 +#else +#define HAS_CUDA_GREEN_CONTEXT() 0 +// Suppress unsued private field warnings as this class is not supposed to be called +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field") +#endif - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( - result_data, - &nb_groups, - &device_resource, - &remaining, - 0, // default flags - num_sms)); - - TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); - - // Generate resource descriptor - CUdevResourceDesc desc; - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( - &desc, result_data, 1)); +namespace at::cuda { - // Create green context - // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( - &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); +GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { +#if HAS_CUDA_GREEN_CONTEXT() + int driver_version; + C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); + TORCH_CHECK( + driver_version >= 12080, "cuda driver too old to use green context!"); + CUcontext pctx = nullptr; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuCtxGetCurrent_(&pctx)); + if (C10_UNLIKELY(!pctx)) { + TORCH_WARN( + "Attempted to create a green context but" + " there was no primary context! Creating a primary context..."); + + cudaFree(0); + } - // Convert to regular context - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); - TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); + CUdevice device; + device_id_ = device_id; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); + + // Get device resources + CUdevResource device_resource; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( + device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); + + // Split resources + std::vector result(1); + auto result_data = result.data(); + unsigned int nb_groups = 1; + CUdevResource remaining; + + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( + result_data, + &nb_groups, + &device_resource, + &remaining, + 0, // default flags + num_sms)); + + TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); + + // Generate resource descriptor + CUdevResourceDesc desc; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( + &desc, result_data, 1)); + + // Create green context + // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: + // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuGreenCtxCreate_( + &green_ctx_, desc, device, CU_GREEN_CTX_DEFAULT_STREAM)); + + // Convert to regular context + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuCtxFromGreenCtx_(&context_, green_ctx_)); + TORCH_CHECK(context_, "Green ctx conversion to regular ctx failed!"); #else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); #endif } std::unique_ptr GreenContext::create( uint32_t num_sms, std::optional device_id) { -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() if (!device_id.has_value()) { device_id = at::cuda::current_device(); } - return std::make_unique(device_id.value(), num_sms); + return std::unique_ptr(new GreenContext(device_id.value(), num_sms)); #else TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); #endif @@ -80,7 +92,7 @@ namespace at::cuda { // Implement move operations GreenContext::GreenContext(GreenContext&& other) noexcept{ -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() device_id_ = std::exchange(other.device_id_, -1); green_ctx_ = std::exchange(other.green_ctx_, nullptr); context_ = std::exchange(other.context_, nullptr); @@ -91,7 +103,7 @@ namespace at::cuda { } GreenContext& GreenContext::operator=(GreenContext&& other) noexcept{ -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() if (this != &other) { // Clean up current resources if (green_ctx_) { @@ -120,7 +132,7 @@ namespace at::cuda { } GreenContext::~GreenContext() noexcept{ -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() C10_CUDA_DRIVER_CHECK( c10::cuda::DriverAPI::get()->cuGreenCtxDestroy_(green_ctx_)); #else @@ -128,25 +140,9 @@ namespace at::cuda { #endif } - // Get the underlying CUDA context - CUcontext GreenContext::getContext() const { -#if CUDA_HAS_GREEN_CONTEXT - return context_; -#else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); -#endif - } - - // Get the underlying green context -#if CUDA_HAS_GREEN_CONTEXT - CUgreenCtx GreenContext::getGreenContext() const { - return green_ctx_; - } -#endif - // Make this context current void GreenContext::setContext() { -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() auto current_stream = c10::cuda::getCurrentCUDAStream(); parent_stream_ = current_stream.stream(); @@ -175,7 +171,7 @@ namespace at::cuda { } void GreenContext::popContext() { -#if CUDA_HAS_GREEN_CONTEXT +#if HAS_CUDA_GREEN_CONTEXT() // see above note about stream being hardcoded to the default stream at::cuda::CUDAEvent ev; ev.record(c10::cuda::getCurrentCUDAStream()); diff --git a/aten/src/ATen/cuda/CUDAGreenContext.h b/aten/src/ATen/cuda/CUDAGreenContext.h index 4f198e2e1c06e..f9fa2cd112e3f 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.h +++ b/aten/src/ATen/cuda/CUDAGreenContext.h @@ -1,53 +1,38 @@ #pragma once #include - -#if defined(CUDA_VERSION) && !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include #include -#include -#include -#include -#define CUDA_HAS_GREEN_CONTEXT 1 -#else -#define CUDA_HAS_GREEN_CONTEXT 0 -#endif + +// Forward declare green context as opaque ptr +typedef struct CUgreenCtx_st* CUgreenCtx; namespace at::cuda { class TORCH_CUDA_CPP_API GreenContext { public: - GreenContext(uint32_t device_id, uint32_t num_sms); - - static std::unique_ptr create(uint32_t num_sms, std::optional device_id); + // Green context creation + static std::unique_ptr create( + uint32_t num_sms, + std::optional device_id); + ~GreenContext() noexcept; // Delete copy constructor and assignment GreenContext(const GreenContext&) = delete; GreenContext& operator=(const GreenContext&) = delete; - // Implement move operations - GreenContext(GreenContext&& other) noexcept; - GreenContext& operator=(GreenContext&& other) noexcept; - ~GreenContext() noexcept; - - // Get the underlying CUDA context - CUcontext getContext() const; - - // Get the underlying green context -#if CUDA_HAS_GREEN_CONTEXT - CUgreenCtx getGreenContext() const; -#endif - // Make this context current void setContext(); void popContext(); private: -#if CUDA_HAS_GREEN_CONTEXT + GreenContext(uint32_t device_id, uint32_t num_sms); + // Implement move operations + GreenContext(GreenContext&& other) noexcept; + GreenContext& operator=(GreenContext&& other) noexcept; + int32_t device_id_ = -1; CUgreenCtx green_ctx_ = nullptr; CUcontext context_ = nullptr; cudaStream_t parent_stream_ = nullptr; -#endif }; } // namespace at::cuda diff --git a/aten/src/ATen/cuda/CUDASparse.h b/aten/src/ATen/cuda/CUDASparse.h index e00e50b38d2de..ceffa8e86eedd 100644 --- a/aten/src/ATen/cuda/CUDASparse.h +++ b/aten/src/ATen/cuda/CUDASparse.h @@ -7,17 +7,6 @@ #endif -#if defined(USE_ROCM) -// hipSparse const API added in v2.4.0 -#if HIPSPARSE_VERSION >= 200400 -#define AT_USE_HIPSPARSE_GENERIC_API() 1 -#else -#define AT_USE_HIPSPARSE_GENERIC_API() 1 -#endif -#else // USE_ROCM -#define AT_USE_HIPSPARSE_GENERIC_API() 0 -#endif // USE_ROCM - // cuSparse Generic API spsv function was added in CUDA 11.3.0 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500) #define AT_USE_CUSPARSE_GENERIC_SPSV() 1 diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index d7832c761ae55..6175e69827e2f 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -155,8 +155,8 @@ size_t parseChosenWorkspaceSize() { while (next != end) { std::smatch match = *next; TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)"); - size_t curr_size = (size_t) std::stoi(match.str(1)); - size_t count = (size_t) std::stoi(match.str(2)); + size_t curr_size = std::stoull(match.str(1)); + size_t count = std::stoull(match.str(2)); total_size += curr_size * 1024 * count; next++; } diff --git a/aten/src/ATen/cuda/Sleep.cu b/aten/src/ATen/cuda/Sleep.cu index 586520e25327d..a1feb045732a2 100644 --- a/aten/src/ATen/cuda/Sleep.cu +++ b/aten/src/ATen/cuda/Sleep.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -24,8 +25,22 @@ __global__ void spin_kernel(int64_t cycles) { #endif } } + +thread_local int *flag = nullptr; + +__global__ void busy_wait_for_flag_kernel(int *flag) { + atomicExch(flag, 1); + while (atomicAdd(flag, 0) == 1) { + // do nothing + } +} + +__global__ void clear_flag_kernel(int *flag) { + atomicExch(flag, 0); } +} // anonymous namespace + void sleep(int64_t cycles) { dim3 grid(1); dim3 block(1); @@ -33,6 +48,26 @@ void sleep(int64_t cycles) { C10_CUDA_KERNEL_LAUNCH_CHECK(); } +void busy_wait_for_flag() { + if (!flag) { + flag = (int*)c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int)); + } + dim3 grid(1); + dim3 block(1); + busy_wait_for_flag_kernel<<>>(flag); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void clear_flag() { + if (!flag) { + flag = (int*)c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int)); + } + dim3 grid(1); + dim3 block(1); + clear_flag_kernel<<>>(flag); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + #ifdef USE_ROCM __global__ void flush_icache_kernel() { diff --git a/aten/src/ATen/cuda/Sleep.h b/aten/src/ATen/cuda/Sleep.h index ef5e83a832f73..28719593af009 100644 --- a/aten/src/ATen/cuda/Sleep.h +++ b/aten/src/ATen/cuda/Sleep.h @@ -7,6 +7,11 @@ namespace at::cuda { // enqueues a kernel that spins for the specified number of cycles TORCH_CUDA_CU_API void sleep(int64_t cycles); +// enqueues a kernel that spins until a flag is cleared by a +// corresponding call to clear_flag() +TORCH_CUDA_CU_API void busy_wait_for_flag(); +TORCH_CUDA_CU_API void clear_flag(); + // flushes instruction cache for ROCm; no-op for CUDA TORCH_CUDA_CU_API void flush_icache(); diff --git a/aten/src/ATen/cuda/detail/BLASConstants.cu b/aten/src/ATen/cuda/detail/BLASConstants.cu index 9673880447054..2131c09965fee 100644 --- a/aten/src/ATen/cuda/detail/BLASConstants.cu +++ b/aten/src/ATen/cuda/detail/BLASConstants.cu @@ -2,8 +2,6 @@ #include #include -#include - namespace at { namespace cuda { namespace detail { @@ -12,39 +10,36 @@ __device__ __constant__ float cublas_one_device; __device__ __constant__ float cublas_zero_device; float *get_cublas_device_one() { - static c10::once_flag init_flag; - - c10::call_once(init_flag, []() { + static float *ptr = nullptr; + static auto init_flag = [&]() { const float one = 1.f; AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_one_device, &one, sizeof(float))); - }); + AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast(&ptr), cublas_one_device)); + return true; + }(); - float *ptr; - AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast(&ptr), cublas_one_device)); return ptr; } float *get_cublas_device_zero() { - static c10::once_flag init_flag; - - c10::call_once(init_flag, []() { + static float *ptr = nullptr; + static auto init_flag = [&]() { const float zero = 0.f; AT_CUDA_CHECK(cudaMemcpyToSymbol(cublas_zero_device, &zero, sizeof(float))); - }); + AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast(&ptr), cublas_zero_device)); + return true; + }(); - float *ptr; - AT_CUDA_CHECK(cudaGetSymbolAddress(reinterpret_cast(&ptr), cublas_zero_device)); return ptr; } float *get_user_alpha_ptr() { static float *alpha_ptr; - static c10::once_flag init_flag; - - c10::call_once(init_flag, []() { + static bool init_flag [[maybe_unused]] = []() { AT_CUDA_CHECK(cudaMalloc(&alpha_ptr, sizeof(float))); - }); + return true; + }(); return alpha_ptr; } diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index dbd178e0f8eee..8636d267209e9 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include @@ -136,9 +137,9 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo "Weight strides: ", t.strides(), "\n", "cuDNN suggested memory_format: ", memory_format); - int size[CUDNN_DIM_MAX]; + std::array size; for (const auto i : c10::irange(dim)) { - size[i] = (int) t.size(i); + size[i] = static_cast(t.size(i)); } for (const auto i : c10::irange(dim, pad)) { size[i] = 1; @@ -156,7 +157,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo default: TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters"); } - set(getDataType(t), static_cast(dim), size, filter_format); + set(getDataType(t), static_cast(dim), size.data(), filter_format); } std::string cudnnMemoryFormatToString(cudnnTensorFormat_t tformat) { diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index b415862f29e7c..58c7a0304181c 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -8,8 +9,8 @@ #include -#include #include +#include #include namespace at { @@ -25,8 +26,7 @@ constexpr const char* MTIA_HELP = struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { // this fails the implementation if MTIAHooks functions are called, but // MTIA backend is not present. -#define FAIL_MTIAHOOKS_FUNC(func) \ - TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend."); +#define FAIL_MTIAHOOKS_FUNC(func) TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend."); ~MTIAHooksInterface() override = default; @@ -91,7 +91,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA); } - virtual void setCurrentStream(const c10::Stream& /*stream*/ ) const { + virtual void setCurrentStream(const c10::Stream& /*stream*/) const { FAIL_MTIAHOOKS_FUNC(__func__); } @@ -123,11 +123,9 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { FAIL_MTIAHOOKS_FUNC(__func__); } - - virtual void recordMemoryHistory( - const std::optional& /*enabled*/, - const std::string& /*stacks*/, - size_t /*max_entries*/) const { + virtual void recordMemoryHistory(const std::optional& /*enabled*/, + const std::string& /*stacks*/, + size_t /*max_entries*/) const { FAIL_MTIAHOOKS_FUNC(__func__); } @@ -151,13 +149,46 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { } virtual bool isAvailable() const override; + + /* MTIAGraph related APIs */ + virtual int64_t mtiagraphCreate(bool keep_graph = false) const { + FAIL_MTIAHOOKS_FUNC(__func__); + return -1; + } + + virtual void mtiagraphDestroy(int64_t handle) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void mtiagraphCaptureEnd(int64_t handle) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void mtiagraphInstantiate(int64_t handle) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void mtiagraphReplay(int64_t handle) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual void mtiagraphReset(int64_t handle) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } + + virtual MempoolId_t mtiagraphPool(int64_t handle) const { + FAIL_MTIAHOOKS_FUNC(__func__); + } }; struct TORCH_API MTIAHooksArgs {}; TORCH_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs); -#define REGISTER_MTIA_HOOKS(clsname) \ - C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname) +#define REGISTER_MTIA_HOOKS(clsname) C10_REGISTER_CLASS(MTIAHooksRegistry, clsname, clsname) namespace detail { TORCH_API const MTIAHooksInterface& getMTIAHooks(); diff --git a/aten/src/ATen/functorch/ADInterpreters.cpp b/aten/src/ATen/functorch/ADInterpreters.cpp index 9bf7de9d3baa7..666fdf143039a 100644 --- a/aten/src/ATen/functorch/ADInterpreters.cpp +++ b/aten/src/ATen/functorch/ADInterpreters.cpp @@ -198,7 +198,7 @@ static void autogradBasedTransformSendToNext( } // Step 6 - stack->erase(stack->end() - std::ptrdiff_t(args_size + ret_size), stack->end() - std::ptrdiff_t(ret_size)); + stack->erase(stack->end() - static_cast(args_size + ret_size), stack->end() - static_cast(ret_size)); } void GradInterpreterPtr::processImpl( diff --git a/aten/src/ATen/functorch/BatchRulesNorm.cpp b/aten/src/ATen/functorch/BatchRulesNorm.cpp index 6da55762e1595..4546c56e2f586 100644 --- a/aten/src/ATen/functorch/BatchRulesNorm.cpp +++ b/aten/src/ATen/functorch/BatchRulesNorm.cpp @@ -443,14 +443,14 @@ static bool has_same_shape( if (!tensor.defined()) { return true; } - if (rankWithoutBatchDim(tensor, tensor_bdim) != (int64_t) normalized_shape.size()) { + if (rankWithoutBatchDim(tensor, tensor_bdim) != static_cast(normalized_shape.size())) { return false; } const auto tensor_shape = tensor.sizes(); for (const auto i : c10::irange(normalized_shape.size())) { auto j = i; // (0, 1, 2), 1 -> (0, 2, 3) - if (tensor_bdim.has_value() && (int64_t)i >= tensor_bdim.value()) { + if (tensor_bdim.has_value() && static_cast(i) >= tensor_bdim.value()) { j = j + 1; } if (normalized_shape[i] != tensor_shape[j]) { diff --git a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp index de1a37a9b4320..ecee801965e71 100644 --- a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp @@ -135,7 +135,7 @@ static void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit reduction_case = ReductionCase::DimArray; dims = arguments[dim_arg_pos].toIntList().vec(); if (dims.empty()) { - auto all_dims = range(0, std::max((int64_t)1, logical_dim)); + auto all_dims = range(0, std::max(static_cast(1), logical_dim)); dims = std::vector(all_dims.begin(), all_dims.end()); } } else if (arguments[dim_arg_pos].isInt()) { diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index f5c770371de8e..ae4b5b25988e4 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -432,7 +432,7 @@ namespace { // Eg. Given `indexed_shape.size()` is 5 and // shape of `values` is (N, 2, 3), then following block // will reshape `values` to (N, 1, 1, 2, 3). - if ( (int64_t) indexed_shape.size() > values_.dim()) { + if ( static_cast(indexed_shape.size()) > values_.dim()) { auto values_sizes = values_.sym_sizes(); // number of unit dims (for broadcasting value to indexed_shape) diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp index 08db1d202b4eb..a78d8b0eec7e1 100644 --- a/aten/src/ATen/functorch/BatchRulesViews.cpp +++ b/aten/src/ATen/functorch/BatchRulesViews.cpp @@ -109,7 +109,7 @@ std::tuple> repeat_batch_rule( SymDimVector sizes_with_bdim = { sizes.begin(), sizes.end() }; sizes_with_bdim.insert(sizes_with_bdim.begin(), 1); auto self_ = moveBatchDimToFront(self, self_bdim); - while (self_.dim() < (int64_t)sizes_with_bdim.size()) { + while (self_.dim() < static_cast(sizes_with_bdim.size())) { self_ = self_.unsqueeze(1); } return std::make_tuple(self_.repeat_symint(sizes_with_bdim), 0); @@ -534,20 +534,20 @@ Tensor trace_decomp(const Tensor& tensor) { std::tuple> tril_batch_rule( const Tensor& self, std::optional self_bdim, - int64_t diagonal = 0) { + c10::SymInt diagonal = 0) { TORCH_CHECK(self.dim() >= 2, "tril: The input tensor must have at least 2 dimensions."); auto self_ = moveBatchDimToFront(self, self_bdim); - auto result = at::tril(self_, diagonal); + auto result = at::tril_symint(self_, std::move(diagonal)); return std::make_tuple(std::move(result), 0); } std::tuple> triu_batch_rule( const Tensor& self, std::optional self_bdim, - int64_t diagonal = 0) { + c10::SymInt diagonal = 0) { TORCH_CHECK(self.dim() >= 2, "triu: The input tensor must have at least 2 dimensions."); auto self_ = moveBatchDimToFront(self, self_bdim); - auto result = at::triu(self_, diagonal); + auto result = at::triu_symint(self_, std::move(diagonal)); return std::make_tuple(std::move(result), 0); } diff --git a/aten/src/ATen/functorch/BatchedFallback.cpp b/aten/src/ATen/functorch/BatchedFallback.cpp index 92123c1cd0e22..aab1da68053b7 100644 --- a/aten/src/ATen/functorch/BatchedFallback.cpp +++ b/aten/src/ATen/functorch/BatchedFallback.cpp @@ -191,7 +191,7 @@ static void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, t // simplicity. When that is not the case, this code should be updated. const auto& argument = (*stack)[arguments_begin + arg_idx]; if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() - || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) { + || static_cast(arg_idx) != *batched_tensor_inputs_pos_iter) { // argument isn't a BatchedTensor torch::jit::push(stack, argument); continue; @@ -345,7 +345,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta // simplicity. When that is not the case, this code should be updated. const auto& argument = (*stack)[arguments_begin + arg_idx]; if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() - || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) { + || static_cast(arg_idx) != *batched_tensor_inputs_pos_iter) { // argument isn't a BatchedTensor torch::jit::push(stack, argument); continue; @@ -473,7 +473,7 @@ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::ji // simplicity. When that is not the case, this code should be updated. const auto& argument = (*stack)[arguments_begin + arg_idx]; if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() - || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) { + || static_cast(arg_idx) != *batched_tensor_inputs_pos_iter) { // argument isn't a BatchedTensor torch::jit::push(stack, argument); continue; diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp index 0c2ed37d23765..e51f4901f36bc 100644 --- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp +++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp @@ -157,7 +157,7 @@ Tensor& squeeze__batching_rule(Tensor& self) { const auto physical_shape = batched->value().sizes(); auto how_many_dims_of_size_1_before_bdim = 0; for (const auto i : c10::irange(0, physical_shape.size())) { - if ((int64_t)i == bdim) { + if (static_cast(i) == bdim) { break; } if (physical_shape[i] == 1) { @@ -573,7 +573,7 @@ Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { } auto new_dim = bdim_size.has_value() ? dim + 1 : dim; - std::optional new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt; + std::optional new_bdim = bdim_size.has_value() ? std::make_optional(static_cast(0)) : std::nullopt; auto result = at::cat(tensors_to_cat, new_dim); return makeBatched(result, new_bdim, get_current_level()); } diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 5a37490c02402..839ec2f1ac624 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -1,7 +1,5 @@ // Copyright © 2022 Apple Inc. -#include - #include #include #include @@ -10,9 +8,6 @@ namespace at::mps { -static std::unique_ptr mps_device; -static c10::once_flag mpsdev_init; - static inline MTLLanguageVersion getMetalLanguageVersion(const id& device) { // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants) // host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+ @@ -21,8 +16,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de } MPSDevice* MPSDevice::getInstance() { - c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr(new MPSDevice()); }); - return mps_device.get(); + static MPSDevice mps_device; + return &mps_device; } MPSDevice::~MPSDevice() { diff --git a/aten/src/ATen/native/AveragePool2d.cpp b/aten/src/ATen/native/AveragePool2d.cpp index 368dc02c2832f..035228285bc28 100644 --- a/aten/src/ATen/native/AveragePool2d.cpp +++ b/aten/src/ATen/native/AveragePool2d.cpp @@ -25,18 +25,19 @@ TORCH_PRECOMPUTE_META_FUNC(avg_pool2d) // #20866, #22032: Guarantee this for the official C++ API? TORCH_CHECK(kernel_size.size() == 1 || kernel_size.size() == 2, "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints"); - const int64_t kH = kernel_size[0]; - const int64_t kW = kernel_size.size() == 1 ? kH : kernel_size[1]; + const int kH = safe_downcast(kernel_size[0]); + const int kW = kernel_size.size() == 1 ? kH : safe_downcast(kernel_size[1]); TORCH_CHECK(stride.empty() || stride.size() == 1 || stride.size() == 2, "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints"); - const int64_t dH = stride.empty() ? kH : stride[0]; - const int64_t dW = stride.empty() ? kW : stride.size() == 1 ? dH : stride[1]; + const int dH = stride.empty() ? kH : safe_downcast(stride[0]); + const int dW = stride.empty() ? kW : + stride.size() == 1 ? dH : safe_downcast(stride[1]); TORCH_CHECK(padding.size() == 1 || padding.size() == 2, "avg_pool2d: padding must either be a single int, or a tuple of two ints"); - const int64_t padH = padding[0]; - const int64_t padW = padding.size() == 1 ? padH : padding[1]; + const int padH = safe_downcast(padding[0]); + const int padW = padding.size() == 1 ? padH : safe_downcast(padding[1]); TORCH_CHECK(!divisor_override.has_value() || divisor_override.value() != 0, "divisor must be not zero"); diff --git a/aten/src/ATen/native/AveragePool3d.cpp b/aten/src/ATen/native/AveragePool3d.cpp index 365cfa311512a..158b6a3a200eb 100644 --- a/aten/src/ATen/native/AveragePool3d.cpp +++ b/aten/src/ATen/native/AveragePool3d.cpp @@ -198,9 +198,9 @@ void avg_pool3d_out_frame( int64_t hend = std::min(hstart + kH, iheight + padH); int64_t wend = std::min(wstart + kW, iwidth + padW); int64_t pool_size = (tend - tstart) * (hend - hstart) * (wend - wstart); - tstart = std::max(tstart, (int64_t) 0); - hstart = std::max(hstart, (int64_t) 0); - wstart = std::max(wstart, (int64_t) 0); + tstart = std::max(tstart, static_cast(0)); + hstart = std::max(hstart, static_cast(0)); + wstart = std::max(wstart, static_cast(0)); tend = std::min(tend, itime); hend = std::min(hend, iheight); wend = std::min(wend, iwidth); @@ -377,9 +377,9 @@ void avg_pool3d_backward_out_frame( int64_t hend = std::min(hstart + kH, iheight + padH); int64_t wend = std::min(wstart + kW, iwidth + padW); int64_t pool_size = (tend -tstart) * (hend - hstart) * (wend - wstart); - tstart = std::max(tstart, (int64_t) 0); - hstart = std::max(hstart, (int64_t) 0); - wstart = std::max(wstart, (int64_t) 0); + tstart = std::max(tstart, static_cast(0)); + hstart = std::max(hstart, static_cast(0)); + wstart = std::max(wstart, static_cast(0)); tend = std::min(tend, itime); hend = std::min(hend, iheight); wend = std::min(wend, iwidth); diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 6669357cda456..8ebf50e913a75 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2917,9 +2917,7 @@ static Tensor& linalg_eig_make_complex_eigenvectors(Tensor& complex_vectors, con DEFINE_DISPATCH(linalg_eig_stub); static std::tuple linalg_eig_out_info(const Tensor& input, Tensor& values, Tensor& vectors, Tensor& infos, bool compute_eigenvectors) { - // MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU - // therefore we create all intermediate tensors on CPU - auto options = input.options().device(at::kCPU); + auto options = input.options(); // These internal asserts make explicit the assumptions in the implementation // Error check with the actual error messages are done on the higher level of the hierarchy of calls @@ -2928,16 +2926,13 @@ static std::tuple linalg_eig_out_info(const Tensor& input, Ten // for real-valued 'input', eigenvalues can be real-valued or complex-valued TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == values.scalar_type()) || (input.scalar_type() == values.scalar_type())); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU); // for real-valued 'input', eigenvectors can be real-valued or complex-valued if (compute_eigenvectors) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY((toComplexType(input.scalar_type()) == vectors.scalar_type()) || (input.scalar_type() == vectors.scalar_type())); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.device() == at::kCPU); } TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.scalar_type() == at::kInt); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.device() == at::kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.numel() == std::max(1, batchCount(input))); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_contiguous()); @@ -2986,15 +2981,7 @@ static std::tuple linalg_eig_out_info(const Tensor& input, Ten } } - // MAGMA uses a hybrid CPU-GPU algorithm that performs well only for large matrices - // See: https://github.com/pytorch/pytorch/pull/52491#issuecomment-795685687 - // Here we call CPU path for matrices smaller than 2048x2048 - // that should be in general significantly faster than calling MAGMA - if (input.size(-1) <= 2048) { - linalg_eig_stub(at::kCPU, real_imag_values, maybe_complex_vectors, infos, input.to(kCPU), compute_eigenvectors); - } else { - linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors); - } + linalg_eig_stub(input.device().type(), real_imag_values, maybe_complex_vectors, infos, input, compute_eigenvectors); // if input is not complex we need to do some post-processing if (!input.is_complex()) { @@ -3019,7 +3006,14 @@ static std::tuple linalg_eig_out_info(const Tensor& input, Ten } if (compute_eigenvectors) { if (vectors.is_complex()) { - vectors = linalg_eig_make_complex_eigenvectors(vectors, values, maybe_complex_vectors); + // We move to the CPU because linalg_eig_make_complex_eigenvectors requires it. + // Performance note: this function could be implemented via a TensorIterator, + // which would avoid an explicit host-device synchronization. + auto vectors_cpu = vectors.cpu(); + auto values_cpu = values.cpu(); + auto maybe_complex_vectors_cpu = maybe_complex_vectors.cpu(); + vectors_cpu = linalg_eig_make_complex_eigenvectors(vectors_cpu, values_cpu, maybe_complex_vectors_cpu); + vectors.copy_(vectors_cpu); } else { TORCH_CHECK(false, "torch.linalg.eig: imaginary part of eigenvectors is non-zero, can't safely cast eigenvectors to non-complex dtype.") } @@ -3039,8 +3033,7 @@ std::tuple linalg_eig_out(const Tensor& input, Tensor& values, checkSameDevice("torch.linalg.eig", values, input, "eigenvalues"); checkSameDevice("torch.linalg.eig", vectors, input, "eigenvectors"); - // MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU - auto options = input.options().device(at::kCPU); + auto options = input.options(); auto infos = at::zeros({std::max(1, batchCount(input))}, options.dtype(kInt)); // if result is not empty and not in batched column major format we have to allocate a temporary tensor @@ -3129,8 +3122,7 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) { checkLinalgCompatibleDtype("torch.linalg.eigvals", values.scalar_type(), toComplexType(input.scalar_type()), "eigenvalues"); checkSameDevice("torch.linalg.eigvals", values, input, "eigenvalues"); - // MAGMA doesn't have GPU interface for GEEV routine, it requires inputs to be on CPU - auto options = input.options().device(at::kCPU); + auto options = input.options(); auto infos = at::zeros({std::max(1, batchCount(input))}, options.dtype(kInt)); bool values_expected_type = (values.scalar_type() == toComplexType(input.scalar_type())); @@ -3159,6 +3151,7 @@ Tensor& linalg_eigvals_out(const Tensor& input, Tensor& values) { } Tensor vectors; + vectors = at::empty({0}, input.options()); if (values_tmp_needed) { Tensor values_tmp = at::empty({0}, options.dtype(values_type)); std::tie(values_tmp, std::ignore) = linalg_eig_out_info(input, values_tmp, vectors, infos, /*compute_eigenvectors=*/false); diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index df64aa42e602f..fdc0c09124978 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -946,10 +946,10 @@ void apply_lu_factor(const Tensor& input, const Tensor& pivots, const Tensor& in } }; // avoid overflow - float matrix_rank = float(std::min(m, n)); + auto matrix_rank = std::min(m, n); // A heuristic tested on a 32 core/socket ICX system // https://github.com/pytorch/pytorch/pull/93037#discussion_r1090112948 - int64_t chunk_size_per_thread = int64_t( + int64_t chunk_size_per_thread = static_cast( std::min(1.0, 3200.0 / (matrix_rank * matrix_rank * matrix_rank))); int64_t grain_size = chunk_size_per_thread * at::get_num_threads(); at::parallel_for(0, batch_size, grain_size, loop); diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index 6b7496f49732e..79f7f6112c924 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -267,7 +267,7 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, float input_scale = scale_a.item(); float weight_scale = scale_b.item(); - float output_scale = float(1.0); + float output_scale = 1.0f; if (scale_result.has_value() && (*out_dtype == ScalarType::Float8_e4m3fn || *out_dtype == ScalarType::Float8_e5m2)) { diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index b476ca3cff8f1..76727c8db21ad 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -331,7 +331,7 @@ bool gemv_use_fast_path( [[maybe_unused]] double beta, int64_t incy) { return gemv_use_fast_path( - trans, m, n, (float)alpha, lda, incx, (float)beta, incy); + trans, m, n, static_cast(alpha), lda, incx, static_cast(beta), incy); } template <> @@ -523,8 +523,8 @@ static inline void scal(int64_t n, scalar_t a, scalar_t *x, int64_t incx) if (n == 1) incx = 1; #if AT_BUILD_WITH_BLAS() if (blas_impl::scal_use_fast_path(n, incx)) { - int i_n = (int)n; - int i_incx = (int)incx; + int i_n = static_cast(n); + int i_incx = static_cast(incx); blas_impl::scal_fast_path(&i_n, &a, x, &i_incx); return; } @@ -545,11 +545,11 @@ void gemv(char trans, int64_t m, int64_t n, scalar_t alpha, const scalar_t *a, i #if AT_BUILD_WITH_BLAS() if (blas_impl::gemv_use_fast_path(trans, m, n, alpha, lda, incx, beta, incy)) { TORCH_CHECK(lda >= std::max(1L, m), "lda should be at least max(1,", m, "), but have ", lda); - int i_m = (int)m; - int i_n = (int)n; - int i_lda = (int)lda; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_m = static_cast(m); + int i_n = static_cast(n); + int i_lda = static_cast(lda); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); blas_impl::gemv_fast_path(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy); return; } diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index c17a70ea308ab..3c8f3922f234e 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -680,9 +680,9 @@ void axpy(int64_t n, double a, const double *x, int64_t incx, double *y, int64_t #if AT_BUILD_WITH_BLAS() if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_n = static_cast(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_daxpy(i_n, a, x, i_incx, y, i_incy); #else @@ -705,9 +705,9 @@ void axpy(int64_t n, float a, const float *x, int64_t incx, float *y, int64_t in #if AT_BUILD_WITH_BLAS() if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_n = static_cast(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_saxpy(i_n, a, x, i_incx, y, i_incy); #else @@ -730,9 +730,9 @@ void axpy(int64_t n, c10::complex a, const c10::complex *x, int6 #if AT_BUILD_WITH_BLAS() if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_n = static_cast(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_zaxpy(i_n, &a, x, i_incx, y, i_incy); #else @@ -755,9 +755,9 @@ void axpy(int64_t n, c10::complex a, const c10::complex *x, int64_ #if AT_BUILD_WITH_BLAS() if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_n = static_cast(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_caxpy(i_n, &a, x, i_incx, y, i_incy); #else @@ -781,9 +781,9 @@ void copy(int64_t n, const double *x, int64_t incx, double *y, int64_t incy) { } #if AT_BUILD_WITH_BLAS() if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_n = static_cast(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_dcopy(i_n, x, i_incx, y, i_incy); #else @@ -805,9 +805,9 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy) { } #if AT_BUILD_WITH_BLAS() if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) ) { - int i_n = (int)n; - int i_incx = (int)incx; - int i_incy = (int)incy; + int i_n = static_cast(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_scopy(i_n, x, i_incx, y, i_incy); #else @@ -829,9 +829,9 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_zcopy(i_n, x, i_incx, y, i_incy); #else @@ -853,9 +853,9 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex(n); + int i_incx = static_cast(incx); + int i_incy = static_cast(incy); #if C10_IOS cblas_ccopy(i_n, &x, i_incx, y, i_incy); #else @@ -1082,7 +1082,7 @@ struct Brgemm : public KernelCache { M, N, K, - int64_t(1), + 1, ld_a, ld_b, ld_c, @@ -1096,7 +1096,7 @@ struct Brgemm : public KernelCache { M, N, K, - int64_t(1), + 1, ld_a, ld_b, ld_c, diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 610f454be21fa..2c3f14aab911c 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -410,8 +410,8 @@ struct ConvParams { return false; } static long cudnn_version = detail::getCUDAHooks().versionCuDNN(); - // broken on cuDNN 9.8 - if (cudnn_version >= 90800) { + // broken on cuDNN 9.8 - 9.14 + if (cudnn_version >= 90800 && cudnn_version < 91500) { if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous && (input.scalar_type() == at::kBFloat16 || input.scalar_type() == at::kHalf) && weight.dim() == 5) { diff --git a/aten/src/ATen/native/GridSampler.cpp b/aten/src/ATen/native/GridSampler.cpp index 0ca8ec2a3a887..9ad138b3a663b 100644 --- a/aten/src/ATen/native/GridSampler.cpp +++ b/aten/src/ATen/native/GridSampler.cpp @@ -487,17 +487,17 @@ static Tensor _grid_sampler_2d_cpu_quantized( int64_t out_sC = output.stride(1); int64_t out_sH = output.stride(2); int64_t out_sW = output.stride(3); - uint8_t* inp_ptr = (uint8_t*)input.data_ptr(); - uint8_t* out_ptr = (uint8_t*)output.data_ptr(); - float* grid_ptr = grid.data_ptr(); + const uint8_t* inp_ptr = input.const_data_ptr(); + uint8_t* out_ptr = output.data_ptr(); + const float* grid_ptr = grid.const_data_ptr(); at::parallel_for(0, N, 0, [&](int64_t start, int64_t end) { for (const auto n : c10::irange(start, end)) { - float* grid_ptr_N = grid_ptr + n * grid_sN; - uint8_t* inp_ptr_N = inp_ptr + n * inp_sN; + const float* grid_ptr_N = grid_ptr + n * grid_sN; + const uint8_t* inp_ptr_N = inp_ptr + n * inp_sN; for (const auto h : c10::irange(out_H)) { for (const auto w : c10::irange(out_W)) { // get the corresponding input x, y, z coordinates from grid - float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; + const float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW; float x = *grid_ptr_NHW; float y = grid_ptr_NHW[grid_sCoor]; @@ -527,7 +527,7 @@ static Tensor _grid_sampler_2d_cpu_quantized( float se = (ix - ix_nw) * (iy - iy_nw); // calculate bilinear weighted pixel value and set output pixel - uint8_t* inp_ptr_NC = inp_ptr_N; + const uint8_t* inp_ptr_NC = inp_ptr_N; uint8_t* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW; for (int64_t c = 0; c < C; diff --git a/aten/src/ATen/native/Histogram.cpp b/aten/src/ATen/native/Histogram.cpp index 5919997cf5fe5..26d12ce8b5b8a 100644 --- a/aten/src/ATen/native/Histogram.cpp +++ b/aten/src/ATen/native/Histogram.cpp @@ -318,7 +318,7 @@ static std::vector& histogramdd_bin_edges_out(const Tensor& self, IntArr const int64_t N = self.size(-1); const int64_t M = std::accumulate(self.sizes().begin(), self.sizes().end() - 1, - (int64_t)1, std::multiplies()); + static_cast(1), std::multiplies()); Tensor reshaped_self = self.reshape({ M, N }); auto outer_bin_edges = select_outer_bin_edges(reshaped_self, range); diff --git a/aten/src/ATen/native/Integration.cpp b/aten/src/ATen/native/Integration.cpp index f8f45db66a200..6d592c6c27a45 100644 --- a/aten/src/ATen/native/Integration.cpp +++ b/aten/src/ATen/native/Integration.cpp @@ -40,7 +40,7 @@ Tensor do_trapezoid(const Tensor& y, const Tensor& dx, int64_t dim) { // When dx is constant, the above formula simplifies // to dx * [(\sum_{i=1}^n y_i) - (y_1 + y_n)/2] Tensor do_trapezoid(const Tensor& y, double dx, int64_t dim) { - return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * (0.5)) * dx; + return (y.sum(dim) - (y.select(dim, 0) + y.select(dim, -1)) * 0.5) * dx; } Tensor zeros_like_except(const Tensor& y, int64_t dim) { diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index a744da3bcad2e..1da245972f0cb 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -201,7 +201,7 @@ static Tensor sumproduct_pair(const Tensor& left_, const Tensor& right_, IntArra out_size.reserve(out_num_dim); for (auto& d : lro) out_size.push_back(left.sym_size(d)); for (auto& d : lo) out_size.push_back(left.sym_size(d)); - for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)(d); }; // avoid warning about not using d + for (auto& d : sum_dims_) { out_size.emplace_back(1); (void)d; }; // avoid warning about not using d for (auto& d : ro) out_size.push_back(right.sym_size(d)); std::vector lpermutation(lro); @@ -640,7 +640,7 @@ Tensor einsum(std::string_view equation, TensorList operands, at::OptionalIntArr } } - return ops[0]; + return std::move(ops[0]); } // _trilinear computes a trilinear einstein sum with an unrolled dimension @@ -805,7 +805,7 @@ Tensor tensordot(const Tensor& input1, const Tensor& input2, IntArrayRef dims1, std::vector rsizes; // rsizes: sizes of the result p1.reserve(input1.dim()); p2.reserve(input2.dim()); - rsizes.reserve(input1.dim() + input2.dim() - (int64_t) dims1.size()); + rsizes.reserve(input1.dim() + input2.dim() - static_cast(dims1.size())); SymInt size1 = 1; // number of non-contracted elements in input1 SymInt size2 = 1; // number of non-contracted elements in input2 diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index c07c7a5ac6e07..07bdc19ec8ff7 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1655,7 +1655,7 @@ static inline void baddbmm_cpu_kernel(const Tensor& result, const Tensor& self, auto s0 = self.accessor(); auto m0 = mat2.accessor(); - int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), (int64_t)1); + int64_t grain_size = std::max(internal::GRAIN_SIZE / (is * js * ks), static_cast(1)); using opmath_t = at::opmath_type; parallel_for(0, bs, grain_size, [&](int64_t b_begin, int64_t b_end) { for (const auto b : c10::irange(b_begin, b_end)) { diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp index 576f56986988b..94bf75be2ee8f 100644 --- a/aten/src/ATen/native/LossNLL.cpp +++ b/aten/src/ATen/native/LossNLL.cpp @@ -235,7 +235,7 @@ void nll_loss_out_frame( constexpr int64_t cascade_sum_num_levels = 8; const int64_t level_power = - std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels); + std::max(static_cast(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels); const int64_t level_step = (1 << level_power); const int64_t level_mask = level_step - 1; diff --git a/aten/src/ATen/native/LossNLL2d.cpp b/aten/src/ATen/native/LossNLL2d.cpp index 7bea90cbd5274..6f7fd8ad1e136 100644 --- a/aten/src/ATen/native/LossNLL2d.cpp +++ b/aten/src/ATen/native/LossNLL2d.cpp @@ -129,7 +129,7 @@ void nll_loss2d_forward_out_frame( for (const auto b : c10::irange(start, end)) { for (const auto h : c10::irange(H)) { for (const auto w : c10::irange(W)) { - const int64_t cur_target = (int64_t)target_acc[b][h][w]; + const int64_t cur_target = target_acc[b][h][w]; if (cur_target == ignore_index) { output_acc[b][h][w] = static_cast(0); @@ -188,7 +188,7 @@ void nll_loss2d_forward_out_frame( // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) scalar_t loss_partial_sums[cascade_sum_num_levels] = {0}; const int64_t level_power = - std::max(int64_t(4), utils::CeilLog2(numiter) / cascade_sum_num_levels); + std::max(static_cast(4), utils::CeilLog2(numiter) / cascade_sum_num_levels); const int64_t level_step = (1 << level_power); const int64_t level_mask = level_step - 1; diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 4677542706f6b..6c305ec6c0e50 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -192,7 +192,7 @@ Date: February 1996 x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(c10::pi)))*std::exp(-x*x)); - return(x); + return x; } #undef CENTRAL_RANGE @@ -3819,7 +3819,7 @@ inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) if ((n > 6) && (std::abs(x + x - T(1.0)) < T(1.0))) { if (std::sin(std::acos(x + x - T(1.0)) / T(2.0)) != T(1.0)) { - return std::cos(((n) + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0)); + return std::cos((n + T(0.5)) * std::acos(x + x - T(1.0))) / std::cos(std::acos(x + x - T(1.0)) / T(2.0)); } if (n % 2 == 0) { diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index be6266b17fced..6869f331994e1 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -193,22 +193,22 @@ Tensor _nnpack_spatial_convolution( const size_t input_channels = input.size(1); const size_t output_channels = weight.size(0); const struct nnp_size input_size = { - .width = (size_t)input.size(3), - .height = (size_t)input.size(2), + .width = static_cast(input.size(3)), + .height = static_cast(input.size(2)), }; const struct nnp_padding input_padding = { - .top = (size_t)padding[0], - .right = (size_t)padding[1], - .bottom = (size_t)padding[0], - .left = (size_t)padding[1], + .top = static_cast(padding[0]), + .right = static_cast(padding[1]), + .bottom = static_cast(padding[0]), + .left = static_cast(padding[1]), }; const struct nnp_size kernel_size = { - .width = (size_t)weight.size(3), - .height = (size_t)weight.size(2), + .width = static_cast(weight.size(3)), + .height = static_cast(weight.size(2)), }; const struct nnp_size output_size = { - .width = (size_t)output.size(3), - .height = (size_t)output.size(2), + .width = static_cast(output.size(3)), + .height = static_cast(output.size(2)), }; const nnp_size output_subsample = { .width = static_cast(stride[1]), diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp index 469269ab07dfb..3bb500de9c7c7 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp @@ -248,8 +248,8 @@ void slow_conv_transpose3d_out_cpu_template( Tensor weight = weight_.contiguous(); Tensor bias = bias_.defined() ? bias_.contiguous() : bias_; - const int n_input_plane = (int)weight.size(0); - const int n_output_plane = (int)weight.size(1); + const auto n_input_plane = weight.size(0); + const auto n_output_plane = weight.size(1); bool is_batch = false; if (input.dim() == 4) { diff --git a/aten/src/ATen/native/NamedTensor.cpp b/aten/src/ATen/native/NamedTensor.cpp index 47e698d81e49e..d1118a1268a9e 100644 --- a/aten/src/ATen/native/NamedTensor.cpp +++ b/aten/src/ATen/native/NamedTensor.cpp @@ -84,8 +84,8 @@ static std::vector aligned_size( DimnameList aligned_names, bool is_aligning_two_tensors) { std::vector expanded_sizes(aligned_names.size(), 1); - ptrdiff_t dim = (ptrdiff_t)tensor_sizes.size() - 1; - ptrdiff_t idx = (ptrdiff_t)aligned_names.size() - 1; + ptrdiff_t dim = static_cast(tensor_sizes.size()) - 1; + ptrdiff_t idx = static_cast(aligned_names.size()) - 1; for (; idx >= 0 && dim >= 0; --idx) { if (tensor_names[dim] != aligned_names[idx]) { continue; diff --git a/aten/src/ATen/native/RowwisePrune.cpp b/aten/src/ATen/native/RowwisePrune.cpp index bae698638b2e2..26062a071ff70 100644 --- a/aten/src/ATen/native/RowwisePrune.cpp +++ b/aten/src/ATen/native/RowwisePrune.cpp @@ -25,7 +25,7 @@ std::tuple _rowwise_prune_helper( auto mask_contig = mask.contiguous(); auto mask_data = mask_contig.data_ptr(); for (const auto i : c10::irange(mask.numel())) { - num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0); + num_non_masked_rows += ((mask_data[i] == true) ? 1 : 0); } int num_cols = weights.size(1); auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols}, diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index c3cc459184d93..fa83f3b6122fc 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -176,7 +176,7 @@ void host_softmax( scalar_t* input_data_base = input.data_ptr(); scalar_t* output_data_base = output.data_ptr(); bool* mask_data_base = mask; - int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); + int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast(1)); parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { @@ -265,7 +265,7 @@ void host_softmax_backward( scalar_t* output_data_base = output.data_ptr(); scalar_t* gradOutput_data_base = grad.data_ptr(); bool* mask_data_base = mask; - int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); + int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast(1)); parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { for (const auto i : c10::irange(begin, end)) { diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 451869f521df2..bfb5803eee07b 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -1701,13 +1701,13 @@ Tensor& index_select_out_cpu_( TORCH_CHECK_INDEX( (self_i >= 0) && (self_i < self_dim_size), "index out of range in self"); - auto self_data = static_cast(selfSlice_data) + + auto self_data = const_cast(static_cast( + selfSlice_data)) + self_i * self_stride_bytes; auto result_data = static_cast(resultSlice_data) + i * result_stride_bytes; sub_iter.unsafe_replace_operand(0, result_data); - sub_iter.unsafe_replace_operand( - 1, const_cast(self_data)); + sub_iter.unsafe_replace_operand(1, self_data); copy_stub(sub_iter.device_type(), sub_iter, false); }; }); diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 1886e65fc1edc..0feccbf24b484 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -1089,6 +1090,7 @@ Tensor& rand_out( Tensor rand_like( const Tensor& self, + std::optional generator, std::optional dtype, std::optional layout, std::optional device, @@ -1100,7 +1102,24 @@ Tensor rand_like( pin_memory); auto result = at::empty_like(self, options, optional_memory_format); - return result.uniform_(0, 1, std::nullopt); + return result.uniform_(0, 1, std::move(generator)); +} + +Tensor rand_like( + const Tensor& self, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory, + std::optional optional_memory_format) { + return native::rand_like( + self, + static_cast>(std::nullopt), + dtype, + layout, + device, + pin_memory, + optional_memory_format); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randint ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1197,7 +1216,9 @@ Tensor& randint_out( Tensor randint_like( const Tensor& self, + int64_t low, int64_t high, + std::optional generator, std::optional dtype, std::optional layout, std::optional device, @@ -1209,7 +1230,71 @@ Tensor randint_like( pin_memory); auto result = at::empty_like(self, options, optional_memory_format); - return result.random_(0, high, std::nullopt); + return result.random_(low, high, std::move(generator)); +} + +Tensor randint_like( + const Tensor& self, + int64_t low, + int64_t high, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory, + std::optional optional_memory_format) { + return native::randint_like( + self, + low, + high, + static_cast>(std::nullopt), + dtype, + layout, + device, + pin_memory, + optional_memory_format); +} + +Tensor randint_like( + const Tensor& self, + int64_t high, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory, + std::optional optional_memory_format) { + // See [Note: hacky wrapper removal for TensorOptions] + return native::randint_like( + self, + 0, + high, + static_cast>(std::nullopt), + dtype, + layout, + device, + pin_memory, + optional_memory_format); +} + +Tensor randint_like( + const Tensor& self, + int64_t high, + std::optional generator, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory, + std::optional optional_memory_format) { + // See [Note: hacky wrapper removal for TensorOptions] + return native::randint_like( + self, + 0, + high, + generator, + dtype, + layout, + device, + pin_memory, + optional_memory_format); } Tensor randint_like( @@ -1226,7 +1311,9 @@ Tensor randint_like( int64_t high_scalar = high.item(); return at::native::randint_like( self, + 0, high_scalar, + static_cast>(std::nullopt), dtype, layout, device, @@ -1236,20 +1323,27 @@ Tensor randint_like( Tensor randint_like( const Tensor& self, - int64_t low, - int64_t high, + const Tensor& high, + std::optional generator, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory, std::optional optional_memory_format) { - // See [Note: hacky wrapper removal for TensorOptions] - TensorOptions options = - TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( - pin_memory); - - auto result = at::empty_like(self, options, optional_memory_format); - return result.random_(low, high, std::nullopt); + TORCH_CHECK( + high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(), + "high must be a scalar tensor and on CPU"); + int64_t high_scalar = high.item(); + return at::native::randint_like( + self, + 0, + high_scalar, + generator, + dtype, + layout, + device, + pin_memory, + optional_memory_format); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randn ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1327,6 +1421,7 @@ Tensor& normal_out( Tensor randn_like( const Tensor& self, + std::optional generator, std::optional dtype, std::optional layout, std::optional device, @@ -1338,7 +1433,24 @@ Tensor randn_like( pin_memory); auto result = at::empty_like(self, options, optional_memory_format); - return result.normal_(0, 1, std::nullopt); + return result.normal_(0, 1, std::move(generator)); +} + +Tensor randn_like( + const Tensor& self, + std::optional dtype, + std::optional layout, + std::optional device, + std::optional pin_memory, + std::optional optional_memory_format) { + return native::randn_like( + self, + static_cast>(std::nullopt), + dtype, + layout, + device, + pin_memory, + optional_memory_format); } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ randperm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1382,7 +1494,7 @@ void randperm_cpu(Tensor& result, int64_t n, CPUGeneratorImpl* generator) { // use no-initialization Fischer-Yates variant // https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_.22inside-out.22_algorithm for (int64_t i = 0; i < n; i++) { - int64_t z = (int64_t)(generator->random64() % (i + 1)); + int64_t z = static_cast(generator->random64() % (i + 1)); r__data[i * r__stride_0] = i; r__data[i * r__stride_0] = r__data[z * r__stride_0]; r__data[z * r__stride_0] = i; diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index 717ad1608c0f8..12841ad8e7391 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -40,7 +40,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( "quantized_sparse_linear(): Input tensor rank should be >= 2"); const auto rows_input = c10::multiply_integers(input.sizes().begin(), input.sizes().end() - 1); - const auto cols_input = static_cast(input.size(input.dim() - 1)); + const auto cols_input = input.size(input.dim() - 1); TORCH_CHECK( cols_input == input_channels_, "quantized_sparse_linear: Input tensor's last and weight tensor's" diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp index b9cffe5b0bcbf..c3737e69390f4 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_unpack.cpp @@ -65,8 +65,8 @@ LinearPackedSerializationType PackedLinearWeight::unpack() { #ifdef USE_PYTORCH_QNNPACK LinearPackedSerializationType PackedLinearWeightQnnp::unpack() { - const int64_t N = static_cast(output_channels_); - const int64_t K = static_cast(input_channels_); + const int64_t N = output_channels_; + const int64_t K = input_channels_; float* w_scales_ptr = w_scales_.data_ptr(); diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index bc9b452bc6876..4ea90e7d6ddcf 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -998,7 +998,7 @@ void softplus_backward_kernel(TensorIteratorBase& iter, const Scalar& beta_, con auto threshold = threshold_.to(); const Vec beta_vec(beta); const Vec threshold_vec(threshold); - const Vec one_vec(static_cast(1.0)); + const Vec one_vec(1.0f); cpu_kernel_vec( iter, [beta, threshold](scalar_t a, scalar_t b) -> scalar_t { diff --git a/aten/src/ATen/native/cpu/AtomicAddFloat.h b/aten/src/ATen/native/cpu/AtomicAddFloat.h index 526f86d705b77..1ecfbe0357fa8 100644 --- a/aten/src/ATen/native/cpu/AtomicAddFloat.h +++ b/aten/src/ATen/native/cpu/AtomicAddFloat.h @@ -17,7 +17,7 @@ static inline void cpu_atomic_add_float(float* dst, float fvalue) } uf32_t; uf32_t new_value, old_value; - std::atomic* dst_intV = (std::atomic*)(dst); + std::atomic* dst_intV = (std::atomic*)dst; old_value.floatV = *dst; new_value.floatV = old_value.floatV + fvalue; diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 10e0daacab33c..221f621ea1e06 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -851,7 +851,7 @@ void sigmoid_backward_kernel(TensorIteratorBase& iter) { }); }); } else if (iter.dtype() == kBFloat16) { - auto one_vec = Vectorized((float)(1)); + auto one_vec = Vectorized((float)1); cpu_kernel_vec( iter, [=](BFloat16 a, BFloat16 b) -> BFloat16 { diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 365a79ba52ca9..68c5a867f24ee 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -77,9 +77,7 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne int64_t grain_size = at::internal::GRAIN_SIZE; - auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) { - std::array data; - std::copy_n(base, 2, data.data()); + auto loop = [strides_in, requires_neg](char** data, const int64_t* strides, int64_t size0, int64_t size1) { const int64_t *outer_strides = &strides[2]; for ([[maybe_unused]] const auto it : c10::irange(size1)) { @@ -146,9 +144,7 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne int64_t grain_size = at::internal::GRAIN_SIZE; - auto loop = [strides_in, requires_neg](char** base, const int64_t* strides, int64_t size0, int64_t size1) { - std::array data; - std::copy_n(base, 2, data.data()); + auto loop = [strides_in, requires_neg](char** data, const int64_t* strides, int64_t size0, int64_t size1) { const int64_t *outer_strides = &strides[2]; for ([[maybe_unused]] const auto it : c10::irange(size1)) { diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index 57d3ab89c6174..cce717692e3da 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -493,40 +493,33 @@ void cpu_hflip_vec(at::TensorIterator& iter) { for ([[maybe_unused]] const auto j : c10::irange(size1)) { // vectorized loop with negative stride for output - char** C10_RESTRICT data_ = data_arr.data(); int64_t n = size0; - - char* C10_RESTRICT data[ntensors]; - for (const auto arg : c10::irange(ntensors)) { - data[arg] = data_[arg]; - } - int64_t i = 0; - // data[0] unaligned pre-pass + // data_arr[0] unaligned pre-pass int64_t offset = (j * n + (n - i - Vec::size())) % 32; offset = (offset >= n) ? n : offset; for (; i < offset; i++) { - scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride); - *out_ptr = c10::load((scalar_t *)(data[1] + i * stride)); + scalar_t* out_ptr = (scalar_t*)(data_arr[0] - i * stride); + *out_ptr = c10::load((scalar_t *)(data_arr[1] + i * stride)); } // Empirically found that it is faster to process 3 data items together vs 2 or 4 for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) { - auto out1 = Vec::loadu(data[1] + i * stride); - auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride); - auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride); + auto out1 = Vec::loadu(data_arr[1] + i * stride); + auto out2 = Vec::loadu(data_arr[1] + (i + Vec::size()) * stride); + auto out3 = Vec::loadu(data_arr[1] + (i + 2 * Vec::size()) * stride); // flip the vector: 1234 -> 4321 out1 = flip(out1); out2 = flip(out2); out3 = flip(out3); - out1.store(data[0] - (i + Vec::size() - 1) * stride); - out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride); - out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride); + out1.store(data_arr[0] - (i + Vec::size() - 1) * stride); + out2.store(data_arr[0] - (i + 2 * Vec::size() - 1) * stride); + out3.store(data_arr[0] - (i + 3 * Vec::size() - 1) * stride); } if (i < n) { for (; i < n; i++) { - scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride); - *out_ptr = c10::load((scalar_t *)(data[1] + i * stride)); + scalar_t* out_ptr = (scalar_t*)(data_arr[0] - i * stride); + *out_ptr = c10::load((scalar_t *)(data_arr[1] + i * stride)); } } @@ -560,15 +553,8 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) { const int64_t stride = strides[0]; for ([[maybe_unused]] const auto j : c10::irange(size1)) { - char** C10_RESTRICT data_ = data_arr.data(); int64_t n = size0; - - char* C10_RESTRICT data[ntensors]; - for (const auto arg : c10::irange(ntensors)) { - data[arg] = data_[arg]; - } - - memcpy(data[0], data[1], n * stride); + memcpy(data_arr[0], data_arr[1], n * stride); // advance: for (const auto arg : c10::irange(data_arr.size())) { diff --git a/aten/src/ATen/native/cpu/Unfold2d.cpp b/aten/src/ATen/native/cpu/Unfold2d.cpp index 444ec10861da8..ed69998e99f79 100644 --- a/aten/src/ATen/native/cpu/Unfold2d.cpp +++ b/aten/src/ATen/native/cpu/Unfold2d.cpp @@ -298,7 +298,7 @@ void unfolded2d_copy( memcpy( dst + (size_t)y * output_width + x, src + (size_t)iy * input_width + ix, - sizeof(scalar_t) * (1)); + sizeof(scalar_t) * 1); } } } @@ -317,7 +317,7 @@ void unfolded2d_copy( memcpy( dst + (size_t)y * output_width + x, src + (size_t)iy * input_width + ix + x * dW, - sizeof(scalar_t) * (1)); + sizeof(scalar_t) * 1); } } } diff --git a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h index 073cc4fd7e8bb..146c60e5cd0fa 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h +++ b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -342,7 +342,7 @@ void upsample_avx_bilinear_bicubic_uint8( if (need_horizontal) { int interp_dim = 3; - auto stride = (skip_unpacking) ? num_channels : 4; + auto stride = skip_unpacking ? num_channels : 4; std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) = F::compute_index_ranges_int16_weights( /*input_size=*/xin, @@ -358,7 +358,7 @@ void upsample_avx_bilinear_bicubic_uint8( if (need_vertical) { int interp_dim = 2; - auto stride = (skip_unpacking) ? num_channels * xout : 4 * xout; + auto stride = skip_unpacking ? num_channels * xout : 4 * xout; std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) = F::compute_index_ranges_int16_weights( /*input_size=*/yin, @@ -377,17 +377,17 @@ void upsample_avx_bilinear_bicubic_uint8( // horizontal-only or vertical-only interpolation, and if the tensor doesn't // need repacking if (need_horizontal && (need_vertical || !skip_packing)) { - auto c = (skip_unpacking) ? num_channels : 4; + auto c = skip_unpacking ? num_channels : 4; buffer_horiz = at::empty({c, yin, xout}, input.options()); } if (need_vertical && !skip_packing) { - auto c = (skip_unpacking) ? num_channels : 4; + auto c = skip_unpacking ? num_channels : 4; buffer_vert = at::empty({c, yout, xout}, input.options()); } for (const auto i : c10::irange(batch_size)) { - at::Tensor unpacked_input = (skip_unpacking) ? input[i] : unpack_rgb(input[i]); + at::Tensor unpacked_input = skip_unpacking ? input[i] : unpack_rgb(input[i]); at::Tensor unpacked_output; if (need_horizontal) { @@ -411,7 +411,7 @@ void upsample_avx_bilinear_bicubic_uint8( unpacked_output = unpacked_input = unpacked_output_temp; } if (need_vertical) { - unpacked_output = (skip_packing) ? output[i] : buffer_vert; + unpacked_output = skip_packing ? output[i] : buffer_vert; ImagingResampleVertical( unpacked_output, @@ -502,7 +502,7 @@ void ImagingResampleHorizontalConvolution8u4x( // RGBA: b4_delta = b4_delta_soft = 3 // RGB : b4_delta = 5 // RGB : b4_delta_soft = 4 - const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4); // In block 2 (2 means we process 2 weights values together), we read input data // with _mm_loadl_epi64, i.e. 8 bytes, per one line: @@ -515,7 +515,7 @@ void ImagingResampleHorizontalConvolution8u4x( // RGBA: b2_delta = b2_delta_soft = 1 // RGB : b2_delta = 2 // RGB : b2_delta_soft = 1 - const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1); const auto max_out_x_strided = out_xsize * stride; const auto max_in_x_strided = in_xsize * stride; @@ -819,7 +819,7 @@ void ImagingResampleHorizontalConvolution8u( // RGBA: b8_delta = b8_delta_soft = 7 // RGB : b8_delta = 10 // RGB : b8_delta_soft = 9 - const auto b8_delta = (stride == 4) ? 7 : ((is_last_line) ? 10 : 9); + const auto b8_delta = (stride == 4) ? 7 : (is_last_line ? 10 : 9); // In block 4 (4 means we process 4 weight values together), we read // 16 bytes of input data. @@ -832,7 +832,7 @@ void ImagingResampleHorizontalConvolution8u( // RGBA: b4_delta = b4_delta_soft = 3 // RGB : b4_delta = 5 // RGB : b4_delta_soft = 4 - const auto b4_delta = (stride == 4) ? 3 : ((is_last_line) ? 5 : 4); + const auto b4_delta = (stride == 4) ? 3 : (is_last_line ? 5 : 4); // In block 2 (2 means we process 2 weight values together), we read // 8 bytes of input data. @@ -845,7 +845,7 @@ void ImagingResampleHorizontalConvolution8u( // RGBA: b2_delta = b2_delta_soft = 1 // RGB : b2_delta = 2 // RGB : b2_delta_soft = 1 - const auto b2_delta = (stride == 4) ? 1 : ((is_last_line) ? 2 : 1); + const auto b2_delta = (stride == 4) ? 1 : (is_last_line ? 2 : 1); const auto max_out_x_strided = out_xsize * stride; const auto max_in_x_strided = in_xsize * stride; diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index a9683ba4bef3f..33aae4fbf27a5 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -644,8 +644,8 @@ void weight_to_int4pack_kernel( int32_t val2 = src[(d + 32) * K + k]; int32_t val3 = src[(d + 48) * K + k]; - uint8_t packed02 = (((uint8_t)(val2) << 4)) | ((uint8_t)(val0)); - uint8_t packed13 = (((uint8_t)(val3) << 4)) | ((uint8_t)(val1)); + uint8_t packed02 = ((uint8_t)val2 << 4) | ((uint8_t)val0); + uint8_t packed13 = ((uint8_t)val3 << 4) | ((uint8_t)val1); dst[k * 32 + d] = packed02; dst[k * 32 + 16 + d] = packed13; @@ -656,7 +656,7 @@ void weight_to_int4pack_kernel( int32_t val0 = src[n * K + k]; int32_t val1 = src[n * K + K + k]; - uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0)); + uint8_t packed = ((uint8_t)val1 << 4) | ((uint8_t)val0); dst[k * nb_size / 2 + n / 2] = packed; } } @@ -667,7 +667,7 @@ void weight_to_int4pack_kernel( int32_t val0 = src[(d + 0) * K + k]; int32_t val1 = src[(d + 16) * K + k]; - uint8_t packed01 = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0)); + uint8_t packed01 = ((uint8_t)val1 << 4) | ((uint8_t)val0); dst[k * 16 + d] = packed01; } } else { @@ -676,7 +676,7 @@ void weight_to_int4pack_kernel( int32_t val0 = src[n * K + k]; int32_t val1 = src[n * K + K + k]; - uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0)); + uint8_t packed = ((uint8_t)val1 << 4) | ((uint8_t)val0); dst[k * nb_size / 2 + n / 2] = packed; } } @@ -685,7 +685,7 @@ void weight_to_int4pack_kernel( int32_t val0 = src[n * K + k]; int32_t val1 = src[n * K + K + k]; - uint8_t packed = (((uint8_t)(val1) << 4)) | ((uint8_t)(val0)); + uint8_t packed = ((uint8_t)val1 << 4) | ((uint8_t)val0); dst[k * nb_size / 2 + n / 2] = packed; } #endif @@ -872,16 +872,16 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( for (size_t k_idx = 0; k_idx < k; ++k_idx) { const float src0_0 = src_ptr[k_idx]; - max0 = (std::max)(src0_0, max0); - min0 = (std::min)(src0_0, min0); + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); } // Maximum/minimum int8 values const float qmin = (float)INT8_MIN; const float qmax = (float)INT8_MAX; - const float rmin0 = (std::min)(0.0f, min0); - const float rmax0 = (std::max)(0.0f, max0); + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); @@ -900,8 +900,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( ? qmin - descaled_min0 : qmax - descaled_max0; - zero_point0 = (std::max)(zero_point0, qmin); - zero_point0 = (std::min)(zero_point0, qmax); + zero_point0 = std::max(zero_point0, qmin); + zero_point0 = std::min(zero_point0, qmax); // Round to nearest integer const int32_t nudged_zero_point0 = lrintf(zero_point0); @@ -909,9 +909,9 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( int8_t* dst_ptr = lhs_qa8dx + m_idx * dst_stride; // LHS offset at the beginning of the row - *((float*)(dst_ptr)) = recip_scale0; + *((float*)dst_ptr) = recip_scale0; dst_ptr += sizeof(float); - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + *((int32_t*)dst_ptr) = -nudged_zero_point0; dst_ptr += sizeof(int32_t); // Quantize the channels @@ -922,8 +922,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = (std::max)(v0_s32, static_cast(INT8_MIN)); - v0_s32 = (std::min)(v0_s32, static_cast(INT8_MAX)); + v0_s32 = std::max(v0_s32, static_cast(INT8_MIN)); + v0_s32 = std::min(v0_s32, static_cast(INT8_MAX)); dst_ptr[0] = (int8_t)v0_s32; dst_ptr += sizeof(int8_t); } @@ -988,8 +988,8 @@ void ref_dyn_quant_matmul_4bit_channelwise_kernel( main_acc = main_acc * lhs_scale; // Clamp (min-max) operation - main_acc = (std::max)(main_acc, scalar_min); - main_acc = (std::min)(main_acc, scalar_max); + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); dst_f32[0] = main_acc; dst_f32 += 1; @@ -1024,15 +1024,15 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( for (size_t k_idx = 0; k_idx < k; ++k_idx) { const float src0_0 = src_ptr[k_idx]; - max0 = (std::max)(src0_0, max0); - min0 = (std::min)(src0_0, min0); + max0 = std::max(src0_0, max0); + min0 = std::min(src0_0, min0); } const float qmin = (float)INT8_MIN; const float qmax = (float)INT8_MAX; - const float rmin0 = (std::min)(0.0f, min0); - const float rmax0 = (std::max)(0.0f, max0); + const float rmin0 = std::min(0.0f, min0); + const float rmax0 = std::max(0.0f, max0); const float scale0 = (rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0); const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; @@ -1044,22 +1044,22 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( ? qmin - descaled_min0 : qmax - descaled_max0; - zero_point0 = (std::max)(zero_point0, qmin); - zero_point0 = (std::min)(zero_point0, qmax); + zero_point0 = std::max(zero_point0, qmin); + zero_point0 = std::min(zero_point0, qmax); const int32_t nudged_zero_point0 = lrintf(zero_point0); int8_t* dst_ptr = lhs_qa8dx + row_idx * dst_stride; - *((float*)(dst_ptr)) = recip_scale0; + *((float*)dst_ptr) = recip_scale0; dst_ptr += sizeof(float); - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + *((int32_t*)dst_ptr) = -nudged_zero_point0; dst_ptr += sizeof(int32_t); for (size_t k_idx = 0; k_idx < k; ++k_idx) { const float src0_0 = src_ptr[k_idx]; int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); - v0_s32 = (std::max)( - (std::min)( + v0_s32 = std::max( + std::min( v0_s32 + nudged_zero_point0, static_cast(INT8_MAX)), static_cast(INT8_MIN)); dst_ptr[0] = (int8_t)v0_s32; @@ -1118,8 +1118,8 @@ void ref_dyn_quant_matmul_4bit_groupwise_kernel( } main_acc = main_acc * lhs_scale; - main_acc = (std::max)(main_acc, scalar_min); - main_acc = (std::min)(main_acc, scalar_max); + main_acc = std::max(main_acc, scalar_min); + main_acc = std::min(main_acc, scalar_max); dst_f32[0] = main_acc; dst_f32 += 1; diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index f29be23acd559..186f7d8a6a78a 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -170,10 +169,14 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const #if defined(CUDA_VERSION) || defined(USE_ROCM) const auto scalar_type = mat1.scalar_type(); return (beta.toComplexDouble() == 1.0 - // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] - // is to use lt interface only when self is bias. - && self.dim() == 1 && self.sizes()[0] == mat2_sizes[1] && self.is_contiguous() && result.dim() == 2 && result.is_contiguous() + // Conditions for bias to be fusable + && ( + self.is_contiguous() && + // NOTE: fine to have 1-len dims to the left from the right-most one + (self.dim() == 1 || self.squeeze().dim() == 1) && + self.sizes().back() == mat2_sizes[1] + ) && ( // some dtype restrictions #ifndef USE_ROCM scalar_type == at::ScalarType::Double || @@ -202,8 +205,8 @@ static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const // and the leading stride is at least max(1, other dim length), so we might // end up with contiguous cols but not rows (i.e. holes between different rows) // and vice versa. - && mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 && - mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && + && mat2_sizes[0] < 65535 * 32 && mat2_sizes[1] < 65535 * 32 + && mat1_sizes[0] < 65535 * 32 && mat1_sizes[1] < 65535 * 32 && ( // filter by dtype (scalar_type != at::ScalarType::Half && scalar_type != at::ScalarType::BFloat16) || diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index 407f101023178..f64eb317d0cca 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -213,9 +213,9 @@ _f4_f4_bf16_grouped_mm_fbgemm( const Tensor& mat_a, const Tensor& mat_b, const Tensor& scale_a, - const Tensor& global_scale_a, + const std::optional& global_scale_a, const Tensor& scale_b, - const Tensor& global_scale_b, + const std::optional& global_scale_b, const std::optional& offs, const std::optional& bias, Tensor& out) { @@ -225,14 +225,28 @@ _f4_f4_bf16_grouped_mm_fbgemm( "mat_a must be Float4_e2n1fn_2, got: ", mat_a.scalar_type()); TORCH_CHECK_VALUE(mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_b must be Float4_e2n1fn_2, got: ", mat_b.scalar_type()); - TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn, - "scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type()); - TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn, - "scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type()); - TORCH_CHECK_VALUE(global_scale_a.scalar_type() == at::kFloat, - "global_scale_a must be Float, got: ", global_scale_a.scalar_type()); - TORCH_CHECK_VALUE(global_scale_b.scalar_type() == at::kFloat, - "global_scale_b must be Float, got: ", global_scale_b.scalar_type()); + + std::optional combined_global_scale = std::nullopt; + if (global_scale_a.has_value() || global_scale_b.has_value()) { + // NVFP4 + TORCH_CHECK_VALUE(global_scale_a.has_value() && global_scale_b.has_value(), + "For NVFP4 grouped gemm both of global_scale_{a,b} must have values") + TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e4m3fn, + "scale_a must be Float8_e4m3fn, got: ", scale_a.scalar_type()); + TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e4m3fn, + "scale_b must be Float8_e4m3fn, got: ", scale_b.scalar_type()); + TORCH_CHECK_VALUE(global_scale_a.value().scalar_type() == at::kFloat, + "global_scale_a must be Float, got: ", global_scale_a.value().scalar_type()); + TORCH_CHECK_VALUE(global_scale_b.value().scalar_type() == at::kFloat, + "global_scale_b must be Float, got: ", global_scale_b.value().scalar_type()); + combined_global_scale = global_scale_a.value().mul(global_scale_b.value()); + } else { + // MXFP4 + TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu, + "scale_a must be Float8_e8m0fnu, got: ", scale_a.scalar_type()); + TORCH_CHECK_VALUE(scale_b.scalar_type() == at::kFloat8_e8m0fnu, + "scale_b must be Float8_e8m0fnu, got: ", scale_b.scalar_type()); + } auto o = fbgemm_gpu::f4f4bf16_grouped_mm( mat_a, @@ -241,7 +255,7 @@ _f4_f4_bf16_grouped_mm_fbgemm( scale_b, offs.value(), out, - global_scale_a.mul(global_scale_b) + combined_global_scale ); #else TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_FBGEMM_GENAI, and only for CUDA") @@ -471,9 +485,10 @@ namespace { using acceptance_fn = std::function&, ArrayRef&, c10::ScalarType, std::vector&, ArrayRef&)>; -std::array, 3> scale_grouped_kernel_dispatch = {{ +std::array, 4> scale_grouped_kernel_dispatch = {{ { "rowwise_rowwise", scaled_blas::check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, { "mxfp8_mxfp8", scaled_blas::check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}, + { "mxfp4_mxfp4", scaled_blas::check_mxfp4_recipe, ScaledGemmImplementation::MXFP4_MXFP4}, { "nvfp4_nvfp4", scaled_blas::check_nvfp4_recipe, ScaledGemmImplementation::NVFP4_NVFP4}}}; } // anonymous namespace @@ -599,6 +614,21 @@ _scaled_grouped_mm_cuda_v2( offs.value(), out); } + case ScaledGemmImplementation::MXFP4_MXFP4: { + // scale shape checks + _check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */); + _check_scales_blocked(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */); + return _f4_f4_bf16_grouped_mm_fbgemm( + mat_a, + mat_b, + scale_a[0], /* block-scale A */ + std::nullopt, /* global-scale A */ + scale_b[0], /* block-scale B */ + std::nullopt, /* global-scale B */ + offs.value(), + std::nullopt, /* bias */ + out); + } case ScaledGemmImplementation::NVFP4_NVFP4: { // scale shape checks _check_scales_blocked(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */); diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.cu b/aten/src/ATen/native/cuda/IndexKernelUtils.cu index 3e13f934e21e3..8343c60418952 100644 --- a/aten/src/ATen/native/cuda/IndexKernelUtils.cu +++ b/aten/src/ATen/native/cuda/IndexKernelUtils.cu @@ -13,7 +13,7 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, if (allow_neg_indices) { ind = (ind < 0) ? ind + ind_dim_size : ind; } - CUDA_KERNEL_ASSERT(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"); + CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind); int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits if (off >= slice_size) return; auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 3eeca901a18d5..382a5a065b300 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -54,7 +54,6 @@ namespace { using DtypeScale = float; using DtypeAccum = float; using DtypeEpilogue = float; -using DtypeOutput = cutlass::bfloat16_t; using Multiply = cutlass::epilogue::fusion::Sm90Compute< cutlass::multiplies, @@ -68,12 +67,6 @@ using Add = cutlass::epilogue::fusion::Sm90Compute< DtypeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; -using Cast = cutlass::epilogue::fusion::Sm90Compute< - cutlass::epilogue::thread::Identity, - DtypeOutput, - DtypeEpilogue, - cutlass::FloatRoundStyle::round_to_nearest>; - template struct Schedule; @@ -120,7 +113,8 @@ template < typename FastAccum, typename DtypeA, typename DtypeB, - typename DtypeBias> + typename DtypeBias, + typename DtypeOutput> void f8f8bf16_rowwise_impl( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -181,6 +175,11 @@ void f8f8bf16_rowwise_impl( WScale, cutlass::epilogue::fusion::Sm90EVT>; + using Cast = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, + DtypeOutput, + DtypeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< @@ -313,7 +312,8 @@ template < typename FastAccum, typename DtypeA, typename DtypeB, - typename DtypeBias> + typename DtypeBias, + typename DtypeOutput> void f8f8bf16_rowwise_impl_sm100_sm120( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -372,6 +372,11 @@ void f8f8bf16_rowwise_impl_sm100_sm120( WScale, cutlass::epilogue::fusion::Sm90EVT>; + using Cast = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, + DtypeOutput, + DtypeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT< Cast, cutlass::epilogue::fusion::Sm90EVT< @@ -498,7 +503,8 @@ template < typename FastAccum, typename DtypeA, typename DtypeB, - typename DtypeBias> + typename DtypeBias, + typename DtypeOutput> void f8f8bf16_rowwise_impl_sm89( at::Tensor XQ, // FP8 at::Tensor WQ, // FP8 @@ -765,7 +771,8 @@ template < typename FastAccum, typename DtypeA, typename DtypeB, - typename DtypeBias> + typename DtypeBias, + typename DtypeOutput> void handle_transposition( at::Tensor XQ, at::Tensor WQ, @@ -782,7 +789,8 @@ void handle_transposition( FastAccum, DtypeA, DtypeB, - DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); + DtypeBias, + DtypeOutput>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); } else { dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, @@ -791,7 +799,8 @@ void handle_transposition( FastAccum, DtypeB, DtypeA, - DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle); + DtypeBias, + DtypeOutput>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle); } } @@ -1027,11 +1036,19 @@ void dispatch_fp8_rowwise_kernel_on_bias_dtype( at::Tensor out) { if (bias.has_value() && bias->dtype() == at::kBFloat16) { dispatch_fp8_rowwise_kernel_on_input_dtypes< + cutlass::bfloat16_t, cutlass::bfloat16_t> (XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out); + } else if (bias.has_value() && bias->dtype() == at::kHalf){ + TORCH_CHECK(out.dtype() == at::kHalf, "Output should be Float16 when bias is Float16"); + dispatch_fp8_rowwise_kernel_on_input_dtypes< + cutlass::half_t, + cutlass::half_t> + (XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out); } else { dispatch_fp8_rowwise_kernel_on_input_dtypes< - float> + float, + cutlass::bfloat16_t> //Types...> (XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out); } @@ -1073,14 +1090,14 @@ void check_inputs( if (bias.has_value()) { TORCH_CHECK(bias->device() == b.device()); - TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16); + TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16 || bias->dtype() == at::kHalf); TORCH_CHECK(bias->dim() == 1); TORCH_CHECK(bias->size(0) == b.size(1)); TORCH_CHECK(bias->stride(0) == 1); } TORCH_CHECK(out.device() == a.device()); - TORCH_CHECK(out.dtype() == at::kBFloat16); + TORCH_CHECK(out.dtype() == at::kBFloat16 || out.dtype() == at::kHalf); TORCH_CHECK(out.dim() == 2); TORCH_CHECK(out.size(0) == a.size(0)); TORCH_CHECK(out.size(1) == b.size(1)); diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index eb7e66649d19a..ba52dc6bc042a 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -591,7 +591,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, if ((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales || (dprops->major >= 10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty()))) { - TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); + TORCH_CHECK_VALUE(out.dtype() == kBFloat16 || out.dtype() == kHalf, "Only bf16 and fp16 high precision output types are supported for row-wise scaling."); return _scaled_rowwise_rowwise( mat1, mat2, @@ -736,7 +736,7 @@ _scaled_rowwise_rowwise( if (((dprops->major < 9 || CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900) // cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales || (dprops->major == 10 && (scale_a.sizes().size() || scale_b.sizes().size())))) { - TORCH_CHECK_VALUE(out.dtype() == kBFloat16, "Only bf16 high precision output types are supported for row-wise scaling."); + TORCH_CHECK_VALUE(out.dtype() == kBFloat16 || out.dtype() == kHalf, "Only bf16 and fp16 high precision output types are supported for row-wise scaling."); at::cuda::detail::f8f8bf16_rowwise( mat_a, mat_b, @@ -794,6 +794,24 @@ void _check_deepseek_scale_stride(const Tensor& scale, const Tensor& t, const Sc } } +void +_check_deepseek_support() { +#ifndef USE_ROCM + auto dprops = at::cuda::getCurrentDeviceProperties(); + if (dprops->major != 9) { + // Only on Hopper GPUs + TORCH_CHECK_NOT_IMPLEMENTED( + dprops->major == 9, + "DeepSeek style (1x128, 128x128) scaling only supported in CUDA for SM90") + } + // Only in cublasLt >= 12.9 + TORCH_CHECK_NOT_IMPLEMENTED( + CUBLAS_VERSION < 120900 || cublasLtGetVersion() < 120900, + "DeepSeek style (1x128, 128x128) scaling requires cublasLt >= 12.9" + ); +#endif +} + Tensor& _scaled_block1x128_block1x128( const Tensor& mat_a, const Tensor& mat_b, @@ -802,8 +820,12 @@ _scaled_block1x128_block1x128( const c10::ScalarType out_dtype, const bool use_fast_accum, Tensor& out) { +#ifndef USE_ROCM // Restrictions: // A, B are FP8, scales are fp32, shape K//128 + // CUDA: Only Hopper GPUs + _check_deepseek_support(); + TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat, @@ -821,6 +843,12 @@ _scaled_block1x128_block1x128( _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "1x128 and 128x128 scaling not available with ROCm" + ); +#endif } Tensor& @@ -831,10 +859,12 @@ _scaled_block128x128_block1x128( const c10::ScalarType out_dtype, const bool use_fast_accum, Tensor& out) { +#ifndef USE_ROCM // Restrictions: // A, B are FP8, scales are fp32, shape K//128 - std::cout << "mat_b: " << mat_b.dim() << ", " << mat_b.sizes() << ", " << mat_b.strides() << std::endl; - std::cout << "scale_b: " << scale_b.dim() << ", " << scale_b.sizes() << ", " << scale_b.strides() << std::endl; + // CUDA: Only Hopper GPUs + _check_deepseek_support(); + TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); TORCH_CHECK_VALUE(scale_a.sizes()[0] == ceil_div(mat_a.sizes()[0], 128) && scale_a.sizes()[1] == ceil_div(mat_a.sizes()[1], 128) && scale_a.scalar_type() == kFloat, @@ -852,6 +882,12 @@ _scaled_block128x128_block1x128( _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "1x128 and 128x128 scaling not available with ROCm" + ); +#endif } Tensor& @@ -862,8 +898,12 @@ _scaled_block1x128_block128x128( const c10::ScalarType out_dtype, const bool use_fast_accum, Tensor& out) { +#ifndef USE_ROCM // Restrictions: // A, B are FP8, scales are fp32, A: shape K//128, B: K//128, N//128 + // CUDA: Only Hopper GPUs + _check_deepseek_support(); + TORCH_CHECK_VALUE(isFloat8Type(mat_a.scalar_type()) && isFloat8Type(mat_b.scalar_type()), "mat_a and mat_b must be fp8 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); TORCH_CHECK_VALUE(scale_a.sizes()[0] == mat_a.sizes()[0] && scale_a.sizes()[1] == mat_a.sizes()[1] / 128 && scale_a.scalar_type() == kFloat, @@ -881,6 +921,12 @@ _scaled_block1x128_block128x128( _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, use_fast_accum, out); return out; +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "1x128 and 128x128 scaling not available with ROCm" + ); +#endif } Tensor& diff --git a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu index 26064dd98377c..2def9b196974e 100644 --- a/aten/src/ATen/native/cuda/ScatterGatherKernel.cu +++ b/aten/src/ATen/native/cuda/ScatterGatherKernel.cu @@ -160,8 +160,8 @@ struct _cuda_scatter_gather_internal_kernel { auto offsets = offset_calc.get(i); int64_t idx_dim = *(index_t*)(index_ptr + offsets[2]); - CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size - && "scatter gather kernel index out of bounds"); + CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size + && "scatter gather kernel index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim); f( (scalar_t*)(self_ptr + offsets[0]), @@ -406,9 +406,8 @@ struct _cuda_scatter_fill_internal_kernel { auto offsets = offset_calc.get(i); int64_t idx_dim = *(index_t*)(index_ptr + offsets[1]); - CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size - && "index out of bounds" - ); + CUDA_KERNEL_ASSERT_VERBOSE(idx_dim >= 0 && idx_dim < index_size + && "index out of bounds", "Expected 0 <= idx_dim < index_size (%ld), but got idx_dim = %ld", index_size, idx_dim); f( (scalar_t*)(self_ptr + offsets[0]), diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index c457bd3dba753..730a7ea910961 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -141,7 +141,8 @@ WelfordDataLN cuWelfordOnlineSum( if constexpr (!rms_norm){ U delta = val - curr_sum.mean; U new_count = curr_sum.count + 1.f; -#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL) +//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf` +#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL) U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count); #else U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster @@ -163,7 +164,8 @@ WelfordDataLN cuWelfordCombine( U count = dataA.count + dataB.count; U mean, sigma2; if (count > decltype(dataB.count){0}) { -#if defined(USE_ROCM) && defined(USE_LAYERNORM_FAST_RECIPROCAL) +//Due to low CU count, we run into accuracy issues on gfx90a with `__builtin_amdgcn_rcpf` +#if defined(USE_ROCM) && !defined(__gfx90a__) && defined(USE_LAYERNORM_FAST_RECIPROCAL) auto coef = __builtin_amdgcn_rcpf(count); #else auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index dea59d5913b91..525714e2817ed 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -1881,6 +1881,8 @@ void geqrf_kernel(const Tensor& input, const Tensor& tau) { REGISTER_CUDA_DISPATCH(geqrf_stub, &geqrf_kernel) +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + template static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { #if !AT_MAGMA_ENABLED() @@ -1955,8 +1957,6 @@ static void apply_magma_eigh(const Tensor& values, const Tensor& vectors, const #endif } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eigh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // This is a type dispatch function for 'apply_magma_eigh' // For small inputs result is computed on CPU void linalg_eigh_magma(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors) { @@ -2019,10 +2019,10 @@ This is an in-place routine, content of 'input', 'values', 'vectors' is overwrit For more information see MAGMA's documentation for GEEV routine. */ template -void apply_linalg_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) { +void apply_magma_eig(Tensor& values, Tensor& vectors, Tensor& input, Tensor& infos, bool compute_eigenvectors) { #if !AT_MAGMA_ENABLED() -TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling PyTorch with MAGMA. " - "Either transfer the tensor to the CPU before calling torch.linalg.eig or recompile with MAGMA."); +TORCH_CHECK(false, "Calling torch.linalg.eig with MAGMA requires compiling PyTorch with MAGMA. " + "Either transfer the tensor to the CPU before calling torch.linalg.eig or use cuSolver."); #else TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.device() == at::kCPU); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.device() == at::kCPU); @@ -2076,22 +2076,44 @@ TORCH_CHECK(false, "Calling torch.linalg.eig on a CUDA tensor requires compiling #endif } -// This is a type dispatching helper function for 'apply_linalg_eig' +// MAGMA wrapper: transfers tensors to CPU, calls apply_magma_eig, then copies results back. +void linalg_eig_magma(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors){ + // MAGMA doesn't have GPU interface for the eigendecomposition, and it forces us to transfer to CPU + auto eigenvalues_cpu = eigenvalues.cpu(); + auto eigenvectors_cpu = eigenvectors.cpu(); + auto infos_cpu = infos.cpu(); + + Tensor input_cpu = at::empty(input.sizes(), input.options().device(kCPU)); + input_cpu.transpose_(-2, -1); + input_cpu.copy_(input); + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{ + apply_magma_eig(eigenvalues_cpu, eigenvectors_cpu, input_cpu, infos_cpu, compute_eigenvectors); + }); + + eigenvalues.copy_(eigenvalues_cpu); + eigenvectors.copy_(eigenvectors_cpu); + infos.copy_(infos_cpu); +} void linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) { // This function calculates the non-symmetric eigendecomposition in-place // tensors should be in batched column major memory format - // the content of eigenvalues, eigenvectors and infos is overwritten by 'apply_linalg_eig' + // the content of eigenvalues, eigenvectors and infos is overwritten by 'linalg_eig_magma' or + // 'linalg_eig_cusolver_xgeev' both geev routines modify the provided input matrix in-place, therefore we need a copy - // apply_linalg_eig modifies the provided input matrix in-place, therefore we need a copy - // MAGMA doesn't have GPU interface for the eigendecomposition and it forces us to transfer 'input' to CPU TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda()); - Tensor input_working_copy = at::empty(input.sizes(), input.options().device(kCPU)); - input_working_copy.transpose_(-2, -1); // make input_working_copy to have Fortran contiguous memory layout - input_working_copy.copy_(input); - - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(input.scalar_type(), "linalg_eig_out_cuda", [&]{ - apply_linalg_eig(eigenvalues, eigenvectors, input_working_copy, infos, compute_eigenvectors); - }); +#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + auto preferred_backend = at::globalContext().linalgPreferredBackend(); + switch (preferred_backend) { + case at::LinalgBackend::Cusolver: + default: + linalg_eig_cusolver_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors); + return; + case at::LinalgBackend::Magma: + break; // MAGMA path handled below + } +#endif + linalg_eig_magma(eigenvalues, eigenvectors, infos, input, compute_eigenvectors); } REGISTER_CUDA_DISPATCH(linalg_eig_stub, &linalg_eig_kernel) diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp index 5b28cc6eccf01..a0a581b858d09 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp @@ -753,8 +753,8 @@ static void apply_cholesky_cusolver_potrf_looped(const Tensor& self_working_copy handle, params, uplo, n, datatype, self_working_copy_ptr + i * matrix_stride, lda, datatype, - (char*)workdata_device_ptr + i * worksize_device, worksize_device, - (char*)workdata_host_ptr + i * worksize_host, worksize_host, + static_cast(workdata_device_ptr) + i * worksize_device, worksize_device, + static_cast(workdata_host_ptr) + i * worksize_host, worksize_host, infos_ptr + i ); } @@ -1625,6 +1625,126 @@ void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, #endif } +// cuSOLVER Xgeev (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+) +#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + +template +void apply_xgeev(const Tensor& values, const Tensor& vectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(values.is_cuda()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(vectors.is_cuda()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(input.is_cuda()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(infos.is_cuda()); + + int n = cuda_int_cast(input.size(-1), "n"); + int lda = std::max(1, n); + auto batch_size = batchCount(input); + + if (n == 0 || batch_size == 0) { + // XGeev crashes on empty input, explicitly handle empty input + auto values_shape = IntArrayRef(input.sizes().data(), input.dim() - 1); + values.resize_(values_shape, MemoryFormat::Contiguous); + values.zero_(); + + if (compute_eigenvectors) { + vectors.resize_(input.sizes(), MemoryFormat::Contiguous); + vectors.zero_(); + } else { + vectors.resize_({0}); + } + + infos.resize_({std::max(1, batch_size)}, MemoryFormat::Contiguous); + infos.zero_(); + return; + } + + int64_t vectors_stride = 0; + if (compute_eigenvectors){ + vectors_stride = matrixStride(vectors); + } + + auto values_stride = values.size(-1); + auto vectors_data = vectors.data_ptr(); + auto values_data = values.data_ptr(); + auto infos_data = infos.data_ptr(); + + cusolverDnParams_t params = nullptr; + TORCH_CUSOLVER_CHECK(cusolverDnCreateParams(¶ms)); + + Tensor A_fortran = input.mT().contiguous(); + auto* A_data = A_fortran.data_ptr(); + const auto A_stride = matrixStride(A_fortran); + auto handle = at::cuda::getCurrentCUDASolverDnHandle(); + + const int ldvl = 1; // ldvl >= 1 if jobvl = CUSOLVER_EIG_MODE_NOVECTOR + cusolverEigMode_t jobvl = CUSOLVER_EIG_MODE_NOVECTOR; + + cusolverEigMode_t jobvr; + int ldvr; + if (compute_eigenvectors) { + ldvr = n; // ldvr >= n if jobvr = CUSOLVER_EIG_MODE_VECTOR + jobvr = CUSOLVER_EIG_MODE_VECTOR; + } + else { + ldvr = 1; // ldvr >= 1 if jobvr = CUSOLVER_EIG_MODE_NOVECTOR + jobvr = CUSOLVER_EIG_MODE_NOVECTOR; + } + + scalar_t* W = values.data_ptr(); + scalar_t* VL = nullptr; + scalar_t* VR = vectors.data_ptr(); + + const scalar_t* A_const = A_data; + const scalar_t* W_const = W; + const scalar_t* VL_const = VL; + const scalar_t* VR_const = VR; + + size_t ws_dev = 0, ws_host = 0; + at::cuda::solver::xgeev_bufferSize( + handle, params, + jobvl, jobvr, + n, + A_const, lda, + W_const, + VL_const, ldvl, + VR_const, ldvr, + &ws_dev, &ws_host); + + auto& device_allocator = *at::cuda::getCUDADeviceAllocator(); + auto work_device_data = device_allocator.allocate(ws_dev); + // use pinned memory for best performance. + auto& host_allocator = *at::cuda::getPinnedMemoryAllocator(); + auto work_host_data = host_allocator.allocate(ws_host); + + for (decltype(batch_size) i = 0; i < batch_size; ++i) { + scalar_t* Ai = A_data + i * A_stride; + scalar_t* Wi = values_data + i * values_stride; + scalar_t* VLi = nullptr; // xgeev does not support computing left evs + scalar_t* VRi = compute_eigenvectors ? (vectors_data + i * vectors_stride) : nullptr; + int* info = infos_data + i; + + at::cuda::solver::xgeev( + handle, params, + jobvl, jobvr, + n, + Ai, lda, + Wi, + VLi, ldvl, + VRi, ldvr, + static_cast(work_device_data.get()), ws_dev, + static_cast(work_host_data.get()), ws_host, + info); + } + TORCH_CUSOLVER_CHECK(cusolverDnDestroyParams(params)); +} + +void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(eigenvectors.scalar_type(), "linalg_eig_cuda", [&] { + apply_xgeev(eigenvalues, eigenvectors, input, infos, compute_eigenvectors); + }); +} + +#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + // The 'apply_' word is used for templated by dtype functions that call an API routine // underneath. Since the cusolver API has a slightly different structure we do not prepend // apply_ to this function. diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h index cd03319f96d05..6bd10454d3054 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h @@ -73,6 +73,11 @@ void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau); void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors); + +void linalg_eig_cusolver_xgeev(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& input, const Tensor& infos, bool compute_eigenvectors); + + + void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose); void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots); diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp index af183038bb8e4..0928efb7708b5 100644 --- a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp +++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp @@ -1954,6 +1954,336 @@ void xsyevd, double>( workspaceInBytesOnHost, info)); } + +// cuSOLVER Xgeev bindings (requires cuSOLVER >= 11.7.2, i.e. CUDA 12.8+) +#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + +template <> +void xgeev_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + const float* A, + int64_t lda, + const float* W, + const float* VL, + int64_t ldvl, + const float* VR, + int64_t ldvr, + size_t* workspaceInBytesOnDevice, + size_t* workspaceInBytesOnHost) { + TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize( + handle, params, jobvl, jobvr, n, + CUDA_R_32F, + reinterpret_cast(A), + lda, + CUDA_R_32F, + reinterpret_cast(W), + CUDA_R_32F, + reinterpret_cast(VL), + ldvl, + CUDA_R_32F, + reinterpret_cast(VR), + ldvr, + CUDA_R_32F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost)); +} + +template <> +void xgeev_bufferSize( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + const double* A, + int64_t lda, + const double* W, + const double* VL, + int64_t ldvl, + const double* VR, + int64_t ldvr, + size_t* workspaceInBytesOnDevice, + size_t* workspaceInBytesOnHost) { + TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize( + handle, params, jobvl, jobvr, n, + CUDA_R_64F, + reinterpret_cast(A), + lda, + CUDA_R_64F, + reinterpret_cast(W), + CUDA_R_64F, + reinterpret_cast(VL), + ldvl, + CUDA_R_64F, + reinterpret_cast(VR), + ldvr, + CUDA_R_64F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost)); +} + + +template <> +void xgeev_bufferSize>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + const c10::complex* A, + int64_t lda, + const c10::complex* W, + const c10::complex* VL, + int64_t ldvl, + const c10::complex* VR, + int64_t ldvr, + size_t* workspaceInBytesOnDevice, + size_t* workspaceInBytesOnHost) { + TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize( + handle, params, jobvl, jobvr, n, + CUDA_C_32F, + reinterpret_cast(A), + lda, + CUDA_C_32F, + reinterpret_cast(W), + CUDA_C_32F, + reinterpret_cast(VL), + ldvl, + CUDA_C_32F, + reinterpret_cast(VR), + ldvr, + CUDA_C_32F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost)); +} + +template <> +void xgeev_bufferSize>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + const c10::complex* A, + int64_t lda, + const c10::complex* W, + const c10::complex* VL, + int64_t ldvl, + const c10::complex* VR, + int64_t ldvr, + size_t* workspaceInBytesOnDevice, + size_t* workspaceInBytesOnHost) { + TORCH_CUSOLVER_CHECK(cusolverDnXgeev_bufferSize( + handle, params, jobvl, jobvr, n, + CUDA_C_64F, + reinterpret_cast(A), + lda, + CUDA_C_64F, + reinterpret_cast(W), + CUDA_C_64F, + reinterpret_cast(VL), + ldvl, + CUDA_C_64F, + reinterpret_cast(VR), + ldvr, + CUDA_C_64F, + workspaceInBytesOnDevice, + workspaceInBytesOnHost)); +} + +template <> +void xgeev( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + float* A, + int64_t lda, + float* W, + float* VL, + int64_t ldvl, + float* VR, + int64_t ldvr, + float* bufferOnDevice, + size_t workspaceInBytesOnDevice, + float* bufferOnHost, + size_t workspaceInBytesOnHost, + int* info) { + + TORCH_CUSOLVER_CHECK(cusolverDnXgeev( + handle, + params, + jobvl, + jobvr, + n, + CUDA_R_32F, + reinterpret_cast(A), + lda, + CUDA_R_32F, + reinterpret_cast(W), + CUDA_R_32F, + reinterpret_cast(VL), + ldvl, + CUDA_R_32F, + reinterpret_cast(VR), + ldvr, + CUDA_R_32F, + reinterpret_cast(bufferOnDevice), + workspaceInBytesOnDevice, + reinterpret_cast(bufferOnHost), + workspaceInBytesOnHost, + info)); +} + + + + +template <> +void xgeev( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + double* A, + int64_t lda, + double* W, + double* VL, + int64_t ldvl, + double* VR, + int64_t ldvr, + double* bufferOnDevice, + size_t workspaceInBytesOnDevice, + double* bufferOnHost, + size_t workspaceInBytesOnHost, + int* info) { + + TORCH_CUSOLVER_CHECK(cusolverDnXgeev( + handle, + params, + jobvl, + jobvr, + n, + CUDA_R_64F, + reinterpret_cast(A), + lda, + CUDA_R_64F, + reinterpret_cast(W), + CUDA_R_64F, + reinterpret_cast(VL), + ldvl, + CUDA_R_64F, + reinterpret_cast(VR), + ldvr, + CUDA_R_64F, + reinterpret_cast(bufferOnDevice), + workspaceInBytesOnDevice, + reinterpret_cast(bufferOnHost), + workspaceInBytesOnHost, + info)); + +} + +template <> +void xgeev>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + c10::complex* A, + int64_t lda, + c10::complex* W, + c10::complex* VL, + int64_t ldvl, + c10::complex* VR, + int64_t ldvr, + c10::complex* bufferOnDevice, + size_t workspaceInBytesOnDevice, + c10::complex* bufferOnHost, + size_t workspaceInBytesOnHost, + int* info) { + + TORCH_CUSOLVER_CHECK(cusolverDnXgeev( + handle, + params, + jobvl, + jobvr, + n, + CUDA_C_32F, + reinterpret_cast(A), + lda, + CUDA_C_32F, + reinterpret_cast(W), + CUDA_C_32F, + reinterpret_cast(VL), + ldvl, + CUDA_C_32F, + reinterpret_cast(VR), + ldvr, + CUDA_C_32F, + reinterpret_cast(bufferOnDevice), + workspaceInBytesOnDevice, + reinterpret_cast(bufferOnHost), + workspaceInBytesOnHost, + info)); +} + +template <> +void xgeev>( + cusolverDnHandle_t handle, + cusolverDnParams_t params, + cusolverEigMode_t jobvl, + cusolverEigMode_t jobvr, + int64_t n, + c10::complex* A, + int64_t lda, + c10::complex* W, + c10::complex* VL, + int64_t ldvl, + c10::complex* VR, + int64_t ldvr, + c10::complex* bufferOnDevice, + size_t workspaceInBytesOnDevice, + c10::complex* bufferOnHost, + size_t workspaceInBytesOnHost, + int* info) { + + TORCH_CUSOLVER_CHECK(cusolverDnXgeev( + handle, + params, + jobvl, + jobvr, + n, + CUDA_C_64F, + reinterpret_cast(A), + lda, + CUDA_C_64F, + reinterpret_cast(W), + CUDA_C_64F, + reinterpret_cast(VL), + ldvl, + CUDA_C_64F, + reinterpret_cast(VR), + ldvr, + CUDA_C_64F, + reinterpret_cast(bufferOnDevice), + workspaceInBytesOnDevice, + reinterpret_cast(bufferOnHost), + workspaceInBytesOnHost, + info)); +} + + + + +#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + #endif // USE_CUSOLVER_64_BIT #ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.h b/aten/src/ATen/native/cuda/linalg/CUDASolver.h index cb46608c50b54..26f44c788b354 100644 --- a/aten/src/ATen/native/cuda/linalg/CUDASolver.h +++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.h @@ -674,6 +674,66 @@ template <> void xsyevd, double>( CUDASOLVER_XSYEVD_ARGTYPES(c10::complex, double)); + + +// cuSOLVER Xgeev (non-Hermitian eigen decomposition, CUDA >= 12.8) +#if defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + +#define CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t) \ +cusolverDnHandle_t handle, cusolverDnParams_t params, \ +cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, \ +const scalar_t* A, int64_t lda, const scalar_t* W, \ +const scalar_t* VL, int64_t ldvl, const scalar_t* VR, int64_t ldvr, \ +size_t* workspaceInBytesOnDevice, size_t* workspaceInBytesOnHost + +template +void xgeev_bufferSize( + CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(scalar_t)) { + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xgeev_bufferSize: not implemented"); +} + +template <> +void xgeev_bufferSize(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(float)); + +template <> +void xgeev_bufferSize(CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(double)); + +template <> +void xgeev_bufferSize>( + CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex)); + +template <> +void xgeev_bufferSize>( + CUDASOLVER_XGEEV_BUFFERSIZE_ARGTYPES(c10::complex)); + +#define CUDASOLVER_XGEEV_ARGTYPES(scalar_t) \ +cusolverDnHandle_t handle, cusolverDnParams_t params, \ +cusolverEigMode_t jobvl, cusolverEigMode_t jobvr, int64_t n, scalar_t *A, \ +int64_t lda, scalar_t *W, scalar_t *VL, int64_t ldvl, scalar_t *VR, int64_t ldvr,\ +scalar_t *bufferOnDevice, size_t workspaceInBytesOnDevice, scalar_t *bufferOnHost,\ +size_t workspaceInBytesOnHost, int *info + +template +void xgeev(CUDASOLVER_XGEEV_ARGTYPES(scalar_t)) { + static_assert(false&&sizeof(scalar_t), + "at::cuda::solver::xgeev: not implemented"); +} + +template <> +void xgeev(CUDASOLVER_XGEEV_ARGTYPES(float)); + +template <> +void xgeev(CUDASOLVER_XGEEV_ARGTYPES(double)); + +template <> +void xgeev>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex)); + +template <> +void xgeev>(CUDASOLVER_XGEEV_ARGTYPES(c10::complex)); + +#endif // defined(CUSOLVER_VERSION) && (CUSOLVER_VERSION >= 11702) + #endif // USE_CUSOLVER_64_BIT #ifdef USE_CUSOLVER_64_BIT_XSYEV_BATCHED diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index b86b7436138f2..325b082f314d9 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -119,8 +119,8 @@ void setConvolutionParams( params->input_dim = input.dim(); params->memory_format = memory_format; for (int i = 0; i != params->input_dim; ++i) { - params->input_size[i] = (int)input.sizes()[i]; - params->weight_size[i] = (int)weight.sizes()[i]; + params->input_size[i] = static_cast(input.sizes()[i]); + params->weight_size[i] = static_cast(weight.sizes()[i]); } // ASSERT(padding.size() == stride.size()) // ASSERT(padding.size() == dilation.size()) diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index 081b4afa15ac5..bc064e3ad3167 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -64,7 +64,7 @@ // fastest algorithm combination with a sub optimal mathType. constexpr size_t operator"" _TiB(unsigned long long n) { - return size_t(n) * 1024 * 1024 * 1024 * 1024; + return static_cast(n) * 1024 * 1024 * 1024 * 1024; } namespace at { diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index 8a19fac27bfd4..75ab950e19bbb 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -46,7 +46,7 @@ namespace { // TODO: remove duplicate code in Conv_v7.cpp constexpr int64_t operator"" _TiB(unsigned long long n) { - return size_t(n) << 40; + return static_cast(n) << 40; } uint8_t getAlignment(const Tensor& t) { @@ -93,7 +93,10 @@ cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual( std::vector strides_copy(std::begin(strides), std::end(strides)); fixSizeOneDimStride( - sizes.size(), &sizes[0], (int64_t*)&strides_copy[0], channels_last); + sizes.size(), + &sizes[0], + static_cast(&strides_copy[0]), + channels_last); auto r = cudnn_frontend::TensorBuilder() .setDim(sizes.size(), sizes.data()) .setStrides(strides_copy.size(), strides_copy.data()) diff --git a/aten/src/ATen/native/cudnn/GridSampler.cpp b/aten/src/ATen/native/cudnn/GridSampler.cpp index 3b5f5bd218bb5..fc41957bb7c7d 100644 --- a/aten/src/ATen/native/cudnn/GridSampler.cpp +++ b/aten/src/ATen/native/cudnn/GridSampler.cpp @@ -44,6 +44,7 @@ std::tuple cudnn_grid_sampler_backward( #include #include #include +#include #include #include @@ -59,11 +60,11 @@ void setSamplerDescriptor( SpatialTransformerDescriptor& desc, cudnnDataType_t dataType, const at::Tensor& tensor) { - int inputSize[4] = {0}; + std::array inputSize{0}; for (const auto i : c10::irange(tensor.dim())) { - inputSize[i] = (int)tensor.size(i); + inputSize[i] = static_cast(tensor.size(i)); } - desc.set(dataType, 4, inputSize); + desc.set(dataType, 4, inputSize.data()); } void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input) { diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 704a333b1f84b..e7030a00c71d5 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -656,7 +656,8 @@ void add_projection_weights( TORCH_INTERNAL_ASSERT( nb_dims <= min_dim, "nb_dims = ", nb_dims, "; min_dim = ", min_dim); auto elem_size = dataSize(getCudnnDataType(weight_buf)); - auto offset_bytes = (char*)matrix_pointer - (char*)weight_buf.data_ptr(); + auto offset_bytes = static_cast(matrix_pointer) - + static_cast(weight_buf.data_ptr()); TORCH_INTERNAL_ASSERT( offset_bytes % elem_size == 0, "offset_bytes = ", @@ -794,8 +795,8 @@ get_parameters( "; min_dim = ", min_dim); auto elem_size = dataSize(getCudnnDataType(weight_buf)); - auto offset_bytes = - (char*)matrix_pointer - (char*)weight_buf.data_ptr(); + auto offset_bytes = static_cast(matrix_pointer) - + static_cast(weight_buf.data_ptr()); TORCH_INTERNAL_ASSERT( offset_bytes % elem_size == 0, "offset_bytes = ", diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 4aa53c5e794b8..da4fc3fb6d079 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -330,7 +330,6 @@ Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, } #elif AT_MKL_ENABLED() -#include #include #include diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 2f8448cf57d1f..719865063e20d 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -535,7 +535,7 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, float input_scale = scale_a.item(); float weight_scale = scale_b.item(); - float output_scale = float(1.0); + float output_scale = 1.0f; if (scale_result.has_value() && (*out_dtype == ScalarType::Float8_e4m3fn || *out_dtype == ScalarType::Float8_e5m2)) { diff --git a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp index 1c1841148740a..aef6da7937bf5 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp @@ -530,7 +530,7 @@ static Tensor get_mkldnn_serialized_md(const Tensor& self) { #else TORCH_CHECK(false, "Unexpected IDeep version to do weight serialization."); #endif - Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {(int64_t)serialized_wei_desc.size()}, at::TensorOptions(at::kByte)); + Tensor serialized_md = at::from_blob((void*)serialized_wei_desc.data(), {static_cast(serialized_wei_desc.size())}, at::TensorOptions(at::kByte)); auto res = at::empty_like(serialized_md); // serialized_md shares the buffer with serialized_wei_desc, // which will be released outside of this function thus invalidating the buffer of serialized_md. diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp index b969e6a19c299..4494bc1a7c827 100644 --- a/aten/src/ATen/native/mkldnn/Matmul.cpp +++ b/aten/src/ATen/native/mkldnn/Matmul.cpp @@ -576,14 +576,14 @@ static void _mkldnn_gemm_i8i8i32_with_blas( n, k, alpha, - (int8_t*)self.data_ptr(), + static_cast(self.data_ptr()), lda, ao, - (int8_t*)mat2.data_ptr(), + static_cast(mat2.data_ptr()), ldb, bo, beta, - (int32_t*)result.data_ptr(), + static_cast(result.data_ptr()), ldc, &co); } diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp index 6ef371424eed8..e7ed6b1a68bbf 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp @@ -41,7 +41,7 @@ void woq_matmul_int4_impl( dst_usr_dims; dnnl::memory::dims m1_usr_strides, m2_usr_strides, scale_usr_strides, zp_usr_strides, dst_usr_strides; - int compressed_k = (int)(k / 8); + int compressed_k = k / 8; int num_groups = (int)(k / group_size); m1_usr_dims = {m, k}; m1_usr_strides = {m1.stride(0), m1.stride(1)}; diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 0764b9d5e12d9..5cb6dd38822a6 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -86,6 +86,28 @@ struct zeta_functor { } }; +struct logaddexp_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return c10::metal::logaddexp(a, b); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::logaddexp(float(a), float(b)); + } +}; + +struct logaddexp2_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return c10::metal::logaddexp2(a, b); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::logaddexp2(float(a), float(b)); + } +}; + struct xlog1py_functor { template , bool> = true> inline T operator()(const T a, const T b) { @@ -377,6 +399,10 @@ REGISTER_FLOAT_BINARY_OP(fmin); REGISTER_FLOAT_BINARY_OP(nextafter); REGISTER_FLOAT_BINARY_OP(zeta); REGISTER_INT2FLOAT_BINARY_OP(zeta); +REGISTER_FLOAT_BINARY_OP(logaddexp); +REGISTER_INT2FLOAT_BINARY_OP(logaddexp); +REGISTER_FLOAT_BINARY_OP(logaddexp2); +REGISTER_INT2FLOAT_BINARY_OP(logaddexp2); REGISTER_FLOAT_BINARY_OP(xlog1py); REGISTER_INT2FLOAT_BINARY_OP(xlog1py); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t); @@ -463,6 +489,8 @@ REGISTER_BINARY_OP(add, float2, float2); REGISTER_BINARY_OP(add, half2, half2); REGISTER_BINARY_OP(sub, float2, float2); REGISTER_BINARY_OP(sub, half2, half2); +REGISTER_BINARY_OP(logaddexp, float2, float2); +REGISTER_BINARY_OP(logaddexp, half2, half2); REGISTER_BINARY_ALPHA_OP(add_alpha, float2, float2, float2); REGISTER_BINARY_ALPHA_OP(add_alpha, half2, half2, half2); REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 70211ceef07ad..f8baf2e7f1171 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -89,6 +89,14 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "zeta"); } +static void logaddexp_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "logaddexp"); +} + +static void logaddexp2_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "logaddexp2"); +} + static void xlog1py_mps_kernel(TensorIteratorBase& iter) { TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types"); lib.exec_binary_kernel(iter, "xlog1py"); @@ -211,6 +219,8 @@ static void hypot_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(copysign_stub, ©sign_mps_kernel) REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel) REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel) +REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel); +REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel); REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index bffd792432666..d450a3ed8fe44 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -17,8 +17,6 @@ #include #include #include -#include -#include #include #include #include @@ -277,30 +275,6 @@ static void add_sub_lerp_template(const Tensor& self, } } -TORCH_IMPL_FUNC(logaddexp_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { - mps::BinaryOpBlock logaddexp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* sumTensor = - [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentWithTensor:primaryCastTensor name:nil] - secondaryTensor:[mpsGraph exponentWithTensor:secondaryCastTensor name:nil] - name:nil]; - return [mpsGraph logarithmWithTensor:sumTensor name:nil]; - }; - mps::binaryOpTensor(self, other, output, "logaddexp_out_mps", logaddexp_op_block); -} - -TORCH_IMPL_FUNC(logaddexp2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { - mps::BinaryOpBlock logaddexp2_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* sumTensor = - [mpsGraph additionWithPrimaryTensor:[mpsGraph exponentBase2WithTensor:primaryCastTensor name:nil] - secondaryTensor:[mpsGraph exponentBase2WithTensor:secondaryCastTensor name:nil] - name:nil]; - return [mpsGraph logarithmBase2WithTensor:sumTensor name:nil]; - }; - mps::binaryOpTensor(self, other, output, "logaddexp2_out_mps", logaddexp2_op_block); -} - TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { MPSGraph* mpsGraph = cachedGraph->graph(); diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index f5264cf32d9f2..0c95fec667e80 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -84,6 +84,9 @@ static void get_shapes(MPSShape* input_shape_readonly, Tensor& output, Tensor& save_mean, Tensor& save_var) { + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS"); + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), + "Batch norm for complex is not supported for MPS"); using namespace at::native::mps; struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} @@ -918,6 +921,7 @@ Check if running mean exists (maybe do this check before making graph) // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) const int axis = input_ndim - normalized_ndim; MPSStream* stream = getCurrentMPSStream(); + TORCH_CHECK_NOT_IMPLEMENTED(input.scalar_type() != kLong, "Not implemented for long on MPS"); @autoreleasepool { mps::dispatch_sync_with_rethrow(stream->queue(), ^() { // which kernel variant to use based on the normalized axis N size diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index fdfabecef06b9..2d466f7c79436 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -10,6 +10,7 @@ #include #include #else +#include #include #include #include @@ -544,8 +545,9 @@ static void max_unpool_out_mps_template(const Tensor& input, if (indices.defined() && indices.numel() > 0) { auto output_image_size = c10::multiply_integers(output_size_); - int64_t min_idx = indices.min().item(); - int64_t max_idx = indices.max().item(); + auto [min_idx_tensor, max_idx_tensor] = indices.aminmax(); + int64_t min_idx = min_idx_tensor.item(); + int64_t max_idx = max_idx_tensor.item(); if (min_idx < 0 || max_idx >= output_image_size) { int64_t error_idx = (min_idx < 0) ? min_idx : max_idx; diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index f4e469b79cb48..3747f314adfa1 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -1028,15 +1028,18 @@ Tensor trace_mps(const Tensor& self) { } TORCH_IMPL_FUNC(amax_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) { + TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amax is not defined for complex types"); reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMAX, "amax_out_mps"); } TORCH_IMPL_FUNC(amin_out_mps)(const Tensor& input_t, IntArrayRef dim, bool keepdim, const Tensor& output_t) { + TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "amin is not defined for complex types"); reduction_out_mps(input_t, dim, keepdim, std::nullopt, output_t, MPSReductionType::AMIN, "amin_out_mps"); } TORCH_IMPL_FUNC(aminmax_out_mps) (const Tensor& input_t, std::optional dim_opt, bool keepdim, const Tensor& min_t, const Tensor& max_t) { + TORCH_CHECK(!c10::isComplexType(input_t.scalar_type()), "aminmax is not defined for complex types"); reduction_out_mps(input_t, dim_opt.has_value() ? OptionalIntArrayRef({*dim_opt}) : std::nullopt, keepdim, diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm index 973bef036d564..af5416b21b3bd 100644 --- a/aten/src/ATen/native/mps/operations/Shape.mm +++ b/aten/src/ATen/native/mps/operations/Shape.mm @@ -83,6 +83,31 @@ static void check_shape_except_dim(const Tensor& first, const Tensor& second, in return "int32_t"; } +// If all tensors are contiguous with the same dtype and the cat dimension is 0, +// then we can simply copy each tensor's underlying buffer contiguously into the +// output. +static void cat_out_mps_contiguous_impl(const ITensorListRef& inputs, const Tensor& output) { + MPSStream* stream = getCurrentMPSStream(); + id output_buffer = getMTLBufferStorage(output); + size_t output_offset = output.storage_offset() * output.itemsize(); + + for (const Tensor& input : inputs) { + if (cat_should_skip_tensor(input)) { + continue; + } + + id input_buffer = getMTLBufferStorage(input); + size_t input_offset = input.storage_offset() * input.itemsize(); + auto nbytes = input.nbytes(); + auto profile_id = + getMPSProfiler().beginProfileCopy(input_buffer, output_buffer, input, output, nbytes, /*non_blocking=*/true); + + stream->copy(input_buffer, output_buffer, nbytes, input_offset, output_offset, profile_id, SyncType::NONE); + + output_offset += nbytes; + } +} + // NOTE: `output` is expected to already have the correct size. template static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, const Tensor& output) { @@ -105,7 +130,7 @@ static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, co // copy all the input tensor data into a packed buffer, which would not be // ideal. for (const Tensor& input : inputs) { - if (input.numel() == 0) { + if (cat_should_skip_tensor(input)) { continue; } @@ -240,104 +265,19 @@ static void cat_out_mps_impl(const ITensorListRef& inputs, int64_t dimension, co const Tensor& out) { using namespace mps; - if (out.numel() == 0) { - return; - } - auto materialized_inputs = inputs.materialize(); - auto out_dtype = at::native::result_type(inputs); - - int idx = 0; - for (const Tensor& t : materialized_inputs) { - TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated"); - auto lap = at::get_overlap_status(out, t); - TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full, - "torch.cat(): unsupported operation: the input tensors cannot refer to any " - "of the output memory locations. Found overlap in input tensor ", - idx); - idx++; - } - // Check for type promotion - TORCH_CHECK(canCast(out_dtype, out.scalar_type()), - "torch.cat(): input types can't be cast to the desired output type ", - out.scalar_type()); - TORCH_CHECK(!inputs.empty(), "torch.cat(): invalid number of inputs ", inputs.size()); - - dimension = legacy_cat_wrap_dim(dimension, materialized_inputs); - TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension); - - // previously, size [0] tensors were the only possible empty tensors; thus, it - // wasn't possible to cat empty tensors unless all the other tensors were - // 1-dimensional, so we allowed these tensors to be "skipped". We maintain - // this behavior for backwards compatibility, but only for this specific size - // (i.e. other empty sizes are not skipped). - // FIXME: warn if this is the case - auto should_skip = [](const Tensor& t) { return t.dim() == 1 && t.size(0) == 0; }; - at::assert_no_internal_overlap(out); - - Tensor notSkippedTensor; - // Indices of tensors to be skipped because they're empty - std::vector skipped_tensor_indices; - // Tensors to be read - std::vector input_tensors; - int tensor_idx = 0; - for (const Tensor& t : materialized_inputs) { - if (t.numel() == 0 || should_skip(t)) { - skipped_tensor_indices.push_back(tensor_idx); - tensor_idx++; - continue; - } - input_tensors.push_back(t); - // TODO: Is this OK? - notSkippedTensor = t; - tensor_idx++; - } - // If all inputs are empty tensors, return an empty tensor - if (!notSkippedTensor.defined()) { - return; - } - for (const Tensor& t : inputs) { - TORCH_CHECK(t.device() == notSkippedTensor.device(), - "torch.cat(): all input tensors must be on the same device. Received ", - t.device(), - " and ", - notSkippedTensor.device()); - } - TORCH_CHECK(out.device() == notSkippedTensor.device(), - "torch.cat(): all input tensors and out must be on the same device, but inputs are on ", - notSkippedTensor.device(), - " and out is on ", - out.device()); - - std::vector size(notSkippedTensor.sizes().vec()); - - // Compute size of the result in the cat dimension - int64_t cat_dim_size = 0; - idx = 0; - bool has_large_tensor = false; - for (const Tensor& tensor : materialized_inputs) { - if (isTooLargeForMPSGraph(tensor)) { - has_large_tensor |= true; - } - if (!should_skip(tensor)) { - // TODO: Factor out `check_shape_except_dim` - check_shape_except_dim(notSkippedTensor, tensor, dimension, idx); - cat_dim_size += tensor.size(dimension); - idx++; - } - } - // Compute the size of the result - size[dimension] = cat_dim_size; - // skip resizing if size of result is same as expected - if (out.sizes() != size) { - out.resize_(size, MemoryFormat::Contiguous); - } if (out.numel() == 0) { return; } - has_large_tensor |= isTooLargeForMPSGraph(out); + auto materialized_inputs = inputs.materialize(); + bool has_large_tensor = + isTooLargeForMPSGraph(out) || std::any_of(materialized_inputs.begin(), materialized_inputs.end(), [](auto& t) { + return !cat_should_skip_tensor(t) && isTooLargeForMPSGraph(t); + }); - if (has_large_tensor) { + if (all_contiguous && all_same_dtype && (memory_format == MemoryFormat::Contiguous) && (dimension == 0)) { + return mps::cat_out_mps_contiguous_impl(materialized_inputs, out); + } else if (has_large_tensor) { return mps::cat_out_mps_impl(materialized_inputs, dimension, out); } else { return mps::cat_out_mps_impl(materialized_inputs, dimension, out); diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm index b6a07f14704cc..898acacdb763f 100644 --- a/aten/src/ATen/native/mps/operations/Sort.mm +++ b/aten/src/ATen/native/mps/operations/Sort.mm @@ -31,6 +31,7 @@ void kthvalue_out_mps_impl(const Tensor& self, int64_t k, int64_t dim, Tensor& v indices.copy_(values.toType(at::ScalarType::Long)); return; } + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "kthvalue is not implemented for complex types"); // issue #154890, raising error to prevent crash within MPSGraph until // workaround is implemented. TORCH_CHECK(self.dim() - dim <= 4, "On-going issue on MPSGraph topk when ndims() - axis > 4, see issue #154890"); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0bc89ef493dc9..4424f51827d45 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2602,12 +2602,16 @@ device_check: NoCheck # TensorIterator structured_delegate: exp.out variants: function, method + dispatch: + SparseCPU, SparseCUDA, SparseMPS: exp_sparse tags: [core, pointwise] - func: exp_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator structured_delegate: exp.out variants: function, method + dispatch: + SparseCPU, SparseCUDA, SparseMPS: exp_sparse_ tags: pointwise - func: exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) @@ -2616,6 +2620,7 @@ structured_inherits: TensorIteratorBase dispatch: CPU, CUDA, MPS, MTIA: exp_out + SparseCPU, SparseCUDA, SparseMPS: exp_sparse_out tags: pointwise - func: exp2(Tensor self) -> Tensor @@ -3622,8 +3627,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: logaddexp_out - MPS: logaddexp_out_mps + CPU, CUDA, MPS: logaddexp_out tags: pointwise - func: logaddexp(Tensor self, Tensor other) -> Tensor @@ -3635,8 +3639,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: logaddexp2_out - MPS: logaddexp2_out_mps + CPU, CUDA, MPS: logaddexp2_out tags: pointwise - func: logaddexp2(Tensor self, Tensor other) -> Tensor @@ -4797,6 +4800,12 @@ CompositeExplicitAutograd: rand_like autogen: rand_like.out +- func: rand_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + CompositeExplicitAutograd: rand_like + autogen: rand_like.generator_out + - func: randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor tags: nondeterministic_seeded dispatch: @@ -4845,6 +4854,14 @@ CompositeExplicitAutograd: randint_like autogen: randint_like.out +- func: randint_like.generator(Tensor self, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like + autogen: randint_like.generator_out + - func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor tags: nondeterministic_seeded dispatch: @@ -4853,6 +4870,14 @@ CompositeExplicitAutograd: randint_like autogen: randint_like.Tensor_out +- func: randint_like.Tensor_generator(Tensor self, Tensor high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like + autogen: randint_like.Tensor_generator_out + - func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor tags: nondeterministic_seeded dispatch: @@ -4861,6 +4886,14 @@ CompositeExplicitAutograd: randint_like autogen: randint_like.low_dtype_out +- func: randint_like.low_generator_dtype(Tensor self, SymInt low, SymInt high, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd: randint_like + autogen: randint_like.low_generator_dtype_out + - func: randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor tags: [core, nondeterministic_seeded] dispatch: @@ -4901,6 +4934,14 @@ CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like autogen: randn_like.out +- func: randn_like.generator(Tensor self, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + tags: nondeterministic_seeded + dispatch: + # NB: Although this composite mutates on the inside, it is + # non-differentiable so NonFunctional doesn't apply + CompositeExplicitAutograd, CompositeImplicitAutogradNestedTensor: randn_like + autogen: randn_like.generator_out + - func: randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor tags: [core, nondeterministic_seeded] dispatch: @@ -8867,11 +8908,11 @@ autogen: bitwise_right_shift.Scalar_Tensor_out tags: pointwise -- func: tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) +- func: tril_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!) structured_delegate: tril.out variants: method -- func: triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) +- func: triu_(Tensor(a!) self, SymInt diagonal=0) -> Tensor(a!) structured_delegate: triu.out variants: method @@ -8995,25 +9036,25 @@ - func: cross(Tensor self, Tensor other, int? dim=None) -> Tensor variants: method, function -- func: triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) +- func: triu.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: CPU: triu_cpu CUDA: triu_cuda MPS: triu_mps_out -- func: triu(Tensor self, int diagonal=0) -> Tensor +- func: triu(Tensor self, SymInt diagonal=0) -> Tensor structured_delegate: triu.out variants: method, function -- func: tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) +- func: tril.out(Tensor self, SymInt diagonal=0, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: CPU: tril_cpu CUDA: tril_cuda MPS: tril_mps_out -- func: tril(Tensor self, int diagonal=0) -> Tensor +- func: tril(Tensor self, SymInt diagonal=0) -> Tensor structured_delegate: tril.out variants: method, function diff --git a/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp b/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp index 1997ea1648352..1086b4d0d8c58 100644 --- a/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp +++ b/aten/src/ATen/native/quantized/AffineQuantizerBase.cpp @@ -65,7 +65,7 @@ void quantize_vec( (typename T::underlying*)dst, count, fbgemm::TensorQuantizationParams{ - (float)scale, (int32_t)zero_point, precision}); + static_cast(scale), static_cast(zero_point), precision}); } #if defined(__ARM_NEON__) || defined(__aarch64__) diff --git a/aten/src/ATen/native/quantized/cpu/AdaptiveAveragePooling.cpp b/aten/src/ATen/native/quantized/cpu/AdaptiveAveragePooling.cpp index de7c380b6b67f..c9cc433b11e08 100644 --- a/aten/src/ATen/native/quantized/cpu/AdaptiveAveragePooling.cpp +++ b/aten/src/ATen/native/quantized/cpu/AdaptiveAveragePooling.cpp @@ -40,7 +40,7 @@ inline int start_index(int out_idx, int out_len, int in_len) { * This function computes the start index on input matrix. */ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - return (int)std::floor((float)(out_idx * in_len) / out_len); + return static_cast(std::floor(static_cast(out_idx * in_len) / out_len)); } inline int end_index(int out_idx, int out_len, int in_len) { @@ -49,7 +49,7 @@ inline int end_index(int out_idx, int out_len, int in_len) { * This function computes the end index on input matrix. */ // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - return (int)std::ceil((float)((out_idx + 1) * in_len) / out_len); + return static_cast(std::ceil(static_cast((out_idx + 1) * in_len) / out_len)); } // adaptive avg pool for 2D and 3D inputs diff --git a/aten/src/ATen/native/quantized/cpu/AveragePool2d.cpp b/aten/src/ATen/native/quantized/cpu/AveragePool2d.cpp index 640ce50b76e85..bab0957c4c918 100644 --- a/aten/src/ATen/native/quantized/cpu/AveragePool2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/AveragePool2d.cpp @@ -71,8 +71,8 @@ void avg_pool2d_out_frame( int64_t hend = std::min(hstart + kH, inputHeight + padH); int64_t wend = std::min(wstart + kW, inputWidth + padW); int64_t pool_size = (hend - hstart) * (wend - wstart); - hstart = std::max(hstart, (int64_t)0); - wstart = std::max(wstart, (int64_t)0); + hstart = std::max(hstart, static_cast(0)); + wstart = std::max(wstart, static_cast(0)); hend = std::min(hend, inputHeight); wend = std::min(wend, inputWidth); diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index e9043f06b3018..5ddf1f60e2317 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -646,7 +646,7 @@ class QConvPackWeightInt8 final { torch::List output_padding; output_padding.reserve(kSpatialDim); for ([[maybe_unused]] const auto idx : c10::irange(kSpatialDim)) { - output_padding.push_back((int64_t)0); + output_padding.push_back(0); } return _run(weight, bias, stride, padding, output_padding, dilation, groups, /*transpose=*/false); diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl index ae31d34e80c46..180442b4b09a4 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl @@ -301,6 +301,10 @@ def define_qnnpack(third_party, labels = []): "-DQNNP_PRIVATE=", "-DQNNP_INTERNAL=", ], + fbobjc_compiler_flags = [ + "-Wno-switch-enum", + "-Wno-switch-default", + ], labels = [ "supermodule:android/default/pytorch", "supermodule:ios/default/public.pytorch", diff --git a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp index b94ab0fd0975f..1e1811a0b2c45 100644 --- a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp @@ -134,7 +134,7 @@ class QConvPackWeightInt8Cudnn final { torch::List output_padding; output_padding.reserve(kSpatialDim); for ([[maybe_unused]] const auto idx : c10::irange(kSpatialDim)) { - output_padding.push_back((int64_t)0); + output_padding.push_back(0); } return _run(weight, bias, stride, padding, output_padding, dilation, groups, /*transpose=*/false); diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 80f79c6520378..a7d3cb8d671e2 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -467,6 +467,28 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe !options.has_layout() || options.layout() == kSparse, "expected sparse layout, but got layout ", options.layout()); + + if (indices.numel() > 0) { + Tensor min_indices = + std::get(indices.min(/* dim */ 1, /* keepdim */ false)); + Tensor cpu_min_indices; + if (!indices.is_cpu()) { + cpu_min_indices = min_indices.to(at::DeviceType::CPU); + } else { + cpu_min_indices = min_indices; + } + auto cpu_min_indices_accessor = cpu_min_indices.accessor(); + for (const auto d : c10::irange(indices.size(0))) { + int64_t min_index_in_dim = cpu_min_indices_accessor[d]; + TORCH_CHECK( + min_index_in_dim >= 0, + "found negative index ", + min_index_in_dim, + " for dim ", + d); + } + } + return at::native::_sparse_coo_tensor_unsafe( indices, values, diff --git a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp index fc288f4203db0..79bed48926656 100644 --- a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp +++ b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include #include @@ -175,6 +177,7 @@ COALESCED_UNARY_UFUNC(atanh) COALESCED_UNARY_UFUNC(ceil) COALESCED_UNARY_UFUNC(deg2rad) COALESCED_UNARY_UFUNC(erf) +COALESCED_UNARY_UFUNC(exp) COALESCED_UNARY_UFUNC(erfinv) COALESCED_UNARY_UFUNC(expm1) COALESCED_UNARY_UFUNC(floor) diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index c656dc71a660d..77f31e6cc4ac2 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -16,8 +16,8 @@ void Xcoo2csr(const int *coorowind, int64_t nnz, int64_t m, int *csrrowptr) { "cusparseXcoo2csr only supports m, nnz with the bound [val] <= ", INT_MAX); - int i_nnz = (int)nnz; - int i_m = (int)m; + int i_nnz = static_cast(nnz); + int i_m = static_cast(m); auto handle = at::cuda::getCurrentCUDASparseHandle(); TORCH_CUDASPARSE_CHECK(cusparseXcoo2csr(handle, coorowind, i_nnz, i_m, csrrowptr, CUSPARSE_INDEX_BASE_ZERO)); @@ -202,7 +202,7 @@ void CreateIdentityPermutation(int64_t nnz, int *P) { TORCH_CHECK((nnz <= INT_MAX), "Xcsrsort_bufferSizeExt only supports m, n, nnz with the bound [val] <= ", INT_MAX); - int i_nnz = (int)nnz; + int i_nnz = static_cast(nnz); auto handle = at::cuda::getCurrentCUDASparseHandle(); cusparseCreateIdentityPermutation(handle, i_nnz, P); @@ -213,9 +213,9 @@ void Xcsrsort_bufferSizeExt(int64_t m, int64_t n, int64_t nnz, const int *csrRow TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX), "Xcsrsort_bufferSizeExt only supports m, n, nnz with the bound [val] <=", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_nnz = (int)nnz; + int i_m = static_cast(m); + int i_n = static_cast(n); + int i_nnz = static_cast(nnz); auto handle = at::cuda::getCurrentCUDASparseHandle(); TORCH_CUDASPARSE_CHECK(cusparseXcsrsort_bufferSizeExt(handle, i_m, i_n, i_nnz, csrRowPtr, csrColInd, pBufferSizeInBytes)); @@ -226,9 +226,9 @@ void Xcsrsort(int64_t m, int64_t n, int64_t nnz, const int *csrRowPtr, int *csrC TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX), "Xcsrsort only supports m, n, nnz with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_nnz = (int)nnz; + int i_m = static_cast(m); + int i_n = static_cast(n); + int i_nnz = static_cast(nnz); auto handle = at::cuda::getCurrentCUDASparseHandle(); cusparseMatDescr_t desc; @@ -242,9 +242,9 @@ void Xcoosort_bufferSizeExt(int64_t m, int64_t n, int64_t nnz, const int *cooRow TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX), "Xcoosort_bufferSizeExt only supports m, n, nnz with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_nnz = (int)nnz; + int i_m = static_cast(m); + int i_n = static_cast(n); + int i_nnz = static_cast(nnz); auto handle = at::cuda::getCurrentCUDASparseHandle(); TORCH_CUDASPARSE_CHECK(cusparseXcoosort_bufferSizeExt(handle, i_m, i_n, i_nnz, cooRows, cooCols, pBufferSizeInBytes)); @@ -255,9 +255,9 @@ void XcoosortByRow(int64_t m, int64_t n, int64_t nnz, int *cooRows, int *cooCols TORCH_CHECK((m <= INT_MAX) && (n <= INT_MAX) && (nnz <= INT_MAX), "XcoosortByRow only supports m, n, nnz with the bound [val] <= ", INT_MAX); - int i_m = (int)m; - int i_n = (int)n; - int i_nnz = (int)nnz; + int i_m = static_cast(m); + int i_n = static_cast(n); + int i_nnz = static_cast(nnz); auto handle = at::cuda::getCurrentCUDASparseHandle(); TORCH_CUDASPARSE_CHECK(cusparseXcoosortByRow(handle, i_m, i_n, i_nnz, cooRows, cooCols, P, pBuffer)); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 4141044116bbe..c4aca18ef9e4a 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -155,7 +155,7 @@ void set_params_fprop(Flash_fwd_params ¶ms, // [Minor] We want to round down since when we do the comparison we use <= instead of < // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.p_dropout_in_uint8_t = static_cast(std::floor(params.p_dropout * 255.0)); params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; TORCH_CHECK(p_dropout < 1.f); @@ -307,7 +307,7 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n if (!is_split_eligible(num_splits)) { efficiency.push_back(0.f); } else { - float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float n_waves = static_cast(batch_nheads_mblocks * num_splits) / num_SMs; float eff = n_waves / ceil(n_waves); // printf("num_splits = %d, eff = %f\n", num_splits, eff); if (eff > max_efficiency) { max_efficiency = eff; } diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index c63ca928613e6..eb6a92dd5411a 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -341,7 +341,7 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug) const auto v_num_heads = params.value.sym_size(-3); const bool same_kv_heads = k_num_heads == v_num_heads; - if (requires_same_num_heads && !(same_kv_heads)){ + if (requires_same_num_heads && !same_kv_heads){ if (debug) { TORCH_WARN( "Both fused kernels require key and value to have the same num_heads and batch_size but got: ", diff --git a/aten/src/ATen/native/ts_native_functions.yaml b/aten/src/ATen/native/ts_native_functions.yaml index 17c9bd4234f3e..4ef380704de8e 100644 --- a/aten/src/ATen/native/ts_native_functions.yaml +++ b/aten/src/ATen/native/ts_native_functions.yaml @@ -202,6 +202,7 @@ supported: - select_backward - _trilinear - linalg_pinv.atol_rtol_tensor + - svd - logsumexp.out symint: - empty.memory_format diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp index 05c964cc0f59c..50d9a632008ac 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.cpp +++ b/aten/src/ATen/native/xnnpack/Convolution.cpp @@ -43,9 +43,9 @@ bool available( (kFloat == weight.scalar_type()) && // Bias (bias_sizes_opt.has_value() ? ((1 == bias_sizes_opt->size()) && - ((transposed ? (weight.size(Layout::Filter::input) == + (transposed ? (weight.size(Layout::Filter::input) == ((*bias_sizes_opt)[0] / groups)) - : (weight.size(Layout::Filter::output) == ((*bias_sizes_opt)[0]))))) + : (weight.size(Layout::Filter::output) == ((*bias_sizes_opt)[0])))) : true) && // Padding (padding[Layout::Parameter::height] >= 0) && @@ -133,10 +133,10 @@ const Tensor reorder_weights_for_transpose_conv(const Tensor& weight_nhwc, int kernel_height = weight_nhwc.size(2); int o_offset = 1; - int h_offset = (output_channels_per_group); - int w_offset = (output_channels_per_group)*(kernel_height); - int i_offset = (output_channels_per_group)*(kernel_height)*(kernel_width); - int g_offset = (output_channels_per_group)*(kernel_height)*(kernel_width)*(input_channels_per_group); + int h_offset = output_channels_per_group; + int w_offset = output_channels_per_group*kernel_height; + int i_offset = output_channels_per_group*kernel_height*kernel_width; + int g_offset = output_channels_per_group*kernel_height*kernel_width*input_channels_per_group; Tensor reordered = mobile::empty_with_tail_padding( weight_nhwc.sizes(), diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 9d80c13f5ed82..ba3490bb1b071 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -129,6 +130,7 @@ class TORCH_API Tensor: public TensorBase { return *this; } + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (this->layout()) { case at::kSparse: case at::kSparseCsr: @@ -139,6 +141,7 @@ class TORCH_API Tensor: public TensorBase { default: return this->_conj(); } + C10_DIAGNOSTIC_POP() } // Aliased by Dimname overloads, so need explicit using diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 54900de1ed915..e0681f52586e7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1837,6 +1837,10 @@ def skip_models(self): def skip_models_for_cuda(self): return set() + @property + def skip_models_for_xpu(self): + return set() + @property def skip_models_for_cpu(self): return set() @@ -3927,6 +3931,8 @@ def run(runner, args, original_dir=None): runner.skip_models.update(runner.skip_models_for_cpu_aarch64) elif args.devices == ["cuda"]: runner.skip_models.update(runner.skip_models_for_cuda) + elif args.devices == ["xpu"]: + runner.skip_models.update(runner.skip_models_for_xpu) if not args.multiprocess: runner.skip_models.update(runner.skip_multiprocess_models) diff --git a/benchmarks/dynamo/genai_layers/benchmark.py b/benchmarks/dynamo/genai_layers/benchmark.py index a5ad0b35e50b3..f903f805383ef 100644 --- a/benchmarks/dynamo/genai_layers/benchmark.py +++ b/benchmarks/dynamo/genai_layers/benchmark.py @@ -56,6 +56,20 @@ def list_benchmarks(): print(f"Available benchmarks: {list(BENCHMARK_REGISTRY.keys())}") +def _run_benchmark( + benchmark_cls, + script_args, +): + benchmark = benchmark_cls(script_args) + benchmark.benchmark() + benchmark.report_geomean_speedup() + if script_args.print_benchmark_result: + print(f"Benchmarking results {benchmark.name}:") + print(benchmark.profiling_results) + if script_args.visualize: + benchmark.visualize() + + def run_benchmark( benchmark_name: str, script_args, @@ -71,10 +85,7 @@ def run_benchmark( print("=" * 60) benchmark_class = BENCHMARK_REGISTRY[benchmark_name] - benchmark = benchmark_class(script_args) - benchmark.benchmark() - if script_args.visualize: - benchmark.visualize() + _run_benchmark(benchmark_class, script_args) return True @@ -87,10 +98,7 @@ def run_all_benchmarks(script_args): for name, cls in BENCHMARK_REGISTRY.items(): print(f"\n{'=' * 20} {name.upper()} {'=' * 20}") - benchmark = cls(script_args) - benchmark.benchmark() - if script_args.visualize: - benchmark.visualize() + _run_benchmark(cls, script_args) print() @@ -149,8 +157,43 @@ def main(): help="Whether to exit with an error message for accuracy failure", ) + parser.add_argument( + "--print-benchmark-result", + action="store_true", + help="Whether to print the raw benchmarking result. Easier to quickly check the benchmark results on a server without GUI", + ) + + parser.add_argument( + "--custom-compile-name", + type=str, + default=None, + help="Name for the curve with customized compilation options", + ) + + parser.add_argument( + "--custom-compile-options", + type=str, + default=None, + help="Json string for the custom compile options.", + ) + args = parser.parse_args() + if args.custom_compile_options: + import json + + try: + args.custom_compile_options = json.loads(args.custom_compile_options) + except json.decoder.JSONDecodeError as e: + raise RuntimeError( + f"Invalid json string for --custom-compile-options: {args.custom_compile_options}" + ) from e + + if not args.custom_compile_options: + raise RuntimeError("Found no options for --custom-compile-options") + if not args.custom_compile_name: + raise RuntimeError("Missing label name for the custom compilation") + # Handle list option if args.list: list_benchmarks() diff --git a/benchmarks/dynamo/genai_layers/kernels.py b/benchmarks/dynamo/genai_layers/kernels.py index 81c12495b5523..5c417cfcad31e 100644 --- a/benchmarks/dynamo/genai_layers/kernels.py +++ b/benchmarks/dynamo/genai_layers/kernels.py @@ -8,6 +8,15 @@ import torch.nn.functional as F +# more important shapes used by internal models +extra_shapes_for_norm = ( + (1152 * 500, 384), + (1152 * 500, 512), + (1152 * 1000, 384), + (1152 * 1000, 512), +) + + class CrossEntropyForward(BenchmarkKernel): def __init__(self, script_args): super().__init__(script_args) @@ -346,7 +355,7 @@ def get_shapes(self) -> tuple[tuple[int, ...], ...]: (32768, 65536), (16384, 131072), (8192, 262144), - ) + ) + extra_shapes_for_norm def get_memory_bytes(self, args, kwargs) -> int: x, w = args @@ -438,8 +447,7 @@ def get_shapes(self) -> tuple[tuple[int, ...], ...]: (32768, 4096), (32768, 8192), (32768, 16384), - (32768, 32768), - ) + ) + extra_shapes_for_norm def get_memory_bytes(self, args, kwargs) -> int: x, w, dy = args @@ -553,7 +561,7 @@ def get_shapes(self) -> tuple[tuple[int, ...], ...]: (32768, 16384), (32768, 32768), (32768, 65536), - ) + ) + extra_shapes_for_norm def get_memory_bytes(self, args, kwargs) -> int: x, w = args @@ -627,7 +635,7 @@ def get_shapes(self) -> tuple[tuple[int, ...], ...]: (32768, 16384), (32768, 32768), (32768, 65536), - ) + ) + extra_shapes_for_norm def get_memory_bytes(self, args, kwargs) -> int: x, w, dy = args diff --git a/benchmarks/dynamo/genai_layers/utils.py b/benchmarks/dynamo/genai_layers/utils.py index 30f1c4509d5e6..090e58676f4ed 100644 --- a/benchmarks/dynamo/genai_layers/utils.py +++ b/benchmarks/dynamo/genai_layers/utils.py @@ -6,6 +6,7 @@ from typing import Any, Optional import matplotlib.pyplot as plt +from scipy.stats import gmean import torch from torch._inductor.runtime.benchmarking import benchmarker @@ -107,6 +108,18 @@ def check_accuracy(self, args, kwargs) -> None: for backend in self.available_backends: args_ref, kwargs_ref = self.clone_inputs(args, kwargs) res[backend] = getattr(self, backend)(args_ref, kwargs_ref)() + + if ( + "compiled" in self.available_backends + and self.script_args.custom_compile_options + ): + torch._dynamo.reset() # cause recompile + with torch._inductor.config.patch(self.script_args.custom_compile_options): + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + res[self.script_args.custom_compile_name] = self.compiled( + args_ref, kwargs_ref + )() + gold = res["eager"] tol = {} @@ -115,7 +128,7 @@ def check_accuracy(self, args, kwargs) -> None: "atol": self.script_args.tolerance, "rtol": self.script_args.tolerance, } - for backend in self.available_backends: + for backend in res: if backend == "eager": continue try: @@ -134,37 +147,83 @@ def check_accuracy(self, args, kwargs) -> None: print("Exit right away since --exit-on-accuracy-failure is set") sys.exit(1) + def benchmark_single_shape_for_backend( + self, backend, args, kwargs, setting, fn=None + ) -> bool: + if fn is None: + fn = getattr(self, backend) + args_ref, kwargs_ref = self.clone_inputs(args, kwargs) + try: + avg_time = benchmark_kernel_in_milliseconds(fn(args_ref, kwargs_ref)) + except Exception as e: + print( + f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}" + ) + self.available_backends.remove(backend) # noqa: B909 + return False + mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref) + perf = Performance(setting, avg_time, mem_bytes) + print(f"{self.name} kernel on {backend} backend. {perf}") + self.profiling_results[backend].append(perf) + return True + def benchmark_single_shape( self, args, kwargs=None, should_check_accuracy=True, setting: str = "" ): for backend in self.available_backends: - args_ref, kwargs_ref = self.clone_inputs(args, kwargs) - try: - avg_time = benchmark_kernel_in_milliseconds( - getattr(self, backend)(args_ref, kwargs_ref) + self.benchmark_single_shape_for_backend(backend, args, kwargs, setting) + if ( + "compiled" in self.available_backends + and self.script_args.custom_compile_options + ): + torch._dynamo.reset() # cause recompile + with torch._inductor.config.patch(self.script_args.custom_compile_options): + status = self.benchmark_single_shape_for_backend( + self.script_args.custom_compile_name, + args, + kwargs, + setting, + fn=self.compiled, ) - except Exception as e: - print( - f"Failed to run {backend} backend on {self.name} kernel for {setting} due to {e}" + if not status: + self.script_args.custom_compile_options = ( + None # once fail, don't run again ) - self.available_backends.remove(backend) # noqa: B909 - continue - mem_bytes = self.get_memory_bytes(args_ref, kwargs_ref) - perf = Performance(setting, avg_time, mem_bytes) - print(f"{self.name} kernel on {backend} backend. {perf}") - self.profiling_results[backend].append(perf) if should_check_accuracy: self.check_accuracy(args, kwargs) def visualize(self) -> None: + device_name = torch.cuda.get_device_name(0) visualize_comparison( self.profiling_results, - title=f"{self.name}", + title=f"{self.name} ({device_name})", output_path=f"{self.name}_bench", ) return + def report_geomean_speedup(self) -> None: + print(f"Geomean speedup for benchmark {self.name}") + eager_result = { + result.setting: result for result in self.profiling_results["eager"] + } + print(f" eager {len(eager_result)} data points") + for backend, backend_result in self.profiling_results.items(): + if backend == "eager": + continue + speeduplist = [] + for result in backend_result: + eager_latency = eager_result[result.setting].latency + backend_latency = result.latency + speeduplist.append( + eager_latency / backend_latency if backend_latency != 0 else 0.0 + ) + + if len(speeduplist) > 0: + print( + f" {backend} {len(speeduplist)} data points, {gmean(speeduplist):.2f}x speedup" + ) + def get_backend_colors() -> dict[str, str]: """Get consistent color scheme for different backends.""" @@ -252,5 +311,6 @@ def visualize_comparison( os.makedirs("pics", exist_ok=True) full_path = os.path.join("pics", output_path + ".png") plt.savefig(full_path, dpi=300, bbox_inches="tight", facecolor="white") + print(f"Chart saved to {full_path}") plt.close() diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index 59534e8341cbc..f31dbb3f95796 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -74,7 +74,8 @@ def pip_install(package): REQUIRE_HIGHER_TOLERANCE_AMP = {} REQUIRE_EVEN_HIGHER_TOLERANCE = { - "beit_base_patch16_224", + "deit_base_distilled_patch16_224", + "vit_base_patch16_siglip_256", } # These models need higher tolerance in MaxAutotune mode @@ -354,7 +355,9 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name): if is_training: from torch._inductor import config as inductor_config - if name in REQUIRE_EVEN_HIGHER_TOLERANCE or ( + if name == "beit_base_patch16_224": + tolerance = 16 * 1e-2 + elif name in REQUIRE_EVEN_HIGHER_TOLERANCE or ( inductor_config.max_autotune and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE ): diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index da6a3e1336aa3..ac4ddb4088416 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -124,6 +124,10 @@ def skip_models_for_cpu_aarch64(self): def skip_models_for_cuda(self): return self._skip["device"]["cuda"] + @property + def skip_models_for_xpu(self): + return self._skip["device"]["xpu"] + @property def skip_models_for_freezing_cuda(self): return self._skip["freezing"]["cuda"] diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index c2324eddc3887..b31a85ae26763 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -217,6 +217,9 @@ skip: cuda: [] + xpu: + - *DETECTRON2_MODELS + test: training: - *DETECTRON2_MODELS diff --git a/benchmarks/operator_benchmark/pt/configs.py b/benchmarks/operator_benchmark/pt/configs.py index 807182dd592b9..15895a6e8f52f 100644 --- a/benchmarks/operator_benchmark/pt/configs.py +++ b/benchmarks/operator_benchmark/pt/configs.py @@ -11,6 +11,11 @@ def remove_cuda(config_list): return [config for config in config_list if cuda_config not in config] +def remove_cpu(config_list): + cpu_config = {"device": "cpu"} + return [config for config in config_list if cpu_config not in config] + + # Configs for conv-1d ops conv_1d_configs_short = op_bench.config_list( attr_names=["IC", "OC", "kernel", "stride", "N", "L"], @@ -127,6 +132,18 @@ def remove_cuda(config_list): }, tags=["short"], ) +conv_3d_configs_long = op_bench.cross_product_configs( + IC=[16, 32], + OC=[32, 64], + kernel=[3, 5], + stride=[1, 2], + N=[1], + D=[128], + H=[128], + W=[128], + device=["cpu", "cuda"], + tags=["long"], +) linear_configs_short = op_bench.config_list( attr_names=["N", "IN", "OUT"], diff --git a/benchmarks/operator_benchmark/pt/conv_test.py b/benchmarks/operator_benchmark/pt/conv_test.py index 65baf47e0d673..eb94921989ccf 100644 --- a/benchmarks/operator_benchmark/pt/conv_test.py +++ b/benchmarks/operator_benchmark/pt/conv_test.py @@ -38,6 +38,10 @@ def forward(self, input): op_bench.generate_pt_test( configs.conv_1d_configs_short + configs.conv_1d_configs_long, Conv1dBenchmark ) +op_bench.generate_pt_gradient_test( + configs.remove_cpu(configs.conv_1d_configs_short + configs.conv_1d_configs_long), + Conv1dBenchmark, +) if not torch.backends.mkldnn.is_acl_available(): @@ -103,6 +107,20 @@ def forward(self, input): configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long, Conv2dPointwiseBenchmark, ) +op_bench.generate_pt_gradient_test( + configs.remove_cpu(configs.conv_2d_configs_short + configs.conv_2d_configs_long), + Conv2dBenchmark, +) +op_bench.generate_pt_gradient_test( + configs.remove_cpu(configs.conv_2d_configs_short + configs.conv_2d_configs_long), + ConvTranspose2dBenchmark, +) +op_bench.generate_pt_gradient_test( + configs.remove_cpu( + configs.conv_2d_pw_configs_short + configs.conv_2d_pw_configs_long + ), + Conv2dPointwiseBenchmark, +) """ @@ -134,6 +152,12 @@ def forward(self, input): op_bench.generate_pt_test(configs.conv_3d_configs_short, Conv3dBenchmark) op_bench.generate_pt_test(configs.conv_3d_configs_short, ConvTranspose3dBenchmark) +op_bench.generate_pt_gradient_test( + configs.remove_cpu(configs.conv_3d_configs_long), Conv3dBenchmark +) +op_bench.generate_pt_gradient_test( + configs.remove_cpu(configs.conv_3d_configs_long), ConvTranspose3dBenchmark +) if __name__ == "__main__": diff --git a/c10/core/Backend.h b/c10/core/Backend.h index ff3621352632f..ba536e15fe4a8 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -3,10 +3,13 @@ #include #include #include +#include #include #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + namespace c10 { /** @@ -402,3 +405,5 @@ inline bool isSparseCsr(Backend b) { } } // namespace c10 + +C10_DIAGNOSTIC_POP() diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 96ef6b3522ba7..72e72f49a5e40 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -52,9 +52,7 @@ constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | // where we would like to support composite implicit kernels but not // explicit kernels therefore we manually add the key to the // math_dispatch_keyset - DispatchKeySet{DispatchKey::NestedTensor} | - // Functionalize should always reuse CompositeImplicit decomps. - DispatchKeySet{DispatchKey::Functionalize}; + DispatchKeySet{DispatchKey::NestedTensor}; constexpr DispatchKeySet nested_dispatch_keyset = DispatchKeySet( diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index d46bf7efeed6a..934fa1be8965c 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -15,6 +15,8 @@ #include #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + namespace c10 { struct FunctionalityOffsetAndMask { @@ -966,3 +968,5 @@ using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t< 1>, typename guts::infer_function_traits_t::parameter_types>>; } // namespace c10 + +C10_DIAGNOSTIC_POP() diff --git a/c10/core/DynamicCast.h b/c10/core/DynamicCast.h index 0a845776a263b..c844986c04f2b 100644 --- a/c10/core/DynamicCast.h +++ b/c10/core/DynamicCast.h @@ -69,6 +69,7 @@ template C10_HOST_DEVICE inline dest_t fetch_and_cast( const ScalarType src_type, const void* ptr) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (src_type) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE) FETCH_AND_CAST_CASE(uint16_t, UInt16) @@ -77,6 +78,7 @@ C10_HOST_DEVICE inline dest_t fetch_and_cast( default: ERROR_UNSUPPORTED_CAST } + C10_DIAGNOSTIC_POP() return dest_t(0); // just to avoid compiler warning } @@ -91,6 +93,7 @@ C10_HOST_DEVICE inline void cast_and_store( const ScalarType dest_type, void* ptr, src_t value) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (dest_type) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE) CAST_AND_STORE_CASE(uint16_t, UInt16) @@ -98,6 +101,7 @@ C10_HOST_DEVICE inline void cast_and_store( CAST_AND_STORE_CASE(uint64_t, UInt64) default:; } + C10_DIAGNOSTIC_POP() ERROR_UNSUPPORTED_CAST } diff --git a/c10/core/Layout.h b/c10/core/Layout.h index 0d09e0ed46f4e..a85f2ee6911ce 100644 --- a/c10/core/Layout.h +++ b/c10/core/Layout.h @@ -29,6 +29,7 @@ constexpr auto kSparseBsc = Layout::SparseBsc; constexpr auto kJagged = Layout::Jagged; inline Layout layout_from_backend(Backend backend) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (backend) { case Backend::SparseCPU: case Backend::SparseCUDA: @@ -52,6 +53,7 @@ inline Layout layout_from_backend(Backend backend) { default: return Layout::Strided; } + C10_DIAGNOSTIC_POP() } inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) { @@ -72,6 +74,7 @@ inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) { return stream << "Mkldnn"; case at::kJagged: return stream << "Jagged"; + case Layout::NumOptions: default: TORCH_CHECK(false, "Unknown layout"); } diff --git a/c10/core/MemoryFormat.h b/c10/core/MemoryFormat.h index edc08bb1016c9..8c8531d014713 100644 --- a/c10/core/MemoryFormat.h +++ b/c10/core/MemoryFormat.h @@ -55,6 +55,7 @@ inline std::ostream& operator<<( return stream << "ChannelsLast"; case MemoryFormat::ChannelsLast3d: return stream << "ChannelsLast3d"; + case MemoryFormat::NumOptions: default: TORCH_CHECK(false, "Unknown memory format ", memory_format); } diff --git a/c10/core/QScheme.h b/c10/core/QScheme.h index 559e68508c76e..359bba8ac469f 100644 --- a/c10/core/QScheme.h +++ b/c10/core/QScheme.h @@ -1,9 +1,12 @@ #pragma once +#include #include #include #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + namespace c10 { /** @@ -48,3 +51,5 @@ inline std::string toString(QScheme qscheme) { } } // namespace c10 + +C10_DIAGNOSTIC_POP() diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index e0c84370e878c..ba1068e72695c 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -26,6 +26,8 @@ #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + namespace c10 { // See [dtype Macros note] in torch/headeronly/core/ScalarType.h @@ -288,3 +290,5 @@ C10_API std::pair getDtypeNames( C10_API const std::unordered_map& getStringToDtypeMap(); } // namespace c10 + +C10_DIAGNOSTIC_POP() diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 6e71f5e21c2bb..99e0d7599f5d9 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -21,6 +21,8 @@ #include #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + namespace c10 { inline ScalarType dtype_or_default(std::optional dtype) { @@ -780,3 +782,5 @@ inline bool backend_supports_empty_operator(const TensorOptions& options) { } // namespace detail } // namespace c10 + +C10_DIAGNOSTIC_POP() diff --git a/c10/core/thread_pool.cpp b/c10/core/thread_pool.cpp index 64858bed47d4c..7effe5846a3ef 100644 --- a/c10/core/thread_pool.cpp +++ b/c10/core/thread_pool.cpp @@ -87,7 +87,7 @@ bool ThreadPool::inThreadPool() const { } void ThreadPool::run(std::function func) { - TORCH_CHECK(threads_.size() > 0, "No threads to run a task"); + TORCH_CHECK(!threads_.empty(), "No threads to run a task"); std::unique_lock lock(mutex_); // Set task and signal condition variable so that a worker thread will diff --git a/c10/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp index 975468de9f439..d2dd6da7a58ad 100644 --- a/c10/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -15,7 +15,6 @@ namespace c10::cuda { namespace { // Global stream state and constants -c10::once_flag init_flag; DeviceIndex num_gpus = -1; constexpr int kStreamsPerPoolBits = 5; constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; @@ -226,7 +225,10 @@ void initDeviceStreamState(DeviceIndex device_index) { // Init front-end to ensure initialization only occurs once void initCUDAStreamsOnce() { // Inits default streams (once, globally) - c10::call_once(init_flag, initGlobalStreamState); + auto static init_flag [[maybe_unused]] = [] { + initGlobalStreamState(); + return true; + }(); if (current_streams) { return; diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 29a45ff4c30b6..defce910d7dc6 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -1,4 +1,4 @@ -// Implementation of specal math functions for Metal +// Implementation of special math functions for Metal #pragma once #include #include @@ -624,6 +624,64 @@ inline T spherical_bessel_j0(T x) { return static_cast(::metal::sin(x) / x); } +template +inline ::metal::enable_if_t, T> logaddexp( + T a, + T b) { + float a0 = static_cast(a); + float b0 = static_cast(b); + if (::metal::isinf(a0) && a0 == b0) { + return static_cast(a0); + } else { + float m0 = ::metal::max(a0, b0); + return static_cast( + m0 + ::c10::metal::log1p(::metal::exp(-::metal::abs(a0 - b0)))); + } +} + +// The function is ported from mlx +template +inline ::metal::enable_if_t, T> logaddexp(T a, T b) { + if (::metal::isnan(a.x) || ::metal::isnan(a.y) || ::metal::isnan(b.x) || + ::metal::isnan(b.y)) { + return T(NAN, NAN); + } + + T maxval = a.x > b.x ? a : b; + T minval = a.x < b.x ? a : b; + constexpr auto inf = ::metal::numeric_limits::infinity().x; + + if (minval.x == -inf || maxval.x == inf) { + return maxval; + } + + float2 maxval_ = static_cast(maxval); + float2 minval_ = static_cast(minval); + float m = ::metal::exp(minval_.x - maxval_.x); + float2 dexp{ + m * ::metal::cos(minval_.y - maxval_.y), + m * ::metal::sin(minval_.y - maxval_.y), + }; + return static_cast(maxval_ + ::c10::metal::log1p(dexp)); +} + +template +inline T logaddexp2(T a, T b) { + constexpr auto log_2 = float(0.693147180559945309417232121458176); + constexpr auto inv_log_2 = float(1) / log_2; + float a0 = static_cast(a); + float b0 = static_cast(b); + if (::metal::isinf(a0) && a0 == b0) { + return static_cast(a0); + } else { + float m0 = ::metal::max(a0, b0); + return static_cast( + m0 + + ::c10::metal::log1p(::metal::pow(float(2), -::metal::abs(a0 - b0))) * + inv_log_2); + } +} + template inline float xlog1py(T x, T y) { if (::metal::isnan(y)) { diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 43d0eff27b8e8..51e04174e32d2 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -322,6 +322,24 @@ inline float log1p(float x) { return rc; } +// The function is ported from mlx +inline float2 log1p(float2 in) { + float x = in.x; + float y = in.y; + float zabs = ::metal::precise::sqrt(x * x + y * y); + float theta = ::metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = ::metal::sqrt((x + 1) * (x + 1) + y * y); + return {::metal::log(z0), theta}; + } +} + template struct pair { T1 first; diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp index d01cdd2b1d24b..c484811db91ac 100644 --- a/c10/mobile/CPUProfilingAllocator.cpp +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -34,7 +34,7 @@ struct MemEvent { bool overlaps(const MemBlock& a, const MemBlock& b) { // two blocks dont overlap if // |---a--------|--------------b--------| - // strat_a end_a <= start_b end_b + // start_a end_a <= start_b end_b return !( (a.end_offset <= b.start_offset) || (b.end_offset <= a.start_offset)); } diff --git a/c10/test/util/TypeList_test.cpp b/c10/test/util/TypeList_test.cpp index d2fe4432e393a..274cbfa442186 100644 --- a/c10/test/util/TypeList_test.cpp +++ b/c10/test/util/TypeList_test.cpp @@ -239,7 +239,7 @@ struct Class2 { struct mapper_call_func { template - decltype(auto) operator()(T) { + auto operator()(T) { return T::type::func(); } }; @@ -254,7 +254,7 @@ TEST(TypeListTest, MapTypesToValues_members) { struct mapper_call_nonexistent_function { template - decltype(auto) operator()(T) { + auto operator()(T) { return T::type::this_doesnt_exist(); } }; diff --git a/c10/util/Bitset.h b/c10/util/Bitset.h index 782cefbd922e0..f1d521bd7e513 100644 --- a/c10/util/Bitset.h +++ b/c10/util/Bitset.h @@ -33,7 +33,7 @@ struct bitset final { constexpr bitset() noexcept = default; constexpr bitset(const bitset&) noexcept = default; constexpr bitset(bitset&&) noexcept = default; - // there is an issure for gcc 5.3.0 when define default function as constexpr + // there is an issue for gcc 5.3.0 when define default function as constexpr // see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754. bitset& operator=(const bitset&) noexcept = default; bitset& operator=(bitset&&) noexcept = default; diff --git a/c10/util/C++17.h b/c10/util/C++17.h index fcdaaae3cb450..5dafb245f92e8 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -53,7 +53,7 @@ namespace guts { // member functions. namespace detail { template -C10_HOST_DEVICE constexpr decltype(auto) apply_impl( +C10_HOST_DEVICE constexpr auto apply_impl( F&& f, Tuple&& t, std::index_sequence) { @@ -62,7 +62,7 @@ C10_HOST_DEVICE constexpr decltype(auto) apply_impl( } // namespace detail template -C10_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) { +C10_HOST_DEVICE constexpr auto apply(F&& f, Tuple&& t) { return detail::apply_impl( std::forward(f), std::forward(t), diff --git a/c10/util/Exception.h b/c10/util/Exception.h index f0c85a8b13d8c..6b2fd626bfb5e 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -469,7 +469,7 @@ C10_API std::string GetExceptionString(const std::exception& e); namespace c10::detail { template -decltype(auto) torchCheckMsgImpl(const char* /*msg*/, const Args&... args) { +auto torchCheckMsgImpl(const char* /*msg*/, const Args&... args) { return ::c10::str(args...); } inline C10_API const char* torchCheckMsgImpl(const char* msg) { diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index eaf3cbfc601e8..d02c9380a563d 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -795,7 +795,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase { std::move(I + 1, this->end(), I); // Drop the last elt. this->pop_back(); - return (N); + return N; } iterator erase(iterator S, iterator E) { @@ -807,7 +807,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase { // Drop the last elts. this->destroy_range(I, this->end()); this->set_size(I - this->begin()); - return (N); + return N; } private: diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index b2c41bb98ee1d..cbc6f4ec336bb 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -135,7 +135,7 @@ struct _str_wrapper<> final { // Convert a list of string-like arguments into a single string. template -inline decltype(auto) str(const Args&... args) { +inline auto str(const Args&... args) { return detail::_str_wrapper< typename detail::CanonicalizeStrTypes::type...>::call(args...); } diff --git a/c10/util/TypeList.h b/c10/util/TypeList.h index a540a0c5c6744..244e5bb141cd7 100644 --- a/c10/util/TypeList.h +++ b/c10/util/TypeList.h @@ -507,7 +507,7 @@ struct map_types_to_values> final { } // namespace detail template -decltype(auto) map_types_to_values(Func&& func) { +auto map_types_to_values(Func&& func) { return detail::map_types_to_values::call(std::forward(func)); } diff --git a/c10/util/llvmMathExtras.h b/c10/util/llvmMathExtras.h index 6321297a61c75..da29724144951 100644 --- a/c10/util/llvmMathExtras.h +++ b/c10/util/llvmMathExtras.h @@ -370,7 +370,7 @@ constexpr inline bool isShiftedInt(int64_t x) { template constexpr inline std::enable_if_t<(N < 64), bool> isUInt(uint64_t X) { static_assert(N > 0, "isUInt<0> doesn't make sense"); - return X < (UINT64_C(1) << (N)); + return X < (UINT64_C(1) << N); } template constexpr inline std::enable_if_t= 64, bool> isUInt(uint64_t /*X*/) { diff --git a/c10/util/string_view.h b/c10/util/string_view.h index 0cc5da4309f6b..dabbd4f419f6b 100644 --- a/c10/util/string_view.h +++ b/c10/util/string_view.h @@ -324,7 +324,7 @@ class basic_string_view final { constexpr size_type find(basic_string_view v, size_type pos = 0) const noexcept { - if (v.size() == 0) { + if (v.empty()) { return pos <= size() ? pos : npos; } @@ -355,7 +355,7 @@ class basic_string_view final { constexpr size_type rfind(basic_string_view v, size_type pos = npos) const noexcept { // Write it iteratively. This is faster. - if (v.size() == 0) { + if (v.empty()) { return pos <= size() ? pos : size(); } @@ -509,7 +509,7 @@ class basic_string_view final { constexpr size_type find_last_if_(size_type pos, Condition&& condition) const noexcept { // Write it iteratively. This is faster. - if (size() > 0) { + if (!empty()) { pos = std::min(size() - 1, pos); do { if (condition(at_(pos))) { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 244db3c91e0fb..76be395ef4771 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -554,6 +554,17 @@ class DeviceCachingAllocator { } } + double getMemoryFraction() { + if (!set_fraction) { + return 1.0; + } + + c10::xpu::DeviceProp device_prop; + c10::xpu::get_device_properties(&device_prop, device_index); + return static_cast(allowed_memory_maximum) / + static_cast(device_prop.global_mem_size); + } + void setMemoryFraction(double fraction) { c10::xpu::DeviceProp device_prop; c10::xpu::get_device_properties(&device_prop, device_index); @@ -724,6 +735,11 @@ class XPUAllocator : public DeviceAllocator { device_allocators[device]->resetAccumulatedStats(); } + double getMemoryFraction(DeviceIndex device) { + assertValidDevice(device); + return device_allocators[device]->getMemoryFraction(); + } + void setMemoryFraction(double fraction, DeviceIndex device) { assertValidDevice(device); TORCH_CHECK_VALUE( @@ -777,6 +793,10 @@ void recordStream(const DataPtr& dataPtr, XPUStream stream) { return allocator.recordStream(dataPtr, stream); } +double getMemoryFraction(DeviceIndex device) { + return allocator.getMemoryFraction(device); +} + void setMemoryFraction(double fraction, DeviceIndex device) { return allocator.setMemoryFraction(fraction, device); } diff --git a/c10/xpu/XPUCachingAllocator.h b/c10/xpu/XPUCachingAllocator.h index 44ac34fe9a9b0..b0b0f2ca969e1 100644 --- a/c10/xpu/XPUCachingAllocator.h +++ b/c10/xpu/XPUCachingAllocator.h @@ -25,6 +25,8 @@ C10_XPU_API void raw_delete(void* ptr); C10_XPU_API void recordStream(const DataPtr& dataPtr, XPUStream stream); +C10_XPU_API double getMemoryFraction(DeviceIndex device); + C10_XPU_API void setMemoryFraction(double fraction, DeviceIndex device); } // namespace c10::xpu::XPUCachingAllocator diff --git a/c10/xpu/XPUFunctions.cpp b/c10/xpu/XPUFunctions.cpp index 26edf295d1fca..7afd3da1556d7 100644 --- a/c10/xpu/XPUFunctions.cpp +++ b/c10/xpu/XPUFunctions.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -33,7 +32,6 @@ namespace { * one iGPU and enumerate all iGPUs on that platform. * 3. If neither dGPUs nor iGPUs are found, conclude that no GPUs are available. */ -c10::once_flag init_flag; thread_local DeviceIndex curDeviceIndex = 0; struct DevicePool { @@ -149,7 +147,10 @@ inline void initGlobalDevicePoolState() { } inline void initDevicePoolCallOnce() { - c10::call_once(init_flag, initGlobalDevicePoolState); + auto static init_flag [[maybe_unused]] = [] { + initGlobalDevicePoolState(); + return true; + }(); } void initDeviceProperties(DeviceProp* device_prop, DeviceIndex device) { diff --git a/c10/xpu/XPUStream.cpp b/c10/xpu/XPUStream.cpp index 1daca30885da4..baf44ff11cbac 100644 --- a/c10/xpu/XPUStream.cpp +++ b/c10/xpu/XPUStream.cpp @@ -12,7 +12,6 @@ namespace c10::xpu { namespace { // Global stream state and constants -c10::once_flag init_flag; DeviceIndex num_gpus = -1; constexpr int kStreamsPerPoolBits = 5; constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; @@ -163,7 +162,10 @@ void initDeviceStreamState(DeviceIndex device) { } void initXPUStreamsOnce() { - c10::call_once(init_flag, initGlobalStreamState); + auto static init_flag [[maybe_unused]] = [] { + initGlobalStreamState(); + return true; + }(); if (current_streams) { return; diff --git a/caffe2/serialize/crc_alt.h b/caffe2/serialize/crc_alt.h index 9d1c4f1dc7ddc..8c3e85df46ae8 100644 --- a/caffe2/serialize/crc_alt.h +++ b/caffe2/serialize/crc_alt.h @@ -38,7 +38,7 @@ uint32_t crc32_combine (uint32_t crcA, uint32_t crcB, size_t lengthB); /// compute CRC32 (bitwise algorithm) uint32_t crc32_bitwise (const void* data, size_t length, uint32_t previousCrc32 = 0); -/// compute CRC32 (half-byte algoritm) +/// compute CRC32 (half-byte algorithm) uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32 = 0); #ifdef CRC32_USE_LOOKUP_TABLE_BYTE @@ -96,7 +96,7 @@ uint32_t crc32_16bytes_prefetch(const void* data, size_t length, uint32_t previo #define __BIG_ENDIAN 4321 #endif -// define endianess and some integer data types +// define endianness and some integer data types #if defined(_MSC_VER) || defined(__MINGW32__) // Windows always little endian #define __BYTE_ORDER __LITTLE_ENDIAN @@ -168,7 +168,7 @@ namespace /// zlib's CRC32 polynomial const uint32_t Polynomial = 0xEDB88320; - /// swap endianess + /// swap endianness static inline uint32_t swap(uint32_t x) { #if defined(__GNUC__) || defined(__clang__) @@ -229,7 +229,7 @@ uint32_t crc32_bitwise(const void* data, size_t length, uint32_t previousCrc32) } -/// compute CRC32 (half-byte algoritm) +/// compute CRC32 (half-byte algorithm) uint32_t crc32_halfbyte(const void* data, size_t length, uint32_t previousCrc32) { uint32_t crc = ~previousCrc32; // same as previousCrc32 ^ 0xFFFFFFFF @@ -662,7 +662,7 @@ uint32_t crc32_combine(uint32_t crcA, uint32_t crcB, size_t lengthB) // - if you append length(B) zeros to A and call it A' (think of it as AAAA000) // and prepend length(A) zeros to B and call it B' (think of it as 0000BBB) // then exists a C' = A' ^ B' - // - remember: if you XOR someting with zero, it remains unchanged: X ^ 0 = X + // - remember: if you XOR something with zero, it remains unchanged: X ^ 0 = X // - that means C' = A concat B so that crc(A concat B) = crc(C') = crc(A') ^ crc(B') // - the trick is to compute crc(A') based on crc(A) // and crc(B') based on crc(B) diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 47bd7886dc93e..7c13b2d6ec54a 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -76,7 +76,7 @@ typedef struct mz_zip_archive mz_zip_archive; // 2) Writing with 1-pass sequential access // -> We must take care not to require updating values that have already // been written. We place the variable-length index at the end and do -// not put any indicies into the header to fulfill this constraint. +// not put any index into the header to fulfill this constraint. // The model.json, which contains all the metadata information, // should be written as the last file. One reason is that the size of tensor diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 489751522fb61..785e93a0a22e2 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -519,7 +519,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoadWithAllocator) { std::tie(data_ptr, size) = reader.getRecord("key1", &overrideAllocator); EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1); EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes); - // allcoate with base allocator + // allocate with base allocator std::tie(data_ptr, size) = reader.getRecord("key1"); EXPECT_EQ(overrideAllocator.getAllocatedBytes(), kBytes1); EXPECT_EQ(baseAllocator.getAllocatedBytes(), allocBytes + kBytes1); diff --git a/caffe2/utils/threadpool/WorkersPool.h b/caffe2/utils/threadpool/WorkersPool.h index 274456ffc5322..6504d70c25c2f 100644 --- a/caffe2/utils/threadpool/WorkersPool.h +++ b/caffe2/utils/threadpool/WorkersPool.h @@ -258,6 +258,7 @@ class alignas(kGEMMLOWPCacheLineSize) Worker { case State::HasWork: DCHECK(new_state == State::Ready || new_state == State::ExitAsSoonAsPossible); break; + case State::ExitAsSoonAsPossible: default: abort(); } @@ -292,6 +293,8 @@ class alignas(kGEMMLOWPCacheLineSize) Worker { break; case State::ExitAsSoonAsPossible: return; + case State::Ready: + case State::ThreadStartup: default: abort(); } diff --git a/docs/source/autograd.md b/docs/source/autograd.md index 4218eac05d79d..e78b77e4eb45c 100644 --- a/docs/source/autograd.md +++ b/docs/source/autograd.md @@ -423,8 +423,10 @@ Also see {ref}`saved-tensors-hooks-doc`. ```{eval-rst} .. autofunction:: torch.autograd.graph.get_gradient_edge +``` - +```{eval-rst} +.. autofunction:: torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch ``` % This module needs to be documented. Adding here in the meantime diff --git a/docs/source/notes/libtorch_stable_abi.md b/docs/source/notes/libtorch_stable_abi.md index 20bc0c16c198a..fff32d00cb449 100644 --- a/docs/source/notes/libtorch_stable_abi.md +++ b/docs/source/notes/libtorch_stable_abi.md @@ -2,9 +2,9 @@ ## Overview -The LibTorch Stable ABI (Application Binary Interface) provides an interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. +The LibTorch Stable ABI (Application Binary Interface) provides a limited interface for extending PyTorch functionality without being tightly coupled to specific PyTorch versions. This enables the development of custom operators and extensions that remain compatible across PyTorch releases. This limited set of APIs is not intended to replace existing LibTorch, but rather to provide a stable foundation for a majority of custom extension use cases. If there is any API you would like to see added to the stable ABI, please file a request through a [new issue on the PyTorch repo](https://github.com/pytorch/pytorch/issues). -The stable ABI consists of three main components: +The limited stable ABI consists of three main components: 1. **Stable C headers** - Low-level C API implemented by libtorch (primarily `torch/csrc/inductor/aoti_torch/c/shim.h`) 2. **Header-only C++ library** - Standalone utilities implemented in only headers such that there is no dependence on libtorch (`torch/headeronly/*`) @@ -14,8 +14,8 @@ We discuss each of these in detail ### `torch/headeronly` -This is a set of inlined C++ headers are completely decoupled from libtorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the -`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`. +The inlined C++ headers living in [`torch/headeronly`](https://github.com/pytorch/pytorch/tree/main/torch/headeronly) are completely decoupled from LibTorch. The headers consist of certain utilities that might be familiar to custom extension writers. For example, the +`c10::ScalarType` enum lives here as `torch::headeronly::ScalarType`, as well as a libtorch-independent version of `TORCH_CHECK` that is `STD_TORCH_CHECK`. You can trust all APIs in the `torch::headeronly` namespace to not depend on `libtorch.so`. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt). ### `torch/csrc/stable` @@ -34,8 +34,14 @@ We are continuing to improve coverage in our `torch/csrc/stable` APIs. Please fi ### Stable C headers -The stable C headers used by AOTInductor form the foundation of the stable ABI. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. - Further, the stack-based APIs discussed below which allow the user to call the PyTorch dispatcher don't provide strong guarantees on forward and backward compatibility. +The stable C headers started by AOTInductor form the foundation of the stable ABI. Presently, the available C headers include: + +- [torch/csrc/inductor/aoti_torch/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/c/shim.h): Includes C-style shim APIs for commonly used regarding Tensors, dtypes, CUDA, and the like. +- [torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h): Includes C-style shim APIs for ATen ops from `native_functions.yaml` (e.g. `aoti_torch_aten_new_empty`). +- [torch/csrc/inductor/aoti_torch/generated/c_shim_*.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/inductor/aoti_torch/generated): Includes C-style shim APIs for specific backend kernels dispatched from `native_functions.yaml` (e.g. `aoti_torch_cuda_pad`). These APIs should only be used for the specific backend they are named after (e.g. `aoti_torch_cuda_pad` should only be used within CUDA kernels), as they opt out of the dispatcher. +- [torch/csrc/stable/c/shim.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/stable/c/shim.h): We are building out more ABIs to logically live in `torch/csrc/stable/c` instead of continuing the AOTI naming that no longer makes sense for our general use case. + +These headers are promised to be ABI stable across releases and adhere to a stronger backwards compatibility policy than LibTorch. Specifically, we promise not to modify them for at least 2 years after they are released. However, this is **use at your own risk**. For example, users must handle the memory lifecycle of objects returned by certain APIs. Further, the stack-based APIs discussed below which allow the user to call into the PyTorch dispatcher do not provide strong guarantees on forward and backward compatibility of the underlying op that is called. Unless absolutely necessary, we recommend the high-level C++ API in `torch/csrc/stable` which will handle all the rough edges of the C API for the user. @@ -122,12 +128,38 @@ The above is relevant in two places: } ``` -2. `aoti_torch_call_dispatcher` +2. `torch_call_dispatcher` This API allows you to call the PyTorch dispatcher from C/C++ code. It has the following signature: + ```cpp - aoti_torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack); + torch_call_dispatcher(const char* opName, const char* overloadName, StableIValue* stack, uint64_t extension_build_version); ``` - `aoti_torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, and a stack of - StableIValues. This call will populate any return values of the op into the stack in their StableIValue form, - with `ret0` at index 0, `ret1` at index 1, and so on. + `torch_call_dispatcher` will call the op overload defined by a given `opName`, `overloadName`, a stack of + StableIValues and the `TORCH_ABI_VERSION` of the user extension. This call will populate any return values of the + op into the stack in their StableIValue form, with `ret0` at index 0, `ret1` at index 1, and so on. + + We caution against using this API to call functions that have been registered to the dispatcher by other extensions + unless the caller can guarantee that the signature they expect matches that which the custom extension has + registered. + +### Versioning and Forward/Backward compatibility guarantees + +We provide a `TORCH_ABI_VERSION` macro in `torch/headeronly/version.h` of the form + +``` +[ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ][ byte ] +[MAJ ][ MIN ][PATCH ][ ABI TAG ] +``` + +In the present phase of development, APIs in the C-shim will be versioned based on major.minor.patch release that they are first introduced in, with 2.10 being the first release where this will be enforced. The ABI tag is reserved for future use. + +Extensions can select the minimum abi version to be compatible with using: + +``` +#define TORCH_TARGET_VERSION (((0ULL + major) << 56) | ((0ULL + minor) << 48)) +``` + +before including any stable headers or by passing the equivalent `-D` option to the compiler. Otherwise, the default will be the current `TORCH_ABI_VERSION`. + +The above ensures that if a user defines `TORCH_TARGET_VERSION` to be 0x0209000000000000 (2.9) and attempts to use a C shim API `foo` that was introduced in version 2.10, a compilation error will be raised. Similarly, the C++ wrapper APIs in `torch/csrc/stable` are compatible with older libtorch binaries up to the TORCH_ABI_VERSION they are exposed in and forward compatible with newer libtorch binaries. diff --git a/docs/source/xpu.md b/docs/source/xpu.md index 7a10e29b6af67..6cd82aa984159 100644 --- a/docs/source/xpu.md +++ b/docs/source/xpu.md @@ -76,6 +76,7 @@ :nosignatures: empty_cache + get_per_process_memory_fraction max_memory_allocated max_memory_reserved mem_get_info diff --git a/pyrefly.toml b/pyrefly.toml index cca6f5eb78cc1..249b02227cec9 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -29,6 +29,7 @@ project-excludes = [ "torch/_inductor/runtime/triton_heuristics.py", "torch/_inductor/runtime/triton_helpers.py", "torch/_inductor/runtime/halide_helpers.py", + "torch/utils/tensorboard/summary.py", # formatting issues, will turn on after adjusting where suppressions can be # in import statements "tools/flight_recorder/components/types.py", @@ -46,6 +47,7 @@ project-excludes = [ "torch/distributed/elastic/metrics/__init__.py", "torch/_inductor/fx_passes/bucketing.py", # ==== + "torch/onnx/_internal/exporter/_torchlib/ops/nn.py", "torch/include/**", "torch/csrc/**", "torch/distributed/elastic/agent/server/api.py", diff --git a/setup.py b/setup.py index a980a5f35216a..31e78d0245d93 100644 --- a/setup.py +++ b/setup.py @@ -1106,7 +1106,7 @@ def _embed_libomp(self) -> None: continue self.copy_file(source_lib, target_lib) # Delete old rpath and add @loader_lib to the rpath - # This should prevent delocate from attempting to package another instance + # This should prevent deallocate from attempting to package another instance # of OpenMP library in torch wheel as well as loading two libomp.dylib into # the address space, as libraries are cached by their unresolved names install_name_tool_args = [ diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_ops.py b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_ops.py index a307f65cbc6fa..627df5c7f782f 100644 --- a/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_ops.py +++ b/test/cpp_extensions/open_registration_extension/torch_openreg/tests/test_ops.py @@ -246,7 +246,7 @@ def test_scaled_dot_product_fused_attention_overrideable_backward(self): max_k, philox_seed, philox_offset, - debug_attn_mask, + _debug_attn_mask, ) = torch.ops.aten._scaled_dot_product_fused_attention_overrideable( q_privateuse1, k_privateuse1, v_privateuse1, attn_bias=attn_mask_privateuse1 ) @@ -256,25 +256,23 @@ def test_scaled_dot_product_fused_attention_overrideable_backward(self): ) rand_upward_privateuse1 = rand_upward.to("openreg") grad_input_mask = [True, True, True, True] - grad_q, grad_k, grad_v, grad_attn_mask = ( - torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward( - rand_upward_privateuse1, - q_privateuse1, - k_privateuse1, - v_privateuse1, - attn_mask_privateuse1, - grad_input_mask, - output, - logsumexp, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - dropout_p=0.0, - is_causal=False, - philox_seed=philox_seed, - philox_offset=philox_offset, - ) + torch.ops.aten._scaled_dot_product_fused_attention_overrideable_backward( + rand_upward_privateuse1, + q_privateuse1, + k_privateuse1, + v_privateuse1, + attn_mask_privateuse1, + grad_input_mask, + output, + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p=0.0, + is_causal=False, + philox_seed=philox_seed, + philox_offset=philox_offset, ) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_autograd.py b/test/distributed/_composable/fsdp/test_fully_shard_autograd.py index 1ee930a717012..cbb11ae3e774a 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_autograd.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_autograd.py @@ -266,7 +266,7 @@ def hook(param_name: str, param: torch.Tensor) -> None: model(inp).sum().backward() param_names = {param_name for param_name, _ in model.named_parameters()} self.assertEqual(param_names, set(param_name_to_hook_count.keys())) - for param_name, count in param_name_to_hook_count.items(): + for count in param_name_to_hook_count.values(): self.assertEqual(count, 1) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py index d193d65b179a5..a23faed921137 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py @@ -438,8 +438,8 @@ def test_rank0_offload_full_state_dict(self): if self.rank == 0: self.assertEqual(len(full_sd), len(ref_full_sd)) self.assertEqual(list(full_sd.keys()), list(ref_full_sd.keys())) - for (param_name, param), ref_param in zip( - full_sd.items(), ref_full_sd.values() + for param, ref_param in zip( + full_sd.values(), ref_full_sd.values(), strict=True ): self.assertEqual(param.device, torch.device("cpu")) self.assertEqual(param.device, ref_param.device) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 8331cd90ce9bc..da847fc32f761 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -827,7 +827,7 @@ def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: torch.manual_seed(42 + self.rank) inp = torch.randint(0, model_args.vocab_size, (2, 16), device=device_type.type) - for iter_idx in range(5): + for _ in range(5): ref_loss = ref_model(inp).sum() loss = model(inp).sum() self.assertEqual(ref_loss, loss) diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index a184901f6ef05..80b86dcf50764 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -392,11 +392,11 @@ def test_replicate_pp(self, ScheduleClass, MixedPrecisionParam): replicate_size = self.world_size // (pp_size) device_mesh = init_device_mesh( device_type, - mesh_shape=(replicate_size, 1, pp_size), - mesh_dim_names=("replicate", "shard", "pp"), + mesh_shape=(replicate_size, pp_size), + mesh_dim_names=("replicate", "pp"), ) torch.manual_seed(42) - dp_mesh = device_mesh["replicate", "shard"] + dp_mesh = device_mesh["replicate"] pp_mesh = device_mesh["pp"] pp_group = device_mesh["pp"].get_group() @@ -582,11 +582,11 @@ def test_replicate_pp_grads(self, ScheduleClass): replicate_size = self.world_size // (pp_size) device_mesh = init_device_mesh( device_type, - mesh_shape=(replicate_size, 1, pp_size), - mesh_dim_names=("replicate", "shard", "pp"), + mesh_shape=(replicate_size, pp_size), + mesh_dim_names=("replicate", "pp"), ) torch.manual_seed(42) - dp_mesh = device_mesh["replicate", "shard"] + dp_mesh = device_mesh["replicate"] pp_mesh = device_mesh["pp"] pp_group = device_mesh["pp"].get_group() dp_group = device_mesh["replicate"].get_group() diff --git a/test/distributed/_composable/test_replicate_training.py b/test/distributed/_composable/test_replicate_training.py index 0cb7fad1b9b14..1268dad3ff8aa 100644 --- a/test/distributed/_composable/test_replicate_training.py +++ b/test/distributed/_composable/test_replicate_training.py @@ -108,7 +108,7 @@ def test_param_registration_after_forward(self): """Tests the parameter registration after forward.""" device = torch.device(device_type.type, 0) # Single Replicate group - for reshard_after_forward in (True, False, None): + for reshard_after_forward in (False,): torch.manual_seed(42) model = MLP(3, device) # Since seed is per process, not per thread, we broadcast to ensure @@ -131,7 +131,7 @@ def test_param_registration_after_forward(self): self._assert_same_params(model.parameters(), ref_model.parameters()) # Multiple Replicate groups - for reshard_after_forward in (True, False, None): + for reshard_after_forward in (False,): torch.manual_seed(42) model = nn.Sequential(MLP(3, device), MLP(3, device)) for param in model.parameters(): @@ -405,8 +405,8 @@ def _test_train_parity_multi_group( ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) mesh = init_device_mesh( test_device_type, - (self.world_size, 1), - mesh_dim_names=("replicate", "shard"), + (self.world_size,), + mesh_dim_names=("replicate",), ) fully_shard_fn = functools.partial( replicate, @@ -740,8 +740,8 @@ def _test_train_parity_with_activation_checkpointing( # Apply Replicate device_mesh = init_device_mesh( test_device_type, - (self.world_size, 1), - mesh_dim_names=("replicate", "shard"), + (self.world_size,), + mesh_dim_names=("replicate",), ) fsdp_kwargs = { "reshard_after_forward": reshard_after_forward, @@ -868,11 +868,11 @@ def test_gradient_accumulation(self): with/without resharding after backward. """ - shard_size, replicate_size = 1, self.world_size + replicate_size = self.world_size meshes = init_device_mesh( device_type.type, - (replicate_size, shard_size), - mesh_dim_names=("replicate", "shard"), + (replicate_size,), + mesh_dim_names=("replicate",), ) self.run_subtests( { @@ -1145,8 +1145,8 @@ def world_size(self) -> int: def init_global_mesh(self) -> DeviceMesh: return init_device_mesh( device_type.type, - (2, 1, 2), - mesh_dim_names=("dp_replicate", "dp_shard", "tp"), + (2, 2), + mesh_dim_names=("dp_replicate", "tp"), ) @skip_if_lt_x_gpu(8) @@ -1170,7 +1170,7 @@ def _test_replicate_tp( mlp_dim: int, foreach: bool, ): - dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"] + dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"] dp_pg = dp_mesh._flatten().get_group() # used for `replicate()` torch.manual_seed(42) @@ -1229,11 +1229,9 @@ def _test_replicate_tp( for _, p in model.named_parameters(): self.assertIsInstance(p, DTensor) - self.assertEqual(p.device_mesh.ndim, 3) - self.assertEqual(len(p.placements), 3) - self.assertEqual( - p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp") - ) + self.assertEqual(p.device_mesh.ndim, 2) + self.assertEqual(len(p.placements), 2) + self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp")) if __name__ == "__main__": diff --git a/test/distributed/_composable/test_replicate_with_fsdp.py b/test/distributed/_composable/test_replicate_with_fsdp.py index 099f84b9e848f..b4c55e10b4b59 100644 --- a/test/distributed/_composable/test_replicate_with_fsdp.py +++ b/test/distributed/_composable/test_replicate_with_fsdp.py @@ -120,7 +120,7 @@ def _test_replicate_transformer(self, sharding_strategy): if i % 2 == 0: self.assertTrue("replicate" in _get_registry(layer)) for parameter in layer.parameters(): - self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0))) + self.assertEqual(parameter.placements, (Replicate(),)) elif i % 2 == 1: self.assertTrue("fully_shard" in _get_registry(layer)) for parameter in layer.parameters(): @@ -197,14 +197,14 @@ def test_replicate_tp_device_mesh(self): ] global_mesh = self.init_replicate_tp_mesh() - replicate_mesh = global_mesh["replicate", "shard"] + replicate_mesh = global_mesh["replicate"] for layer in layers: replicate(layer, device_mesh=replicate_mesh) for parameter in layer.parameters(): - self.assertEqual(parameter.device_mesh.shape, (2, 1)) - self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0))) + self.assertEqual(parameter.device_mesh.shape, (2,)) + self.assertEqual(parameter.placements, (Replicate(),)) @skip_if_lt_x_gpu(2) def test_train_replicate_fsdp(self): @@ -263,7 +263,6 @@ def test_train_parity_2d_mlp(self): run_subtests( self, { - "reshard_after_forward": [False, True], "use_activation_checkpointing": [False, True], "mlp_dim": [3, 16, 17], }, @@ -273,7 +272,6 @@ def test_train_parity_2d_mlp(self): def _test_train_parity_2d_mlp( self, global_mesh: DeviceMesh, - reshard_after_forward: bool, use_activation_checkpointing: bool, mlp_dim: int, ): @@ -287,13 +285,12 @@ def _test_train_parity_2d_mlp( torch.manual_seed(42) model = MLPStack(mlp_dim) ref_model = copy.deepcopy(model).cuda() - replicate(ref_model, device_mesh=replicate_shard_mesh) + replicate(ref_model, device_mesh=replicate_mesh) ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False) model.parallelize( tp_mesh, replicate_shard_mesh, use_activation_checkpointing, - reshard_after_forward=reshard_after_forward, ) optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False) diff --git a/test/distributed/checkpoint/_experimental/test_checkpointer.py b/test/distributed/checkpoint/_experimental/test_checkpointer.py index f96ecb6e1d7a1..62fde0b3166df 100644 --- a/test/distributed/checkpoint/_experimental/test_checkpointer.py +++ b/test/distributed/checkpoint/_experimental/test_checkpointer.py @@ -560,6 +560,7 @@ def test_async_multiple_saves_ordering(self): # Wait for all to complete for stage_future, write_future in futures: + stage_future.result() write_future.result() # Verify all checkpoints exist and have correct content diff --git a/test/distributed/checkpoint/test_async_process_executor.py b/test/distributed/checkpoint/test_async_process_executor.py index 9dc7095b0d6c3..1eabc75527da6 100644 --- a/test/distributed/checkpoint/test_async_process_executor.py +++ b/test/distributed/checkpoint/test_async_process_executor.py @@ -1,16 +1,26 @@ # Owner(s): ["oncall: distributed checkpointing"] +import os import sys from unittest.mock import patch import torch +import torch.testing._internal.common_utils as common from torch import distributed as dist from torch.distributed.checkpoint._async_process_executor import ( _ProcessBasedAsyncCheckpointExecutor, + _ProcessGroupInitInfo, ) +from torch.distributed.checkpoint.api import CheckpointException from torch.distributed.checkpoint.storage import StorageWriter from torch.distributed.elastic.utils.distributed import get_free_port -from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN +from torch.testing._internal.common_distributed import skip_if_win32 +from torch.testing._internal.common_utils import ( + retry_on_connect_failures, + run_tests, + TEST_WITH_DEV_DBG_ASAN, + TestCase, +) from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -110,47 +120,184 @@ def test_checkpoint_save_failure_continues_serving(self) -> None: "epoch": 5, } - # 1. Simulate a failure in creating PG in background process. - with patch( - "torch.distributed.checkpoint._async_process_executor.get_free_port", - return_value=-1, + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("DCP_USE_PREFIX_STORE", None) + + # 1. Simulate a failure in creating PG in background process. + with patch( + "torch.distributed.checkpoint._async_process_executor.get_free_port", + return_value=-1, + ): + with self.assertRaises(ValueError) as _: + proc_executor = _ProcessBasedAsyncCheckpointExecutor() + fut = proc_executor.execute_save( + staging_future_or_state_dict=test_state_dict, + ) + fut.result() + + # 2. Attempt save with failing storage writer + with patch( + "torch.distributed.checkpoint._async_process_executor.get_free_port", + return_value=get_free_port(), + ) as mock_get_free_port: + proc_executor = _ProcessBasedAsyncCheckpointExecutor() + fut = proc_executor.execute_save( + staging_future_or_state_dict=test_state_dict, + storage_writer=TestStorageWriter(behavior="fail_once"), + ) + self.assertIn( + "fail_once policy triggered failure", str(fut.exception()) + ) + # Verify new process was created for this attempt + if dist.get_rank() == 0: + mock_get_free_port.assert_called_once() + + # 3. Second save attempt with successful storage writer - process should still be alive + with patch( + "torch.distributed.checkpoint._async_process_executor.get_free_port", + ) as mock_get_free_port: + proc_executor = _ProcessBasedAsyncCheckpointExecutor() + fut = proc_executor.execute_save( + staging_future_or_state_dict=test_state_dict, + storage_writer=TestStorageWriter(behavior="success"), + ) + result = fut.result() + # Verify process is still alive + mock_get_free_port.assert_not_called() + # Verify successful save + self.assertIsNotNone(result) + + +class TestAsyncProcessExecutorPrefixStore(TestCase): + @skip_if_win32() + @retry_on_connect_failures + def test_checkpoint_save_with_prefix_store_enabled(self) -> None: + """Test that checkpoint save works when DCP_USE_PREFIX_STORE is enabled.""" + + test_state_dict = { + "model": {"weight": torch.randn(4, 4), "bias": torch.randn(4)}, + "optimizer": {"param_groups": [{"lr": 0.01}]}, + "epoch": 5, + } + + master_addr = "localhost" + master_port = str(common.find_free_port()) + + with patch.dict( + os.environ, + { + "DCP_USE_PREFIX_STORE": "1", + "MASTER_ADDR": master_addr, + "MASTER_PORT": master_port, + }, ): - with self.assertRaises(ValueError) as _: + with patch( + "torch.distributed.checkpoint._async_process_executor.get_free_port" + ) as mock_get_free_port: + dist.init_process_group( + backend=dist.Backend.GLOO, + rank=0, + world_size=1, + ) + proc_executor = _ProcessBasedAsyncCheckpointExecutor() fut = proc_executor.execute_save( staging_future_or_state_dict=test_state_dict, + storage_writer=TestStorageWriter(behavior="success"), ) - fut.result() - - # 2. Attempt save with failing storage writer - with patch( - "torch.distributed.checkpoint._async_process_executor.get_free_port", - return_value=get_free_port(), - ) as mock_get_free_port: - proc_executor = _ProcessBasedAsyncCheckpointExecutor() - fut = proc_executor.execute_save( - staging_future_or_state_dict=test_state_dict, - storage_writer=TestStorageWriter(behavior="fail_once"), - ) - self.assertIn("fail_once policy triggered failure", str(fut.exception())) - # Verify new process was created for this attempt - if dist.get_rank() == 0: - mock_get_free_port.assert_called_once() - - # 3. Second save attempt with successful storage writer - process should still be alive - with patch( - "torch.distributed.checkpoint._async_process_executor.get_free_port", - ) as mock_get_free_port: - proc_executor = _ProcessBasedAsyncCheckpointExecutor() - fut = proc_executor.execute_save( - staging_future_or_state_dict=test_state_dict, - storage_writer=TestStorageWriter(behavior="success"), - ) - result = fut.result() - # Verify process is still alive - mock_get_free_port.assert_not_called() - # Verify successful save - self.assertIsNotNone(result) + result = fut.result() + self.assertIsNotNone(result) + mock_get_free_port.assert_not_called() + + +class TestProcessGroupInitInfo(DTensorTestBase): + """Test suite for _ProcessGroupInitInfo.""" + + @with_comms + def test_process_group_init_info_with_default_pg(self) -> None: + """Test that ProcessGroupInitInfo correctly initializes.""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("DCP_USE_PREFIX_STORE", None) + + pg_init_info = _ProcessGroupInitInfo() + + self.assertEqual(pg_init_info.global_rank, dist.get_rank()) + self.assertEqual(pg_init_info.world_size, dist.get_world_size()) + self.assertIsNotNone(pg_init_info.tcp_store_master_addr) + self.assertGreater(pg_init_info.tcp_store_master_port, 0) + self.assertEqual(pg_init_info.use_prefix_store, False) + + @with_comms + def test_process_group_init_info_with_prefix_store_env_var(self) -> None: + """Test that ProcessGroupInitInfo handles DCP_USE_PREFIX_STORE environment variable.""" + + # Flag enabled, addr/port correctly defined + with patch.dict( + os.environ, + { + "DCP_USE_PREFIX_STORE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + }, + ): + pg_init_info = _ProcessGroupInitInfo() + self.assertTrue(pg_init_info.use_prefix_store) + + # Missing port + with patch.dict( + os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_ADDR": "localhost"} + ): + with self.assertRaises(CheckpointException): + pg_init_info = _ProcessGroupInitInfo() + # Missing addr + with patch.dict( + os.environ, {"DCP_USE_PREFIX_STORE": "1", "MASTER_PORT": "12345"} + ): + with self.assertRaises(CheckpointException): + pg_init_info = _ProcessGroupInitInfo() + # Invalid port + with patch.dict( + os.environ, + { + "DCP_USE_PREFIX_STORE": "1", + "MASTER_ADDR": "localhost", + "MASTER_PORT": "a", + }, + ): + with self.assertRaises(CheckpointException): + pg_init_info = _ProcessGroupInitInfo() + + @with_comms + def test_process_group_init_info_without_prefix_store_env_var(self) -> None: + """Test that ProcessGroupInitInfo defaults to not using prefix store.""" + + # Env var set to 0 + with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "0"}): + pg_init_info = _ProcessGroupInitInfo() + self.assertFalse(pg_init_info.use_prefix_store) + + # Missing env var + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("DCP_USE_PREFIX_STORE", None) + pg_init_info = _ProcessGroupInitInfo() + self.assertFalse(pg_init_info.use_prefix_store) + + # Invalid env var + with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "2"}): + pg_init_info = _ProcessGroupInitInfo() + self.assertFalse(pg_init_info.use_prefix_store) + + with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "true"}): + pg_init_info = _ProcessGroupInitInfo() + self.assertFalse(pg_init_info.use_prefix_store) + + with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": "false"}): + pg_init_info = _ProcessGroupInitInfo() + self.assertFalse(pg_init_info.use_prefix_store) + + with patch.dict(os.environ, {"DCP_USE_PREFIX_STORE": ""}): + pg_init_info = _ProcessGroupInitInfo() + self.assertFalse(pg_init_info.use_prefix_store) if __name__ == "__main__": diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index 5c31645e6b7f4..c05c0884c78ce 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -800,6 +800,7 @@ def test_wrap_bad(self): stderr_redirects={0: stderr_redir}, ret_vals={0: queue}, queue_finished_reading_event=worker_finished_event_mock, + numa_options=None, ) self.assertEqual("hello_0", queue.get()) if stdout_redir: diff --git a/test/distributed/fsdp/test_checkpoint_wrapper.py b/test/distributed/fsdp/test_checkpoint_wrapper.py index db44c5a41ba12..0acb530f441fc 100644 --- a/test/distributed/fsdp/test_checkpoint_wrapper.py +++ b/test/distributed/fsdp/test_checkpoint_wrapper.py @@ -303,7 +303,7 @@ def check_fn(l): ) inp = torch.randn(4, 10, requires_grad=True) - for i in range(6): + for _ in range(6): # Kwarg input loss = model(x=inp).sum() self.assertTrue(loss.requires_grad) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index c80602c5d50f3..00479cf0935b9 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -31,17 +31,17 @@ sys.exit(0) -_DISTRIBUTED_STATE_DICT_IMPLS = ( +_DISTRIBUTED_STATE_DICT_IMPLS = { StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT, -) +} class TestDistributedCheckpoint(FSDPTest): @property def world_size(self): - if torch.cuda.is_available(): - gpu_cnt = torch.cuda.device_count() + if torch.accelerator.is_available(): + gpu_cnt = torch.accelerator.device_count() if gpu_cnt < 2: return gpu_cnt return 2 @@ -93,7 +93,9 @@ def test_distributed_checkpoint(self, state_dict_type) -> None: # TODO: add resharding test case. -devices = ("cuda", "hpu") -instantiate_device_type_tests(TestDistributedCheckpoint, globals(), only_for=devices) +devices = ("cuda", "hpu", "xpu") +instantiate_device_type_tests( + TestDistributedCheckpoint, globals(), only_for=devices, allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_apply.py b/test/distributed/fsdp/test_fsdp_apply.py index c0f1a791c5346..3c88f89851e1c 100644 --- a/test/distributed/fsdp/test_fsdp_apply.py +++ b/test/distributed/fsdp/test_fsdp_apply.py @@ -36,8 +36,8 @@ class TestApply(FSDPTest): @property def world_size(self): - if torch.cuda.is_available(): - gpu_cnt = torch.cuda.device_count() + if torch.accelerator.is_available(): + gpu_cnt = torch.accelerator.device_count() if gpu_cnt < 2: return gpu_cnt return 2 diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 2ae986af785b8..99a1c3ad1707c 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -514,18 +514,17 @@ def test_fsdp_optimizer_overlap(self): def test_fsdp_cpu_training(self): """Tests FSDP training on CPU.""" gloo_pg = dist.new_group(backend="gloo") - for ss in [ # noqa: F841 + for ss in [ ShardingStrategy.NO_SHARD, ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP, - ShardingStrategy.HYBRID_SHARD, - ShardingStrategy._HYBRID_SHARD_ZERO2, ]: torch.manual_seed(42) model = MyModel() ref_model = DDP(deepcopy(model), process_group=gloo_pg) model = FSDP( model, + sharding_strategy=ss, auto_wrap_policy=always_wrap_policy, process_group=gloo_pg, device_id=torch.device("cpu"), diff --git a/test/distributed/nn/jit/test_instantiator.py b/test/distributed/nn/jit/test_instantiator.py index 9d2931ba9b60b..37cd99be10d7b 100644 --- a/test/distributed/nn/jit/test_instantiator.py +++ b/test/distributed/nn/jit/test_instantiator.py @@ -2,7 +2,6 @@ # Owner(s): ["oncall: distributed"] import sys -from pathlib import Path import torch import torch.distributed as dist @@ -45,53 +44,19 @@ def test_get_arg_return_types_from_interface(self): self.assertEqual(return_type_str, "Tuple[Tensor, int, str]") def test_instantiate_scripted_remote_module_template(self): - dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH) - - # Cleanup. - file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") - for file_path in file_paths: - file_path.unlink() - - # Check before run. - file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") - num_files_before = len(list(file_paths)) - self.assertEqual(num_files_before, 0) - generated_module = instantiator.instantiate_scriptable_remote_module_template( MyModuleInterface ) self.assertTrue(hasattr(generated_module, "_remote_forward")) self.assertTrue(hasattr(generated_module, "_generated_methods")) - # Check after run. - file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") - num_files_after = len(list(file_paths)) - self.assertEqual(num_files_after, 1) - def test_instantiate_non_scripted_remote_module_template(self): - dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH) - - # Cleanup. - file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") - for file_path in file_paths: - file_path.unlink() - - # Check before run. - file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") - num_files_before = len(list(file_paths)) - self.assertEqual(num_files_before, 0) - generated_module = ( instantiator.instantiate_non_scriptable_remote_module_template() ) self.assertTrue(hasattr(generated_module, "_remote_forward")) self.assertTrue(hasattr(generated_module, "_generated_methods")) - # Check after run. - file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") - num_files_after = len(list(file_paths)) - self.assertEqual(num_files_after, 1) - if __name__ == "__main__": run_tests() diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 5c90ad8be144e..35eefdad512e6 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -1048,7 +1048,7 @@ def join_process_group(self): grads=grads, ): for _ in range(NUM_EPOCHS): - for input in inputs: + for _input in inputs: # Notify join context that this process has not joined Join.notify_join_context(gradient_setter) # Set gradients manually diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 6d70de134e8e7..8b9d71b78b97c 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -64,6 +64,38 @@ def test_debug_mode_mm(self): self.assertTrue(isinstance(debug_mode.operators[2], _RedistributeCall)) self.assertEqual(next(iter(debug_mode.operators[1])), torch.ops.aten.mm.default) + # check stringification + self.assertTrue(hasattr(debug_mode.operators[0], "args_str")) + self.assertFalse(hasattr(debug_mode.operators[0], "args")) + + # check recording hook + def mm(x, y): + return (x @ y).sum() + + eager_out = mm(x_dtensor, y_dtensor) + + # check recording hook for compiled variant + with ( + DebugMode() as debug_mode, + DebugMode.record_outputs(), + DebugMode.log_tensor_hashes(), + ): + compiled_out = torch.compile(mm, backend="aot_eager")(x_dtensor, y_dtensor) + + # check numerical equivalence + self.assertTrue(torch.equal(eager_out, compiled_out)) + sum_op = next( + iter( + op + for op in debug_mode.operators + if isinstance(op, _OpCall) and str(op.op) == "aten.sum.default" + ) + ) + self.assertTrue(torch.equal(sum_op.record["output"], eager_out.to_local())) + self.assertTrue( + "aten::sum(t: f32[1, 32]) # {'hash': " in debug_mode.debug_string() + ) + def test_debug_string_inside_context(self): mesh = DeviceMesh(self.device_type, list(range(self.world_size))) @@ -86,7 +118,9 @@ def test_debug_mode_backward(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(1)], run_check=False) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_stack_trace=True + ) as debug_mode: z = x_dtensor + y_dtensor z.sum().backward() @@ -119,6 +153,9 @@ def test_debug_mode_backward(self): aten::detach(t: f32[1, 8])""", ) + # check stack trace + self.assertTrue("z.sum().backward()" in debug_mode.operators[-1].stack_trace) + def test_debug_mode_densor_redistribution_trace(self): mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).view(4, 2)) @@ -267,6 +304,7 @@ def test_tensor_attributes(self): record_torchfunction=True, record_faketensor=True, record_tensor_attributes=["a1", "a2"], + store_original_args=True, ) as debug_mode: torch.matmul(y, x) @@ -279,6 +317,9 @@ def test_tensor_attributes(self): aten::_unsafe_view(t: f32[64, 8], [8, 8, 8])""", ) + self.assertTrue(hasattr(debug_mode.operators[0], "args")) + self.assertEqual(id(debug_mode.operators[0].args[0]), id(y)) + @parametrize("has_inner_mode", [True, False]) @parametrize("has_outer_mode", [True, False]) def test_nested_debug_mode(self, has_inner_mode, has_outer_mode): diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index 64d86ba3c129f..eaf3a4042060d 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -20,18 +20,18 @@ _cp_options, _disable_context_parallel_dispatcher, _enable_context_parallel_dispatcher, + _HeadTailLoadBalancer, _is_causal_behavior, + _LoadBalancer, + _PerDocumentHeadTailLoadBalancer, + _PTRRLoadBalancer, _RotateMethod, context_parallel, context_parallel_unshard, set_rotate_method, ) -from torch.distributed.tensor.experimental._cp_custom_ops import flex_cp_allgather -from torch.distributed.tensor.experimental._load_balancer import ( - _HeadTailLoadBalancer, - _LoadBalancer, - _PerDocumentHeadTailLoadBalancer, - _PTRRLoadBalancer, +from torch.distributed.tensor.experimental._context_parallel._cp_custom_ops import ( + flex_cp_allgather, ) from torch.distributed.tensor.parallel import parallelize_module from torch.nn.attention import sdpa_kernel, SDPBackend @@ -52,7 +52,9 @@ from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_utils import run_tests, skipIfRocm from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, + map_local_tensor_for_rank, with_comms, ) @@ -800,11 +802,47 @@ def test_context_parallel_shard(self) -> None: chunks = freqs_cis.chunk(self.world_size * 2) self.assertEqual( freqs_cis_shard, - torch.cat( - [chunks[self.rank], chunks[self.world_size * 2 - self.rank - 1]], dim=0 + map_local_tensor_for_rank( + chunks, + self.rank, + lambda chunks, rank: torch.cat( + [chunks[rank], chunks[self.world_size * 2 - rank - 1]], + dim=0, + ), ), ) +RingAttentionTestWithLocalTensor = create_local_tensor_test_class( + RingAttentionTest, + skipped_tests=[ + # Need to make attention implementation local tensor friendly, e.g. + # rewrite "rank local" logic + "test_ring_attention_sdpa", + ], +) + +CPFlexAttentionTestWithLocalTensor = create_local_tensor_test_class( + CPFlexAttentionTest, + skipped_tests=[ + # Missing support for batched tensors + "test_cp_flex_attention_causal_mask", + "test_cp_flex_attention_document_mask", + ], +) + +TestCPCustomOpsWithLocalTensor = create_local_tensor_test_class( + TestCPCustomOps, + skipped_tests=[ + # Missing support for fake tensors + "test_flex_cp_custom_op", + ], +) + +TestShardingWithLocalTensor = create_local_tensor_test_class( + TestSharding, +) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_convolution_ops.py b/test/distributed/tensor/test_convolution_ops.py index 68c52353b21ae..de4343eef6a4e 100644 --- a/test/distributed/tensor/test_convolution_ops.py +++ b/test/distributed/tensor/test_convolution_ops.py @@ -16,6 +16,7 @@ from torch.nn import functional as F from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( + create_local_tensor_test_class, DTensorTestBase, skip_if_lt_x_gpu, with_comms, @@ -203,34 +204,42 @@ def test_conv_backward_none_grad_inp(self): self.assertTrue(b_dt.grad is not None) self.assertTrue(x_dt.grad is None) + def _run_single_arg_fwd(self, model, arg) -> tuple[torch.Tensor, torch.Tensor]: + """Given model and arg, runs fwd model local and distbuted given device_mesh""" + device_mesh = self.build_device_mesh() + model_copy = copy.deepcopy(model).to(device=self.device_type) + dist_model = distribute_module(model, device_mesh, _conv_fn) + arg_dt = DTensor.from_local(arg, device_mesh, [Replicate()]) + out_dt = dist_model(arg_dt.to(device=self.device_type)) + out = model_copy(arg) + return (out_dt.full_tensor(), out) + @with_comms def test_conv1d(self): - device_mesh = self.build_device_mesh() model = nn.Conv1d(64, 64, 3, padding=1) - model_gt = copy.deepcopy(model) - x = torch.randn(1, 64, 8) - x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) - model_dt = distribute_module( - model, device_mesh, _conv_fn, input_fn=None, output_fn=None - ) - out_dt = model_dt(x_dt) - out = model_gt(x) + x = torch.randn(1, 64, 8, device=self.device_type) + out_dt, out = self._run_single_arg_fwd(model, x) self.assertEqual(out_dt.shape, out.shape) @with_comms def test_conv3d(self): - device_mesh = self.build_device_mesh() model = nn.Conv3d(64, 64, 3, padding=1) - model_gt = copy.deepcopy(model).to(device=self.device_type) x = torch.randn(1, 64, 8, 8, 8, device=self.device_type) - x_dt = DTensor.from_local(x, device_mesh, [Replicate()]) - model_dt = distribute_module( - model, device_mesh, _conv_fn, input_fn=None, output_fn=None - ) - out_dt = model_dt(x_dt) - out = model_gt(x) + out_dt, out = self._run_single_arg_fwd(model, x) self.assertEqual(out_dt.shape, out.shape) +DistConvolutionOpsTestWithLocalTensor = create_local_tensor_test_class( + DistConvolutionOpsTest, + # Send / recv ops are not supported + skipped_tests=[ + "test_conv1d", + "test_conv3d", + "test_conv_backward_none_grad_inp", + "test_depthwise_convolution", + "test_downsampling_convolution", + ], +) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py index a16e2f2fdd795..d2104066811be 100644 --- a/test/distributed/tensor/test_dtensor_export.py +++ b/test/distributed/tensor/test_dtensor_export.py @@ -520,6 +520,21 @@ def causal_mask(b, h, q_idx, kv_idx): 2, ) + def test_union_typed_annotation(self): + def fn(leaf: torch.Tensor | DTensor): + def nest_fn(leaf: torch.Tensor | DTensor): + # def nest_fn(leaf: Union[torch.Tensor, DTensor]): # this works + if isinstance(leaf, DTensor): + leaf = leaf.to_local() + return leaf + + return nest_fn(leaf) + 1 + + z = torch.randn(16, 16) + gm = graph_capture_and_aot_export_joint_with_descriptors(fn, (z,)) + + self.assertEqual(fn(z), gm(z)[0]) + instantiate_parametrized_tests(DTensorExportTest) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 51a8186bac509..f031085b23bd2 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -60,9 +60,9 @@ def linear_op_reductions(self, op_str): shard_spec = [Shard(0)] tensor = torch.randn(12, 8, 8) - # TODO: check `all` correctness and test `all` on a bool tensor - if op_str in ("any"): - # test out a bool tensor for any + if op_str in ("any", "all"): + # Test bool tensor for any() and all() reduction ops + # Previously all() had a bug using sum reduction instead of product tensor = tensor < 0 dtensor = distribute_tensor(tensor, device_mesh, shard_spec) diff --git a/test/distributed/test_aten_comm_compute_reordering.py b/test/distributed/test_aten_comm_compute_reordering.py index 4894c6853cdae..5b1db2d8dfe14 100644 --- a/test/distributed/test_aten_comm_compute_reordering.py +++ b/test/distributed/test_aten_comm_compute_reordering.py @@ -887,6 +887,135 @@ def func(a, b, c, d, *, ranks): correct = func(a, b, c, d, ranks=ranks) self.assertTrue(same(test_out, correct)) + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch(get_bucket_patches()) + def test_custom_estimation_with_fake_tensor_mode(self): + """Test that custom estimation can use FakeTensorMode for analysis.""" + from torch._subclasses.fake_tensor import FakeTensorMode + + estimation_calls = 0 + + def estimate_with_fake_mode(fx_node, compute_multiplier=1.0): + with FakeTensorMode(): + nonlocal estimation_calls + estimation_calls += 1 + assert isinstance(torch.rand([20]), torch._subclasses.FakeTensor) + + return 1.0 + + patches = get_bucket_patches() + patches["aten_distributed_optimizations.custom_runtime_estimation"] = ( + estimate_with_fake_mode + ) + + def func(a, b, *, ranks): + # Two independent all_gathers that should be bucketed + ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) + ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) + + # Matmul that can hide the collectives + mm1 = torch.matmul(a, a) + + return ag1.sum() + ag2.sum() + mm1.sum() + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + inputs_a = torch.ones(4, 4, dtype=torch.float, device=device_type) + inputs_b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2 + ranks = list(range(self.world_size)) + + func_c = functools.partial(func, ranks=ranks) + with torch._inductor.config.patch(patches): + compiled = torch.compile(func_c) + out, aten_graph_str = run_and_get_aten_graph( + compiled, inputs_a, inputs_b + ) + + # Verify the custom estimation was called + self.assertTrue( + estimation_calls > 0, "Custom estimation should have been called" + ) + + correct = func(inputs_a, inputs_b, ranks=ranks) + self.assertTrue(same(out, correct)) + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch(get_bucket_patches()) + def test_multidtype_bucketing(self): + """Test that all_gathers with different dtypes get bucketed together.""" + + def func(a, b, c, *, ranks): + # Three all_gathers with different dtypes + ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) # float32 + ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) # float16 + ag3 = _functional_collectives.all_gather_tensor(c, 0, ranks) # float16 + + # Use all results + return ag1.sum() + ag2.sum() + ag3.sum() + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + a = torch.ones(4, 4, dtype=torch.float32, device=device_type) + b = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 2 + c = torch.ones(4, 4, dtype=torch.float16, device=device_type) * 3 + ranks = list(range(self.world_size)) + + func_c = functools.partial(func, ranks=ranks) + compiled = torch.compile(func_c) + out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c) + + # Should have 1 bucketed all_gather despite different dtypes + FileCheck().check_count( + "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True + ).run(aten_graph_str) + + # Verify correctness + correct = func(a, b, c, ranks=ranks) + self.assertTrue(same(out, correct)) + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @torch._inductor.config.patch(get_bucket_patches()) + def test_basic_all_reduce_bucketing(self): + """Test that independent all_reduce operations get bucketed together.""" + + def func(a, b, c): + # Three independent all_reduces that should be bucketed + ar1 = _functional_collectives.all_reduce(a, "sum", "0") + ar2 = _functional_collectives.all_reduce(b, "sum", "0") + ar3 = _functional_collectives.all_reduce(c, "sum", "0") + + return ar1.sum() + ar2.sum() + ar3.sum() + + with _dynamo_dist_per_rank_init( + self.rank, + self.world_size, + self.backend(device_type), + fake_pg=not at_least_x_gpu(2), + ): + a = torch.ones(4, 4, dtype=torch.float, device=device_type) + self.rank + b = torch.ones(4, 4, dtype=torch.float, device=device_type) * 2 + c = torch.ones(4, 4, dtype=torch.float, device=device_type) * 3 + + compiled = torch.compile(func) + out, aten_graph_str = run_and_get_aten_graph(compiled, a, b, c) + + # Should see a single bucketed all_reduce + FileCheck().check_count( + "torch.ops._c10d_functional.wait_tensor.default", 1, exactly=True + ).run(aten_graph_str) + + # Verify correctness + correct = func(a, b, c) + self.assertTrue(same(out, correct)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/distributed/test_c10d_ops_nccl.py b/test/distributed/test_c10d_ops_nccl.py index 4dd4fc72361cf..b559665643fd5 100644 --- a/test/distributed/test_c10d_ops_nccl.py +++ b/test/distributed/test_c10d_ops_nccl.py @@ -278,7 +278,7 @@ def test_alltoall_ops_with_cudafree_race(self): tmp.append(torch.rand(10 ** (3 + i), device=local_device)) race_tensors.append(tmp) - for i in range(10): + for _ in range(10): race_tensors.pop() work = pg.alltoall_base(output, input, [], [], opts) # this triggers cudaFree diff --git a/test/distributed/test_c10d_spawn_gloo.py b/test/distributed/test_c10d_spawn_gloo.py index 9a9dacf22cc23..c4667bb5dd486 100644 --- a/test/distributed/test_c10d_spawn_gloo.py +++ b/test/distributed/test_c10d_spawn_gloo.py @@ -195,7 +195,7 @@ def test_gather(self): for i, t in enumerate(tensors): self.assertEqual(t, torch.ones(5, 5, device=device) + i) elif self.rank == 0: - for i, t in enumerate(tensors): + for t in tensors: zeros = torch.zeros(5, 5, device=device) self.assertEqual(t, zeros) y = torch.sum(torch.stack(tensors), axis=0) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 62e5143d06226..ac3103e09341d 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -23,6 +23,7 @@ sink_waits_iterative, ) from torch._inductor.compile_fx import compile_fx as inductor_compile_fx +from torch._inductor.fx_passes.bucketing import is_all_gather_into_tensor from torch._inductor.scheduler import ( _get_mm_like_fn, BaseSchedulerNode, @@ -2197,6 +2198,48 @@ def test_sync_decision_cross_ranks(self): saved_values = _sync_decision_cross_ranks(test_graph, saved_values) self.assertEqual(saved_values, [wt1]) + @skip_if_lt_x_gpu(2) + def test_comm_analysis(self): + store = c10d.FileStore(self.file_name, self.world_size) + torch.cuda.set_device(self.rank) + c10d.init_process_group( + backend="nccl", store=store, rank=self.rank, world_size=self.world_size + ) + group = c10d.distributed_c10d._get_default_group() + group_name = "default" + torch._C._distributed_c10d._register_process_group( + group_name, torch.distributed.group.WORLD + ) + group_size = group.size() + + def func(inp, group_size, group_name): + ag_0_out = torch.ops._c10d_functional.all_gather_into_tensor( + inp, group_size, group_name + ) + ag_0_wait = torch.ops.c10d_functional.wait_tensor(ag_0_out) + ag_1_out = torch.ops._c10d_functional.all_gather_into_tensor( + ag_0_wait, group_size, group_name + ) + ag_1_wait = torch.ops.c10d_functional.wait_tensor(ag_1_out) + return ag_1_wait + + gm = make_fx(func)(torch.ones(4, 4, device=self.device), group_size, group_name) + g = gm.graph + for n in g.nodes: + if is_all_gather_into_tensor(n): + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=False + ) + assert est_ms > 0 + est_ms_nccl = estimate_nccl_collective_runtime_from_fx_node( + n, use_nccl_estimator=True + ) + assert est_ms_nccl > 0 + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/distributed/test_overlap_bucketing_unit.py b/test/distributed/test_overlap_bucketing_unit.py new file mode 100644 index 0000000000000..c26bf0e93bab4 --- /dev/null +++ b/test/distributed/test_overlap_bucketing_unit.py @@ -0,0 +1,572 @@ +# Owner(s): ["module: inductor"] +import unittest + +import torch +import torch._dynamo +import torch._dynamo.logging +import torch._dynamo.test_case +import torch.distributed as dist +import torch.fx as fx + +# for some reason importing functional collectives after dynamo breaks collectives handling! +from torch._C import FileCheck +from torch._inductor.test_case import TestCase as InductorTestCase +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_distributed import requires_accelerator_dist_backend +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, +) +from torch.testing._internal.inductor_utils import HAS_GPU +from torch.utils._ordered_set import OrderedSet + + +# flake8: noqa: B950 +# Owner(s): ["module: inductor"] + + +aten = torch.ops.aten + +from torch.testing._internal.common_fsdp import get_devtype + + +device_type = str(get_devtype()) + + +import torch +import torch._dynamo +import torch._dynamo.logging +import torch._dynamo.test_case + + +# for some reason importing functional collectives after dynamo breaks collectives handling! + + +@requires_accelerator_dist_backend(["nccl", "xccl"]) +def build_collective_info(graph, hiding_annotations): + """ + Build CollectiveInfo dict from manual hiding annotations. + + hiding_annotations: dict mapping collective_start -> hiding_compute_node + """ + from torch._inductor.fx_passes.overlap_scheduling import CollectiveInfo + + collective_info = {} + + # Find all collective starts and their corresponding waits + start_to_wait = {} + for node in graph.nodes: + if node.op == "call_function" and "wait_tensor" in str(node.target): + wait_input = node.args[0] + if isinstance(wait_input, fx.Node): + start_to_wait[wait_input] = node + + # Build CollectiveInfo for each collective + for start_node, wait_node in start_to_wait.items(): + hiding_node = hiding_annotations.get(start_node) + + # Estimate size and time + size_bytes = 16 * 4 # 4x4 tensor of floats + estimated_time_ms = 1.0 # Dummy time + exposed_time_ms = 0.0 if hiding_node else 1.0 # Hidden if has hiding_node + + collective_info[start_node] = CollectiveInfo( + start_node=start_node, + wait_node=wait_node, + size_bytes=size_bytes, + estimated_time_ms=estimated_time_ms, + exposed_time_ms=exposed_time_ms, + hiding_node=hiding_node, + ) + + return collective_info + + +def compute_ancestors(graph): + """Compute ancestor sets for all nodes in the graph.""" + node_ancestors = {} + + for node in graph.nodes: + ancestors = OrderedSet() + stack = list(node.all_input_nodes) + visited = set() + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + ancestors.add(current) + stack.extend(current.all_input_nodes) + + node_ancestors[node] = ancestors + + return node_ancestors + + +@requires_accelerator_dist_backend() +@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") +@instantiate_parametrized_tests +class TestOverlapPreservingBucketing(InductorTestCase): + """ + Unit tests for overlap-preserving bucketing pass. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + from torch.testing._internal.distributed.fake_pg import FakeStore + + store = FakeStore() + dist.init_process_group(backend="fake", rank=0, world_size=2, store=store) + cls.device = "cuda" + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + dist.destroy_process_group() + + def test_can_bucket_independent_collectives(self): + """ + Test that independent collectives with separate hiding nodes CAN bucket. + + Graph structure: + ag1_start -> ag2_start -> mm1 (hides ag1) -> mm2 (hides ag2) -> ag1_wait -> ag2_wait + """ + + def func(a, b): + group_name = "0" + group_size = 1 + + # Start both collectives + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, group_name + ) + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, group_size, group_name + ) + + # Independent compute that can hide both + mm1 = torch.mm(a, a) + mm2 = torch.mm(b, b) + + # Wait for both + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + + return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + + # Use fake mode to trace without executing + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device) + b = torch.ones(4, 4, device=self.device) * 2 + + # Trace with make_fx + traced = make_fx(func)(a, b) + + # Find nodes using find_nodes + ag1, ag2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + mm1, mm2 = traced.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + + # Manually annotate hiding relationships + hiding_annotations = { + ag1: mm1, # mm1 hides ag1 + ag2: mm2, # mm2 hides ag2 + } + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + bucketer.bucket_collectives() + + # Verify: should have 1 bucketed collective (all_gather_into_tensor_out) + graph_str = str(traced.graph) + FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run( + graph_str + ) + + def test_cant_bucket_nested_hiding_intervals(self): + """ + Test that nested hiding intervals prevent bucketing. + + Graph structure: + ag1_start -> ag2_start -> mm2 (hides ag2) -> ag2_wait -> mm1 (hides ag1) -> ag1_wait + + ag2's hiding interval is nested inside ag1's hiding interval. + """ + + def func(a, b): + group_name = "0" + group_size = 1 + + # ag1 starts first + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, group_name + ) + + # ag2 starts (inside ag1's interval) + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, group_size, group_name + ) + + # mm2 hides ag2 + mm2 = torch.mm(b[:2, :2], b[:2, :2]) + + # ag2 waits (still inside ag1's interval) + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + + # mm1 uses ag2's result and hides ag1 + mm1 = torch.mm(a + ag2_out[:4, :4], a) + + # ag1 waits last + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + + return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + + # Use fake mode to trace without executing + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device) + b = torch.ones(4, 4, device=self.device) * 2 + + # Trace with make_fx + traced = make_fx(func)(a, b) + + # Find nodes using find_nodes + ag1, ag2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + mm_nodes = traced.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + # mm2 is the first mm, mm1 is the second (based on graph order) + mm2 = mm_nodes[0] + mm1 = mm_nodes[1] + + # Manually annotate hiding relationships + hiding_annotations = { + ag1: mm1, # mm1 hides ag1 + ag2: mm2, # mm2 hides ag2 + } + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + bucketer.bucket_collectives() + + # Verify: nested hiding intervals should prevent bucketing + # Should have 2 separate all_gathers, not 1 bucketed one + graph_str = str(traced.graph) + FileCheck().check_count("all_gather_into_tensor", 2, exactly=False).run( + graph_str + ) + + @parametrize("final_mm_hidden", (True, False)) + def test_cant_bucket_ag_with_rs_hiding_interval_between(self, final_mm_hidden): + """ + Test that all_gathers can't bucket when a reduce_scatter's hiding interval is between them. + + Graph structure: + ag1_start -> mm1 (hides ag1) -> ag1_wait -> + rs_start -> mm2 (hides rs) -> rs_wait -> + + if final_mm_hidden: + ag2_start -> mm3 (hides ag2) -> ag2_wait + + if final_mm_hidden: + Bucketing ag1 and ag2 would require moving one of them, which would break hiding relationships: + - Moving ag2 earlier would break ag2's hiding by mm3 + - Moving ag1 later would break ag1's hiding by mm1 + - The rs hiding interval creates an obstacle between them + + otherwise, we can bucket + """ + + def func(a, b, c): + group_name = dist.distributed_c10d._get_default_group().group_name + group_size = 1 + + # First all_gather + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, group_size, group_name + ) + mm1 = torch.mm(a, a) # hides ag1 + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + + # Reduce scatter in between + rs = torch.ops._c10d_functional.reduce_scatter_tensor( + b, "sum", group_size, group_name + ) + mm2 = torch.mm(b[:4, :4], b[:4, :4]) # hides rs + rs_out = torch.ops._c10d_functional.wait_tensor(rs) + + # Second all_gather + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + c, group_size, group_name + ) + mm3 = torch.mm(c, c) # hides ag2 + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + + return ag1_out.sum() + rs_out.sum() + ag2_out.sum(), mm1, mm2, mm3 + + # Use fake mode to trace without executing + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device) + b = torch.ones(8, 4, device=self.device) + c = torch.ones(4, 4, device=self.device) + + # Trace with make_fx + traced = make_fx(func)(a, b, c) + + ag1, ag2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + (rs,) = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.reduce_scatter_tensor.default, + ) + mm1, mm2, mm3 = traced.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + + # Manually annotate hiding relationships + hiding_annotations = { + ag1: mm1, # mm1 hides ag1 + # rs: mm2, # mm2 hides rs + ag2: mm3, + } + if final_mm_hidden: + hiding_annotations[rs] = mm2 + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing logic to find buckets (without applying them, which would require process groups) + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + + bucketer.bucket_collectives() + + graph_str = str(traced.graph) + + # check order of mms preserved + FileCheck().check("%mm").check("%mm_1").check("%mm_2").run(graph_str) + + if final_mm_hidden: + # Should NOT bucket - 2 separate all_gathers + # Count all_gather node names (works even when wrapped in control_deps) + FileCheck().check_count("%all_gather_into_tensor", 2, exactly=False).run( + graph_str + ) + else: + # Should bucket - 1 bucketed all_gather (all_gather_into_tensor_out) + FileCheck().check_count( + "%all_gather_into_tensor_out", 1, exactly=False + ).run(graph_str) + + def test_can_bucket_all_reduce(self): + """ + Test that all_reduce operations CAN bucket together. + + Graph structure: + ar1_start -> ar2_start -> mm1 (hides ar1) -> mm2 (hides ar2) -> ar1_wait -> ar2_wait + """ + + def func(a, b): + group_name = "0" + + # Start both all_reduce operations + ar1 = torch.ops._c10d_functional.all_reduce(a, "sum", group_name) + ar2 = torch.ops._c10d_functional.all_reduce(b, "sum", group_name) + + # Independent compute that can hide both + mm1 = torch.mm(a, a) + mm2 = torch.mm(b, b) + + # Wait for both + ar1_out = torch.ops._c10d_functional.wait_tensor(ar1) + ar2_out = torch.ops._c10d_functional.wait_tensor(ar2) + + return ar1_out.sum() + ar2_out.sum() + mm1.sum() + mm2.sum() + + # Use fake mode to trace without executing + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device) + b = torch.ones(4, 4, device=self.device) * 2 + + # Trace with make_fx + traced = make_fx(func)(a, b) + + # Find nodes + ar1, ar2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_reduce.default, + ) + mm1, mm2 = traced.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + + # For all_reduce, start_node == wait_node (no separate wait) + hiding_annotations = { + ar1: mm1, + ar2: mm2, + } + + # Build collective info + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + ) + bucketer.bucket_collectives() + + # Verify: should have 1 bucketed all_reduce + # After bucketing, there should be only one all_reduce node (the bucketed one) + graph_str = str(traced.graph) + FileCheck().check_count("%all_reduce", 1, exactly=True).check_count( + "%mm", 2 + ).run(graph_str) + + def test_can_bucket_multidtype_collectives(self): + """ + Test that all_gathers with different dtypes CAN bucket together. + + Graph structure: + ag1_float32 -> mm1 (hides ag1) -> ag1_wait + ag2_bfloat16 -> mm2 (hides ag2) -> ag2_wait + """ + + def func(a, b): + group_name = "0" + group_size = 1 + + # Start both collectives with different dtypes + ag1 = torch.ops._c10d_functional.all_gather_into_tensor( + a, + group_size, + group_name, # float32 + ) + ag2 = torch.ops._c10d_functional.all_gather_into_tensor( + b, + group_size, + group_name, # bfloat16 + ) + + # Independent compute that can hide both + mm1 = torch.mm(a, a) + mm2 = torch.mm(b.float(), b.float()) + + # Wait for both + ag1_out = torch.ops._c10d_functional.wait_tensor(ag1) + ag2_out = torch.ops._c10d_functional.wait_tensor(ag2) + + return ag1_out.sum() + ag2_out.sum() + mm1.sum() + mm2.sum() + + # Use fake mode to trace without executing + with FakeTensorMode(): + a = torch.ones(4, 4, device=self.device, dtype=torch.float32) + b = torch.ones(4, 4, device=self.device, dtype=torch.bfloat16) + + # Trace with make_fx + traced = make_fx(func)(a, b) + + # Find nodes using find_nodes + ag1, ag2 = traced.graph.find_nodes( + op="call_function", + target=torch.ops._c10d_functional.all_gather_into_tensor.default, + ) + mm_nodes = traced.graph.find_nodes( + op="call_function", target=torch.ops.aten.mm.default + ) + mm1 = mm_nodes[0] + mm2 = mm_nodes[1] + + # Manually annotate hiding relationships + hiding_annotations = { + ag1: mm1, # mm1 hides ag1 + ag2: mm2, # mm2 hides ag2 + } + + # Build collective info and ancestors + collective_info = build_collective_info(traced.graph, hiding_annotations) + node_ancestors = compute_ancestors(traced.graph) + scheduled = OrderedSet(traced.graph.nodes) + + # Run bucketing with multidtype mode + from torch._inductor.fx_passes.overlap_preserving_bucketer import ( + OverlapPreservingBucketer, + ) + + bucketer = OverlapPreservingBucketer( + traced.graph, + collective_info, + node_ancestors, + scheduled, + bucket_mode="custom_ops_multidtype", + ) + bucketer.bucket_collectives() + + # Verify: should have 1 bucketed collective (all_gather_into_tensor_out) + # even though dtypes are different + graph_str = str(traced.graph) + FileCheck().check_count("all_gather_into_tensor_out", 1, exactly=False).run( + graph_str + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_aot_autograd.py b/test/dynamo/test_aot_autograd.py index a51e28e37a098..47e2be2d17c19 100644 --- a/test/dynamo/test_aot_autograd.py +++ b/test/dynamo/test_aot_autograd.py @@ -1718,6 +1718,39 @@ def repro(sentinel: torch.Tensor, skip_squeeze: bool = False) -> torch.Tensor: self.assertEqual(eager_no_sq, comp_ind_no_sq) self.assertEqual(eager_no_sq.stride(), comp_ind_no_sq.stride()) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + def test_unbacked_activation_specialized_in_inductor(self): + """Test compilation with unbacked operations like nonzero.""" + torch._dynamo.reset() + + def fuzzed_program(arg_0, sentinel): + var_node_1 = arg_0 + var_node_5 = torch.full((1, 2), -66, dtype=torch.int32) + var_node_6 = torch.full((1, 2), 77, dtype=torch.int64) + var_node_4 = torch.ops.aten.add(var_node_5, var_node_6) + var_node_7 = torch.full((1, 2), -64, dtype=torch.int32) + var_node_3 = torch.ops.aten.mul(var_node_4, var_node_7) + var_node_9 = torch.full((3, 4), False, dtype=torch.bool) + var_node_8 = torch.nonzero(var_node_9) + var_node_2 = torch.ops.aten.add(var_node_3, var_node_8) + var_node_0 = torch.ops.aten.div(var_node_1, var_node_2) + result = var_node_0 * sentinel + if result.is_complex(): + result = result.real + return result + + sentinel = torch.tensor(1.0, requires_grad=True) + arg_0 = torch.randint(0, 3, (1, 2), dtype=torch.int64) + args = (arg_0,) + (sentinel,) + + result_original = fuzzed_program(*args) + + compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True) + result_compiled = compiled_program(*args) + + self.assertTrue(torch.allclose(result_original, result_compiled)) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_aot_autograd_cache.py b/test/dynamo/test_aot_autograd_cache.py index 04af76c90c529..40b8b1a5b6c71 100644 --- a/test/dynamo/test_aot_autograd_cache.py +++ b/test/dynamo/test_aot_autograd_cache.py @@ -41,6 +41,20 @@ from torch.testing._internal.two_tensor import TwoTensor +def aot_eager_regional_inductor(): + """ + Regional inductor backend for AOT autograd. + Uses regional_inductor as both forward and backward compiler. + """ + from torch._dynamo.backends.common import aot_autograd + from torch.fx.passes.regional_inductor import regional_inductor + + return aot_autograd( + fw_compiler=regional_inductor, + bw_compiler=regional_inductor, + ) + + def saved_tensors_hooks_to_gm( pack_fn, unpack_fn, @@ -1898,6 +1912,171 @@ def fn(x, y): # no recompiles self.assertFalse(counters) + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"bundled_autograd_cache": True}) + def test_regional_inductor_basic(self): + """ + Basic test for regional inductor with bundled autograd cache. + Tests that regional inductor compilation results can be cached and hit. + """ + import torch.fx.traceback as fx_traceback + + def fn(x, y): + sin = torch.sin(x) + # Mark this region to be compiled with inductor + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, device="cpu") + y = torch.randn(10, device="cpu") + + # Compile with regional inductor backend + compiled_fn = torch.compile( + fn, backend=aot_eager_regional_inductor(), fullgraph=True + ) + + # First call should miss in cache + result1 = compiled_fn(x, y) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Second call should hit (after clearing dynamo) + self._clear_dynamo_and_codecache() + result2 = compiled_fn(x, y) + self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1) + + # Results should be the same + self.assertEqual(result1, result2) + + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"bundled_autograd_cache": True}) + def test_regional_inductor_with_backward(self): + """ + Test regional inductor with backward pass and bundled autograd cache. + Note: Regional inductor triggers multiple AOT autograd compilations: + - One for the outer graph (with regional inductor backend) + - One for each marked region (via standalone_compile) + """ + import torch.fx.traceback as fx_traceback + + def fn(x, y): + sin = torch.sin(x) + # Mark this region to be compiled with inductor + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + x2 = x.detach().clone().requires_grad_(True) + y2 = y.detach().clone().requires_grad_(True) + + # Compile with regional inductor backend + compiled_fn = torch.compile( + fn, backend=aot_eager_regional_inductor(), fullgraph=True + ) + + # First call: AOT autograd compiles the outer graph (1 miss) + # Regional inductor then compiles the marked region (1 more miss) + result1 = compiled_fn(x, y) + result1.sum().backward() + + # We expect 2 cache misses: outer graph + marked region + initial_misses = counters["aot_autograd"]["autograd_cache_miss"] + initial_saves = counters["aot_autograd"]["autograd_cache_saved"] + self.assertGreater(initial_misses, 0) + self.assertGreater(initial_saves, 0) + + # Second call should hit (after clearing dynamo) + self._clear_dynamo_and_codecache() + result2 = compiled_fn(x2, y2) + result2.sum().backward() + + # Should have cache hits now + final_hits = counters["aot_autograd"]["autograd_cache_hit"] + self.assertGreater(final_hits, 0) + + # Cache misses and saves should not increase + self.assertEqual( + counters["aot_autograd"]["autograd_cache_miss"], initial_misses + ) + self.assertEqual( + counters["aot_autograd"]["autograd_cache_saved"], initial_saves + ) + + # Results and gradients should be the same + self.assertEqual(result1, result2) + self.assertEqual(x.grad, x2.grad) + self.assertEqual(y.grad, y2.grad) + + @inductor_config.patch("fx_graph_remote_cache", False) + @inductor_config.patch("fx_graph_cache", True) + @functorch_config.patch({"enable_autograd_cache": True}) + @functorch_config.patch({"bundled_autograd_cache": True}) + def test_regional_inductor_cache_miss_on_change(self): + """ + Test that changing the function causes a cache miss with regional inductor. + Regional inductor creates multiple AOT compilations, so we track + the change in cache misses rather than absolute counts. + """ + import torch.fx.traceback as fx_traceback + + def fn1(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + def fn2(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 2 # Changed from +1 to +2 + return torch.sin(add) + + x = torch.randn(10) + y = torch.randn(10) + + # Compile first function + compiled_fn1 = torch.compile( + fn1, backend=aot_eager_regional_inductor(), fullgraph=True + ) + result1 = compiled_fn1(x, y) + first_misses = counters["aot_autograd"]["autograd_cache_miss"] + first_saves = counters["aot_autograd"]["autograd_cache_saved"] + self.assertGreater(first_misses, 0) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertGreater(first_saves, 0) + + # Compile second function (different graph) + self._clear_dynamo_and_codecache() + compiled_fn2 = torch.compile( + fn2, backend=aot_eager_regional_inductor(), fullgraph=True + ) + result2 = compiled_fn2(x, y) + # Should miss because graph is different (more misses than before) + self.assertGreater( + counters["aot_autograd"]["autograd_cache_miss"], first_misses + ) + self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0) + self.assertGreater( + counters["aot_autograd"]["autograd_cache_saved"], first_saves + ) + + # Results should be different + self.assertNotEqual(result1, result2) + @functorch_config.patch({"bundled_autograd_cache": True}) class AOTAutogradCacheBundledTests(AOTAutogradCacheTests): diff --git a/test/dynamo/test_aot_compile.py b/test/dynamo/test_aot_compile.py index c822569f62484..8f39435b922ae 100644 --- a/test/dynamo/test_aot_compile.py +++ b/test/dynamo/test_aot_compile.py @@ -582,6 +582,23 @@ def test_aot_compile_with_super_call(self): actual = compiled_fn(fn, *inputs) self.assertEqual(expected, actual) + def test_aot_compile_with_default_args(self): + def fn(x, y=1): + return x + x + + compiled_fn = torch.compile(fn, fullgraph=True).aot_compile( + ((torch.randn(3, 4),), {}) + ) + inputs = (torch.randn(3, 4),) + expected = fn(*inputs) + actual = compiled_fn(*inputs) + self.assertEqual(expected, actual) + compiled_fn.save_compiled_function(self.path()) + with open(self.path(), "rb") as f: + compiled_fn = torch.compiler.load_compiled_function(f) + actual = compiled_fn(*inputs) + self.assertEqual(expected, actual) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_compile.py b/test/dynamo/test_compile.py index 7df0ba2f1d3e4..1f7290c51dd8d 100644 --- a/test/dynamo/test_compile.py +++ b/test/dynamo/test_compile.py @@ -234,27 +234,6 @@ def fn(x, y): with self.assertRaises(IndexError): fn(torch.randn(10), 99) - def test_list_bad_weakref(self): - import weakref - - a = torch.Event() - with self.assertRaises(TypeError): - weakref.ref(a) - - @torch.compile(backend="eager") - class Mod(torch.nn.Module): - def __init__(self, event): - super().__init__() - self.event = event - - def forward(self, x): - return x * int(self.event.query()) - - e = torch.Event() - m = Mod(e) - a = torch.randn(10) - self.assertEqual(m(a), a) - # The private variants of the below functions are extensively tested # So as long as the signatures match we're good diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index 9ce4d714fbd97..8810a30aaf3b7 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -275,6 +275,59 @@ def test_fn(): self.assertEqual(out.backend, "eager") self.assertEqual(out.subsystem, None) + @config.patch( + { + "test_configs.bisect_pre_grad_graph": True, + "test_configs.bisect_keep_custom_backend_for_inductor": True, + } + ) + def test_bisect_pre_grad_graph(self): + def f(x): + for i in range(5): + x = x + 1 + return x.relu() + + class MyBackend: + def __call__(self, gm, example_inputs): + node_idx = 0 + + def node_to_graph_id(node): + nonlocal node_idx + out = 0 if node_idx < 3 else 1 + node_idx += 1 + return out + + split_gm = torch.fx.passes.split_module.split_module( + gm, None, node_to_graph_id, keep_original_order=True + ) + + for name, submod in split_gm.named_modules(): + if "submod_" in name: + # the test case is simple enough that using + # the original example_inputs works for sub + # moule + submod.forward = torch._inductor.standalone_compile( + submod, + example_inputs, + dynamic_shapes="from_example_inputs", + options={}, + ) + + return split_gm + + def test_fn(): + torch._dynamo.reset() + + x = torch.randn(1024, device="cuda") + with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"): + opt_f = torch.compile(f, backend=MyBackend()) + return torch.allclose(opt_f(x), f(x)) + + out = CompilerBisector.do_bisect(test_fn) + self.assertEqual(out.backend, "inductor") + self.assertEqual(out.subsystem, "pre_grad_graph") + self.assertEqual(out.bisect_number, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 2311eac402c71..0433354b953b9 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -230,7 +230,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 20) + self.assertExpectedInline(str(cnts.op_count), """9""") @unittest.expectedFailure # https://github.com/pytorch/pytorch/issues/118204 @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") @@ -335,7 +335,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 37) + self.assertExpectedInline(str(cnts.op_count), """15""") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_cuda_stream_compared_with_constant(self): @@ -517,7 +517,7 @@ def fn(x, cur_stream, new_stream): res = opt_fn(x, cur_stream, new_stream) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 27) + self.assertExpectedInline(str(cnts.op_count), """16""") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_cuda_event_method(self): @@ -537,7 +537,7 @@ def fn(x): with torch.cuda.stream(new_stream): x = torch.add(x, 4) - new_event = torch.cuda.Event() + new_event = torch.Event() new_event.record(new_stream) new_event.wait(cur_stream) @@ -557,7 +557,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertEqual(cnts.op_count, 27) + self.assertExpectedInline(str(cnts.op_count), """16""") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_cuda_device(self): diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index eae4d06d98904..692ec0884399f 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -4,10 +4,10 @@ from unittest.mock import patch import torch -from functorch import make_fx from torch._dynamo import debug_utils from torch._dynamo.debug_utils import aot_graph_input_parser, generate_env_vars_string from torch._dynamo.test_case import TestCase +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_device_type import instantiate_device_type_tests diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 2b626132103a1..0eb21c9cef068 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -2064,6 +2064,23 @@ def f(): self.assertEqual(f(), 1) + def test_error_on_graph_break_nonempty_checkpoint(self): + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def fn(x): + x = x + 1 + x = x + 1 + x = x + 1 + with torch._dynamo.error_on_graph_break(True): + torch._dynamo.graph_break() + return x + 1 + + with self.assertRaises(Unsupported): + fn(torch.ones(3)) + + self.assertEqual(cnts.frame_count, 0) + def test_nested_compile_fullgraph(self): # Test that fullgraph=True cannot be toggled back by fullgraph=False inp = torch.ones(3) diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index be605ccdd1e18..966acd1d81394 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -341,7 +341,7 @@ def test_ordered_dict_reordered_keys(self): def fn(x, d): y = 0 - for idx, (key, value) in enumerate(d.items()): + for idx, value in enumerate(d.values()): if idx == 0: y += torch.sin(x * value) else: @@ -366,7 +366,7 @@ def keys(self): def fn(x, d): y = 0 - for idx, (key, value) in enumerate(d.items()): + for idx, value in enumerate(d.values()): if idx == 0: y += torch.sin(x * value) else: @@ -847,7 +847,7 @@ def fn(x): d = {"a": 2, "b": 3, "c": 5 * x} mp = types.MappingProxyType(d) y = torch.sin(x * mp["a"]) - for k, v in mp.items(): # noqa: PERF102 + for v in mp.values(): y += torch.cos(x * v) return mp @@ -864,7 +864,7 @@ def test_mapping_proxy_for_nonlocal(self): def fn(x): mp = types.MappingProxyType(d) y = torch.sin(x * mp["a"]) - for k, v in mp.items(): # noqa: PERF102 + for v in mp.values(): y += torch.cos(x * v) d["d"] = 4 return mp @@ -885,7 +885,7 @@ def test_mapping_proxy_existing(self): def fn(x, mp): y = torch.sin(x * mp["a"]) - for k, v in mp.items(): # noqa: PERF102 + for v in mp.values(): y += torch.cos(x * v) if isinstance(mp, types.MappingProxyType): y *= 2 @@ -1100,6 +1100,20 @@ def f(x): self.assertEqual(ref, res) + def test_iter_default_dict(self): + def f(x): + d = defaultdict(list) + d[0] = 42 + for k in d: + d[k] += 1 + return x + 1, d + + x = torch.ones(2) + ref = f(x) + res = torch.compile(f, backend="eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + @parametrize("op", ["or_", "and_", "xor", "sub"]) def test_dict_keys_binop(self, op): op = getattr(operator, op) @@ -1623,6 +1637,12 @@ def test_dict_type_comparison(self): self.assertNotEqual(self.thetype, other) self.assertTrue(self.thetype is not other, f"{self.thetype=}, {other=}") + @make_dynamo_test + def test_dict___iter__(self): + d = self.thetype({1: 2}) + it = d.__iter__() + self.assertEqual(next(it), 1) + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 17e28d38001c7..2e4702e668e93 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -1159,6 +1159,7 @@ def fn(x): torch._dynamo.graph_break() NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame. + Most recent bytecode instructions traced (max 20): TRACE RESUME 0 [] TRACE LOAD_FAST 'x' [] @@ -1172,7 +1173,8 @@ def fn(x): TRACE LOAD_GLOBAL 'torch' [] TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker(unrealized: )] TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker(unrealized: )] -TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: )]""", +TRACE CALL 0 [NullVariable, LazyVariableTracker(unrealized: )] +""", ) @torch._dynamo.config.patch(verbose=True) @@ -1234,17 +1236,28 @@ def f3(x): self.assertIn("Foo().attr = x # 1", records[-1].getMessage()) def post_munge(s): - return re.sub( + s = re.sub( r"torch_dynamo_resume_in_f(\d)_at_(\d+)", r"torch_dynamo_resume_in_f\1_at_N", s, ) + # remove most recent bytecode instructions + # DOTALL is needed to entirely remove TRACE ... lines (including the newline) + return re.sub(r"TRACE.*$", "", s, flags=re.DOTALL) self.assertExpectedInline( post_munge(munge_exc(records[-1].getMessage(), skip=0)), """\ Graph break in user code at test_error_messages.py:N -Graph Break Reason: STORE_ATTR-caused graph break +Graph Break Reason: Encountered graph break when attempting to store an object's attribute (STORE_ATTR): + +Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html User code traceback: File "test_error_messages.py", line N, in test_graph_break_traceback_above_dynamo_shows_user_code f3(torch.randn(3)) @@ -1257,8 +1270,12 @@ def post_munge(s): File "test_error_messages.py", line N, in torch_dynamo_resume_in_f3_at_N Foo().attr = x + File "test_error_messages.py", line N, in __setattr__ + torch._dynamo.graph_break() NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame. + +Most recent bytecode instructions traced (max 20): """, ) @@ -1483,6 +1500,110 @@ def bad_clean_and_assemble_instructions(instructions, *args): ): fn(torch.randn(3)) + @make_logging_test(graph_breaks=True) + def test_step_graph_break(self, records): + @torch.compile(backend="eager") + def fn(x): + x = x + 1 + x = x + 2 + torch._dynamo.step_unsupported() + return x + 4 + + fn(torch.ones(3)) + + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered: + +User code traceback: + File "test_error_messages.py", line N, in test_step_graph_break + fn(torch.ones(3)) + File "test_error_messages.py", line N, in fn + torch._dynamo.step_unsupported() +""", + ) + + torch._dynamo.reset() + + with torch._dynamo.error_on_graph_break(True): + self.assertExpectedInlineMunged( + Unsupported, + lambda: fn(torch.ones(3)), + """\ +cannot resume from torch._dynamo.step_unsupported() + Explanation: traced torch._dynamo.step_unsupported(), but Dynamo is instructed to error on graph break. This graph break is used for debugging only. + Hint: Remove the torch._dynamo.step_unsupported() call. + Hint: Make sure fullgraph=False and error_on_graph_break=False. + Hint: This is likely to be a Dynamo bug. Please report an issue to PyTorch. + + Developer debug context: + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0284.html + +from user code: + File "test_error_messages.py", line N, in fn + torch._dynamo.step_unsupported()""", + ) + + @make_logging_test(graph_breaks=True) + def test_store_attr_graph_break(self, records): + class Foo: + def __setattr__(self, name, value): + torch._dynamo.graph_break() + + @torch.compile(backend="eager") + def fn(x): + Foo().attr = x + + fn(torch.ones(3)) + + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break in user code at test_error_messages.py:N +Graph Break Reason: Encountered graph break when attempting to store an object's attribute (STORE_ATTR): + +Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html +User code traceback: + File "test_error_messages.py", line N, in test_store_attr_graph_break + fn(torch.ones(3)) + File "test_error_messages.py", line N, in fn + Foo().attr = x + File "test_error_messages.py", line N, in __setattr__ + torch._dynamo.graph_break() +""", + ) + + torch._dynamo.reset() + + with torch._dynamo.error_on_graph_break(True): + self.assertExpectedInlineMunged( + Unsupported, + lambda: fn(torch.ones(3)), + """\ +Call to `torch._dynamo.graph_break()` + Explanation: User-inserted graph break. Message: None + Hint: Remove the `torch._dynamo.graph_break()` call. + + Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html + +from user code: + File "test_error_messages.py", line N, in fn + Foo().attr = x + File "test_error_messages.py", line N, in __setattr__ + torch._dynamo.graph_break()""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index c06331cea7dbf..419e57a1cc280 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -82,6 +82,18 @@ def update_global(x): return x * _variable +def pos_only_fn(*args, **kwargs): + return _pos_only_fn(*args, **kwargs) + + +def _pos_only_fn(a, b=3, /, **kwargs): + return ( + a * b + kwargs.get("a", -13) * kwargs.get("b", 42), + "a" in kwargs, + "b" in kwargs, + ) + + @contextlib.contextmanager def update_global_ctx(x): try: @@ -941,6 +953,19 @@ def fn(x, arg): self.assertEqual(fn(input, [1, 2, 3]), input + 1) self.assertEqual(fn(input, (1, 2, 3)), input + 1) + def test_pos_only_args_with_same_name_in_star_kwargs(self): + opt_fn = torch.compile(pos_only_fn, backend="eager", fullgraph=True) + a = torch.randn(4) + b = torch.randn(4) + x = torch.randn(4) + y = torch.randn(4) + self.assertEqual(pos_only_fn(a), opt_fn(a)) + self.assertEqual(pos_only_fn(a, a=x), opt_fn(a, a=x)) + self.assertEqual(pos_only_fn(a, b=y), opt_fn(a, b=y)) + self.assertEqual(pos_only_fn(a, b=b, a=x), opt_fn(a, b=b, a=x)) + self.assertEqual(pos_only_fn(a, a=x, b=y), opt_fn(a, a=x, b=y)) + self.assertEqual(pos_only_fn(a, b, a=x, b=y), opt_fn(a, b, a=x, b=y)) + @make_test def test_len_constant_misc_iterables(x): a = len((1, 2, 3)) @@ -5216,6 +5241,63 @@ def forward(self, x): x = torch.randn(1) self.assertEqual(opt_mod(x), x + 1) + def test_full_with_tensor_fill_value(self): + """Test that torch.full works correctly with dynamic tensor fill_value""" + + # Test with tensor fill_value (the bug case) + def func_tensor(x): + return torch.full((2,), x, dtype=torch.float64) + + func_compiled = torch.compile(func_tensor) + + # Test with different values + x1 = torch.tensor(5.0, dtype=torch.float64) + x2 = torch.tensor(10.0, dtype=torch.float64) + + result1 = func_compiled(x1) + expected1 = torch.full((2,), x1, dtype=torch.float64) + self.assertEqual(result1, expected1) + + # This is where the bug occurred - second call reused first value + result2 = func_compiled(x2) + expected2 = torch.full((2,), x2, dtype=torch.float64) + self.assertEqual(result2, expected2) + + # Test with different dtypes + for dtype in [torch.float32, torch.float64, torch.int32, torch.int64]: + + def func_typed(x): + return torch.full((3,), x, dtype=dtype) + + func_typed_compiled = torch.compile(func_typed) + x_typed = torch.tensor(7, dtype=dtype) + result = func_typed_compiled(x_typed) + expected = torch.full((3,), x_typed, dtype=dtype) + self.assertEqual(result, expected) + + # Test with non-tensor fill_value (scalar) to ensure we didn't break existing behavior + def func_scalar(size): + return torch.full((size,), 42.0, dtype=torch.float32) + + func_scalar_compiled = torch.compile(func_scalar) + + result_scalar = func_scalar_compiled(5) + expected_scalar = torch.full((5,), 42.0, dtype=torch.float32) + self.assertEqual(result_scalar, expected_scalar) + + # Test with different scalar values + def func_scalar_param(): + # Test multiple calls with different hardcoded scalar values + a = torch.full((2,), 3.14, dtype=torch.float32) + b = torch.full((2,), 2.71, dtype=torch.float32) + return a, b + + func_scalar_param_compiled = torch.compile(func_scalar_param) + result_a, result_b = func_scalar_param_compiled() + + self.assertEqual(result_a, torch.full((2,), 3.14, dtype=torch.float32)) + self.assertEqual(result_b, torch.full((2,), 2.71, dtype=torch.float32)) + instantiate_parametrized_tests(FunctionTests) instantiate_parametrized_tests(DefaultsTests) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 47e9ee3cb888e..4994dffdddb43 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -363,6 +363,40 @@ def f(x, y): self._exec_and_verify_payload() + def test_metrics_context(self): + """ + When TORCH_COMPILE_DEBUG is set, provenance_tracking_level is set to 1, and + the generated fx_graph_runnable crashed with, + RuntimeError: Cannot add inductor_provenance outside of a MetricsContext + """ + import torch._inductor.config as inductor_config + + def f(x): + return x * 2 + 1 + + # Enable provenance tracking to trigger the code path that adds metrics + with inductor_config.patch( + {"trace.enabled": True, "trace.provenance_tracking_level": 1} + ): + x = torch.randn(4, 4) + torch.compile(f)(x) + self._exec_and_verify_payload() + + @torch._dynamo.config.patch(assume_static_by_default=False) + def test_dynamic_expression(self): + """ + Test not emitting something like "s27*s53**2 = 36" + """ + + def f(x): + return torch.ops.aten._adaptive_avg_pool2d( + x, (6, 6) + ), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5)) + + x = torch.randn(2, 4, 16, 16) + torch.compile(f)(x) + self._exec_and_verify_payload() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 693c90a10b3a4..204e5114320f6 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2858,7 +2858,7 @@ def save_activations(mod, inp, out): def fn(x): return wrap(lambda x: model(x), x) - for i in range(2): + for _ in range(2): # second iteration is key, hooks would have fired during aot trace # on first iter activations.clear() diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 3f3a3bd7f6537..125958596eb5e 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -807,7 +807,7 @@ class Mod(torch.nn.Module): def __init__(self) -> None: super().__init__() self.layers = torch.nn.ModuleList() - for i in range(10): + for _ in range(10): layer = torch.nn.Linear(16, 16) layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp)) layer = torch.compile(layer, backend=cnts) diff --git a/test/dynamo/test_list.py b/test/dynamo/test_list.py index 60c799d0b6a44..41e5da15b5378 100644 --- a/test/dynamo/test_list.py +++ b/test/dynamo/test_list.py @@ -168,6 +168,14 @@ def test___contains__(self): self.assertRaises(TypeError, p.__contains__) self.assertRaises(TypeError, p.__contains__, 1, 2) + @make_dynamo_test + def test___iter__(self): + p = self.thetype([1]) + it = p.__iter__() + self.assertEqual(next(it), 1) + it = p.__iter__().__iter__() + self.assertEqual(next(it), 1) + class ListTests(TupleTests): # List methods diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 3856b5078375c..c47a26a7f6f7b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1272,6 +1272,20 @@ def fn(d): r2 = opt_fn(d) self.assertEqual(r1, r2) + def test_tensor__iter__(self): + def fn(x): + it = x.__iter__() + for y in it: + y.add_(1.0) + return y + + torch._dynamo.testing.standard_test( + self, + fn, + 1, + expected_ops=20, + ) + def test_tensor_iter(self): def fn(x): for y in x: @@ -1961,6 +1975,15 @@ def run(n): self.assertTrue(same(res2, torch.ones(2))) self.assertTrue(same(res3, torch.ones(3))) + def test_range___iter__(self): + def func(x): + it = range(3).__iter__() + return x + next(it) + + opt_func = torch.compile(func, backend="eager", fullgraph=True) + x = torch.randn(3) + self.assertTrue(same(func(x), opt_func(x))) + def test_range_iter_side_effects(self): @torch.compile(backend="eager", fullgraph=True) def run(x, it): @@ -4324,6 +4347,33 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) + def test_tying_union_new_syntax(self): + def fn(x): + def inner1(y: torch.Tensor | None): + return y + + def inner2(y: None | torch.Tensor): + return y + + def inner3(y: torch.Tensor | list[int]): + return y + + return x + 1 + + torch.compile(fn, backend="eager", fullgraph=True)(torch.ones(3)) + + @unittest.expectedFailure + def test_typing_union_new_syntax_reconstruct(self): + def fn(x): + return ( + x + 1, + torch.Tensor | None, + None | torch.Tensor, + torch.Tensor | list[int], + ) + + torch.compile(fn, backend="eager", fullgraph=True)(torch.ones(3)) + def test_optimize_on_module(self): class MockModule(torch.nn.Module): def __init__(self) -> None: @@ -6970,7 +7020,7 @@ def guard_failures(failure): # guard is expected for both static and dynamic shapes self.assertTrue(guard_failure is not None) self.assertIn( - """len(x) == 10""", + """size mismatch at index 0. expected 10, actual 9""", guard_failure[0], ) @@ -9608,6 +9658,18 @@ def fn(img): self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1") self.assertEqual(res, img1 + torch.sin(img1)) + def test_str___iter__(self): + def fn(x): + s = "a" + if next(s.__iter__()) == "a": + return x + 1 + else: + return x + + x = torch.randn(3) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + def test_str_format_return2(self): @torch.compile(backend="eager", fullgraph=True) def fn(img): diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index c251ce28bac49..4718ef0795897 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -262,7 +262,7 @@ def __init__(self) -> None: self.count = 3 def forward(self, x): - for i in range(self.count): + for _ in range(self.count): x = torch.sigmoid(self.linear1(x)) return x @@ -509,7 +509,7 @@ def __init__(self) -> None: self.layer = torch.nn.Linear(10, 10) def forward(self, x): - for i in range(self.cfg.count): + for _ in range(self.cfg.count): x = self.layer(x + self.cfg.val) return x @@ -781,7 +781,7 @@ def __init__(self) -> None: def forward(self, x): counter = 0 - for param in self.parameters(): + for _param in self.parameters(): counter += 1 return x * self.scale * counter @@ -841,7 +841,7 @@ def __init__( def forward(self, init_features): features = [init_features] - for idx, layer in enumerate(self.values()): + for layer in self.values(): new_features = layer(features) features.append(new_features) return torch.cat(features, 1) @@ -2161,7 +2161,7 @@ def fn(x, mod): cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") opt_mod = torch.compile(fn, backend=cnts) - for i in range(8): + for _ in range(8): mod = Mod() opt_mod(torch.randn(5, 5), mod) @@ -2516,7 +2516,7 @@ def save_activations(name, mod, inp, out): compiled_model = torch.compile(model, backend="aot_eager") activations = compiled_activations - for i in range(2): + for _ in range(2): # second iteration is key, hooks would have fired during aot trace # on first iter compiled_activations.clear() @@ -2526,7 +2526,7 @@ def save_activations(name, mod, inp, out): loss.backward() activations = eager_activations - for i in range(2): + for _ in range(2): # second iteration is key, hooks would have fired during aot trace # on first iter eager_activations.clear() @@ -2575,12 +2575,12 @@ def forward(self, x): def save_activations(mod, inp, out): activations.append(inp) - for name, module in model.named_modules(): + for module in model.modules(): module.register_forward_hook(save_activations) cnt = torch._dynamo.testing.CompileCounter() model = torch.compile(model, backend=cnt, fullgraph=True) - for i in range(2): + for _ in range(2): # second iteration is key, hooks would have fired during aot trace # on first iter activations.clear() @@ -2703,7 +2703,7 @@ def backward_pre_hook(name, mod, grad_out): model = torch.compile(model, backend="aot_eager") - for i in range(2): + for _ in range(2): # second iteration is key, hooks would have fired during aot trace # on first iter x = torch.randn((20, 10)) @@ -2763,6 +2763,22 @@ def forward(self, x): self.assertEqual(eager_res, optim_res) self.assertEqual(cnt.frame_count, 1) + def test_specialized_module___iter__(self): + ml = torch.nn.ModuleList( + [ + torch.nn.Linear(10, 10), + ] + ) + ml.torchdynamo_force_dynamic = False + + def f(x): + it = ml.__iter__() + return next(it)(x) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(10) + self.assertEqual(f(x), opt_f(x)) + def test_module_dict_iter_keys(self): class MyModule(torch.nn.Module): def __init__(self) -> None: diff --git a/test/dynamo/test_nested_graph_breaks.py b/test/dynamo/test_nested_graph_breaks.py index fcd55eaa0dc1b..bc41e19c9ef01 100644 --- a/test/dynamo/test_nested_graph_breaks.py +++ b/test/dynamo/test_nested_graph_breaks.py @@ -380,6 +380,41 @@ def outer(x): self.assertEqual(cnts.frame_count, 2) self.assertEqual(cnts.op_count, 13) + def test_dead_nested_cells(self): + global f1, f2, f3 + + def f3(x, cell1): + cell1 += 2 + x = x + cell1 + torch._dynamo.graph_break() + return x + cell1 + + def f1(cell1=0): + def inner(x): + x += 4 + x = f3(x, cell1) + return x + 8 + + return inner + + def f2(x): + return f1()(x + 16) + 32 + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(backend=cnts)(f2) + x = torch.zeros(3) + res = f2(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + # If we don't handle dead cells in nested functions correctly, + # frame_count will increase since we also + # graph break when we attempt to codegen inner. + # The exact issue was that side_effects was failing to codegen inner's cell's creation. + # So when we try to codegen cells for resume functions, we end up trying to codegen + # a CellVariable without a source, which leads to a graph break we can't resume from. + self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.op_count, 6) + def test_cells_double_graph_break(self): def f1(x1): cell1 = x1 + 1 @@ -806,6 +841,39 @@ def f8(x): ) ) + def test_disable_nested_graph_breaks(self): + global f1, f2, f3, f4, f5 + + def f1(x): + x = x + 1 + torch._dynamo.graph_break() + return x + 2 + + def f2(x): + return f1(x + 4) + 8 + + # NOTE since the disable_nested_graph_breaks decorator is implemented as a + # context manager, we don't need to separately test context manager usage. + @torch._dynamo.disable_nested_graph_breaks + def f3(x): + return f2(x + 16) + 32 + + def f4(x): + return f3(x + 64) + 128 + + def f5(x): + return f4(x + 256) + 512 + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(backend=cnts)(f5) + x = torch.zeros(3) + res = f5(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + # 2 frames from each of f5+f4, f3, f2, f1 + self.assertEqual(cnts.frame_count, 8) + self.assertEqual(cnts.op_count, 10) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_regional_inductor.py b/test/dynamo/test_regional_inductor.py index 0dc208ddadc56..524d7fa499c39 100644 --- a/test/dynamo/test_regional_inductor.py +++ b/test/dynamo/test_regional_inductor.py @@ -1,13 +1,16 @@ # Owner(s): ["module: dynamo"] import functools +from typing import TYPE_CHECKING import torch import torch._inductor.test_case import torch.fx.traceback as fx_traceback import torch.utils.checkpoint from torch._dynamo.backends.common import aot_autograd +from torch._functorch._aot_autograd.autograd_cache import BundledCompiledForward from torch._guards import detect_fake_mode +from torch._inductor.output_code import RegionalOutputCode from torch._inductor.test_case import run_tests from torch._inductor.utils import run_fw_bw_and_get_code from torch.fx._graph_pickler import GraphPickler @@ -21,6 +24,10 @@ from torch.testing._internal.triton_utils import requires_cuda_and_triton +if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxKwargs + + # Open questions / follow-ups # 1) CSE behavior with meta custom nodes # Common subexpression elimination may not differentiate between distinct meta @@ -462,5 +469,154 @@ def flex_attn_fn(x): self.assertEqual(len(codes), 2) +@skipIfTorchDynamo("Not a suitable dynamo wrapped test") +class TestRegionalOutputCode(torch._inductor.test_case.TestCase): + """Tests for RegionalOutputCode and BundledAOTAutogradResult.""" + + def test_regional_output_code_serialization(self): + """Test that RegionalOutputCode can be serialized and deserialized.""" + + def fn(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + + # Compile with regional inductor + with torch.fx.traceback.preserve_node_meta(enable=False): + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + + fake_mode = FakeTensorMode() + with fake_mode: + fake_x = fake_mode.from_tensor(x) + fake_y = fake_mode.from_tensor(y) + gm = make_fx(fn)(fake_x, fake_y) + + # Run regional_inductor on the graph + result_gm = regional_inductor(gm, fake_x, fake_y) + + # Create RegionalOutputCode + output_code = RegionalOutputCode(result_gm) + + # Test that we can call it + self.assertIsNotNone(output_code._graph_module) + + # Serialize + output_code.prepare_for_serialization() + self.assertIsNone(output_code._graph_module) + self.assertIsNotNone(output_code._serialized_graph_module) + + # Deserialize via post_compile + from torch._inductor.output_code import CompiledFxGraphConstants + + fx_config: _CompileFxKwargs = {"is_backward": False} + output_code.post_compile( + [fake_x, fake_y], CompiledFxGraphConstants(), fx_config + ) + self.assertIsNotNone(output_code._graph_module) + self.assertIsInstance(output_code._graph_module, torch.fx.GraphModule) + + # Test that deserialized graph works + with fake_mode: + result = output_code([fake_x, fake_y]) + self.assertIsNotNone(result) + + def test_regional_output_code_with_backward(self): + """Test RegionalOutputCode with both forward and backward compilation.""" + + def fn(x, y): + sin = torch.sin(x) + with fx_traceback.annotate({"compile_with_inductor": 0}): + mul = sin * y + add = mul + 1 + return torch.sin(add) + + x = torch.randn(10, requires_grad=True) + y = torch.randn(10, requires_grad=True) + + # Compile with regional inductor backend + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + + fake_mode = FakeTensorMode() + with fake_mode: + fake_x = fake_mode.from_tensor(x) + fake_y = fake_mode.from_tensor(y) + + # Create forward graph + with torch.fx.traceback.preserve_node_meta(enable=False): + gm = make_fx(fn)(fake_x, fake_y) + forward_gm = regional_inductor(gm, fake_x, fake_y) + + # Create forward output code + fw_code = RegionalOutputCode(forward_gm) + + # Verify it can be called + with fake_mode: + result = fw_code([fake_x, fake_y]) + self.assertIsNotNone(result) + + # Test serialization round-trip + fw_code.prepare_for_serialization() + + # Deserialize via post_compile + + from torch._inductor.output_code import CompiledFxGraphConstants + + fx_config: _CompileFxKwargs = {"is_backward": False} + fw_code.post_compile([fake_x, fake_y], CompiledFxGraphConstants(), fx_config) + + with fake_mode: + result2 = fw_code([fake_x, fake_y]) + self.assertIsNotNone(result2) + + def test_regional_compiled_forward_backward(self): + """Test BundledCompiledForward and BundledCompiledBackward with RegionalOutputCode.""" + + def fn(x): + with fx_traceback.annotate({"compile_with_inductor": 0}): + return torch.sin(x) * 2 + + x = torch.randn(5, requires_grad=True) + + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.proxy_tensor import make_fx + + fake_mode = FakeTensorMode() + with fake_mode: + fake_x = fake_mode.from_tensor(x) + + with torch.fx.traceback.preserve_node_meta(enable=False): + gm = make_fx(fn)(fake_x) + compiled_gm = regional_inductor(gm, fake_x) + + # Create forward using the generic BundledCompiledForward + fw_code = RegionalOutputCode(compiled_gm) + fw_compiled = BundledCompiledForward[RegionalOutputCode](result=fw_code) + + # Test pre_save + fw_compiled.pre_save() + # After pre_save, fw_compiled.result is a copy with serialized graph + self.assertIsNotNone(fw_compiled.result._serialized_graph_module) + self.assertIsNone( + fw_compiled.result._graph_module + ) # Should be cleared after serialization + + # Test load (doesn't deserialize yet) + loaded_code = fw_compiled.load([fake_x]) + self.assertIsNone(loaded_code._graph_module) # Not yet deserialized + self.assertIsNotNone(loaded_code._serialized_graph_module) + + fx_config: _CompileFxKwargs = {"is_backward": False} + post_compiled = fw_compiled.post_compile(loaded_code, fx_config) + self.assertIsNotNone(post_compiled) + self.assertIsNotNone(post_compiled._graph_module) # Now deserialized + + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c6138f7574fd4..7cd10ae356f99 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1000,18 +1000,6 @@ def tearDown(self) -> None: self.exit_stack.close() super().tearDown() - def test_compiled_module_truthiness(self): - # Test with empty ModuleList - original_empty = nn.ModuleList() - compiled_empty = torch.compile(original_empty) - self.assertEqual(bool(original_empty), bool(compiled_empty)) - self.assertFalse(bool(compiled_empty)) - # Test with non-empty ModuleList - original_filled = nn.ModuleList([nn.Linear(10, 5)]) - compiled_filled = torch.compile(original_filled) - self.assertEqual(bool(original_filled), bool(compiled_filled)) - self.assertTrue(bool(compiled_filled)) - def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals, builder): root = guard_manager_wrapper.root cloned_root = root.clone_manager(lambda x: True) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index a73edc930df31..e05e1304d2860 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -1,11 +1,20 @@ # Owner(s): ["module: dynamo"] +import functools +import unittest +import weakref import torch import torch._dynamo.test_case import torch._dynamo.testing +from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda +requires_multigpu = functools.partial( + unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" +) + + class TestStreams(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -15,6 +24,154 @@ def setUpClass(cls): def tearDownClass(cls): super().tearDownClass() + @requires_cuda + def test_stream_weakref(self): + s = torch.Stream() + weakref.ref(s) + + @requires_cuda + def test_event_weakref(self): + e = torch.Event() + weakref.ref(e) + + @requires_cuda + def test_stream_enter_exit(self): + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + with s1: + z1 = torch.add(x, y) + with s2: + z = torch.add(x, y) + y = z + 2 + z1 + + return y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + fn_opt = torch.compile(fn, fullgraph=True) + actual = fn_opt(*inp) + self.assertEqual(expected, actual) + + @requires_cuda + def test_stream_context_graph_break(self): + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + with s1: + z1 = torch.add(x, y) + with s2: + z = torch.add(x, y) + y = z + 2 + z1 + torch._dynamo.graph_break() + y = y + 1 + + return y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + fn_opt = torch.compile(fn) + actual = fn_opt(*inp) + self.assertEqual(expected, actual) + + @requires_cuda + def test_stream_input(self): + def fn(x, y, s): + z = torch.add(x, y) + y = z + 2 + return y, s + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(device="cuda")) + expected = fn(*inp) + fn_opt = torch.compile(fn, fullgraph=True) + actual = fn_opt(*inp) + self.assertEqual(expected, actual) + + @requires_cuda + def test_local_stream_return(self): + def fn(x, y): + s = torch.Stream() + z = torch.add(x, y) + y = z + 2 + return y, s + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + fn_opt = torch.compile(fn, fullgraph=True) + _, s0 = fn_opt(*inp) + _, s1 = fn_opt(*inp) + # Streams will be different values for each invocation + # so don't check for equality + self.assertIsInstance(s0, torch.Stream) + # Stream should be newly allocated on each call + self.assertNotEqual(s0, s1) + + @requires_cuda + def test_get_current_stream_return(self): + def fn(x, s): + with s: + s0 = torch.accelerator.current_stream() + return x, s0 + + s_inp = torch.Stream(device="cuda") + inp = (torch.ones(2, 2) + 1, s_inp) + fn_opt = torch.compile(fn, fullgraph=True) + _, s0 = fn_opt(*inp) + _, s1 = fn_opt(*inp) + self.assertEqual(s_inp, s0) + self.assertEqual(s0, s1) + + @requires_cuda + @requires_multigpu() + def test_get_current_stream_return_different_device(self): + def fn(x, s0, s1): + with s1: + with s0: + s = torch.accelerator.current_stream(torch.device("cuda:1")) + return s + + s0 = torch.Stream(device="cuda:0") + s1 = torch.Stream(device="cuda:1") + inp = (torch.ones(2, 2) + 1, s0, s1) + fn_opt = torch.compile(fn, fullgraph=True) + s_act = fn_opt(*inp) + s_exp = fn(*inp) + self.assertEqual(s_act, s_exp) + + @requires_cuda + @requires_multigpu() + def test_get_current_stream_return_no_index(self): + def fn(x, s0, s1): + with s1: + with s0: + s = torch.accelerator.current_stream(torch.device("cuda")) + return s + + s0 = torch.Stream(device="cuda:0") + s1 = torch.Stream(device="cuda:1") + inp = (torch.ones(2, 2) + 1, s0, s1) + fn_opt = torch.compile(fn, fullgraph=True) + s_act = fn_opt(*inp) + s_exp = fn(*inp) + self.assertEqual(s_act, s_exp) + + def test_nested_stream_enter_exit(self): + pass + + def test_stream_enter_exit_graph_break(self): + pass + + def test_nested_stream_enter_exit_graph_break(self): + pass + + def test_local_stream_enter_exit(self): + pass + + def test_local_stream_nested_enter_exit(self): + pass + + def test_stream_with_mutation(self): + pass + @requires_cuda def test_run_opcheck(self): from torch._dynamo.variables.streams import fork_stream, join_stream diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 35036fd1de3fa..16c765bfb1409 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -580,7 +580,7 @@ def fn(x, it): def test_enumerate_not_break_graph(self): def fn(a, b): - for i, x in enumerate(a.shape): + for _, x in enumerate(a.shape): b = b + x for i, x in enumerate(b.shape, 8): b = b + x * i diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 9bfccd94b1f7e..e9c6df7e959f8 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -466,6 +466,7 @@ def test_no_special_handlers_for_torch_non_c_bindings(self): "handle_cudnn_is_acceptable", # No global state "handle_assert", # No global state (constant) "handle_nested_tensor", # No global state + "handle_current_stream", # Safely implemented ) for fn in handlers: if isinstance(fn, staticmethod) or inspect.ismethod(fn): diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index 91862e6d3eb00..2085e46c5000e 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -697,7 +697,7 @@ def f(x, y): @torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True) def test_unspecialized_float_multiply_precision(self): dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64] - for i, dtype in enumerate(dtypes): + for dtype in dtypes: def fn(x, y): return x * y @@ -722,7 +722,7 @@ def f(x): return x + y.item() dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64] - for i, dtype in enumerate(dtypes): + for dtype in dtypes: x = torch.ones(3, 3, dtype=dtype) self.assertEqual(f(x), x + x.sum().item()) diff --git a/test/dynamo_expected_failures/CPython313-test_heapq-TestHeapPython.test_empty_merges b/test/dynamo_expected_failures/CPython313-test_heapq-TestHeapPython.test_empty_merges deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_exhausted_iterator b/test/dynamo_expected_failures/CPython313-test_list-ListTest.test_exhausted_iterator deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_init b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_init deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 42c63ad8706f2..cf5416f45a33a 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1089,6 +1089,8 @@ aten::rand.names aten::rand.names_out aten::rand.out aten::rand_like +aten::rand_like.generator +aten::rand_like.generator_out aten::rand_like.out aten::randint aten::randint.generator @@ -1100,9 +1102,15 @@ aten::randint.low_out aten::randint.out aten::randint_like aten::randint_like.Tensor +aten::randint_like.Tensor_generator +aten::randint_like.Tensor_generator_out aten::randint_like.Tensor_out +aten::randint_like.generator +aten::randint_like.generator_out aten::randint_like.low_dtype aten::randint_like.low_dtype_out +aten::randint_like.low_generator_dtype +aten::randint_like.low_generator_dtype_out aten::randint_like.out aten::randn.generator aten::randn.generator_with_names @@ -1110,6 +1118,8 @@ aten::randn.generator_with_names_out aten::randn.names aten::randn.names_out aten::randn_like +aten::randn_like.generator +aten::randn_like.generator_out aten::randn_like.out aten::random aten::random.from diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index a1cc88568107f..d67175f8aa3da 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -522,6 +522,83 @@ def forward(self, args_0): ) self.assertEqual(ep(*inps), MyModel()(*inps)) + def test_dynamo_graph_capture_full_tracing_context(self) -> None: + class Foo(torch.nn.Module): + def forward(self, x): + return x + x.shape[0] + + foo = Foo() + + def make_inputs(b: int): + ret = (torch.randn(b, 3),) + torch._dynamo.mark_dynamic(ret[0], 0) + return ret + + trace_inputs = make_inputs(2) + gm = dynamo_graph_capture_for_export(foo)(*trace_inputs) + test_inputs = make_inputs(3) + self.assertEqual(gm(*test_inputs), foo(*test_inputs)) + self.assertIsNotNone(gm.meta["tracing_context"].fake_mode) + self.assertEqual(len(gm.meta["tracing_context"].tensor_to_context), 1) + + def test_dynamo_graph_capture_dict_keys_getitem(self): + class Module(torch.nn.Module): + def forward(self, x): + return x * 2 + + foo = Module() + + class BlockMask: + def __init__(self, d): + self.d = d + + block_mask = BlockMask(torch.randn(4)) + + def pre_hook_function(m, input): + block_mask.d = input[0] + 1 + return input # Return a tuple of modified inputs + + foo.register_forward_pre_hook(pre_hook_function) + + def make_inputs(): + return (torch.randn(4),) + + trace_inputs = make_inputs() + gm = dynamo_graph_capture_for_export(foo)(*trace_inputs) + test_inputs = make_inputs() + self.assertExpectedInline( + gm.code.strip("\r\n "), + """\ +def forward(self, args_0): + _tree_leaf_0, _tree_leaf_1, = pytree.tree_leaves((self, args_0,)) + L_args_0_ , = self._in_shuffle_graph(_tree_leaf_0, _tree_leaf_1) + l_args_0_ = L_args_0_ + add = l_args_0_ + 1 + mul = l_args_0_ * 2; l_args_0_ = None + return pytree.tree_unflatten(self._out_shuffle_graph(_tree_leaf_0, _tree_leaf_1, mul, add), self._out_spec)""", + ) + self.assertEqual(gm(*test_inputs), foo(*test_inputs)) + + def test_dynamo_graph_capture_with_tensor_constant(self): + outer = torch.randn(2, 3) + + class MyModel(torch.nn.Module): + def forward(self, x): + z = x + outer + return z + + foo = MyModel() + + def make_inputs(): + return (torch.randn(2, 3),) + + trace_inputs = make_inputs() + gm = dynamo_graph_capture_for_export(foo)(*trace_inputs) + test_inputs = make_inputs() + self.assertEqual(gm(*test_inputs), foo(*test_inputs)) + self.assertEqual(len(list(gm.buffers())), len(list(foo.buffers()))) + self.assertEqual(len(list(gm.parameters())), len(list(foo.parameters()))) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") def test_dynamo_graph_capture_fx_graph_annotate_overlap_pass(self): class DummyOp(torch.autograd.Function): @@ -585,6 +662,16 @@ def forward(self, args_0, args_1): test_inputs = input_fn() self.assertEqual(gm(*test_inputs), model(*test_inputs)) + def test_dynamo_graph_capture_default_args(self): + class Module(torch.nn.Module): + def forward(self, x, y=1): + return x + y + + m = Module() + ep = dynamo_graph_capture_for_export(m)(torch.randn(2, 3)) + test_inputs = (torch.randn(2, 3),) + self.assertEqual(ep(*test_inputs), m(*test_inputs)) + if __name__ == "__main__": run_tests() diff --git a/test/export/test_export.py b/test/export/test_export.py index 3250d82c3eae8..3908f03b11e55 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16747,6 +16747,74 @@ def forward(self, x, y): self.assertEqual(result_non_strict, result_strict) + def test_tril_dynamic_diagonal(self): + class Module(torch.nn.Module): + def forward(self, x, y): + x_len = x.shape[0] + y_len = y.shape[0] + mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device) + mask = mask.tril(diagonal=y_len - x_len) + return mask + + x = torch.randn(3, 4) + y = torch.randn(5, 4) + x_len = Dim("x_len", min=1, max=64) + y_len = Dim("y_len", min=1, max=64) + ep = export( + Module(), + (x, y), + dynamic_shapes={ + "x": {0: x_len}, + "y": {0: y_len}, + }, + ) + eager_out = Module()(x, y) + exported_out = ep.module()(x, y) + self.assertEqual(eager_out, exported_out) + self.assertEqual(exported_out.shape, (3, 5)) + x2 = torch.randn(4, 4) + y2 = torch.randn(7, 4) + eager_out2 = Module()(x2, y2) + exported_out2 = ep.module()(x2, y2) + self.assertEqual(eager_out2, exported_out2) + self.assertEqual(exported_out2.shape, (4, 7)) + expected_mask = torch.ones(3, 5, dtype=torch.bool).tril(diagonal=2) + self.assertEqual(eager_out, expected_mask) + + def test_triu_dynamic_diagonal(self): + class Module(torch.nn.Module): + def forward(self, x, y): + x_len = x.shape[0] + y_len = y.shape[0] + mask = torch.ones(x_len, y_len, dtype=torch.bool, device=x.device) + mask = mask.triu(diagonal=y_len - x_len) + return mask + + x = torch.randn(3, 4) + y = torch.randn(5, 4) + x_len = Dim("x_len", min=1, max=64) + y_len = Dim("y_len", min=1, max=64) + ep = export( + Module(), + (x, y), + dynamic_shapes={ + "x": {0: x_len}, + "y": {0: y_len}, + }, + ) + eager_out = Module()(x, y) + exported_out = ep.module()(x, y) + self.assertEqual(eager_out, exported_out) + self.assertEqual(exported_out.shape, (3, 5)) + x2 = torch.randn(4, 4) + y2 = torch.randn(7, 4) + eager_out2 = Module()(x2, y2) + exported_out2 = ep.module()(x2, y2) + self.assertEqual(eager_out2, exported_out2) + self.assertEqual(exported_out2.shape, (4, 7)) + expected_mask = torch.ones(3, 5, dtype=torch.bool).triu(diagonal=2) + self.assertEqual(eager_out, expected_mask) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 2f68cdf479439..472ddcf556f83 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -600,6 +600,8 @@ def add_kernel( in_ptr1, out_ptr, n_elements, + fval, + ival, BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) @@ -608,7 +610,7 @@ def add_kernel( mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) - output = x + y + output = x + y + fval + ival tl.store(out_ptr + offsets, output, mask=mask) def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: @@ -618,7 +620,9 @@ def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def grid(meta): return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16) + wrap_triton(add_kernel)[grid]( + x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16 + ) return output @@ -633,7 +637,9 @@ def custom_add_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def grid(meta): return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) - wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8) + wrap_triton(add_kernel)[grid]( + x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16, num_warps=8 + ) return output @@ -661,34 +667,44 @@ def forward(self, x, y): self.assertIsNotNone(triton_node) args = [] - kwargs = [] + kwargs = {} for arg in triton_node.inputs: if arg.kind == ArgumentKind.POSITIONAL: args.append(arg.arg) elif arg.kind == ArgumentKind.KEYWORD: - kwargs.append(arg.arg) + kwargs[arg.name] = arg.arg - self.assertEqual(len(args), 4) - self.assertEqual(len(kwargs), 5) + self.assertEqual(len(args), 6) + # Always: name, grid, output_indices and num_warps are + # Triton version dependent: num_cpu_threads, shared_memory_bytes + self.assertTrue(len(kwargs) >= 4) for i in range(3): self.assertIsNotNone(args[i].as_tensor) self.assertEqual(args[3].as_int, 3) - - self.assertEqual(kwargs[0].as_string, "add_kernel") # name - self.assertEqual(kwargs[1].as_ints, [1, 1, 1]) # grid - self.assertEqual(kwargs[2].as_ints, [2]) # output indices + self.assertAlmostEqual(args[4].as_float, 3.14, places=2) + self.assertEqual(args[5].as_int, 42) + kernel_name = kwargs["name"].as_string + symbol_name = kernel_name.rpartition("_")[0] + self.assertEqual(symbol_name, "add_kernel") + self.assertEqual(kwargs["grid"].as_ints, [1, 1, 1]) + self.assertEqual(kwargs["output_indices"].as_ints, [2]) self.assertEqual( - kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4 - ) # num warps - self.assertEqual(kwargs[4].as_int, 0) # shared mem bytes + kwargs["num_warps"].as_int, 8 if isinstance(m, MyModelAutotune) else 4 + ) + + if "num_cpu_threads" in kwargs: + self.assertEqual(kwargs["num_cpu_threads"].as_int, 0) + if "shared_memory_bytes" in kwargs: + self.assertEqual(kwargs["shared_memory_bytes"].as_int, 0) self.assertEqual(len(triton_node.outputs), 1) self.assertIsNotNone(triton_node.outputs[0].as_tensors) self.assertEqual( - len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints) + len(triton_node.outputs[0].as_tensors), + len(kwargs["output_indices"].as_ints), ) self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem") diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py index 25535dc5334ee..7949d2bb46cbf 100644 --- a/test/functorch/test_aot_joint_with_descriptors.py +++ b/test/functorch/test_aot_joint_with_descriptors.py @@ -675,7 +675,7 @@ def forward(self, x): # Verify buffer handling buffer_count = 0 - for desc, (node, grad_node) in input_grad_nodes.items(): + for desc, (node, _grad_node) in input_grad_nodes.items(): if isinstance(desc, BufferAOTInput): buffer_count += 1 self.assertIsNotNone(node) @@ -764,13 +764,13 @@ def forward(self, x): self.assertIn(node, named_params.values()) # Check that param_grads contains the same parameter nodes - for desc, (param_node, grad_node) in param_grads.items(): + for desc, (param_node, _grad_node) in param_grads.items(): self.assertIn(param_node, param_nodes) self.assertEqual(param_node, named_params[desc.target]) # Check that all_input_grads contains the parameter nodes param_count = 0 - for desc, (input_node, grad_node) in all_input_grads.items(): + for desc, (input_node, _grad_node) in all_input_grads.items(): if isinstance(desc, ParamAOTInput): param_count += 1 self.assertIn(input_node, param_nodes) @@ -1069,6 +1069,31 @@ def forward(self, x): ('call_function', 'index', {'pp_stage': 0})""", ) + def test_static_input_indices(self): + """Test basic linear module with aot_export_joint_with_descriptors""" + + class SimpleLinear(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 2) + + def forward(self, x): + return self.linear(x) + + model = SimpleLinear() + inputs = (torch.randn(4, 3),) + gm = _dynamo_graph_capture_for_export(model)(*inputs) + fake_mode = gm.meta.get("fake_mode", None) + + with tracing(TracingContext(fake_mode)): + with ExitStack() as stack: + joint = aot_export_joint_with_descriptors( + stack, + gm, + inputs, + ) + self.assertEqual(joint._aot_state.fw_metadata.static_input_indices, [0, 1]) + if __name__ == "__main__": run_tests() diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 29b69322d2fc8..fba7a96288caf 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -7356,7 +7356,6 @@ def fn(x): aot_eager = torch.compile(backend="aot_eager")(fn)(x) self.assertEqual(eager, aot_eager, atol=0, rtol=0) - @unittest.expectedFailure @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_rms_norm(self): # Only CUDA rms norm fails to be decomposed @@ -7555,7 +7554,7 @@ def _tg3(y): (_inp, _tg3), ] - for i, (inp_fn, tg_fn) in enumerate(TEST_CASES): + for inp_fn, tg_fn in TEST_CASES: ref_x = inp_fn() x = ref_x.detach().clone().requires_grad_() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 5bfd1f200dd02..5034661fa3e05 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -742,7 +742,7 @@ def forward(self, pred_1, x_1): def test_cond_in_forloop(self): def for_loop_fake(x): - for i in range(3): + for _ in range(3): x = x * x + 1 return x @@ -3088,9 +3088,7 @@ def run_test_and_get_grads_loss(model, initial_hs, inputs): ) # Compare gradients for each layer - for i, (uncompiled_grad, compiled_grad) in enumerate( - zip(uncompiled_grads, compiled_grads) - ): + for uncompiled_grad, compiled_grad in zip(uncompiled_grads, compiled_grads): self.assertEqual( uncompiled_grad, compiled_grad, diff --git a/test/functorch/test_dims.py b/test/functorch/test_dims.py index eb5202d4bb2ef..a0cd59c02665a 100644 --- a/test/functorch/test_dims.py +++ b/test/functorch/test_dims.py @@ -282,7 +282,7 @@ def f(): # python 3.11 adapts bytecode after a number of iterations # check that we still match names correctly - for i in range(10): + for _ in range(10): f() @skipIf(not TEST_CUDA, "no CUDA") diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py index 8ff9fb438619a..ed1a049a81bb9 100644 --- a/test/fx/test_fx_const_fold.py +++ b/test/fx/test_fx_const_fold.py @@ -707,6 +707,50 @@ def forward(self, x, y): fold_result = mod_folded(in_x, in_y) self.assertTrue(torch.equal(fold_result, base_result)) + def test_fold_pure_subgraph(self): + class SubModule(torch.nn.Module): + def forward(self): + return torch.full((5, 10), 2.0) + 1 + + # Create a parent graph with this module as a subgraph and output + ep = torch.export.export(SubModule(), ()) + parent_graph = torch.fx.Graph() + call_mod = parent_graph.call_module("sub", args=()) + get_item = parent_graph.call_function( + operator.getitem, args=(call_mod, slice(None)) + ) + parent_graph.output((get_item,)) + parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph) + + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( + parent, device_for_folded_attrs="cpu" + ) + self._verify_const_fold_mod(mod_folded) + + def test_do_not_fold_impure_subgraph(self): + """ + Skip folding any subgraph containing impure ops. + """ + + class SubModule(torch.nn.Module): + def forward(self): + return torch.randn(5, 10) + 1 + + # Create a parent graph with this module as a subgraph and output + ep = torch.export.export(SubModule(), ()) + parent_graph = torch.fx.Graph() + call_mod = parent_graph.call_module("sub", args=()) + get_item = parent_graph.call_function( + operator.getitem, args=(call_mod, slice(None)) + ) + parent_graph.output((get_item,)) + parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph) + + mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs( + parent, device_for_folded_attrs="cpu" + ) + self.assertIsNone(mod_folded.const_subgraph_module) + if __name__ == "__main__": raise_on_run_directly("test/test_fx.py") diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index 67facfb127d8e..d3e5d36dbed54 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -491,9 +491,7 @@ def ins_sc(): def ins_dense(): return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0]) - for i, (ins_fn, expected_fw_count) in enumerate( - zip([ins_sc, ins_dense], [2, 1]) - ): + for ins_fn, expected_fw_count in zip([ins_sc, ins_dense], [2, 1]): reset_counter() ref_out = fn(*ins_fn()) assert_counter(expected_fw_count, 0) @@ -524,16 +522,14 @@ def ins_sc_req_grad(): ), ) - for i, ( + for ( ins_fn_req_grad, ( expected_fw_count, expected_fw_count_after_bw, expected_bw_count_after_bw, ), - ) in enumerate( - zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)]) - ): + ) in zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)]): ref_ins = ins_fn_req_grad() reset_counter() ref_out = fn(*ref_ins) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index cd719dd17fd1f..8f009f30a0a60 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -4869,7 +4869,7 @@ def forward(self, *inputs): return result inputs = [] - for i in range(1000): + for _ in range(1000): inputs.append(torch.ones(8, 8, 8, dtype=torch.float16, device=self.device)) inputs = tuple(inputs) model = Model() diff --git a/test/inductor/test_augmented_graph_helper.py b/test/inductor/test_augmented_graph_helper.py index ef1f92e23268c..92dcfa1b37b85 100644 --- a/test/inductor/test_augmented_graph_helper.py +++ b/test/inductor/test_augmented_graph_helper.py @@ -360,6 +360,191 @@ def test_multiple_merge_unmerge(self): self.assertEqual(self.tracker.merge_sets[nodes[0]], {nodes[0]}) self.assertEqual(len(self.tracker.merge_sets[nodes[1]]), 1) + # ========== Dependency Transfer Tests ========== + + def test_transfer_with_cross_deps(self): + """Test transfer when erased nodes depend on each other.""" + # old_start -> old_wait, both get replaced + # Should become: new_start -> new_wait + graph = fx.Graph() + x = graph.placeholder("x") + old_start = graph.call_function(torch.relu, args=(x,), name="old_start") + old_wait = graph.call_function(torch.abs, args=(x,), name="old_wait") + compute = graph.call_function(torch.neg, args=(old_wait,), name="compute") + graph.output(compute) + + tracker = AugmentedGraphHelper(graph) + + # Add cross-dependency: old_start -> old_wait + tracker.add_extra_dep(n=old_wait, dep=old_start) + # Add extra dep: compute -> old_wait + tracker.add_extra_dep(n=compute, dep=old_wait) + + # Create replacements + new_start = graph.call_function(torch.sigmoid, args=(x,), name="new_start") + new_wait = graph.call_function(torch.tanh, args=(x,), name="new_wait") + + # Transfer both at once + tracker.transfer_erased_node_deps({old_start: new_start, old_wait: new_wait}) + + # new_wait should depend on new_start (cross-dep redirected correctly) + self.assertIn(new_start, tracker.extra_deps[new_wait]) + + # compute should depend on new_wait + self.assertIn(new_wait, tracker.extra_deps[compute]) + + # Old nodes should be cleaned up + self.assertEqual(len(tracker.extra_deps[old_start]), 0) + self.assertEqual(len(tracker.extra_deps[old_wait]), 0) + self.assertEqual(len(tracker.extra_uses[old_start]), 0) + self.assertEqual(len(tracker.extra_uses[old_wait]), 0) + + def test_transfer_preserves_external_deps(self): + """Test that external dependencies are preserved correctly.""" + # external1 -> old1, old2 -> external2 + # Should become: external1 -> new1, new2 -> external2 + graph = fx.Graph() + x = graph.placeholder("x") + external1 = graph.call_function(torch.relu, args=(x,), name="external1") + old1 = graph.call_function(torch.abs, args=(x,), name="old1") + old2 = graph.call_function(torch.neg, args=(x,), name="old2") + external2 = graph.call_function(torch.sigmoid, args=(x,), name="external2") + graph.output(external2) + + tracker = AugmentedGraphHelper(graph) + + # Add deps: old1 -> external1, external2 -> old2 + tracker.add_extra_dep(n=old1, dep=external1) + tracker.add_extra_dep(n=external2, dep=old2) + + # Create new nodes + new1 = graph.call_function(torch.tanh, args=(x,), name="new1") + new2 = graph.call_function(torch.exp, args=(x,), name="new2") + + # Transfer + tracker.transfer_erased_node_deps({old1: new1, old2: new2}) + + self.assertIn(external1, tracker.extra_deps[new1]) + + self.assertIn(new2, tracker.extra_deps[external2]) + self.assertNotIn(old2, tracker.extra_deps[external2]) + + def test_transfer_with_merge_sets(self): + """Test transfer when nodes have merge sets.""" + graph = fx.Graph() + x = graph.placeholder("x") + old_a = graph.call_function(torch.relu, args=(x,), name="old_a") + old_b = graph.call_function(torch.abs, args=(x,), name="old_b") + dep = graph.call_function(torch.neg, args=(x,), name="dep") + user = graph.call_function(torch.sigmoid, args=(x,), name="user") + graph.output(user) + + tracker = AugmentedGraphHelper(graph) + + # Merge old_a and old_b + tracker.merge_to_set(old_a, old_b) + + # Add deps: old_a -> dep, user -> old_a + tracker.add_extra_dep(n=old_a, dep=dep) + tracker.add_extra_dep(n=user, dep=old_a) + + # Create new node + new = graph.call_function(torch.tanh, args=(x,), name="new") + + # Transfer (only need to specify one from merge set) + tracker.transfer_erased_node_deps({old_a: new}) + + # new should have dep on dep + self.assertIn(dep, tracker.extra_deps[new]) + + # user should depend on new + self.assertIn(new, tracker.extra_deps[user]) + + # Both old nodes should be cleaned up + self.assertEqual(len(tracker.extra_deps[old_a]), 0) + self.assertEqual(len(tracker.extra_deps[old_b]), 0) + + def test_transfer_multiple_merge_sets_with_chain(self): + """Test transferring multiple merge sets that depend on each other. + + Setup: + node1 (singleton) + node2, node3 (merged) + other_node (singleton) + node4, node5 (merged) + + Dependencies: + node2 -> node1 + other_node -> node3 + node4 -> other_node + + Transfer: + (node2, node3) -> new_2_3 + (node4, node5) -> new_4_5 + + Expected: + new_2_3 -> node1 + other_node -> new_2_3 + new_4_5 -> other_node + """ + graph = fx.Graph() + x = graph.placeholder("x") + + # Create nodes + node1 = graph.call_function(torch.relu, args=(x,), name="node1") + node2 = graph.call_function(torch.abs, args=(x,), name="node2") + node3 = graph.call_function(torch.neg, args=(x,), name="node3") + other_node = graph.call_function(torch.sigmoid, args=(x,), name="other_node") + node4 = graph.call_function(torch.tanh, args=(x,), name="node4") + node5 = graph.call_function(torch.exp, args=(x,), name="node5") + graph.output(other_node) + + tracker = AugmentedGraphHelper(graph) + + # Merge node2 and node3 + tracker.merge_to_set(node2, node3) + + # Merge node4 and node5 + tracker.merge_to_set(node4, node5) + + # Add dependencies + tracker.add_extra_dep(n=node2, dep=node1) # node2 -> node1 + tracker.add_extra_dep(n=other_node, dep=node3) # other_node -> node3 + tracker.add_extra_dep(n=node4, dep=other_node) # node4 -> other_node + + # Create replacement nodes + new_2_3 = graph.call_function(torch.sin, args=(x,), name="new_2_3") + new_4_5 = graph.call_function(torch.cos, args=(x,), name="new_4_5") + + # Transfer both merge sets atomically + tracker.transfer_erased_node_deps( + { + node2: new_2_3, # This will transfer both node2 and node3 + node4: new_4_5, # This will transfer both node4 and node5 + } + ) + + # Verify: new_2_3 should depend on node1 + self.assertIn(node1, tracker.extra_deps[new_2_3]) + + # Verify: other_node should depend on new_2_3 (not node3) + self.assertIn(new_2_3, tracker.extra_deps[other_node]) + self.assertNotIn(node3, tracker.extra_deps[other_node]) + + # Verify: new_4_5 should depend on other_node + self.assertIn(other_node, tracker.extra_deps[new_4_5]) + + # Verify: old nodes are cleaned up + self.assertEqual(len(tracker.extra_deps[node2]), 0) + self.assertEqual(len(tracker.extra_deps[node3]), 0) + self.assertEqual(len(tracker.extra_deps[node4]), 0) + self.assertEqual(len(tracker.extra_deps[node5]), 0) + + # Verify: bidirectional consistency + self.assertIn(new_2_3, tracker.extra_uses[node1]) + self.assertIn(other_node, tracker.extra_uses[new_2_3]) + self.assertIn(new_4_5, tracker.extra_uses[other_node]) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index bf474bfbf1776..dc730e408b706 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -182,7 +182,7 @@ def test_async(self): @torch.compile(fullgraph=True, backend="inductor") def model_add(x, y): out = x - for i in range(500): + for _ in range(500): out = torch.add(out, y) return out diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 8fde26c6acf67..50a389e8663f9 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -2,12 +2,14 @@ import operator import os import tempfile +from threading import Event from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, SubprocException, SubprocPool, ) +from torch._inductor.compile_worker.timer import Timer from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.inductor_utils import HAS_CPU @@ -81,6 +83,59 @@ def test_logging(self): pool.shutdown() +class TestTimer(TestCase): + def test_basics(self): + done = Event() + + def doit(): + done.set() + + t = Timer(0.1, doit) + t.sleep_time = 0.1 + t.record_call() + self.assertTrue(done.wait(4)) + t.quit() + + def test_repeated_calls(self): + done = Event() + + def doit(): + done.set() + + t = Timer(0.1, doit) + t.sleep_time = 0.1 + for _ in range(10): + t.record_call() + self.assertTrue(done.wait(4)) + done.clear() + t.quit() + + def test_never_fires(self): + done = Event() + + def doit(): + done.set() + + t = Timer(999, doit) + t.sleep_time = 0.1 + t.record_call() + self.assertFalse(done.wait(4)) + t.quit() + + def test_spammy_calls(self): + done = Event() + + def doit(): + done.set() + + t = Timer(1, doit) + t.sleep_time = 0.1 + for _ in range(400): + t.record_call() + self.assertTrue(done.wait(4)) + t.quit() + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 6781e16aa6d7b..3001f86f4cfce 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -405,7 +405,7 @@ def __init__(self, ioc): self.grad_acc_hooks = [] self.grad_acc = [] self.params = [self.fc1.weight, self.fc2.weight] - for i, param in enumerate(self.params): + for param in self.params: def wrapper(param): param_tmp = param.expand_as(param) @@ -1558,7 +1558,7 @@ def _forward(self, x): dtype=input_tensor.dtype, device=DEVICE ) - for iteration in range(10): + for _ in range(10): for param in model_parameters: param.grad = None output_tensor = model( @@ -1599,7 +1599,7 @@ def eager_check(): eager_check() - for i in range(5): + for _ in range(5): with compiled_autograd._enable(compiler_fn): eager_check() diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 1ce5d88a20f8d..937208d9fd531 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1543,22 +1543,26 @@ def fn( with config.patch({"cpp.simdlen": None}): torch._dynamo.reset() metrics.reset() - self.common( - fn, - ( - x, - scale, - zero_point, - use_dequant, - use_quant, - quant_min, - quant_max, - dtype, - dequant_out_dtype, - ), + inputs = ( + x, + scale, + zero_point, + use_dequant, + use_quant, + quant_min, + quant_max, + dtype, + dequant_out_dtype, ) + self.common(fn, inputs) check_metrics_vec_kernel_count(1) + # Check that both main and tail loops are vectorized + if dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + compiled_fn = torch.compile(fn) + _, code = run_and_get_cpp_code(compiled_fn, *inputs) + FileCheck().check_count("loadu", 2, exactly=True).run(code) + @requires_vectorization def test_dequant_quant_lowering_uint8(self): self._test_dequant_quant_lowering_helper(torch.uint8) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index ffdb7b112f894..1804f4692124f 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -38,8 +38,10 @@ from torch.testing._internal.common_utils import ( DeterministicGuard, freeze_rng_state, + instantiate_parametrized_tests, IS_FBCODE, MI350_ARCH, + parametrize, skipIfRocmArch, TEST_WITH_ASAN, TEST_WITH_ROCM, @@ -85,6 +87,7 @@ aten = torch.ops.aten +@instantiate_parametrized_tests class CudaReproTests(TestCase): device = "cuda" common = check_model_cuda @@ -541,7 +544,7 @@ def forward(self, x): input = torch.randn(10, 10, device="cuda", requires_grad=True) - for i in range(2): + for _ in range(2): output_ref = model_ref(input) output_res = model_opt(input) output_ref.sum().backward() @@ -2441,6 +2444,60 @@ def forward(self, x): f"Max diff: {torch.max(torch.abs(eager_output - compiled_output)):.6f}", ) + @parametrize( + "quantiles_shape,quantiles_strides,batch_size", + [ + ((100, 10), (10, 1), 16), # Contiguous C-order + ((100, 10), (1, 100), 16), # Transposed/F-order + ((80, 12), (1, 80), 16), # Transposed different size + ((50, 20), (1, 50), 16), # Transposed medium + ((200, 8), (1, 200), 16), # Transposed large x small + ((25, 40), (1, 25), 16), # Transposed small x large + ((20, 5, 8), (40, 1, 5), 16), # 3D case with mixed strides + ((20, 5, 8), (1, 20, 100), 16), # 3D case different stride order + ], + ) + def test_searchsorted_stride_permutations( + self, quantiles_shape, quantiles_strides, batch_size + ): + class Foo(torch.nn.Module): + def __init__(self, quantiles: torch.Tensor) -> None: + super().__init__() + assert quantiles.shape[0] > 0 + quantiles = quantiles.T + self.q = torch.nn.Parameter(quantiles, requires_grad=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.searchsorted(self.q, x.T).T + + torch.manual_seed(42) + + # Create contiguous tensor first + numel = 1 + for dim in quantiles_shape: + numel *= dim + data = torch.randn(numel, dtype=torch.float32, device="cuda") + + # Create tensor with specified shape and strides + quantiles = torch.as_strided( + data, size=quantiles_shape, stride=quantiles_strides + ) + + quantiles = torch.sort(quantiles, dim=0)[0] + + x_shape = (batch_size,) + quantiles_shape[1:] + x = torch.randn(*x_shape, dtype=torch.float32, device="cuda") + + foo = Foo(quantiles) + foo_compiled = torch.compile(Foo(quantiles), fullgraph=True) + + # Test eager vs compiled + with torch.no_grad(): + eager = foo(x) + compiled = foo_compiled(x) + + self.assertEqual(eager, compiled) + def test_identity_load(self): device = "cuda" diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index c46c3b86055cb..db15ff03e0c11 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -383,7 +383,7 @@ def non_mut(x): foo = get_compile_fn(backend)(foo) with capture_stderr() as captured_output: - for i in range(3): + for _ in range(3): torch.compiler.cudagraph_mark_step_begin() inp = torch.rand([4], device="cuda") @@ -415,7 +415,7 @@ def non_mut(x): foo = get_compile_fn(backend)(foo) with capture_stderr() as captured_output: - for i in range(3): + for _ in range(3): torch.compiler.cudagraph_mark_step_begin() inp = torch.rand([4], device="cuda") @@ -493,7 +493,7 @@ def inp(): # Should warn for current_node=None mut(inp()) - for i in range(3): + for _ in range(3): torch.compiler.cudagraph_mark_step_begin() tmp = foo(inp()) mut(tmp) # should not warn @@ -945,35 +945,46 @@ def f(x, flag): self.assertEqual(num_partitions, 1) @torch.library.custom_op("mylib::baz", mutates_args=()) - def baz(x: torch.Tensor, flag: int) -> torch.Tensor: + def baz(x: torch.Tensor) -> torch.Tensor: return x.clone() @baz.register_fake - def _(x, flag): + def _(x): return x.clone() - def should_partition(x, flag): - return flag + # custom_should_partition_ops takes effect which lead to 2 partitions + torch._inductor.config.custom_should_partition_ops = ["mylib::baz"] - torch._inductor.scheduler.register_should_partition_rule( - torch.ops.mylib.baz.default, should_partition - ) - - def f(x, flag): + def f(x): x = x + 1 - x = baz(x, flag) + x = baz(x) x = x + 1 return x f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True) - _, code = run_and_get_code(f_compiled, x, True) + _, code = run_and_get_code(f_compiled, x) num_partitions = get_num_partitions(code) self.assertEqual(num_partitions, 2) - _, code = run_and_get_code(f_compiled, x, False) + # update the config should NOT force recompile + torch._inductor.config.custom_should_partition_ops = [] + with torch.compiler.set_stance("fail_on_recompile"): + f_compiled(x) + + # run_and_get_code forces recompile. Now we should cache miss, recompile, and + # only have 1 partition. + _, code = run_and_get_code(f_compiled, x) num_partitions = get_num_partitions(code) self.assertEqual(num_partitions, 1) + # test that op_overload name takes effect which lead to 2 partitions + torch._inductor.config.custom_should_partition_ops = ["mylib::baz.default"] + + f_compiled = torch.compile(f, mode="reduce-overhead", fullgraph=True) + _, code = run_and_get_code(f_compiled, x) + num_partitions = get_num_partitions(code) + self.assertEqual(num_partitions, 2) + @torch._inductor.config.patch("graph_partition", True) @torch._inductor.config.patch("implicit_fallbacks", True) def test_graph_partition_with_memory_plan_reuse(self): @@ -2169,7 +2180,7 @@ def bwd(loss): model = torch.nn.Linear(10, 10, bias=False, device="cuda") x = torch.randn(10, 10, device="cuda") - for i in range(5): + for _ in range(5): out = model(x) bwd(out.sum()) model.weight.grad = None @@ -4494,7 +4505,7 @@ def fn(x, y): ] for i, compile_fn in enumerate(compile_fns): torch.manual_seed(0) - for index in range(3): + for _ in range(3): x = torch.randn(4, 4, device=device, requires_grad=True) y = torch.randn(4, 4, device=device, requires_grad=True) diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py new file mode 100644 index 0000000000000..adc46a0f390a4 --- /dev/null +++ b/test/inductor/test_custom_op_autotune.py @@ -0,0 +1,507 @@ +# Owner(s): ["module: inductor"] +""" +Tests for custom operation autotuning with PyTorch Inductor. + +Validates that custom ops can be registered with multiple CustomOpConfigs, where each +config specifies an optional decomposition function and its associated parameters. +Inductor benchmarks all variants and automatically selects the best performing one. +""" + +import torch +from torch._inductor import config +from torch._inductor.kernel.custom_op import ( + CustomOpConfig, + register_custom_op_autotuning, +) +from torch._inductor.test_case import run_tests, TestCase +from torch.testing._internal.common_utils import skipIfXpu +from torch.testing._internal.inductor_utils import HAS_GPU + + +torch.set_float32_matmul_precision("high") + + +class TestCustomOpAutoTune(TestCase): + """Test custom operation autotuning functionality.""" + + def setUp(self) -> None: + """Set up test environment with appropriate device and dtype.""" + super().setUp() + self.device = "cuda" if HAS_GPU else "cpu" + self.dtype = torch.float16 if self.device == "cuda" else torch.float32 + + def _run_autotune_test(self, op_object, inputs, expected, test_name): + """Shared test infrastructure for autotuning tests.""" + + @torch.compile + def test_model(*args): + return op_object(*args) + + torch._dynamo.reset() + autotune_backends = "TRITON" if self.device == "cuda" else "ATEN" + + with config.patch( + max_autotune=True, + max_autotune_gemm_backends=autotune_backends, + fx_graph_cache=False, + benchmark_kernel=True, + ): + compiled_result = test_model(*inputs) + + self.assertEqual( + compiled_result.shape, expected.shape, f"{test_name} shape mismatch" + ) + torch.testing.assert_close( + compiled_result, + expected, + rtol=2e-1, + atol=5e-1, + msg=f"{test_name} numerical mismatch", + ) + + def _assert_implementations_equivalent(self, decompositions, inputs, op_name): + """Utility to assert that all implementations produce equivalent results.""" + implementations = [(func.__name__, func) for func in decompositions] + results = {} + for name, impl in implementations: + result = impl(*inputs) + results[name] = result + + # Basic sanity checks + self.assertTrue( + torch.isfinite(result).all(), + f"{op_name} {name} produced non-finite values", + ) + + # Verify numerical equivalence + reference_name, reference_result = next(iter(results.items())) + for name, result in results.items(): + if name != reference_name: + rtol = 1e-1 if "Approximated" in name else 1e-2 + atol = 1e-1 if "Approximated" in name else 1e-2 + torch.testing.assert_close( + result, + reference_result, + rtol=rtol, + atol=atol, + msg=f"{op_name} {name} differs from {reference_name}", + ) + + def _create_rmsnorm_inputs(self, batch_size=32, seq_len=2048, hidden_dim=512): + """Create test inputs for RMSNorm operations.""" + input_tensor = torch.randn( + batch_size, + seq_len, + hidden_dim, + device=self.device, + dtype=self.dtype, + requires_grad=False, + ) + weight = torch.randn( + hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False + ) + return input_tensor, weight + + def _create_mlp_inputs( + self, + batch_size=2, + seq_len=32, + hidden_dim=512, + intermediate_dim=1024, + output_dim=256, + ): + """Create test inputs for MLP operations.""" + input_tensor = torch.randn( + batch_size, + seq_len, + hidden_dim, + device=self.device, + dtype=self.dtype, + requires_grad=False, + ) + gate_weight = torch.randn( + hidden_dim, + intermediate_dim, + device=self.device, + dtype=self.dtype, + requires_grad=False, + ) + up_weight = torch.randn( + hidden_dim, + intermediate_dim, + device=self.device, + dtype=self.dtype, + requires_grad=False, + ) + down_weight = torch.randn( + intermediate_dim, + output_dim, + device=self.device, + dtype=self.dtype, + requires_grad=False, + ) + return input_tensor, gate_weight, up_weight, down_weight + + @skipIfXpu + def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self): + """Test RMSNorm autotuning with multiple decomposition variants and dynamic shapes. + + Validates: + - Multiple decomposition implementations with different computational approaches + - Dynamic shape handling across multiple compilations + """ + test_op_name = f"test_lib::rmsnorm_{id(self)}" + + def rmsnorm_decomposition1( + x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8 + ) -> torch.Tensor: + """Variance-based approach: compute variance then rsqrt.""" + variance = x.pow(2).mean(dim=-1, keepdim=True) + rstd = torch.rsqrt(variance + eps) + return x * rstd * weight + + def rmsnorm_decomposition2( + x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8 + ) -> torch.Tensor: + """Separate normalization and scaling: compute normalized value then scale.""" + x_var = x + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * weight + return x + + @torch.library.custom_op(test_op_name, mutates_args=()) + def test_rmsnorm_op( + input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8 + ) -> torch.Tensor: + return torch.nn.functional.rms_norm( + input_tensor, input_tensor.shape[-1:], weight, eps=eps + ) + + @test_rmsnorm_op.register_fake + def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8): + return torch.empty_like(input_tensor) + + decompositions = [ + rmsnorm_decomposition1, + rmsnorm_decomposition2, + ] + + register_custom_op_autotuning( + test_rmsnorm_op, + configs=[CustomOpConfig(decomp) for decomp in decompositions], + name="test_rmsnorm_autotuned", + input_gen_fns={ + "x": lambda x: torch.randn_like(x, device=self.device) * 0.02, + "weight": lambda weight: torch.ones_like(weight, device=self.device), + }, + ) + + # Test multiple shapes to verify dynamic shape handling + test_shapes = [(2, 16, 128), (8, 32, 256)] + + for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes): + input_tensor, weight = self._create_rmsnorm_inputs( + batch_size, seq_len, hidden_dim + ) + + # Test numerical equivalence for all decompositions + self._assert_implementations_equivalent( + decompositions, (input_tensor, weight), f"RMSNorm_{i}" + ) + + # Test autotuning + expected = rmsnorm_decomposition1(input_tensor, weight) + self._run_autotune_test( + test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}" + ) + + @skipIfXpu + def test_mlp_custom_op_autotune(self): + """Test MLP autotuning with method parameter controlling different decomposition variants. + + Validates parametric tuning where the same decomposition function uses different + algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights). + """ + test_op_name = f"test_lib::mlp_{id(self)}" + + def mlp_variants( + input_tensor: torch.Tensor, + gate_weight: torch.Tensor, + up_weight: torch.Tensor, + down_weight: torch.Tensor, + method: int = 0, + ) -> torch.Tensor: + """MLP implementation with different computational approaches controlled by method parameter.""" + + if method == 0: + gate_proj = torch.matmul(input_tensor, gate_weight) + up_proj = torch.matmul(input_tensor, up_weight) + gated = torch.relu(gate_proj) * up_proj + return torch.matmul(gated, down_weight) + + elif method == 1: + batch_shape = input_tensor.shape[:-1] + hidden_dim = input_tensor.shape[-1] + output_dim = down_weight.shape[-1] + + input_2d = input_tensor.view(-1, hidden_dim) + + gate_proj = torch.mm(input_2d, gate_weight) + up_proj = torch.mm(input_2d, up_weight) + + gated = torch.relu(gate_proj) * up_proj + output_2d = torch.mm(gated, down_weight) + + return output_2d.view(*batch_shape, output_dim) + + @torch.library.custom_op(test_op_name, mutates_args=()) + def test_mlp_op( + input_tensor: torch.Tensor, + gate_weight: torch.Tensor, + up_weight: torch.Tensor, + down_weight: torch.Tensor, + method: int = 0, + ) -> torch.Tensor: + return mlp_variants( + input_tensor, gate_weight, up_weight, down_weight, method=method + ) + + @test_mlp_op.register_fake + def _( + input_tensor: torch.Tensor, + gate_weight: torch.Tensor, + up_weight: torch.Tensor, + down_weight: torch.Tensor, + method: int = 0, + ): + return torch.empty( + input_tensor.shape[:-1] + (down_weight.shape[-1],), + device=input_tensor.device, + dtype=input_tensor.dtype, + ) + + # Use explicit config with method parameter as tuning knob + register_custom_op_autotuning( + test_mlp_op, + configs=[ + CustomOpConfig(method=0), + CustomOpConfig(method=1), + ], + name="test_mlp_autotuned", + input_gen_fns={ + "input_tensor": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, + "gate_weight": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.05, + "up_weight": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.05, + "down_weight": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.05, + }, + ) + + # Create test inputs + input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs() + + # Test that all method variants produce numerically equivalent results + expected = mlp_variants( + input_tensor, gate_weight, up_weight, down_weight, method=0 + ) + + # Test autotuning + self._run_autotune_test( + test_mlp_op, + (input_tensor, gate_weight, up_weight, down_weight), + expected, + "MLP", + ) + + def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): + """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values.""" + # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] + k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256 + a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False) + b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False) + return a, b + + @skipIfXpu + def test_decompose_k_custom_op_autotune(self): + """Test decompose_k autotuning with parametric tuning for k_splits values. + + Validates numerical parameter sweep where k_splits controls how the K dimension + is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]). + """ + test_op_name = f"test_lib::decompose_k_{id(self)}" + + def decompose_k_implementation( + a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 + ) -> torch.Tensor: + """Matrix multiply with k-way decomposition - Python implementation.""" + m = a.shape[0] + n = b.shape[1] + k = a.shape[1] + + k_parts = k // k_splits + B = k_splits + + a_reshaped = torch.permute( + a.reshape(m, B, k_parts), (1, 0, 2) + ) # [B, m, k_parts] + b_reshaped = b.reshape(B, k_parts, n) # [B, k_parts, n] + + result = torch.bmm(a_reshaped, b_reshaped) # [B, m, n] + + return torch.sum(result, dim=0) # [m, n] + + @torch.library.custom_op(test_op_name, mutates_args=()) + def test_decompose_k_op( + a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 + ) -> torch.Tensor: + """Matrix multiply with k-way decomposition - custom op using the decomposition.""" + return decompose_k_implementation(a, b, k_splits) + + @test_decompose_k_op.register_fake + def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): + return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype) + + # Register autotuning with different k_splits values using decomposition function + register_custom_op_autotuning( + test_decompose_k_op, + configs=[ + CustomOpConfig(k_splits=2), + CustomOpConfig(k_splits=4), + CustomOpConfig(k_splits=8), + CustomOpConfig(k_splits=16), + CustomOpConfig(k_splits=32), + CustomOpConfig(k_splits=64), + CustomOpConfig(k_splits=128), + ], + name="test_decompose_k_autotuned", + input_gen_fns={ + "a": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, + "b": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, + }, + ) + + a, b = self._create_decompose_k_inputs() + expected = a @ b + self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK") + + @skipIfXpu + def test_multi_parameter_tuning(self): + """Test autotuning with multiple parameters for combinatorial parameter exploration. + + Validates parametric tuning with multiple parameters (scale_mode and chunk_size) + to test combinatorial exploration of the parameter space. + """ + test_op_name = f"test_lib::multi_param_{id(self)}" + + def multi_param_scaling( + x: torch.Tensor, + factor: torch.Tensor, + scale_mode: int = 1, + chunk_size: int = 16, + ) -> torch.Tensor: + """Different scaling approaches controlled by scale_mode parameter.""" + if scale_mode == 1: + # Simple broadcasting + return x * factor + elif scale_mode == 2: + # Process in chunks + batch_size, seq_len = x.shape[:2] + chunks = [] + for start in range(0, seq_len, chunk_size): + end = min(start + chunk_size, seq_len) + chunk = x[:, start:end] + chunks.append(chunk * factor) + return torch.cat(chunks, dim=1) + elif scale_mode == 3: + # Using einsum for scaling + return torch.einsum("...i,i->...i", x, factor) + + @torch.library.custom_op(test_op_name, mutates_args=()) + def multi_param_op( + x: torch.Tensor, + factor: torch.Tensor, + scale_mode: int = 1, + chunk_size: int = 16, + ) -> torch.Tensor: + return multi_param_scaling(x, factor, scale_mode, chunk_size) + + @multi_param_op.register_fake + def _( + x: torch.Tensor, + factor: torch.Tensor, + scale_mode: int = 1, + chunk_size: int = 16, + ): + return torch.empty_like(x) + + # Use explicit configs with scale_mode and chunk_size parameters as tuning knobs + register_custom_op_autotuning( + multi_param_op, + configs=[ + CustomOpConfig(scale_mode=1), # Broadcast + CustomOpConfig(scale_mode=2, chunk_size=16), # Chunked 16 + CustomOpConfig(scale_mode=2, chunk_size=32), # Chunked 32 + CustomOpConfig(scale_mode=3), # Einsum + ], + name="multi_param_autotuned", + input_gen_fns={ + "x": lambda t: torch.randn_like(t, device=self.device) * 0.1, + "factor": lambda t: torch.ones( + t.shape[-1], device=self.device, dtype=t.dtype + ), + }, + ) + + # Create test inputs + test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype) + test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0 + + # Verify numerical equivalence across all approaches + expected_result = test_x * test_factor + + # Test each scale_mode variant + configs = [ + (1, 16), # broadcast, chunk_size ignored + (2, 16), # chunked with size 16 + (2, 32), # chunked with size 32 + (3, 16), # einsum, chunk_size ignored + ] + + for scale_mode, chunk_size in configs: + result = multi_param_scaling( + test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size + ) + torch.testing.assert_close( + result, + expected_result, + rtol=1e-5, + atol=1e-5, + msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected", + ) + + # Test autotuning + self._run_autotune_test( + multi_param_op, (test_x, test_factor), expected_result, "MultiParam" + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index f75eff65382df..5f3735ac87e0d 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -122,16 +122,52 @@ def cuda_kernel_profiler(kernel_pattern="flash_attncute"): result["found"] = any(kernel_pattern in name for name in kernel_names) -def flash_vs_triton(q, k, v, score_mod=None, rtol=5e-3, atol=5e-3): +def flash_vs_triton(q, k, v, score_mod=None, block_mask=None, rtol=2): compiled_fn = torch.compile(flex_attention) + + out_ref_fp32 = flex_attention( + q.to(torch.float32), + k.to(torch.float32), + v.to(torch.float32), + score_mod=score_mod, + block_mask=block_mask, + ).to(q.dtype) + out_flash = compiled_fn( - q, k, v, score_mod=score_mod, kernel_options={"force_flash": True} + q, + k, + v, + score_mod=score_mod, + block_mask=block_mask, + kernel_options={"force_flash": True}, + ) + out_triton = compiled_fn( + q, + k, + v, + score_mod=score_mod, + block_mask=block_mask, + kernel_options={"force_flash": False}, ) - out_no_flash = compiled_fn( - q, k, v, score_mod=score_mod, kernel_options={"force_flash": False} + + assert out_flash.shape == out_ref_fp32.shape == out_triton.shape + assert not torch.isnan(out_flash).any() + assert not torch.isnan(out_triton).any() + assert not torch.isnan(out_ref_fp32).any() + assert torch.isfinite(out_flash).all() + assert torch.isfinite(out_triton).all() + assert torch.isfinite(out_ref_fp32).all() + + fwd_atol = 2 * (out_ref_fp32 + 0.3 - 0.3 - out_ref_fp32).abs().max().item() + + triton_error = (out_triton - out_ref_fp32).abs().max().item() + flash_error = (out_flash - out_ref_fp32).abs().max().item() + + assert flash_error <= rtol * triton_error + fwd_atol, ( + f"Flash error {flash_error:.2e} exceeds {rtol}x Triton error {triton_error:.2e} + {fwd_atol:.2e}" ) - torch.testing.assert_close(out_flash, out_no_flash, rtol=rtol, atol=atol) - return out_flash, out_no_flash + + return out_flash, out_triton, out_ref_fp32 def name_fn(score_mod): @@ -162,26 +198,6 @@ def test_flash_attention_unfriendly_seqlen_with_causal( q, k, v = create_test_tensors(seq_len=seq_len, dtype=dtype, device=device) flash_vs_triton(q, k, v, score_mod=_causal) - @dtypes(torch.float16, torch.bfloat16) - def test_force_flash_error_with_block_mask(self, device, dtype): - """Test that force_flash=True raises error when BlockMask is provided.""" - q, k, v = create_test_tensors(dtype=dtype, device=device) - - # Create a causal block mask - def causal_mask(b, h, q_idx, kv_idx): - return q_idx >= kv_idx - - block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) - - compiled_fn = torch.compile(flex_attention) - with self.assertRaisesRegex( - RuntimeError, - r"force_flash=True but flash attention cannot be used.*BlockMask.*not supported", - ): - compiled_fn( - q, k, v, block_mask=block_mask, kernel_options={"force_flash": True} - ) - @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_kernel_called(self, device, dtype): """Test that flash attention kernel is actually called when force_flash=True.""" @@ -252,12 +268,26 @@ def test_flash_attention_with_dual_buffer_bias(self, device, dtype): score_mod = create_dual_buffer_bias(num_heads=4, seq_len=512, dtype=dtype) flash_vs_triton(q, k, v, score_mod=score_mod) + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_score_view_buffer(self, device, dtype): + """Score modifier should load from a non-contiguous view.""" + num_heads = 4 + q, k, v = create_test_tensors(num_heads=num_heads, dtype=dtype, device=device) + + base_scales = torch.rand(num_heads, 2, device=device, dtype=dtype) + 0.5 + scales_view = base_scales[:, 0] + assert not scales_view.is_contiguous() + + def score_view_mod(score, b, h, q_idx, kv_idx): + return score + scales_view[h] + + flash_vs_triton(q, k, v, score_mod=score_view_mod) + @dtypes(torch.float16, torch.bfloat16) def test_force_flash_error_with_requires_grad(self, device, dtype): """Test that force_flash=True raises error when tensor requires gradients.""" q, k, v = create_test_tensors(dtype=dtype, device=device) - # Create a score mod with requires_grad tensor bias = torch.randn(4, device=device, dtype=dtype, requires_grad=True) def score_mod_with_grad(score, b, h, q_idx, kv_idx): @@ -276,6 +306,166 @@ def score_mod_with_grad(score, b, h, q_idx, kv_idx): kernel_options={"force_flash": True}, ) + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_block_mask(self, device, dtype): + """Test flash attention with block mask and mask_mod.""" + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_block_mask_with_score_mod(self, device, dtype): + """Test flash attention with both block mask and score_mod.""" + q, k, v = create_test_tensors(dtype=dtype, device=device) + + def causal_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + block_mask = create_block_mask(causal_mask, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, score_mod=_times_two, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_mask_mod_buffer(self, device, dtype): + """Test flash attention with mask_mod that loads from buffer.""" + q, k, v = create_test_tensors( + batch_size=2, num_heads=4, dtype=dtype, device=device + ) + + mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1 + + def custom_mask(b, h, q_idx, kv_idx): + bias_value = mask_bias[h] + return (q_idx >= kv_idx) | (bias_value > 0) + + block_mask = create_block_mask(custom_mask, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_doc_mask(self, device, dtype): + """Test flash attention with a document-aware mask_mod.""" + # Use shorter sequences to make the document layout explicit. + seq_len = 128 + q, k, v = create_test_tensors( + batch_size=2, num_heads=4, seq_len=seq_len, dtype=dtype, device=device + ) + lengths_per_batch = ( + (16, 31, 25, 56), # batch 0 + (40, 9, 23, 56), # batch 1 uses a different document arrangement + ) + document_ids = [] + for lengths in lengths_per_batch: + assert sum(lengths) == seq_len + doc_tokens = [] + for doc_id, length in enumerate(lengths): + doc_tokens.extend([doc_id] * length) + document_ids.append(doc_tokens) + document_ids = torch.tensor(document_ids, device=device, dtype=torch.long) + + def document_mask(b, _h, q_idx, kv_idx): + doc_id_q = document_ids[b, q_idx // 2] + doc_id_kv = document_ids[b, kv_idx] + return doc_id_q == doc_id_kv + + block_mask = create_block_mask( + document_mask, 2, 1, seq_len, seq_len, device=device + ) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_mask_mod_with_view_buffer(self, device, dtype): + """Mask modifier should support buffers that are non-contiguous views.""" + batch_size, num_heads, seq_len = 2, 4, 512 + q, k, v = create_test_tensors( + batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device + ) + + base_bias = torch.randn(num_heads, 3, device=device, dtype=dtype) + mask_bias_view = base_bias[:, 1] + assert not mask_bias_view.is_contiguous() + + def mask_with_view_buffer(b, h, q_idx, kv_idx): + bias_value = mask_bias_view[h] + double_bias = bias_value * 2 + return (q_idx >= kv_idx) | (double_bias > 0) + + block_mask = create_block_mask( + mask_with_view_buffer, + batch_size, + num_heads, + seq_len, + seq_len, + device=device, + ) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_mask_mod_with_dual_buffers(self, device, dtype): + """Mask modifier should support multiple captured buffers.""" + batch_size, num_heads, seq_len = 2, 4, 512 + q, k, v = create_test_tensors( + batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device + ) + + head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.2 + batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.2 + + def dual_buffer_mask(b, h, q_idx, kv_idx): + head_term = head_bias[h] + batch_term = batch_bias[b] + causal = q_idx >= kv_idx + bias_cond = (head_term + batch_term).to(torch.float32) > 0 + return causal | bias_cond + + block_mask = create_block_mask( + dual_buffer_mask, batch_size, num_heads, seq_len, seq_len, device=device + ) + flash_vs_triton(q, k, v, block_mask=block_mask) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_score_mod_with_many_buffer_indexing(self, device, dtype): + batch_size, num_heads, seq_len = 2, 4, 512 + q, k, v = create_test_tensors( + batch_size=batch_size, num_heads=num_heads, dtype=dtype, device=device + ) + + head_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.15 + query_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05 + kv_scale = torch.randn(seq_len, device=device, dtype=dtype) * 0.05 + batch_bias = torch.randn(batch_size, device=device, dtype=dtype) * 0.1 + + def complex_score(score, b, h, q_idx, kv_idx): + head_term = head_bias[h] + query_term = query_scale[q_idx] + kv_term = kv_scale[kv_idx] + batch_term = batch_bias[b] + return score + head_term + query_term - kv_term + batch_term + + flash_vs_triton(q, k, v, score_mod=complex_score) + + @dtypes(torch.float16, torch.bfloat16) + def test_flash_attention_with_score_and_mask_buffers(self, device, dtype): + """Test flash attention with both score_mod and mask_mod using buffers.""" + q, k, v = create_test_tensors( + batch_size=2, num_heads=4, dtype=dtype, device=device + ) + + score_bias = torch.randn(4, device=device, dtype=dtype) * 0.2 + mask_bias = torch.randn(4, device=device, dtype=dtype) * 0.1 + + def score_with_buffer(score, b, h, q_idx, kv_idx): + return score + score_bias[h] + + def mask_with_buffer(b, h, q_idx, kv_idx): + bias_value = mask_bias[h] + return (q_idx >= kv_idx) | (bias_value > 0) + + block_mask = create_block_mask(mask_with_buffer, 2, 4, 512, 512, device=device) + flash_vs_triton(q, k, v, score_mod=score_with_buffer, block_mask=block_mask) + instantiate_device_type_tests(TestFlexFlash, globals(), only_for="cuda") diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 1d7551ba1e5aa..f26a2347e4e86 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -794,14 +794,16 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): _get_torch_cuda_version() < (12, 9), "cuBLAS blockwise scaling added in CUDA 12.9", ) - @parametrize( - "shape", ((16, 256, 256), (1024, 512, 1024)) - ) # TODO (jananisriram): add scaling recipe overrides for shapes like (16, 256, 64) and (256, 16, 64) + @parametrize("shape", ((16, 256, 256), (1024, 512, 1024))) @parametrize("use_fast_accum", (False, True)) - def test_blockwise1x128_blockwise128x128_scaling( + @parametrize( + "scaling_block_sizes", ((1, 128, 128, 128), (1, 128, 1, 128)) + ) # (BlockWise1x128, BlockWise128x128), (BlockWise1x128, BlockWise1x128) + def test_main_loop_scaling( self, shape: tuple[int, int, int], use_fast_accum: bool, + scaling_block_sizes: tuple[int, int, int, int], ): # Only bf16 output type is supported for non-tensorwise scaling, not fp32 dtype: torch.dtype = torch.bfloat16 @@ -814,20 +816,28 @@ def test_blockwise1x128_blockwise128x128_scaling( w = torch.randn(N, K, dtype=dtype, device=device) bias = None + am, ak, bn, bk = scaling_block_sizes + # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_blockwise( - w, dtype_float8, block_outer=128, block_inner=128 + w, dtype_float8, block_outer=bn, block_inner=bk ) w_t_fp8 = w_fp8.t() - w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N) + if (bn, bk) == (1, 128): + w_inverse_scale = ( + w_inverse_scale.t().contiguous().t().t() + ) # 1x128 blocks need scales to be outer-dim-major + else: + w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N) # quantize input x x_fp8, x_inverse_scale = _quantize_blockwise( - x, dtype_float8, block_outer=1, block_inner=128 + x, dtype_float8, block_outer=am, block_inner=ak ) - x_inverse_scale = ( - x_inverse_scale.t().contiguous().t() - ) # 1x128 blocks need scales to be outer-dim-major + if (am, ak) == (1, 128): + x_inverse_scale = ( + x_inverse_scale.t().contiguous().t() + ) # 1x128 blocks need scales to be outer-dim-major def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): y = torch._scaled_mm( @@ -872,9 +882,15 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): FileCheck().check( f"SCALE_RECIPE_A : tl.constexpr = {ScalingType.BlockWise1x128.value}" ).run(code[0]) + + if (bn, bk) == (1, 128): + check_scale_recipe_b = ScalingType.BlockWise1x128.value + else: + check_scale_recipe_b = ScalingType.BlockWise128x128.value FileCheck().check( - f"SCALE_RECIPE_B : tl.constexpr = {ScalingType.BlockWise128x128.value}" + f"SCALE_RECIPE_B : tl.constexpr = {check_scale_recipe_b}" ).run(code[0]) + self.assertEqual(y_eager.dtype, dtype) self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index 72eb37c1e1b96..cb70eb7b22f5c 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -1114,6 +1114,58 @@ def test_mismatched_branch_dynamic(self, pred: bool): dynamic_shapes=dynamic_shapes, ) + def test_const_folded_subgraph(self): + """ + If a graph only contains a call_module node to a subgraph, + where the subgraph can be const-folded away, + validate the fake mode used in FXConverter generation is not None. + """ + device = self.device + shape = (5, 10) + + class Submodule(torch.nn.Module): + def forward(self): + return torch.randn(*shape, device=device) + 1 + + # Create a parent graph with this module as a subgraph and output + ep = torch.export.export(Submodule(), ()) + parent_graph = torch.fx.Graph() + call_mod = parent_graph.call_module("sub", args=()) + get_item = parent_graph.call_function( + operator.getitem, args=(call_mod, slice(None)) + ) + parent_graph.output((get_item,)) + parent = torch.fx.GraphModule({"sub": ep.module()}, parent_graph) + + # Verify FXConverter.generate uses non-null fake mode + # Intercept _set_node_metadata_hook to ensure fake_mode is not None + orig_set_hook = torch._inductor.codegen.wrapper_fxir._set_node_metadata_hook + called = False + + def mock_set_hook(gm: torch.fx.GraphModule, fn): + nonlocal called + called = True + # Please update this check if `fake_mode` is + # no longer used in FXConverter call to _node_metadata_hook + self.assertTrue("fake_mode" in fn.keywords) + self.assertIsNotNone(fn.keywords["fake_mode"]) + return orig_set_hook(gm, fn) + + self.assertFalse(called) + with unittest.mock.patch.object( + torch._inductor.codegen.wrapper_fxir, + "_set_node_metadata_hook", + mock_set_hook, + ): + args = () + compiled = torch._inductor.aot_compile( + parent, args, options={"fx_wrapper": True} + ) + self.assertTrue(called) + + compiled_out = compiled(*args) + self.assertEqual(compiled_out.shape, shape) + class TestReplaceFloorDiv(InductorTestCase): """ diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 611b1dd966e1b..f875a7c7f5bac 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -443,10 +443,16 @@ def test_print_floor_div(self): s2 = sympy.S(-1) expr = FloorDiv(s1, s2) self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1LL)*s1") if sys.platform in [ - "darwin", - "win32", - ] else "(-1L)*s1" + self.assertEqual( + cexpr(expr), + "(-1LL)*s1" + if sys.platform + in [ + "darwin", + "win32", + ] + else "(-1L)*s1", + ) s0 = sympy.Symbol("s0", integer=True) s2 = sympy.S(2) diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 45045a3c41893..f213fba0d4c3e 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -970,9 +970,7 @@ def debug_inductor_force_stride_order(orig_fn, input_tensor, stride): self.assertEqual(len(actual_outputs), len(expected_outputs)) self.assertEqual(2, len(actual_outputs)) - for i, actual, expected in zip( - itertools.count(), actual_outputs, expected_outputs - ): + for actual, expected in zip(actual_outputs, expected_outputs): self.assertEqual(expected, actual) if self.device == "cpu": diff --git a/test/inductor/test_inductor_scheduler.py b/test/inductor/test_inductor_scheduler.py index ef383bb8fee98..3a3583c144ebd 100644 --- a/test/inductor/test_inductor_scheduler.py +++ b/test/inductor/test_inductor_scheduler.py @@ -210,7 +210,7 @@ def _create_mock_node(self, name: str, reads: list[str], writes: list[str]) -> M return node -instantiate_device_type_tests(TestScheduler, globals()) +instantiate_device_type_tests(TestScheduler, globals(), allow_xpu=True) if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_lookup_table.py b/test/inductor/test_lookup_table.py new file mode 100644 index 0000000000000..250a822267833 --- /dev/null +++ b/test/inductor/test_lookup_table.py @@ -0,0 +1,1063 @@ +# Owner(s): ["module: inductor"] +import re +import unittest +from functools import partial +from typing import Any, Optional, Union +from unittest.mock import patch + +import torch +import torch.nn as nn +from torch._inductor import config as inductor_config +from torch._inductor.choices import InductorChoices +from torch._inductor.kernel_inputs import MMKernelInputs +from torch._inductor.lookup_table.choices import LookupTableChoices +from torch._inductor.select_algorithm import ( + add_preprocessing_fn, + clear_preprocessing_fns, + ExternKernelCaller, + TritonTemplateCaller, +) +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import fresh_cache, get_num_sms, TMA_DESCRIPTOR_SIZE +from torch._inductor.virtualized import V +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + TEST_WITH_ROCM, +) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA_AND_TRITON, HAS_GPU +from torch.utils._triton import has_triton_stable_tma_api, has_triton_tma_device + + +class MockTensorNode: + """Mock input node that wraps a real tensor for testing""" + + def __init__(self, tensor: torch.Tensor): + self.tensor = tensor + + def get_device(self) -> torch.device: + return self.tensor.device + + def get_dtype(self) -> torch.dtype: + return self.tensor.dtype + + def get_size(self) -> tuple[int, ...]: + return tuple(self.tensor.shape) + + def get_stride(self) -> tuple[int, ...]: + return tuple(self.tensor.stride()) + + +class MockMMKernelInputs(MMKernelInputs): + """Mock MMKernelInputs that subclasses the real class and uses real tensors""" + + def __init__( + self, + tensors: list[torch.Tensor], + scalars: Optional[dict[str, Union[float, int]]] = None, + mat1_idx: int = -2, + mat2_idx: int = -1, + ): + """Initialize with real tensors, creating mock nodes for the base class""" + mock_nodes = [MockTensorNode(t) for t in tensors] + super().__init__(mock_nodes, scalars, mat1_idx=mat1_idx, mat2_idx=mat2_idx) + self.tensors = tensors # Keep reference to original tensors + + def shapes_hinted(self) -> tuple[tuple[int, ...], ...]: + """Delegate to symbolic since real tensors already have int shapes""" + return self.shapes_symbolic() + + def strides_hinted(self) -> tuple[tuple[int, ...], ...]: + """Delegate to symbolic since real tensors already have int strides""" + return self.strides_symbolic() # pyre-ignore + + def mnk_hinted(self) -> tuple[int, int, int]: + """Delegate to symbolic since real tensors already have int dimensions""" + return self.mnk_symbolic() # pyre-ignore + + @property + def device_type(self) -> Optional[str]: + return self.tensors[0].device.type + + +class BaseLookupTableTest(TestCase): + """Base class for lookup table tests with common setup and utilities""" + + def setUp(self): + super().setUp() + self.original_table = inductor_config.lookup_table.table + self.original_max_autotune = getattr(inductor_config, "max_autotune", False) + inductor_config.max_autotune = True + # Set the lookup table choices handler + V.set_choices_handler(LookupTableChoices()) + + def tearDown(self): + inductor_config.lookup_table.table = self.original_table + inductor_config.max_autotune = self.original_max_autotune + # Restore original choices handler + V.set_choices_handler(InductorChoices()) + super().tearDown() + + def create_mock_mm_kernel_inputs( + self, + shapes: Optional[list[tuple[int, ...]]] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.float32, + scalars: Optional[dict[str, Union[float, int]]] = None, + ) -> MockMMKernelInputs: + """Create MockMMKernelInputs with real tensors""" + if shapes is None: + shapes = [(128, 128), (128, 128)] # Default MM shapes + + tensors = [] + for shape in shapes: + # Create a real tensor with the specified shape, device, and dtype + tensor = torch.randn(shape, device=device, dtype=dtype) + tensors.append(tensor) + + return MockMMKernelInputs(tensors, scalars) + + def create_lookup_key(self, method, kernel_inputs): + """Create a lookup key using LookupTableChoices""" + choices = LookupTableChoices() + return choices.make_lookup_key(kernel_inputs, method) + + def create_config(self, template_id, **kwargs): + """Create a backend configuration with template_id field""" + config = {"template_id": template_id} + + # Add minimal defaults based on template type + if template_id == "triton": + config.update( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 64, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + } + ) + elif template_id == "tma": + config.update( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "BLOCK_K": 64, + "num_stages": 4, + "num_warps": 8, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + } + ) + elif template_id == "decompose_k": + config.update({"k": 4}) + + config.update(kwargs) + return config + + +@unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available") +@instantiate_parametrized_tests +class TestLookupTable(BaseLookupTableTest): + """Consolidated tests for lookup table functionality""" + + def test_lookup_mismatch(self): + """Test mismatch scenario in lookup table""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + lookup_table_data = { + self.create_lookup_key("mm", kernel_inputs): [self.create_config("triton")] + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + # looking for addmm but created the entry with mm - should mismatch the key and return + # an empty result + result = test_choices.lookup_template_configs( + kernel_inputs, "addmm", ["triton"] + ) + self.assertEqual(result, {}) + + def test_successful_lookup_with_template_filtering(self): + """Test successful lookup that filters configs by template_id""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config("triton", BLOCK_M=128, BLOCK_N=128), + self.create_config("triton", BLOCK_M=64, BLOCK_N=64), + self.create_config("tma", BLOCK_M=256, BLOCK_N=128), + self.create_config("decompose_k", k_split=4), + ] + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + + # Test triton template filtering + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + assert result is not None, "Result should not be None" + self.assertEqual(len(result["triton"]), 2) + for config in result["triton"]: + self.assertNotIn("template_id", config) + self.assertIn("BLOCK_M", config) + + # Test tma template filtering + result = test_choices.lookup_template_configs(kernel_inputs, "mm", ["tma"]) + assert result is not None, "Result should not be None" + self.assertEqual(len(result["tma"]), 1) + self.assertNotIn("template_id", result["tma"][0]) + self.assertEqual(result["tma"][0]["BLOCK_M"], 256) + + # Test decompose_k template filtering + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["decompose_k"] + ) + assert result is not None, "Result should not be None" + self.assertEqual(len(result["decompose_k"]), 1) + self.assertNotIn("template_id", result["decompose_k"][0]) + self.assertEqual(result["decompose_k"][0]["k_split"], 4) + + def test_empty_table(self): + """Test when template lookup table is empty""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + with patch.object(inductor_config.lookup_table, "table", {}): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + self.assertEqual(result, {}) + + def test_validation_error(self): + """Test validation error for invalid config""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + invalid_config = {"BLOCK_M": 128} # missing template_id + + lookup_table_data = { + self.create_lookup_key("mm", kernel_inputs): [invalid_config] + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + with self.assertRaises(ValueError) as cm: + test_choices.lookup_template_configs(kernel_inputs, "mm", ["triton"]) + self.assertIn("missing required 'template_id' field", str(cm.exception)) + + def test_cpu_input_returns_empty(self): + """Test that CPU tensor input returns empty dict""" + # Create kernel inputs with CPU tensors + kernel_inputs = self.create_mock_mm_kernel_inputs(device=torch.device("cpu")) + + lookup_table_data = { + self.create_lookup_key("mm", kernel_inputs): [self.create_config("triton")] + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + self.assertEqual(result, {}) # Should return empty dict for CPU + + def test_multiple_calls_work(self): + """Test that calling lookup functions multiple times works correctly""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config("triton", BLOCK_M=128), + self.create_config("tma", BLOCK_M=256), + ] + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + + # First calls + result1 = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + result2 = test_choices.lookup_template_configs(kernel_inputs, "mm", ["tma"]) + assert result1 is not None, "Result1 should not be None" + assert result2 is not None, "Result2 should not be None" + self.assertEqual(len(result1["triton"]), 1) + self.assertEqual(len(result2["tma"]), 1) + + # Second calls should work the same + result3 = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + result4 = test_choices.lookup_template_configs(kernel_inputs, "mm", ["tma"]) + assert result3 is not None, "Result3 should not be None" + assert result4 is not None, "Result4 should not be None" + self.assertEqual(len(result3["triton"]), 1) + self.assertEqual(len(result4["tma"]), 1) + + def test_batch_lookup_mixed_entries(self): + """Test batch lookup where some templates have entries and others don't""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config("triton", BLOCK_M=128), + self.create_config("tma", BLOCK_M=256), + # No decompose_k config in lookup table + ] + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + + # Test batch lookup with mixed results + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton", "tma", "decompose_k"] + ) + assert result is not None, "Result should not be None" + + # Should have entries for triton and tma, but not decompose_k + self.assertIn("triton", result) + self.assertIn("tma", result) + self.assertNotIn("decompose_k", result) + + self.assertEqual(len(result["triton"]), 1) + self.assertEqual(len(result["tma"]), 1) + self.assertEqual(result["triton"][0]["BLOCK_M"], 128) + self.assertEqual(result["tma"][0]["BLOCK_M"], 256) + + @parametrize( + "config_hash,template_hash,expected_kept", + [ + # Hash matching (config kept) + ("hash123", "hash123", True), + # Hash mismatch (config filtered) + ("hash123", "hash456", False), + # Config without hash (config kept) + (None, "hash123", True), + # Template without hash (config kept) + ("hash123", None, True), + # Both None (config kept) + (None, None, True), + ], + ) + def test_template_hash_checking(self, config_hash, template_hash, expected_kept): + """Test template hash validation behavior""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config = self.create_config("triton", BLOCK_M=128, BLOCK_N=64) + if config_hash is not None: + config["template_hash"] = config_hash + + template_hash_map = ( + {"triton": template_hash} if template_hash is not None else {} + ) + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): [config]} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object(inductor_config.lookup_table, "check_src_hash", True), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + if expected_kept: + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + self.assertEqual(len(result["triton"]), 1) + # template_hash should be removed from returned config + self.assertNotIn("template_hash", result["triton"][0]) + else: + # Config was filtered out due to hash mismatch + self.assertEqual(result, {}) + + def test_template_hash_checking_disabled(self): + """Test that hash checking is skipped when config flag is disabled""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + # Create config with mismatching hash + config = self.create_config("triton", BLOCK_M=128, template_hash="hash123") + + # Provide different template hash that would normally cause filtering + template_hash_map = {"triton": "hash456"} + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): [config]} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object( + inductor_config.lookup_table, + "check_src_hash", + False, + ), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + # Should keep config even with mismatching hash since checking is disabled + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + self.assertEqual(len(result["triton"]), 1) + # template_hash should still be removed from returned config + self.assertNotIn("template_hash", result["triton"][0]) + + def test_template_hash_mixed_scenarios(self): + """Test mixed hash scenarios with multiple configs""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + config_list = [ + self.create_config( + "triton", BLOCK_M=128, template_hash="correct_hash" + ), # Should be kept + self.create_config( + "triton", BLOCK_M=64, template_hash="wrong_hash" + ), # Should be filtered + self.create_config("triton", BLOCK_M=32), # No hash, should be kept + ] + + template_hash_map = {"triton": "correct_hash"} + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): config_list} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object(inductor_config.lookup_table, "check_src_hash", True), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + # Should keep 2 configs: the one with correct hash and the one without hash + self.assertEqual(len(result["triton"]), 2) + + # Check that kept configs have expected BLOCK_M values + kept_block_ms = [config["BLOCK_M"] for config in result["triton"]] + self.assertIn(128, kept_block_ms) # Config with correct hash + self.assertIn(32, kept_block_ms) # Config without hash + self.assertNotIn( + 64, kept_block_ms + ) # Config with wrong hash should be filtered + + # template_hash should be removed from returned configs + for config in result["triton"]: + self.assertNotIn("template_hash", config) + + @parametrize( + "config_hash,description", + [ + ("definitely_malformed_hash_!@#$%", "malformed hash"), + (12345, "non-string hash"), + ("", "empty string hash"), + (None, "missing hash field"), + ], + ) + def test_hash_checking_disabled_edge_cases(self, config_hash, description): + """Test that configs are kept when hash checking is disabled, regardless of hash validity""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + # Create config with potentially problematic hash + config = self.create_config("triton", BLOCK_M=128) + if config_hash is not None: + config["template_hash"] = config_hash + # If config_hash is None, don't add template_hash field at all + + # Provide a valid template hash that would normally be used for comparison + template_hash_map = {"triton": "valid_template_hash_abc123"} + + lookup_table_data = {self.create_lookup_key("mm", kernel_inputs): [config]} + + with ( + patch.object(inductor_config.lookup_table, "table", lookup_table_data), + patch.object(inductor_config.lookup_table, "check_src_hash", False), + ): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"], template_hash_map + ) + + # Should keep config regardless of hash validity since checking is disabled + assert result is not None, f"Result should not be None for {description}" + self.assertIn( + "triton", result, f"Should have triton result for {description}" + ) + self.assertEqual( + len(result["triton"]), 1, f"Should have 1 config for {description}" + ) + # template_hash should be removed from returned config + self.assertNotIn( + "template_hash", + result["triton"][0], + f"template_hash should be removed from result for {description}", + ) + # Other config fields should be preserved + self.assertEqual( + result["triton"][0]["BLOCK_M"], + 128, + f"BLOCK_M should be preserved for {description}", + ) + + @parametrize( + "table_has_device_key,lookup_device_matches,expected_found", + [ + # Device-specific key in table, same device -> found + (True, True, True), + # Device-specific key in table, different device -> not found + (True, False, False), + # Device-agnostic key in table, same device -> found + (False, True, True), + # Device-agnostic key in table, different device -> found (device-agnostic) + (False, False, True), + ], + ) + def test_device_key_lookup_scenarios( + self, table_has_device_key, lookup_device_matches, expected_found + ): + """Test lookup behavior with device-specific vs device-agnostic keys""" + # Create kernel inputs for "device_1" (our reference device) + kernel_inputs_device1 = self.create_mock_mm_kernel_inputs() + + # Create config + config = self.create_config("triton", BLOCK_M=128) + + # Create a test choices class for generating the table key + class TableKeyChoices(LookupTableChoices): + @staticmethod + def _get_device_key(device): + if device.type != "cuda": + return None + return "device_1" # Always device_1 for table key generation + + table_key_choices = TableKeyChoices() + + # Generate table key based on whether it should include device + if table_has_device_key: + table_key = table_key_choices.make_lookup_key( + kernel_inputs_device1, "mm", include_device=True + ) + else: + table_key = table_key_choices.make_lookup_key( + kernel_inputs_device1, "mm", include_device=False + ) + + lookup_table_data = {table_key: [config]} + + # Create test choices class for the actual lookup with different device behavior + if lookup_device_matches: + + class TestChoices(LookupTableChoices): + @staticmethod + def _get_device_key(device): + if device.type != "cuda": + return None + return "device_1" + + else: + + class TestChoices(LookupTableChoices): + @staticmethod + def _get_device_key(device): + if device.type != "cuda": + return None + return "device_2" + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = TestChoices() + result = test_choices.lookup_template_configs( + kernel_inputs_device1, "mm", ["triton"] + ) + + if expected_found: + assert result is not None, ( + f"Result should not be None when expected_found={expected_found}" + ) + self.assertIn("triton", result, "Should have triton result when found") + self.assertEqual(len(result["triton"]), 1, "Should have exactly 1 config") + self.assertEqual( + result["triton"][0]["BLOCK_M"], 128, "Config should be preserved" + ) + else: + self.assertEqual( + result, + {}, + f"Should return empty dict when expected_found={expected_found}", + ) + + def test_device_key_priority(self): + """Test that device-specific keys take priority over device-agnostic keys""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + # Create two different configs + device_specific_config = self.create_config( + "triton", BLOCK_M=256 + ) # Different BLOCK_M + device_agnostic_config = self.create_config("triton", BLOCK_M=128) + + # Create a test choices instance to generate keys + key_choices = LookupTableChoices() + + # Create both key types for the same inputs + device_key = key_choices.make_lookup_key( + kernel_inputs, "mm", include_device=True + ) + device_agnostic_key = key_choices.make_lookup_key( + kernel_inputs, "mm", include_device=False + ) + + # Put both in the table + lookup_table_data = { + device_key: [device_specific_config], + device_agnostic_key: [device_agnostic_config], + } + + with patch.object(inductor_config.lookup_table, "table", lookup_table_data): + test_choices = LookupTableChoices() + result = test_choices.lookup_template_configs( + kernel_inputs, "mm", ["triton"] + ) + + # Should get device-specific config (BLOCK_M=256), not device-agnostic (BLOCK_M=128) + assert result is not None, "Result should not be None" + self.assertIn("triton", result) + self.assertEqual(len(result["triton"]), 1) + self.assertEqual( + result["triton"][0]["BLOCK_M"], + 256, + "Should use device-specific config when both exist", + ) + + def test_make_lookup_key_variants(self): + """Test the make_lookup_key_variants helper function""" + kernel_inputs = self.create_mock_mm_kernel_inputs() + + test_choices = LookupTableChoices() + device_key, device_agnostic_key = test_choices.make_lookup_key_variants( + kernel_inputs, "mm" + ) + + # Both should be strings + self.assertIsInstance(device_key, str) + self.assertIsInstance(device_agnostic_key, str) + + # Device key should be longer (contains device info) + self.assertGreater(len(device_key), len(device_agnostic_key)) + + # Device-agnostic key should be contained in device key (as a substring after device part) + self.assertIn(device_agnostic_key.split("+mm")[0], device_key) + + +class UnifiedModel(nn.Module): + """Unified model for different matrix operations""" + + def __init__(self, operation="mm"): + super().__init__() + self.operation = operation + + def forward(self, *args): + if self.operation == "mm": + return torch.mm(args[0], args[1]) + elif self.operation == "addmm": + return torch.addmm(args[0], args[1], args[2]) + elif self.operation == "bmm": + return torch.bmm(args[0], args[1]) + elif self.operation == "mm_plus_mm": + return torch.mm(args[0], args[1]) + torch.mm(args[2], args[3]) + else: + raise ValueError(f"Unsupported operation: {self.operation}") + + +def verify_choice_names(choices: list[Any], pattern: str, expected_count: int = 1): + """Verify choices match expected pattern and count""" + if len(choices) != expected_count: + raise ValueError(f"Expected {expected_count} choices, got {len(choices)}") + for choice in choices: + if not re.search(pattern, choice.name): + raise ValueError( + f"Choice name '{choice.name}' doesn't match pattern '{pattern}'" + ) + return choices + + +class BaseE2ELookupTableTest(BaseLookupTableTest): + """Base class for E2E lookup table tests""" + + def setUp(self): + torch._dynamo.reset() + clear_preprocessing_fns() + self.device = torch.device("cuda") + self.dev_key = LookupTableChoices._get_device_key(self.device) + self.original_lookup_table = inductor_config.lookup_table.table + # Set the lookup table choices handler + V.set_choices_handler(LookupTableChoices()) + + def tearDown(self): + inductor_config.lookup_table.table = self.original_lookup_table + # Restore original choices handler + V.set_choices_handler(InductorChoices()) + clear_preprocessing_fns() + + def create_tensors(self, operation, b=8, m=64, n=64, k=32): + """Create test tensors for operations with configurable dimensions""" + if operation in ["mm", "addmm", "mm_plus_mm"]: + A = torch.randn(m, k, device=self.device, dtype=torch.float16) + B = torch.randn(k, n, device=self.device, dtype=torch.float16) + if operation == "mm": + return [A, B] + if operation == "addmm": + return [ + torch.randn((m, n), device=self.device, dtype=torch.float16), + A, + B, + ] + elif operation == "mm_plus_mm": + return [ + A, + B, + torch.randn(m, k, device=self.device, dtype=torch.float16), + torch.randn(k, n, device=self.device, dtype=torch.float16), + ] + elif operation == "bmm": + return [ + torch.randn(b, m, k, device=self.device, dtype=torch.float16), + torch.randn(b, k, n, device=self.device, dtype=torch.float16), + ] + else: + raise ValueError(f"Unsupported operation: {operation}") + + def setup_lookup_table(self, operation, tensors, configs): + """Setup lookup table with configuration""" + scalars = {} + if operation in ["addmm", "baddbmm"]: + scalars["beta"] = 1 + scalars["alpha"] = 1 + mock_kernel_inputs = MockMMKernelInputs(tensors, scalars) + flat_key = self.create_lookup_key(operation, mock_kernel_inputs) + inductor_config.lookup_table.table = {flat_key: configs} + + def run_model(self, operation, tensors, config_patches=None): + """Run compiled model with configuration""" + config = {"max_autotune_gemm": True, "test_configs.max_mm_configs": 4} + if config_patches: + config.update(config_patches) + + model = UnifiedModel(operation) + with inductor_config.patch(config): + compiled_model = torch.compile(model.to(self.device)) + return compiled_model(*tensors) + + def create_basic_config(self, template_id): + """Create basic configuration for template""" + configs = { + torch._inductor.kernel.mm.mm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + }, + torch._inductor.kernel.mm_plus_mm.mm_plus_mm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + }, + torch._inductor.kernel.bmm.bmm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 64, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + }, + torch._inductor.kernel.mm.persistent_tma_mm_template.uid: { + "BLOCK_M": 64, + "BLOCK_N": 64, + "BLOCK_K": 32, + "num_stages": 2, + "num_warps": 2, + "EVEN_K": True, + "USE_FAST_ACCUM": False, + "ACC_TYPE": "tl.float32", + "GROUP_M": 8, + "A_ROW_MAJOR": True, + "B_ROW_MAJOR": True, + "NUM_SMS": get_num_sms(), + "TMA_SIZE": TMA_DESCRIPTOR_SIZE, + "TMA_EXPERIMENTAL_API": not has_triton_stable_tma_api(), + }, + torch._inductor.kernel.mm.aten_bias_addmm.uid: {}, + torch._inductor.kernel.mm.decompose_k_subgraph_template.uid: {"k_split": 4}, + } + return {"template_id": template_id, **configs.get(template_id, {})} + + def _create_simple_matmul_model(self): + """Create a simple matmul model for recording tests""" + + class SimpleMatmul(nn.Module): + def forward(self, a, b): + return torch.mm(a, b) + + return SimpleMatmul() + + def _create_test_inputs(self, device="cuda"): + """Create test inputs for matmul""" + return [ + torch.randn(512, 512, device=device, dtype=torch.float32), + torch.randn(512, 512, device=device, dtype=torch.float32), + ] + + +@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support lookup table") +@unittest.skipIf(not HAS_CUDA_AND_TRITON, "CUDA not available") +@instantiate_parametrized_tests +class TestLookupTableE2E(BaseE2ELookupTableTest): + """E2E tests for lookup table functionality""" + + @parametrize("max_autotune", [True, False]) + @fresh_cache() + def test_no_lookup_table_entry_autotune_modes(self, max_autotune): + """Test when there's no lookup table entry with different autotune modes""" + tensors = self.create_tensors("mm") + + # Setup lookup table with different key to force no match + self.setup_lookup_table( + "mm", + [ + torch.randn(64, 64, device=self.device), + torch.randn(64, 64, device=self.device), + ], + [], + ) + + # Inline validation function + def validate_choices(choices): + if max_autotune: + assert len(choices) > 2, ( + f"Max-autotune should have >2 choices, got {len(choices)}" + ) + assert any(isinstance(c, ExternKernelCaller) for c in choices), ( + "Should have ExternKernelCaller" + ) + assert any(isinstance(c, TritonTemplateCaller) for c in choices), ( + "Should have TritonTemplateCaller" + ) + else: + assert len(choices) == 1, ( + f"No max-autotune should have 1 choice, got {len(choices)}" + ) + assert isinstance(choices[0], ExternKernelCaller), ( + f"Should be ExternKernelCaller, got {type(choices[0])}" + ) + return choices + + add_preprocessing_fn(validate_choices) + self.run_model( + "mm", + tensors, + {"max_autotune_gemm": max_autotune, "max_autotune": max_autotune}, + ) + + @parametrize("operation", ["mm", "addmm", "bmm", "mm_plus_mm"]) + @fresh_cache() + def test_valid_lookup_table_entry(self, operation): + """Test when there's a valid entry for the operation""" + k = 256 if operation == "mm_plus_mm" else 64 + tensors = self.create_tensors(operation, k=k) + + # Map operation to actual template UID + template_mapping = { + "mm": torch._inductor.kernel.mm.mm_template.uid, + "addmm": torch._inductor.kernel.mm.mm_template.uid, + "bmm": torch._inductor.kernel.bmm.bmm_template.uid, + "mm_plus_mm": torch._inductor.kernel.mm_plus_mm.mm_plus_mm_template.uid, + } + template_id = template_mapping[operation] + config = self.create_basic_config(template_id) + + self.setup_lookup_table(operation, tensors, [config]) + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_", expected_count=1) + ) + self.run_model(operation, tensors) + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") + @parametrize("operation", ["mm", "addmm"]) + @fresh_cache() + def test_tma_lookup_table_entry(self, operation): + """Test TMA template entry""" + tensors = self.create_tensors(operation) + config = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + + self.setup_lookup_table(operation, tensors, [config]) + add_preprocessing_fn( + partial( + verify_choice_names, + pattern="triton_mm_persistent_tma_", + expected_count=1, + ) + ) + self.run_model( + operation, tensors, {"triton.enable_persistent_tma_matmul": True} + ) + + @fresh_cache() + def test_decompose_k_lookup_table_entry(self): + """Test decompose_k template entry""" + tensors = self.create_tensors("mm", m=32, n=32, k=32 * 32) + config = self.create_basic_config( + torch._inductor.kernel.mm.decompose_k_subgraph_template.uid + ) + + self.setup_lookup_table("mm", tensors, [config]) + add_preprocessing_fn( + partial( + verify_choice_names, pattern="decompose_k|bmm_dtype", expected_count=1 + ) + ) + self.run_model("mm", tensors) + + @fresh_cache() + def test_bias_addmm_lookup_table_entry(self): + """Test bias_addmm template entry""" + # Create bias with stride[0] == 0 for bias_addmm eligibility + bias_unexpanded = torch.randn(64, device=self.device, dtype=torch.float16) + expanded_bias = bias_unexpanded.expand(64, 64) + tensors = [ + expanded_bias, + torch.randn(64, 32, device=self.device, dtype=torch.float16), + torch.randn(32, 64, device=self.device, dtype=torch.float16), + ] + + config = self.create_basic_config(torch._inductor.kernel.mm.aten_bias_addmm.uid) + self.setup_lookup_table("addmm", tensors, [config]) + add_preprocessing_fn( + partial(verify_choice_names, pattern="bias_addmm", expected_count=1) + ) + + # Run with original unexpanded bias + with inductor_config.patch( + {"max_autotune_gemm": True, "triton.autotune_cublasLt": True} + ): + model = UnifiedModel("addmm") + compiled_model = torch.compile(model.to(self.device), mode="max-autotune") + compiled_model(bias_unexpanded, tensors[1], tensors[2]) + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") + @fresh_cache() + def test_multiple_configs_same_template(self): + """Test multiple configurations for same template""" + tensors = self.create_tensors("mm") + + config1 = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + config1.update({"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 8}) + + config2 = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + config2.update({"BLOCK_M": 64, "BLOCK_N": 64, "num_warps": 4}) + + self.setup_lookup_table("mm", tensors, [config1, config2]) + add_preprocessing_fn( + partial( + verify_choice_names, + pattern="triton_mm_persistent_tma_", + expected_count=2, + ) + ) + self.run_model("mm", tensors, {"triton.enable_persistent_tma_matmul": True}) + + @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") + @fresh_cache() + def test_mixed_template_configs(self): + """Test mixing different template types""" + tensors = self.create_tensors("mm") + + triton_config = self.create_basic_config( + torch._inductor.kernel.mm.mm_template.uid + ) + triton_config.update({"BLOCK_M": 128, "num_warps": 8}) + + tma_config = self.create_basic_config( + torch._inductor.kernel.mm.persistent_tma_mm_template.uid + ) + tma_config.update({"BLOCK_M": 256, "num_warps": 4}) + + self.setup_lookup_table("mm", tensors, [triton_config, tma_config]) + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_", expected_count=2) + ) + self.run_model("mm", tensors, {"triton.enable_persistent_tma_matmul": True}) + + @fresh_cache() + def test_template_hash_filtering_e2e(self): + """Test end-to-end template hash filtering in real MM operation""" + tensors = self.create_tensors("mm") + + # Get the actual src_hash from the template + actual_hash = torch._inductor.kernel.mm.mm_template.src_hash + + # Create configs - one with correct hash, one with wrong hash + correct_config = self.create_basic_config( + torch._inductor.kernel.mm.mm_template.uid + ) + correct_config.update( + {"BLOCK_M": 128, "template_hash": actual_hash} # Use actual hash + ) + + wrong_config = self.create_basic_config( + torch._inductor.kernel.mm.mm_template.uid + ) + wrong_config.update( + { + "BLOCK_M": 64, + "template_hash": "definitely_wrong_hash_12345", # Wrong hash + } + ) + + self.setup_lookup_table("mm", tensors, [correct_config, wrong_config]) + + # Should only get 1 choice since the wrong hash config gets filtered + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_", expected_count=1) + ) + + # Ensure hash checking is enabled + with patch.object(inductor_config.lookup_table, "check_src_hash", True): + self.run_model("mm", tensors) + + +if __name__ == "__main__": + from torch._inductor.utils import is_big_gpu + + if HAS_GPU and HAS_CPU and is_big_gpu(): + run_tests() diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 420a9ee82927d..2f753b7ae0e69 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -2095,7 +2095,7 @@ def func_test1(x, y, z, m): # Test loop. def test_func2(x): - for i in range(10): + for _ in range(10): x = torch.matmul(x, x) return x diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index bf994b5e6b847..158200edc729e 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -343,7 +343,7 @@ def f(a, b, c): def test_fusion_acc_large_reads(self): def f(x, y, z): res = torch.zeros_like(x[0]) - for i in range(4): + for _ in range(4): temp = torch.matmul(x, y) + z res = res + temp return res diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index e8580be7c10b3..230a2514b9171 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -6,7 +6,6 @@ from torch._dynamo.utils import same from torch._inductor import metrics, utils from torch._inductor.test_case import run_tests, TestCase -from torch.testing import FileCheck from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -66,20 +65,6 @@ def f(x): self.check_numeric(f, (x,)) self.assertEqual(0, metrics.codegen_mix_order_reduction) - @inductor_config.patch(split_reductions=False) - def test_skip_due_to_non_persistent_reduction(self): - """ - We only generate mix order reduction if one of the reduction is - persistent reduction. - """ - - def f(x): - return x.sum(dim=1), x.sum(dim=0) - - x = torch.randn(32768, 2048, device=GPU_TYPE) - self.check_numeric(f, (x,)) - self.assertEqual(0, metrics.codegen_mix_order_reduction) - @instantiate_parametrized_tests class MixOrderReductionTest(TestBase): @@ -92,35 +77,105 @@ class MixOrderReductionTest(TestBase): ], ) @parametrize("swap", (False, True)) - @parametrize("shape", ((32768, 768), (32769, 768))) - @inductor_config.patch(split_reductions=False) - def test_mix_order_reduction(self, name, swap, shape): + @parametrize("split_reductions", (False, True)) + @parametrize("shape", ((32768, 768), (32769, 768), (32, 1024, 768))) + def test_mix_order_reduction(self, name, swap, split_reductions, shape): + # torch.prod does not accept tuple for dim argument + if name == "prod" and len(shape) == 3: + self.skipTest("Invalid combination") + def f(x): + def outer_red(): + if len(shape) == 3: + return reduction_fn(x, dim=(0, 1)) + else: + assert len(shape) == 2 + return reduction_fn(x, dim=0) + if swap: - return reduction_fn(x, dim=0), reduction_fn(x, dim=1) + return outer_red(), reduction_fn(x, dim=-1) else: - return reduction_fn(x, dim=1), reduction_fn(x, dim=0) + return reduction_fn(x, dim=-1), outer_red() reduction_fn = getattr(torch, name) - M, N = shape dtype = torch.float - x = torch.randn(M, N, dtype=dtype, device=GPU_TYPE) + x = torch.randn(shape, dtype=dtype, device=GPU_TYPE) - opt_f = torch.compile(f) + opt_f = torch.compile( + f, + options={ + "split_reductions": split_reductions, + }, + ) ref = f(x) act = opt_f(x) self.assertTrue(same(ref, act, tol=1e-3), f"ref:\n{ref}\nact:\n{act}") + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + + @inductor_config.patch(unroll_reductions_threshold=1) + def test_3layer_split_reduction(self): + """ + Use a larger M and smaller N to trigger a 3 layer split reduction. + """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x): + return x.sum(dim=-1), x.sum(dim=0) + + x = torch.randn(32768 * 256, 2, dtype=torch.float, device=GPU_TYPE) + self.check_numeric(f, (x,)) + # We don't do mix order reduction for split redutions + # with more than 2 layers + self.assertEqual(metrics.codegen_mix_order_reduction, 0) + + def test_independent_split_size(self): + """ + Make sure mix order reduction can pick the split size it wants + """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x): + return x.sum(dim=-1), x.sum(dim=0) + + def check_one_split_size(split_size): + torch._dynamo.reset() + + with inductor_config.patch( + "triton.mix_order_reduction_split_size", split_size + ): + self.check_numeric(f, (x,)) + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + + _, (code,) = utils.run_and_get_code(torch.compile(f), x) + self.assertTrue(f"'RSPLIT_SIZE': {split_size}" in code) + + x = torch.randn(32768, 768, dtype=torch.float, device=GPU_TYPE) - expected_num_kernel = 1 + (not inductor_config.triton.mix_order_reduction) - if name == "mean" and inductor_config.triton.mix_order_reduction: - # for mean we generate one more kernel to do the division - # this kernel should be very cheap since tensor size is small - expected_num_kernel = 2 + check_one_split_size(8) + check_one_split_size(16) + + @inductor_config.patch(split_reductions=False) + def test_non_contiguous_input(self): + def f(x): + return x.sum(dim=-1), x.sum(dim=[0, 1]) + + x = torch.randn(1024, 32, 768, dtype=torch.float, device=GPU_TYPE).permute( + 1, 0, 2 + ) + self.check_numeric(f, (x,)) self.assertEqual( - expected_num_kernel, - metrics.generated_kernel_count, + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, ) @inductor_config.patch(split_reductions=False) @@ -128,8 +183,8 @@ def test_multi_workspace_allocation(self): def f(x, y): return x.sum(dim=0), x.sum(dim=1), y.sum(dim=0), y.sum(dim=1) - x = torch.randn(128 * 15, 128, device=GPU_TYPE) - y = torch.randn(256 * 15, 256, device=GPU_TYPE) + x = torch.randn(4096, 32, device=GPU_TYPE) + y = torch.randn(4098, 34, device=GPU_TYPE) self.check_numeric(f, (x, y)) expected_mix_order_reduction = ( @@ -146,9 +201,9 @@ def f(x, y): torch.float, ], ) - @parametrize("shape", ((32768, 768), (32769, 768))) - @inductor_config.patch(split_reductions=False) - def test_rms_norm_bwd(self, wdtype, shape): + @parametrize("split_reductions", (False, True)) + @parametrize("shape", ((32768, 2048), (32768, 768), (32769, 768))) + def test_rms_norm_bwd(self, wdtype, split_reductions, shape): def f(x, w, eps): orig_dtype = x.dtype @@ -173,21 +228,21 @@ def fwd_bwd(f): dy = torch.randn_like(x) eps = 1e-5 - opt_f = torch.compile(f) + opt_f = torch.compile( + f, + options={ + "split_reductions": split_reductions, + }, + ) ref = fwd_bwd(f) act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f) self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}") - expected_num_kernel = 1 + (not inductor_config.triton.mix_order_reduction) - if wdtype == torch.bfloat16 and inductor_config.triton.mix_order_reduction: - # one extra kernel for downcasting - expected_num_kernel = 2 - FileCheck().check_count( - "@triton.jit", - expected_num_kernel, - exactly=True, - ).run(bwd_wrapper) + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) @parametrize( "wbdtype", @@ -196,9 +251,9 @@ def fwd_bwd(f): torch.float, ], ) + @parametrize("split_reductions", (False, True)) @parametrize("shape", ((32768, 768), (32769, 768))) - @inductor_config.patch(split_reductions=False) - def test_layer_norm_bwd_with_bias(self, wbdtype, shape): + def test_layer_norm_bwd_with_bias(self, wbdtype, split_reductions, shape): def f(x, w, b, eps): return F.layer_norm(x, x.shape[-1:], w.float(), b.float(), eps) @@ -219,25 +274,25 @@ def fwd_bwd(f): dy = torch.randn_like(x) eps = 1e-5 - opt_f = torch.compile(f) + opt_f = torch.compile( + f, + options={ + "split_reductions": split_reductions, + }, + ) ref = fwd_bwd(f) act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f) self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}") - expected_num_kernel = 1 + (not inductor_config.triton.mix_order_reduction) - if wbdtype == torch.bfloat16 and inductor_config.triton.mix_order_reduction: - # one extra kernel for downcasting - expected_num_kernel = 2 - FileCheck().check_count( - "@triton.jit", - expected_num_kernel, - exactly=True, - ).run(bwd_wrapper) + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) + @parametrize("split_reductions", (False, True)) @parametrize("shape", ((32768, 768), (32769, 768))) - @inductor_config.patch(split_reductions=False) - def test_layer_norm_bwd_no_bias(self, shape): + def test_layer_norm_bwd_no_bias(self, split_reductions, shape): def f(x, w, eps): return F.layer_norm(x, x.shape[-1:], w, bias=None, eps=eps) @@ -257,17 +312,21 @@ def fwd_bwd(f): dy = torch.randn_like(x) eps = 1e-5 - opt_f = torch.compile(f) + opt_f = torch.compile( + f, + options={ + "split_reductions": split_reductions, + }, + ) ref = fwd_bwd(f) act, (_, bwd_wrapper) = utils.run_and_get_code(fwd_bwd, opt_f) self.assertTrue(same(ref, act, tol=1e-2), f"ref:\n{ref}\nact:\n{act}") - FileCheck().check_count( - "@triton.jit", - 1 + (not inductor_config.triton.mix_order_reduction), - exactly=True, - ).run(bwd_wrapper) + self.assertEqual( + inductor_config.triton.mix_order_reduction, + metrics.codegen_mix_order_reduction, + ) @inductor_config.patch( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index 7c91fd2b9faf6..1870a0e373be0 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -10,6 +10,7 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_triton_code from torch.testing import FileCheck +from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -106,6 +107,9 @@ def f(x, y): self._check_equal(f, (x, y)) self._check_code(f, (x, y), 1, 1) + @skipIfXpu( + msg="Intel triton issue: https://github.com/intel/intel-xpu-backend-for-triton/issues/5394" + ) def test_3mm_add(self): def f(x, y, z, w, r, t): return x @ y + z @ w + r @ t @@ -152,6 +156,5 @@ def f(x, y): torch.set_default_device(GPU_TYPE) if __name__ == "__main__": - # TODO: support native matmul on xpu - if HAS_GPU and GPU_TYPE != "xpu": + if HAS_GPU: run_tests() diff --git a/test/inductor/test_ordered_set.py b/test/inductor/test_ordered_set.py index 216b8ab0f0216..c588018fcf667 100644 --- a/test/inductor/test_ordered_set.py +++ b/test/inductor/test_ordered_set.py @@ -539,7 +539,7 @@ def test_discard(self): # s.discard(self.thetype(self.word)) def test_pop(self): - for i in range(len(self.s)): + for _ in range(len(self.s)): elem = self.s.pop() self.assertNotIn(elem, self.s) self.assertRaises(KeyError, self.s.pop) @@ -990,7 +990,7 @@ def test_instancesWithoutException(self): def test_changingSizeWhileIterating(self): s = OrderedSet([1, 2, 3]) try: - for i in s: + for _ in s: s.update([4]) # noqa: B909 except RuntimeError: pass diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index e6cf6bbcc91bd..11d1d4ce371a0 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -245,7 +245,7 @@ def fn(a, b, c): skip_first=3, wait=1, warmup=1, active=2, repeat=1 ), ) as prof: - for idx in range(10): + for _ in range(10): fn(*inputs) prof.step() diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index cc8596d903610..0d59616bc5338 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -7,6 +7,7 @@ import os import re import shutil +import sys import tempfile import unittest import zipfile @@ -24,7 +25,7 @@ ) from torch._inductor.fx_passes.post_grad import post_grad_passes from torch._inductor.test_case import run_tests, TestCase -from torch._inductor.utils import run_and_get_code +from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code from torch._inductor.virtualized import V from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.triton_utils import requires_cuda_and_triton @@ -32,8 +33,12 @@ try: from .test_aot_inductor_utils import AOTIRunnerUtil + from .test_torchinductor import copy_tests except ImportError: from test_aot_inductor_utils import AOTIRunnerUtil + from test_torchinductor import ( + copy_tests, # @manual=fbcode//caffe2/test/inductor:test_inductor-library + ) trace_log = logging.getLogger("torch.__trace") @@ -806,5 +811,135 @@ def forward(self, x): self.assertTrue("aoti_torch_cpu_convolution" in keys) +class ProvenanceTracingKernelContextTemplate: + def test_jit_inductor_with_flag(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, a, b, c): + x = self.fc1(x) + x = self.relu(x) + x = self.sigmoid(x) + d = a * 3.14 + y = torch.addmm(c, d, b) + z = torch.nn.functional.gelu(y) + return x, z + + model = Model().to(self.device) + x = torch.randn(8, 10).to(self.device) + a = torch.randn(10, 20).to(self.device) + b = torch.randn(20, 30).to(self.device) + c = torch.randn(10, 30).to(self.device) + example_inputs = (x, a, b, c) + + with config.patch( + { + "cpp.enable_kernel_profile": True, + } + ): + torch.compile(model)(*example_inputs) + + @unittest.skipIf(sys.platform == "darwin", "Different kernel names on MacOS") + def test_aoti_python_stack_traces(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, x, a, b, c): + x = self.fc1(x) + x = self.relu(x) + x = self.sigmoid(x) + d = a * 3.14 + y = torch.addmm(c, d, b) + z = torch.nn.functional.gelu(y) + return x, z + + x = torch.randn(8, 10).to(self.device) + a = torch.randn(10, 20).to(self.device) + b = torch.randn(20, 30).to(self.device) + c = torch.randn(10, 30).to(self.device) + example_inputs = (x, a, b, c) + model = Model().to(self.device) + + ep = torch.export.export(model, example_inputs) + _, code = run_and_get_cpp_code(torch._inductor.aoti_compile_and_package, ep) + + self.assertTrue("KernelContextGuard" not in code) + + with config.patch( + { + "trace.provenance_tracking_level": 1, + "cpp.enable_kernel_profile": True, + } + ): + package_path, code = run_and_get_cpp_code( + torch._inductor.aoti_compile_and_package, ep + ) + + FileCheck().check( + "#include " + ).check("thread_local KernelContext* tls_kernel_context = nullptr;").run( + code + ) + + if self.device == "cuda": + FileCheck().check( + """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check( + """KernelContextGuard _ctx("triton_poi_fused_addmm_relu_sigmoid_0", R"(""" + ).check("call_triton_poi_fused_addmm_relu_sigmoid_0(").check( + """KernelContextGuard _ctx("triton_poi_fused_mul_1", R"(""" + ).check("call_triton_poi_fused_mul_1(").check( + """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check( + """ KernelContextGuard _ctx("triton_poi_fused_addmm_gelu_2", R"(""" + ).check("call_triton_poi_fused_addmm_gelu_2(").run(code) + else: + FileCheck().check( + """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check( + """KernelContextGuard _ctx("cpp_fused_mul_relu_sigmoid_0", R"(""" + ).check("cpp_fused_mul_relu_sigmoid_0(").check( + """KernelContextGuard _ctx("aoti_torch_cpu_addmm_out", R"(""" + ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu_addmm_out(").check( + """ KernelContextGuard _ctx("cpp_fused_gelu_1", R"(""" + ).check("cpp_fused_gelu_1(").run(code) + + compiled_model = torch._inductor.aoti_load_package(package_path) + result = compiled_model(*example_inputs) + self.assertEqual(result, model(*example_inputs)) + + +class TestProvenanceTracingKernelContextCpu(TestCase): + device = "cpu" + + +copy_tests( + ProvenanceTracingKernelContextTemplate, + TestProvenanceTracingKernelContextCpu, + "cpu", +) + + +@unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") +@unittest.skipIf(not torch.cuda.is_available(), "No CUDA") +class TestProvenanceTracingKernelContextGpu(TestCase): + device = "cuda" + + +copy_tests( + ProvenanceTracingKernelContextTemplate, + TestProvenanceTracingKernelContextGpu, + "cuda", +) + + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index a8def93ed72b6..675d912c0c01f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2090,7 +2090,7 @@ def test_reduction_config_limit(self): from torch._inductor.runtime.triton_heuristics import triton_config_reduction size_hints = {"x": 67108864, "r0_": 8192} - for i in range(4): + for _ in range(4): size_hints["x"] = next_power_of_2(size_hints["x"]) triton_config_reduction(size_hints, 1, 2048, 1, 8) @@ -5033,13 +5033,13 @@ def test_multi_threading(self): def run_weights_sharing_model(m, inp): with torch.no_grad(): - for i in range(num_run): + for _ in range(num_run): y = m(inp) numb_instance = 2 threads = [] compiled_m = torch.compile(model) - for i in range(1, numb_instance + 1): + for _ in range(1, numb_instance + 1): thread = threading.Thread( target=run_weights_sharing_model, args=(compiled_m, inp) ) @@ -6474,7 +6474,11 @@ def fn(x1, x2, x3, x4): # Constant folding was explicitly turned off due to issue #108388 # Turn it back on for test @unittest.skipIf(config.triton.native_matmul, "native matmul has better precision") - @torch._inductor.config.patch(joint_graph_constant_folding=True) + @torch._inductor.config.patch( + joint_graph_constant_folding=True, + # Numerical accuracy failure for triton fp16 + max_autotune_gemm_backends="ATEN", + ) def test_remove_no_ops(self): def matmul_with_op(x, y, fn): return fn(x @ y) @@ -6902,7 +6906,11 @@ def b(x): _, (code0, code1) = _run_and_get_stripped_kernels(b, x) self.assertEqual(code0, code1) - @config.patch(force_disable_caches=True) + @config.patch( + force_disable_caches=True, + # Test expects a single (fused) kernel to be generated + max_autotune_gemm_backends="ATEN", + ) @skip_if_cpp_wrapper("run_and_get_kernels issue") @unittest.skipIf(config.triton.native_matmul, "matmul is now generated") def test_deterministic_codegen_with_suffix(self): @@ -8423,6 +8431,22 @@ def fn(x): self.assertEqual(fn(x[0:]), x[16:][:16]) self.assertEqual(fn(x[128:]), x[128 + 16 :][:16]) + def test_index_float_zero(self): + def fn(arg0, arg1, arg2): + t1 = torch.tanh(arg0) + t2 = t1.clone() + t2.fill_(arg1.item()) + t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long) + return torch.nn.functional.embedding(t3, arg2) + + arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device) + arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device) + arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device) + + cfn = torch.compile(fullgraph=True, dynamic=True)(fn) + + self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2)) + # from GPT2ForSequenceClassification @skip_if_gpu_halide def test_index_tensor(self): @@ -10989,6 +11013,29 @@ def get_mask(W: torch.Tensor, percentage_nonzeros: torch.Tensor): p = torch.tensor(0.50, device=self.device) get_mask(x, p) + def test_flexible_layout_immutable_free_symbols(self): + import sympy + + x = sympy.Symbol("x") + y = sympy.Symbol("y") + z = sympy.Symbol("z") + + layout = torch._inductor.ir.FlexibleLayout( + self.device, torch.float32, size=(x, y) + ) + + # pad_strides works since it does not add new symints + layout.pad_strides() + + # same symints and different order should work + layout.size = (y, x) + + # adding new symints should fail + with self.assertRaisesRegex( + AssertionError, "Expected free symbols unchanged, but got" + ): + layout.size = (z,) + def test_sqrt_dynamic_shapes(self): # TIMM convit_base model: https://github.com/pytorch/pytorch/issues/97877. # TODO: support cuda path. @@ -14115,6 +14162,8 @@ def _is_triggering_buffer_reuse(fn, *inputs): code_disallowed = re.sub(r"AOT ID: .*", "AOT ID: ['test']", code_disallowed) return code_allowed != code_disallowed + # If matmul is implemented by triton there is more reuse + @config.patch(max_autotune_gemm_backends="ATEN") @unittest.skipIf(config.triton.native_matmul, "matmul is now generated") def test_allow_reuse_disable_if_exceed_peak(self): @torch.compile @@ -15522,6 +15571,9 @@ def fn(inp, weight): self.assertEqual(fn_opt(*inps), fn(*inps)) @torch._functorch.config.patch("donated_buffer", True) + # The inplace updating does not happen after we fused the + # layernorm backward + @torch._inductor.config.patch("triton.mix_order_reduction", False) def test_donated_buffer_inplace(self): batch_size = 32 seq_length = 50 diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 4739d00f1f4ad..e4ee0e4b2bd4c 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -497,7 +497,7 @@ def call_triton_add( x: torch.Tensor, y: torch.Tensor, ): - for i in range(4): + for _ in range(4): x = add_in_loop(x, y) return x @@ -2971,7 +2971,7 @@ def add_4_times_kernel( x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) output = tl.zeros((n_elements,), dtype=tl.float32) - for i in range(4): + for _ in range(4): output += x + y tl.store(out_ptr + offsets, output, mask=mask) @@ -3041,8 +3041,8 @@ def add_4_times_kernel( x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) output = tl.zeros((n_elements,), dtype=tl.float32) - for i in range(2): - for j in range(2): + for _ in range(2): + for _ in range(2): output += x + y tl.store(out_ptr + offsets, output, mask=mask) @@ -3078,8 +3078,8 @@ def add_4_times_kernel( y = tl.load(in_ptr1 + offsets, mask=mask) output1 = tl.zeros((n_elements,), dtype=tl.float32) output2 = tl.zeros((n_elements,), dtype=tl.float32) - for i in range(2): - for j in range(2): + for _ in range(2): + for _ in range(2): output1 += y output2 += x output = output1 + output2 diff --git a/test/inductor/test_utils.py b/test/inductor/test_utils.py index fa666dfc987ec..9516b4ee08940 100644 --- a/test/inductor/test_utils.py +++ b/test/inductor/test_utils.py @@ -91,7 +91,7 @@ def test_sympy_str(self): def test_flops_fx(self): def create_fx_node( - aten: torch._ops.OpOverloadPacket, args, kwargs + aten, op_overload: torch._ops.OpOverload, args, kwargs ) -> tuple[torch.fx.Node, torch.fx.Node]: node1 = torch.fx.Node( graph=torch.fx.Graph(), @@ -101,8 +101,13 @@ def create_fx_node( args=args, kwargs=kwargs, ) - name: str = aten.overloads()[0] - op_overload: torch._ops.OpOverload = getattr(aten, name) + # name: str = aten.overloads()[0] + # if aten == torch.ops.aten.addmm: + # name = "default" + # print(aten) + # print(aten.overloads()) + # print(name) + # op_overload: torch._ops.OpOverload = getattr(aten, name) node2 = torch.fx.Node( graph=torch.fx.Graph(), name="", @@ -119,17 +124,25 @@ def create_fx_node( trues = [ ( torch.ops.aten.addmm, + torch.ops.aten.addmm.default, (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)), {}, ), ( torch.ops.aten.bmm, + torch.ops.aten.bmm.default, (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)), {}, ), - (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}), + ( + torch.ops.aten.mm, + torch.ops.aten.mm.default, + (torch.Tensor(2, 3), torch.Tensor(3, 2)), + {}, + ), ( torch.ops.aten.convolution, + torch.ops.aten.convolution.default, ( torch.Tensor(2, 2, 3), torch.Tensor(2, 2, 2), @@ -145,6 +158,7 @@ def create_fx_node( ), ( torch.ops.aten._convolution, + torch.ops.aten._convolution.deprecated, ( torch.Tensor(2, 2, 2), torch.Tensor(2, 2, 2), @@ -166,17 +180,19 @@ def create_fx_node( falses = [ ( torch.ops.aten.add, + torch.ops.aten.add.Tensor, (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)), {}, ), ( torch.ops.aten.mul, + torch.ops.aten.mul.Tensor, (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)), {}, ), ] - for t, args, kwargs in trues: - fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs) + for t, t2, args, kwargs in trues: + fx_node_1, fx_node_2 = create_fx_node(t, t2, args, kwargs) self.assertTrue( countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}" ) @@ -185,8 +201,8 @@ def create_fx_node( ) self.assertNotEqual(count_flops_fx(fx_node_1), None) self.assertNotEqual(count_flops_fx(fx_node_2), None) - for f, args, kwargs in falses: - fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs) + for f, f2, args, kwargs in falses: + fx_node_1, fx_node_2 = create_fx_node(f, f2, args, kwargs) self.assertFalse( countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}" ) diff --git a/test/jit/test_async.py b/test/jit/test_async.py index b739963ad5ea1..2621ac9414e06 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -35,7 +35,7 @@ def foo(x): def test_async_future_type_python(self): def foo(inp): futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], []) - for i in range(5): + for _ in range(5): futures.append(torch.jit.fork(lambda x: x, inp)) all_outputs = [] for future in futures: @@ -458,7 +458,7 @@ def add_one(input): class TestListFutureModule(nn.Module): def forward(self, input): input_list = [] - for i in range(3): + for _ in range(3): input_list.append(input) fut_list: List[Future[torch.Tensor]] = [] diff --git a/test/jit/test_autodiff.py b/test/jit/test_autodiff.py index 798f382968fe9..06117684971b1 100644 --- a/test/jit/test_autodiff.py +++ b/test/jit/test_autodiff.py @@ -68,7 +68,7 @@ def fn(a, b, c): fn_s = torch.jit.script(fn) - for i in range(4): + for _ in range(4): x, y = fn_s(a, b, c) self.assertFalse(x.requires_grad) self.assertTrue(y.requires_grad) @@ -90,7 +90,7 @@ def fn(a, b, c): b = torch.rand((10, 10), requires_grad=False) c = torch.rand((10, 10), requires_grad=True) - for i in range(4): + for _ in range(4): x_s, y_s, z_s = fn_s(a, b, c) x, y, z = fn(a, b, c) @@ -115,7 +115,7 @@ def fn(a, b, c): b = torch.rand((10, 10), requires_grad=False) c = torch.rand((10, 10), requires_grad=True) - for i in range(4): + for _ in range(4): x_s, y_s, z_s = fn_s(a, b, c) x, y, z = fn(a, b, c) @@ -141,7 +141,7 @@ def fn(a, b, c): b = torch.rand((10, 10), requires_grad=True) c = torch.rand((10, 10), requires_grad=True) - for i in range(4): + for _ in range(4): x_s, y_s, z_s = fn_s(a, b, c) x, y, z = fn(a, b, c) diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 58bd66e7df165..b8853d2e6f5f4 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -2989,7 +2989,7 @@ def forward(self): test_script.segments_groupby_col # Smoketest for flakiness. Takes around 2s. - for i in range(300): + for _ in range(300): test = Test() test_script = torch.jit.script(test) diff --git a/test/jit/test_logging.py b/test/jit/test_logging.py index e03ffa9e0a137..37c379bde6c1b 100644 --- a/test/jit/test_logging.py +++ b/test/jit/test_logging.py @@ -19,7 +19,7 @@ def test_bump_numeric_counter(self): class ModuleThatLogs(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): - for i in range(x.size(0)): + for _ in range(x.size(0)): x += 1.0 torch.jit._logging.add_stat_value("foo", 1) @@ -33,7 +33,7 @@ def forward(self, x): old_logger = torch.jit._logging.set_logger(logger) try: mtl = ModuleThatLogs() - for i in range(5): + for _ in range(5): mtl(torch.rand(3, 4, 5)) self.assertEqual(logger.get_counter_val("foo"), 15) @@ -60,7 +60,7 @@ def test_time_measurement_counter(self): class ModuleThatTimes(torch.jit.ScriptModule): def forward(self, x): tp_start = torch.jit._logging.time_point() - for i in range(30): + for _ in range(30): x += 1.0 tp_end = torch.jit._logging.time_point() torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) @@ -80,7 +80,7 @@ class ModuleThatTimes(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): tp_start = torch.jit._logging.time_point() - for i in range(30): + for _ in range(30): x += 1.0 tp_end = torch.jit._logging.time_point() torch.jit._logging.add_stat_value("mytimer", tp_end - tp_start) @@ -97,7 +97,7 @@ def forward(self, x): def test_counter_aggregation(self): def foo(x): - for i in range(3): + for _ in range(3): torch.jit._logging.add_stat_value("foo", 1) return x + 1.0 diff --git a/test/jit/test_misc.py b/test/jit/test_misc.py index 93c82d98c93ec..4f9eb39ef714e 100644 --- a/test/jit/test_misc.py +++ b/test/jit/test_misc.py @@ -518,7 +518,7 @@ def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: ref = fn(x) script_fn = torch.jit.script(fn) - for i in range(4): + for _ in range(4): res = script_fn(x) self.assertEqual(ref, res) diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py index eaedf48080b92..7a8bbf58224bb 100644 --- a/test/jit/test_module_containers.py +++ b/test/jit/test_module_containers.py @@ -300,7 +300,7 @@ def forward(self, inputs): # note: unable to index moduledict with a string variable currently i = 0 - for key in self.moduledict: + for _ in self.moduledict: i += 1 assert i == len(self.moduledict), "iteration failing for ModuleDict" diff --git a/test/jit/test_tracer.py b/test/jit/test_tracer.py index 8d5cfffbcad8e..a39542b5b21b9 100644 --- a/test/jit/test_tracer.py +++ b/test/jit/test_tracer.py @@ -2025,7 +2025,7 @@ def weighted_kernel_sum(self, weight): module = torch.jit.trace_module(n, inputs) check_inputs = [] - for i in range(2): + for _ in range(2): check_weight = torch.rand(1, 1, 3, 3) check_forward_input = torch.rand(1, 1, 3, 3) check_inputs.append( diff --git a/test/jit/test_types.py b/test/jit/test_types.py index a7b0752ab7500..df82d0a0e5bba 100644 --- a/test/jit/test_types.py +++ b/test/jit/test_types.py @@ -341,7 +341,7 @@ def __init__(self, x: int): self.x = x def set(self, val: int): - for i in range(3): + for _ in range(3): self.x: int = val # Type annotation in __init__, should not fail diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index 3e06539515384..bc88867bd50b3 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -85,6 +85,7 @@ def init_lists(): "linalg_inv_ex", "linalg_pinv.atol_rtol_tensor", "logsumexp", + "svd", } # For some ops, we don't support all variants. Here we use formatted_name # to uniquely identify the variant. @@ -220,20 +221,15 @@ def get_name(op): # noqa: F841 torch._lazy.wait_device_ops() prefix = "aten" if op.name in FALLBACK_LIST else "lazy" symint_suffix = "_symint" if op.name in HAS_SYMINT_SUFFIX else "" - found = f"{prefix}::{op.name}{symint_suffix}" in remove_suffixes( - torch._lazy.metrics.counter_names() - ) + metrics = remove_suffixes(torch._lazy.metrics.counter_names()) + cands = [f"{prefix}::{op.name}{symint_suffix}"] # check aliases - if not found: - for alias in op.aliases: - alias_found = ( - f"{prefix}::{alias.name}{symint_suffix}" - in remove_suffixes(torch._lazy.metrics.counter_names()) - ) - found = found or alias_found - if found: - break - self.assertTrue(found) + for alias in op.aliases: + cands.append(f"{prefix}::{alias.name}{symint_suffix}") + + self.assertTrue( + any(c in metrics for c in cands), f"none of {cands} not found in {metrics}" + ) @ops( [ diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 64be8aac150ca..f6d0355461596 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -12,6 +12,16 @@ import torch.nn as nn import torch.nn.functional as F from torch.testing import make_tensor + + +def _get_cudnn_version(): + """Safely get cuDNN version, returning None if unavailable.""" + try: + return torch.backends.cudnn.version() + except RuntimeError: + return None + + from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN, tf32_on_and_off from torch.testing._internal.common_device_type import ( disablecuDNN, @@ -47,9 +57,11 @@ gradgradcheck, instantiate_parametrized_tests, MACOS_VERSION, + MI300_ARCH, parametrize as parametrize_test, run_tests, set_default_dtype, + skipIfRocmArch, subtest, TEST_SCIPY, TEST_WITH_ROCM, @@ -3393,8 +3405,9 @@ def test_contig_wrong_stride_cudnn(self, device): F.conv_transpose2d(x, torch.randn(16, 1, 1, 1, device=device)) F.conv2d(x, torch.randn(1, 16, 1, 1, device=device)) + @skipIfRocmArch(MI300_ARCH) @onlyCUDA - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @tf32_on_and_off(0.005) def test_Conv2d_size_1_kernel(self, device): x_cpu = torch.randn(2, 3, 5, 5) conv_cpu = torch.nn.Conv2d(3, 3, kernel_size=1) @@ -3425,8 +3438,9 @@ def test_Conv2d_size_1_kernel(self, device): exact_device=False, ) + @skipIfRocmArch(MI300_ARCH) @onlyCUDA - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @tf32_on_and_off(0.005) def test_ConvTranspose2d_size_1_kernel(self, device): x_cpu = torch.randn(2, 3, 5, 5) conv_cpu = torch.nn.ConvTranspose2d(3, 3, kernel_size=1) @@ -4206,10 +4220,7 @@ def test_conv3d_cudnn_broken(self, device): @largeTensorTest("20GB") @largeTensorTest("64GB", "cpu") # TODO(eqy): Remove this once it is fixed in cuDNN and we can dispatch to it again - @xfailIf( - torch.backends.cudnn.version() is not None - and torch.backends.cudnn.version() > 91000 - ) + @xfailIf(_get_cudnn_version() is not None and _get_cudnn_version() > 91000) def test_depthwise_conv_64bit_indexing(self, device): x = torch.randn(1, 2, 32800, 32800, dtype=torch.half).to( memory_format=torch.channels_last diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 72e3665cfdd5d..4e8821656b7e1 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -873,7 +873,7 @@ def test_register_state_dict_post_hook(self, private): ) def linear_state_dict_post_hook(module, state_dict, prefix, local_metadata): - for name, param in module.named_parameters(recurse=False): + for name, _param in module.named_parameters(recurse=False): state_dict[prefix + name] = torch.nn.Parameter( state_dict[prefix + name] ) diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index c3a7b829b2b15..f20ee2a29d573 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -898,6 +898,16 @@ def test_AdaptiveMaxPool_zero_batch_dim(self, device): inp = torch.ones(1, 0, 50, 44, 31, device=device) mod(inp) + @onlyCPU + def test_LPPool1d_kernel_size_overflow_large(self, device): + avgpool = torch.nn.LPPool1d( + -1.38119e150, 7879455037536781369, ceil_mode=True + ).to(device) + inp = torch.randn(3, 15, device=device) + + with self.assertRaisesRegex(RuntimeError, "integer out of range"): + avgpool(inp) + @onlyNativeDeviceTypes def test_AvgPool2d_empty(self, device): avgpool = torch.nn.AvgPool2d(3, stride=2).to(device) diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py index b59dcd4eec5f8..748c903d8308a 100644 --- a/test/onnx/exporter/test_small_models_e2e.py +++ b/test/onnx/exporter/test_small_models_e2e.py @@ -17,6 +17,8 @@ class _WithExport: def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram: + if isinstance(model, torch.nn.Module): + model = model.eval() onnx_program = torch.onnx.export( model, args, @@ -751,7 +753,7 @@ def forward(self, x): x = torch.randn(2, 5, 3) onnx_program = self.export(RMSNormModel(), (x,), opset_version=23) - onnx_testing.assert_onnx_program(onnx_program, backend="reference") + onnx_testing.assert_onnx_program(onnx_program) # Test with multi-dimensional normalized_shape class RMSNormModel2D(torch.nn.Module): @@ -760,7 +762,7 @@ def forward(self, x): x = torch.randn(2, 5, 7, 3) onnx_program = self.export(RMSNormModel2D(), (x,), opset_version=23) - onnx_testing.assert_onnx_program(onnx_program, backend="reference") + onnx_testing.assert_onnx_program(onnx_program) def test_rms_norm_with_weight(self): """Test RMS normalization with weight parameter.""" @@ -790,7 +792,7 @@ def forward(self, x): onnx_program = self.export(RMSNormWithEps(), (x,), opset_version=23) - onnx_testing.assert_onnx_program(onnx_program, backend="reference") + onnx_testing.assert_onnx_program(onnx_program) def test_enable_gqa_in_attention_23_with_dropout(self): class Model(torch.nn.Module): diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index a474d71d49b73..2e96f70cf56f2 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -6106,7 +6106,7 @@ def test_loop_nested(self): class NestedLoopsModel(torch.jit.ScriptModule): @torch.jit.script_method def forward(self, x): - for i in range(5): + for _ in range(5): a = 0 while a < 4: a += 1 @@ -6145,7 +6145,7 @@ def test_loop_transpose(self): class LoopModel(torch.nn.Module): def forward(self, x): res = torch.zeros_like(x[0]) - for i in range(x.size(0)): + for _ in range(x.size(0)): res += x[0].transpose(0, 1) return res @@ -6780,7 +6780,7 @@ def forward(self, x): a = torch.ones( 12, ) - for i in range(10): + for _ in range(10): a.add_( torch.ones( 12, @@ -6809,7 +6809,7 @@ def forward(self, x): b_ref = b # not used in loop, should not be altered. for i in range(10): if i == 3: - for j in range(5): + for _ in range(5): a += _bias _bias.add_( torch.ones( @@ -6854,7 +6854,7 @@ def forward(self, x): ) for i in range(10): if i == 3: - for j in range(5): + for _ in range(5): self._bias += torch.arange( 12, ) @@ -6881,7 +6881,7 @@ def forward(self, x): ) for i in range(10): if i == 3: - for j in range(5): + for _ in range(5): self._bias.copy_( torch.arange( 12, @@ -8567,7 +8567,7 @@ def test_sequance_loopcarried(self): class SequanceLoopModel(torch.nn.Module): def forward(self, x): outputs = [] - for i in range(3): + for _ in range(3): outputs += [x] return torch.stack(outputs).transpose(0, 1) @@ -9768,9 +9768,9 @@ def forward(self, input1: Tensor, input2: Tensor, input3: Tensor) -> Tensor: a = (input1, input2) b = a c = (input1, input2, input3) - for i in range(5): + for _ in range(5): d = a[0] - for j in range(2): + for _ in range(2): e, f = a a = (d, f) f = c[2] @@ -9794,7 +9794,7 @@ def test_lower_tuple_2(self): class TupleModule(torch.nn.Module): def forward(self, input1: Tensor, input2: Tensor) -> tuple[Tensor, Tensor]: a = (input1, input2) - for x in range(5): + for _ in range(5): c, d = a a = (c, d) return a @@ -9812,7 +9812,7 @@ def forward( ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: a = input1 b = input2 - for x in range(5): + for _ in range(5): c, d = a e, f = b if c.shape[0] == e.shape[0]: @@ -11418,7 +11418,7 @@ def set_cell_anchors(self, anchors): self.conv.weight = torch.arange(10) for i in range(10): if i == 3: - for j in range(10): + for _ in range(10): w = self.conv.weight self.conv.weight = torch.arange(10) + w @@ -11480,7 +11480,7 @@ def __init__(self) -> None: def set_cell_anchors(self, anchors): self.conv.weight = torch.randn(3, 10) for i in range(self.conv.weight.size(0)): - for j in range(10): + for _ in range(10): self.conv.bias = torch.randn(3, 10, 3) self.conv.weight = anchors * i self.boxes.append(torch.ones(3, 3)) @@ -11795,7 +11795,7 @@ def forward(self, x, y): elem = torch.matmul(x[0], y) for i in range(x.size(0)): res.append(torch.matmul(x[i], y)) - for i in range(x.size(0)): + for _ in range(x.size(0)): elem = res.pop() for i in range(x.size(0)): res.append(torch.matmul(x[i], y)) @@ -11815,7 +11815,7 @@ def forward(self, x, y): elem = torch.matmul(x[0], y) for i in range(x.size(0)): res.append(torch.matmul(x[i], y)) - for i in range(x.size(0)): + for _ in range(x.size(0)): del res[0] for i in range(x.size(0)): res.append(torch.matmul(x[i], y)) @@ -12452,7 +12452,7 @@ def __init__(self, dim, index, updates, loop_count): self.loop_count = loop_count def forward(self, x): - for i in range(self.loop_count): + for _ in range(self.loop_count): x.index_add_(self.dim, self.index, self.updates) return x diff --git a/test/onnx/torchlib/ops_test_data.py b/test/onnx/torchlib/ops_test_data.py index 6dd3a39a8d6fc..d48ae8f1a28f9 100644 --- a/test/onnx/torchlib/ops_test_data.py +++ b/test/onnx/torchlib/ops_test_data.py @@ -458,6 +458,18 @@ def _where_input_wrangler( TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True), TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20), + TorchLibOpInfo( + "nn.functional.group_norm", nn_ops.aten_group_norm, opset_introduced=21 + ).skip( + reason="ONNX Runtime does not support zero sized inputs for GroupNorm", + matcher=lambda sample: sample.input.numel() == 0, + ), + TorchLibOpInfo( + "nn.functional.rms_norm", nn_ops.aten_rms_norm, opset_introduced=23 + ).skip( + reason="ONNX Runtime does not support <1d inputs or zero sized inputs for RMSNorm", + matcher=lambda sample: len(sample.input.shape) < 2 or sample.input.numel() == 0, + ), ) diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index 2519207126102..797822ea4deee 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -192,7 +192,7 @@ def old_pattern2(): def test_old_pattern_warning_resuming(self): epochs = 35 - for i, group in enumerate(self.opt.param_groups): + for group in self.opt.param_groups: group["initial_lr"] = 0.01 with warnings.catch_warnings(record=True) as ws: @@ -209,7 +209,7 @@ def old_pattern(): def test_old_pattern_warning_resuming_with_arg(self): epochs = 35 - for i, group in enumerate(self.opt.param_groups): + for group in self.opt.param_groups: group["initial_lr"] = 0.01 with warnings.catch_warnings(record=True) as ws: @@ -226,7 +226,7 @@ def old_pattern2(): def test_old_pattern_warning_with_overridden_optim_step(self): epochs = 35 - for i, group in enumerate(self.opt.param_groups): + for group in self.opt.param_groups: group["initial_lr"] = 0.01 with warnings.catch_warnings(record=True) as ws: @@ -299,7 +299,7 @@ def new_step(o, *args, **kwargs): self.opt.step = types.MethodType(new_step, self.opt) def new_pattern(): - for e in range(epochs): + for _ in range(epochs): self.opt.step() scheduler.step() @@ -2617,7 +2617,7 @@ def test_constant_initial_params_swalr(self): sch = SWALR(opt, swa_lr=swa_lr) ori_param_groups = copy.deepcopy(opt.param_groups) - for i in range(2): + for _ in range(2): lr.multiply_(0.5) swa_lr.multiply_(0.5) opt.step() diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index b30d25ec9af63..25fb60674e59e 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -109,7 +109,7 @@ def test_mem_leak(self): t = torch.rand(1, 1).cuda() p = psutil.Process() last_rss = collections.deque(maxlen=5) - for outer_idx in range(10): + for _ in range(10): with _profile(use_cuda=True): for _ in range(1024): t = torch.mm(t, t) @@ -1054,7 +1054,7 @@ def trace_handler(p): schedule=torch.profiler.schedule(wait=1, warmup=1, active=2), on_trace_ready=trace_handler, ) as p: - for idx in range(8): + for _ in range(8): self.payload(use_cuda=use_cuda) p.step() @@ -1144,14 +1144,14 @@ def run_batch(): # See https://github.com/pytorch/pytorch/issues/88446 optimizer_step() - for idx in range(niters): + for _ in range(niters): run_batch() with profile( activities=supported_activities(), schedule=torch.profiler.schedule(wait=1, warmup=1, active=2), ) as p: - for idx in range(niters): + for _ in range(niters): run_batch() p.step() @@ -1508,7 +1508,7 @@ def test_profiler_correlation_id(self): ) inputs = torch.randn(40, 16, 18, 260) uint32_max = 2**32 - 1 - for i in range(5): + for _ in range(5): with profile() as prof: model(inputs) for event in prof.profiler.kineto_results.events(): @@ -2023,7 +2023,7 @@ def test_profiler_time_scale(self): WAIT_TIME = 10 with profile() as p: with torch.profiler.record_function("test_span"): - for i in range(WAIT_TIME): + for _ in range(WAIT_TIME): torch.rand(4, 4) time.sleep(1) events = p.events() @@ -2072,7 +2072,7 @@ def _schedule_helper(self, warmup, active, repeat, acc_events=True): ), acc_events=acc_events, ) as prof: - for i in range(100): + for _ in range(100): torch.add(1, 2) prof.step() # print(prof.key_averages()) @@ -2124,7 +2124,7 @@ def test_cpu_annotation_overlap(self): adjust_profiler_step=True ), ) as prof: - for i in range(5): + for _ in range(5): self._step_helper_func(prof) with TemporaryFileName(mode="w+") as fname: prof.export_chrome_trace(fname) @@ -3161,7 +3161,7 @@ def unpack(fmt, offset): r.seed(1) text_sections = get_text_sections() addrs = [] - for i in range(200): + for _ in range(200): s = r.randrange(0, len(text_sections)) start, size = text_sections[s] addr = r.randrange(start, start + size) diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index a98f5e379343c..4dc56e03488a3 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -1062,8 +1062,8 @@ def _test_qtensor_masked_fill(self, device): mask = torch.randint(0, 2, (numel, ), device=device) mask = mask.bool() x = torch.rand(numel, device=device) - qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qtype) for qtype, fill_with in itertools.product(types, fills): + qx = torch.quantize_per_tensor(x, scale=scale, zero_point=zero_point, dtype=qtype) q_masked_fill = qx.clone() q_masked_fill.masked_fill_(mask, fill_with) ref = qx.clone() diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index c1e8ecfa214bc..f69852760e8a0 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -344,7 +344,7 @@ def test_forward_per_tensor_half_precision_numerics(self): maxi = 255 mini = 0 - for i in range(20): + for _ in range(20): X1 = torch.randn(5, 5).to(torch.float16) Y1 = torch.fake_quantize_per_tensor_affine(X1, scale, zero, mini, maxi) Y1r = _fake_quantize_per_tensor_affine_reference(X1, scale, zero, mini, maxi) @@ -770,7 +770,7 @@ def test_forward_per_channel_half_precision_numerics(self): mini = 0 maxi = 255 - for i in range(20): + for _ in range(20): X1 = torch.randn(4, 5).to(torch.float16) Y1 = torch.fake_quantize_per_channel_affine(X1, scale, zero, axis, mini, maxi) Y1r = _fake_quantize_per_channel_affine_reference(X1, scale, zero, axis, mini, maxi) @@ -1028,7 +1028,7 @@ def _test_numerical_consistency(self, test_type): zero_types = [torch.int] devices = [torch.device('cpu'), torch.device('cuda')] if torch.cuda.is_available() else [torch.device('cpu')] axis = 1 - for i in range(20): + for _ in range(20): for torch_type, float_type, device, zero_type in itertools.product(torch_types, float_types, devices, zero_types): X = torch.randn(3, 3, device=device).to(float_type) scales = (10 * torch.randn(3, device=device)).abs() diff --git a/test/quantization/eager/test_numeric_suite_eager.py b/test/quantization/eager/test_numeric_suite_eager.py index cd11e96859937..f1b89fc5790ed 100644 --- a/test/quantization/eager/test_numeric_suite_eager.py +++ b/test/quantization/eager/test_numeric_suite_eager.py @@ -154,7 +154,7 @@ def compare_and_validate_results(float_model, q_model): for v in weight_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) model_list = [SingleLayerLinearDynamicModel(qengine)] for model in model_list: @@ -178,7 +178,7 @@ def compare_and_validate_results(float_model, q_model): for v in weight_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) model_list = [LSTMwithHiddenDynamicModel(qengine)] for model in model_list: @@ -200,7 +200,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): for v in ob_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) model_list = [ AnnotatedConvModel(qengine), @@ -235,7 +235,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): for v in ob_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) linear_data = self.calib_data[0][0] module_swap_list = [nn.Linear] @@ -260,7 +260,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): for v in ob_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) linear_data = self.calib_data[0][0] module_swap_list = [nn.Linear] @@ -314,7 +314,7 @@ def test_compare_model_stub_functional_static(self): for v in ob_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) @override_qengines def test_compare_model_stub_linear_dynamic(self): @@ -328,7 +328,7 @@ def compare_and_validate_results(float_model, q_model, module_swap_list, data): for v in ob_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) linear_data = self.calib_data[0][0] @@ -357,7 +357,7 @@ def compare_and_validate_results( for v in ob_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) lstm_input = torch.rand((1, 1, 2)) lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) @@ -411,7 +411,7 @@ def compare_and_validate_results(float_model, q_model, data): for v in act_compare_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) linear_data = self.calib_data[0][0] model_list = [AnnotatedSingleLayerLinearModel(qengine)] @@ -447,7 +447,7 @@ def test_compare_model_outputs_functional_static(self): for v in act_compare_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) @override_qengines def test_compare_model_outputs_linear_dynamic(self): @@ -464,7 +464,7 @@ def compare_and_validate_results(float_model, q_model, data): for v in act_compare_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(v["float"][i].shape == v["quantized"][i].shape) + self.assertTrue(v["float"][i].shape == val.shape) linear_data = self.calib_data[0][0] @@ -493,18 +493,12 @@ def compare_and_validate_results(float_model, q_model, input, hidden): for v in act_compare_dict.values(): self.assertTrue(len(v["float"]) == len(v["quantized"])) for i, val in enumerate(v["quantized"]): - self.assertTrue(len(v["float"][i]) == len(v["quantized"][i])) + self.assertTrue(len(v["float"][i]) == len(val)) if i == 0: - self.assertTrue( - v["float"][i][0].shape == v["quantized"][i][0].shape - ) + self.assertTrue(v["float"][i][0].shape == val[0].shape) else: - self.assertTrue( - v["float"][i][0].shape == v["quantized"][i][0].shape - ) - self.assertTrue( - v["float"][i][1].shape == v["quantized"][i][1].shape - ) + self.assertTrue(v["float"][i][0].shape == val[0].shape) + self.assertTrue(v["float"][i][1].shape == val[1].shape) lstm_input = torch.rand((1, 1, 2)) lstm_hidden = (torch.rand(1, 1, 2), torch.rand(1, 1, 2)) diff --git a/test/quantization/fx/test_model_report_fx.py b/test/quantization/fx/test_model_report_fx.py index 58c88c487348d..78408c1b5a36d 100644 --- a/test/quantization/fx/test_model_report_fx.py +++ b/test/quantization/fx/test_model_report_fx.py @@ -525,7 +525,7 @@ def forward(self, x): def run_model_and_common_checks(self, model, ex_input, num_epochs, batch_size): # split up data into batches split_up_data = torch.split(ex_input, batch_size) - for epoch in range(num_epochs): + for _epoch in range(num_epochs): # reset all model report obs model.apply( lambda module: module.reset_batch_and_epoch_values() @@ -952,7 +952,7 @@ def test_prepare_model_callibration(self): # see whether observers properly in regular nn.Module # there should be 4 observers present in this case modules_observer_cnt = 0 - for fqn, module in prepared_for_callibrate_model.named_modules(): + for module in prepared_for_callibrate_model.modules(): if isinstance(module, ModelReportObserver): modules_observer_cnt += 1 @@ -999,7 +999,7 @@ def get_module_and_graph_cnts(self, callibrated_fx_module): """ # get the number of observers stored as modules modules_observer_cnt = 0 - for fqn, module in callibrated_fx_module.named_modules(): + for module in callibrated_fx_module.modules(): if isinstance(module, ModelReportObserver): modules_observer_cnt += 1 @@ -1058,7 +1058,7 @@ def test_generate_report(self): # now calibrate the two models num_iterations = 10 - for i in range(num_iterations): + for _ in range(num_iterations): example_input = torch.tensor(torch.randint(100, (1, 3, 3, 3)), dtype=torch.float) prepared_for_callibrate_model_full(example_input) prepared_for_callibrate_model_single(example_input) @@ -1324,7 +1324,7 @@ def test_input_weight_equalization_determine_points(self): fused = self._get_prepped_for_calibration_model(self.TwoBlockComplexNet(), detector_set, fused=True) # reporter should still give same counts even for fused model - for prepared_for_callibrate_model, mod_report in [non_fused, fused]: + for prepared_for_callibrate_model, _mod_report in [non_fused, fused]: # supported modules to check mods_to_check = {nn.Linear, nn.Conv2d} @@ -1345,7 +1345,7 @@ def test_input_weight_equalization_determine_points(self): self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted) # assert that each of the desired modules have the observers inserted - for fqn, module in prepared_for_callibrate_model.named_modules(): + for module in prepared_for_callibrate_model.modules(): # check if module is a supported module is_in_include_list = sum(isinstance(module, x) for x in mods_to_check) > 0 @@ -1569,7 +1569,7 @@ def test_outlier_detection_determine_points(self): self.assertEqual(number_of_obs_found, correct_number_of_obs_inserted) # assert that each of the desired modules have the observers inserted - for fqn, module in prepared_for_callibrate_model.named_modules(): + for module in prepared_for_callibrate_model.modules(): # check if module is a supported module is_in_include_list = isinstance(module, tuple(mods_to_check)) diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py index b53b9b0193e07..2b8afe1c7c8d8 100644 --- a/test/quantization/fx/test_numeric_suite_fx.py +++ b/test/quantization/fx/test_numeric_suite_fx.py @@ -2252,7 +2252,7 @@ def test_logger_enabled_and_save_activations_flags(self): msp(*example_input) def _check_logger_count(model, exp_count_stats, exp_count_comparisons): - for name, mod in model.named_modules(): + for mod in model.modules(): if isinstance(mod, OutputLogger): self.assertTrue( len(mod.stats) == exp_count_stats, diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index c54c741bcec3d..cd922d94c60c3 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -4672,13 +4672,13 @@ def forward(self, x): m = prepare(m, {"": qconfig}, example_inputs=example_inputs) # check that there is a duplicated observer instance actpp_module_count = 0 - for name, module in m.named_modules(remove_duplicate=False): + for module in m.modules(remove_duplicate=False): if isinstance(module, actpp_module_class): actpp_module_count += 1 self.assertEqual(actpp_module_count, 2) actpp_module_count = 0 - for name, module in m.named_modules(): + for module in m.modules(): if isinstance(module, actpp_module_class): actpp_module_count += 1 self.assertEqual(actpp_module_count, 1) @@ -5732,7 +5732,7 @@ def forward(self, x): m = M().eval() qconfig_dict = func(backend) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1))) - for name, mod in m.named_modules(): + for mod in m.modules(): if _is_activation_post_process(mod) and mod.dtype == torch.quint8: if backend == "fbgemm": lower_bnd = 0 @@ -9435,7 +9435,7 @@ def _test_model_impl( criterion = nn.CrossEntropyLoss() train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) else: - for i in range(10): + for _ in range(10): prepared(input_value) # print('after observation root:', prepared.root) @@ -9480,7 +9480,7 @@ def _test_model_impl( optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001) train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) else: - for i in range(10): + for _ in range(10): qeager(input_value) # print('ref after observation:', qeager) diff --git a/test/quantization/jit/test_quantize_jit.py b/test/quantization/jit/test_quantize_jit.py index ec7618fb551b8..81bdd50adbd43 100644 --- a/test/quantization/jit/test_quantize_jit.py +++ b/test/quantization/jit/test_quantize_jit.py @@ -331,7 +331,7 @@ class SubModule(torch.nn.Module): def __init__(self, dim, num_blocks, enable_bias, enable_affine): super().__init__() layers = [] - for i in range(num_blocks): + for _ in range(num_blocks): layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias)) bn_obj = bn_module[dim](num_features=20, affine=enable_affine) if enable_affine: diff --git a/test/test_autograd.py b/test/test_autograd.py index ee6d9c09282bd..6c3e250df7c7c 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -54,6 +54,7 @@ dtypes, dtypesIfCUDA, dtypesIfMPS, + expectedFailureMPS, instantiate_device_type_tests, onlyCPU, onlyCUDA, @@ -72,6 +73,7 @@ run_tests, scoped_load_inline, set_warn_always_context, + skipCUDANonDefaultStreamIf, skipIfMPS, skipIfNoLapack, skipIfTorchDynamo, @@ -582,14 +584,14 @@ def unpack(x): ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack) ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x) - for i in range(10): + for _ in range(10): with ctx_2: ctx_1.__enter__() x = torch.randn(3, 3, requires_grad=True) x.sin().sum().backward() # Clean up - for i in range(10): + for _ in range(10): ctx_1.__exit__() # Validate there are no more hooks on the stack @@ -2989,7 +2991,7 @@ def coro_enable_grad(state): state = set() with torch.enable_grad(): coro = coro_no_grad(state) - for i in range(5): + for _ in range(5): next(coro) coro.close() @@ -2998,7 +3000,7 @@ def coro_enable_grad(state): state = set() with torch.no_grad(): coro = coro_enable_grad(state) - for i in range(5): + for _ in range(5): next(coro) coro.close() @@ -5293,7 +5295,7 @@ def test_profiler_aggregation_lstm(self): rnn = torch.nn.LSTM(10, 20, 2) total_time_s = 0 with profile(record_shapes=True, use_kineto=kineto_available()) as prof: - for i in range(20): + for _ in range(20): input = torch.randn(5, 3, 10) h = torch.randn(2, 3, 20) c = torch.randn(2, 3, 20) @@ -5925,7 +5927,7 @@ def backward(ctx, grad): self.assertTrue(p_a == p_g or p_b == p_g) # Run backwards multiple times to ensure accumulation works. - for i in range(10): + for _ in range(10): loss.backward(retain_graph=True) # non-contiguous indices and value, we should trigger a copy. @@ -5943,7 +5945,7 @@ def backward(ctx, grad): self.assertFalse(p_b == p_g) # Run backwards multiple times to ensure accumulation works. - for i in range(10): + for _ in range(10): loss.backward(retain_graph=True) def test_gradcheck_single_input(self): @@ -7132,7 +7134,7 @@ def test_checkpointing(self): ) feat_combined = [] - for r in range(num_inp): + for _ in range(num_inp): data_r = torch.empty(1, nz_inp) data_r.uniform_() data_r.requires_grad = True @@ -7202,7 +7204,7 @@ def __init__(self, n, use_checkpoint, use_reentrant): self.use_checkpoint = use_checkpoint self.use_reentrant = use_reentrant self.layers = nn.ModuleList() - for i in range(self.n): + for _ in range(self.n): layer = nn.Sequential( nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256) ) @@ -7513,7 +7515,7 @@ def forward(self, data): feat_combined = [] feat_combined_no_checkpoint = [] - for r in range(num_inp): + for _ in range(num_inp): data_r = torch.empty(1, nz_inp) data_r.uniform_() data_r.requires_grad = input_requires_grad @@ -11714,7 +11716,7 @@ def test_scatter_index_reduce_prod_gradgrad_error(self, device): def test_parameter_resize(self, device): asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device)) - for i in range(2): + for _ in range(2): with torch.no_grad(): asd.set_(asd[1:]) asd.grad = None @@ -11942,7 +11944,7 @@ def _test_reentrant_parent_error_on_cpu(self, device): # Child gpu graph (much longer than parent graph). prev = t2 * t2 - for i in range(10): + for _ in range(10): prev = prev * t2 reentrant_root = prev @@ -13325,9 +13327,12 @@ def assert_all_streams_default(self, num_devices=1): ) # AttributeError: module 'torch.mps' has no attribute 'default_stream' - @skipIfMPS - @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") - def test_consumer_to_single_producer_case_2_correctness(self): + @expectedFailureMPS + @skipCUDANonDefaultStreamIf(True) + def test_consumer_to_single_producer_case_2_correctness(self, device): + if device == "cpu": + self.skipTest("requires accelerator") + # Device Stream # Consumer (MulBackward): cuda:0 s0 # Producer : cuda:0 s1 @@ -13430,36 +13435,43 @@ def call_backward(x): test() # AttributeError: module 'torch.mps' has no attribute 'default_stream' - @skipIfMPS - @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") + @expectedFailureMPS + @skipCUDANonDefaultStreamIf(True) @unittest.skipIf( torch.accelerator.device_count() < 2, "accelerator count is less than 2" ) def test_consumer_to_single_producer_case_3_correctness_non_default_ambient_stream( - self, + self, device ): + if device == "cpu": + self.skipTest("requires accelerator") self._test_consumer_to_single_producer_case_3_correctness( non_default_ambient_stream=True ) # AttributeError: module 'torch.mps' has no attribute 'default_stream' - @skipIfMPS - @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") + @expectedFailureMPS + @skipCUDANonDefaultStreamIf(True) @unittest.skipIf( torch.accelerator.device_count() < 2, "accelerator count is less than 2" ) - def test_consumer_to_single_producer_case_3_correctness(self): + def test_consumer_to_single_producer_case_3_correctness(self, device): + if device == "cpu": + self.skipTest("requires accelerator") self._test_consumer_to_single_producer_case_3_correctness( non_default_ambient_stream=False ) # AttributeError: module 'torch.mps' has no attribute 'default_stream' - @skipIfMPS - @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") + @expectedFailureMPS + @skipCUDANonDefaultStreamIf(True) @unittest.skipIf( torch.accelerator.device_count() < 2, "accelerator count is less than 2" ) - def test_consumer_to_single_producer_case_4_correctness(self): + def test_consumer_to_single_producer_case_4_correctness(self, device): + if device == "cpu": + self.skipTest("requires accelerator") + # Device Stream # Consumer: cuda:0 cuda:0 default # Producer: cuda:1 s1 @@ -13516,12 +13528,15 @@ def test(): test() # AttributeError: module 'torch.mps' has no attribute 'default_stream' - @skipIfMPS - @unittest.skipIf(not torch.accelerator.is_available(), "requires accelerator") + @expectedFailureMPS + @skipCUDANonDefaultStreamIf(True) @unittest.skipIf( torch.accelerator.device_count() < 2, "accelerator count is less than 2" ) - def test_consumer_to_multi_producer_case_4_correctness(self): + def test_consumer_to_multi_producer_case_4_correctness(self, device): + if device == "cpu": + self.skipTest("requires accelerator") + # Device Stream # Consumer : cuda:0 cuda:0 default # @@ -13603,12 +13618,11 @@ def test(): for _ in range(2): test() - # AttributeError: module 'torch.mps' has no attribute 'default_stream' - @skipIfMPS # This test may spuriously fail on non-cuda accelerators (since we won't # be calling sleep) - @unittest.skipIf(not TEST_CUDA, "requires CUDA") - def test_side_stream_backward_overlap(self): + @onlyCUDA + @skipCUDANonDefaultStreamIf(True) + def test_side_stream_backward_overlap(self, device): # In case 2/3, we would designate the consumer as the accumulation # stream and naively, one might have the consumer wait for the producer # as soon as we've added to the InputBuffer the first time. @@ -13709,6 +13723,54 @@ def check_ordering(): populate_events() check_ordering() + @expectedFailureMPS + def test_warn_on_accumulate_grad_stream_mismatch_flag(self, device): + if device == "cpu": + self.skipTest("requires accelerator") + + def do_test(suppress_warn, keep_grad_acc): + def _test(): + with warnings.catch_warnings(record=True) as warns: + warnings.simplefilter("always") + + with torch.Stream(0) as s0: + a = torch.ones(8, 8, device=device, requires_grad=True) + if keep_grad_acc: + # create grad_acc under s1 and keep alive with b + b = a.clone() + + with torch.Stream(0) as s1: + s1.wait_stream(s0) + c = a.sum() + + c.backward() + + filter_str = "set_warn_on_accumulate_grad_stream_mismatch" + return sum([filter_str in str(w.message) for w in warns]) > 0 + + if suppress_warn: + try: + torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch( + False + ) + actual_warn = _test() + finally: + torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch( + True + ) + else: + actual_warn = _test() + + expect_warn = not suppress_warn and keep_grad_acc + self.assertEqual(actual_warn, expect_warn) + + # Warn by default + self.assertTrue(torch._C._warn_on_accumulate_grad_stream_mismatch()) + + for suppress_warn in (True, False): + for keep_grad_acc in (True, False): + do_test(suppress_warn=suppress_warn, keep_grad_acc=keep_grad_acc) + class TestMultithreadAutograd(TestCase): def _run_py_multithread_fn( @@ -15196,6 +15258,9 @@ def log_grad_order(grad: torch.Tensor, name: str, order): instantiate_device_type_tests( TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda") ) +instantiate_device_type_tests( + TestAutogradStreamSynchronization, globals(), except_for=None +) instantiate_parametrized_tests(TestAutograd) instantiate_parametrized_tests(TestNestedCheckpoint) diff --git a/test/test_cuda.py b/test/test_cuda.py index 0d290f08d9cf4..7ef757442f8d7 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -35,6 +35,7 @@ from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_cuda import ( _create_scaling_case, + HAS_WORKING_NVML, SM70OrLater, TEST_CUDNN, TEST_MULTIGPU, @@ -1168,7 +1169,7 @@ def test_record_stream_on_shifted_view(self): stream_record = torch.cuda.Stream() with torch.cuda.stream(stream_record): - torch.cuda._sleep(int(50 * get_cycles_per_ms())) + torch.cuda._busy_wait_for_flag() view.record_stream(stream_record) @@ -1181,6 +1182,7 @@ def test_record_stream_on_shifted_view(self): with torch.cuda.stream(stream_alloc): try_realloc = torch.cuda.FloatTensor([10, 10]) + torch.cuda._clear_flag() self.assertNotEqual(try_realloc.data_ptr(), data_ptr) @@ -4389,7 +4391,7 @@ def test_memory_plots_metadata(self): torch._C._cuda_clearCublasWorkspaces() torch.cuda.memory.empty_cache() torch.cuda.memory._set_memory_metadata("metadata test") - torch.cuda.memory._record_memory_history(context="all") + torch.cuda.memory._record_memory_history(context=context) x = torch.rand(3, 4, device="cuda") del x torch.cuda.memory.empty_cache() @@ -4783,7 +4785,7 @@ def free(): total -= x.numel() choices = [alloc, free, torch.cuda.memory.empty_cache] - for i in range(N): + for _ in range(N): while total >= 1024 * 1024 * 1024 / (4 * 10): free() (action,) = random.choices(choices, weights=[1, 1 if mem else 0, 0.1]) @@ -4802,6 +4804,7 @@ def test_nvml_get_handler(self): def test_temperature(self): self.assertTrue(0 <= torch.cuda.temperature() <= 150) + @unittest.skipIf(not HAS_WORKING_NVML, "pynvml availble but broken") @unittest.skipIf(TEST_WITH_ROCM, "flaky for AMD gpu") @unittest.skipIf(not TEST_PYNVML, "pynvml/amdsmi is not available") def test_device_memory_used(self): @@ -6966,7 +6969,8 @@ def test_compile_kernel_large_shared_memory(self): with self.assertRaises(RuntimeError): kernel.set_shared_memory_config(excessive_shared_mem) - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @skipIfRocmArch(MI300_ARCH) + @tf32_on_and_off(0.005) @unittest.skipIf(not TEST_CUDA, "No CUDA") def test_compile_kernel_advanced(self): # Test matrix multiplication diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 3ac803239c53f..ba3fe63ed1f1b 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -242,14 +242,14 @@ def __len__(self): dataset = CustomDataset(self, x) dataset = random_split(dataset, [5])[0] data_loader = DataLoader(dataset) - for batch in data_loader: + for _batch in data_loader: pass # fractional splitting dataset = CustomDataset(self, x) dataset = random_split(dataset, [1.0])[0] data_loader = DataLoader(dataset) - for batch in data_loader: + for _batch in data_loader: pass def test_splits_reproducibility(self): @@ -1155,7 +1155,7 @@ def __iter__(self): worker_info = torch.utils.data.get_worker_info() assert worker_info is not None worker_id = worker_info.id - for idx in range(self.length // worker_info.num_workers): + for _ in range(self.length // worker_info.num_workers): yield worker_id def __len__(self): @@ -2000,7 +2000,7 @@ def test_multi_epochs_reproducibility(self): dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers ) - for ind in range(num_epochs): + for _ in range(num_epochs): for batch_idx, sample in enumerate(dataloader): self.assertEqual( sample.tolist(), [batch_idx % num_workers] * batch_size @@ -3018,7 +3018,7 @@ def _create_dp(buffer_size): # Same seeds dl_res = [] - for epoch in range(2): + for _epoch in range(2): torch.manual_seed(123) dl_res.append(list(dl)) self.assertEqual(dl_res[0], dl_res[1]) @@ -3238,7 +3238,7 @@ def test_dataset_not_reset(self): ) dataset.start = 0 for i in range(10): - for x in dataloader: + for _ in dataloader: pass # Changing the start value here doesn't have any effect in the dataset # cached by the workers. since they are not recreated between epochs diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 2790145665b13..5a535e7e00663 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -658,7 +658,7 @@ def collate_j(column): ] actual_i = [] - for i, j in df_numbers: + for i, _ in df_numbers: actual_i.append(i) self.assertEqual(expected_i, actual_i) @@ -2632,7 +2632,7 @@ def __init__(self, dp: IterDataPipe[tuple[int, str]]): self.dp = dp def __iter__(self) -> Iterator[int]: - for a, b in self.dp: + for a, _ in self.dp: yield a # Non-DataPipe input with DataPipe hint diff --git a/test/test_decomp.py b/test/test_decomp.py index c65bc07cd9c9b..f5c791c8cbe88 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -1258,11 +1258,10 @@ def forward_pass_fn(): ) # check RMSNorm was fused with sinh + self.assertTrue("triton_per_fused__fused_rms_norm_sinh" in generated_codes[0]) self.assertTrue( - "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0] - ) - self.assertTrue( - "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1] + "triton_per_fused__fused_rms_norm__fused_rms_norm_backward_cosh_mul" + in generated_codes[1] ) diff --git a/test/test_extension_utils.py b/test/test_extension_utils.py index d624f93a1d3fe..d114a06dcef5d 100644 --- a/test/test_extension_utils.py +++ b/test/test_extension_utils.py @@ -2,7 +2,7 @@ import sys import torch -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase class DummyPrivateUse1Module: @@ -60,6 +60,9 @@ def test_external_module_register(self): with self.assertRaisesRegex(RuntimeError, "The runtime module of"): torch._register_device_module("privateuseone", DummyPrivateUse1Module) + @skipIfTorchDynamo( + "accelerator doesn't compose with privateuse1 : https://github.com/pytorch/pytorch/issues/166696" + ) def test_external_module_register_with_renamed_backend(self): torch.utils.rename_privateuse1_backend("foo") with self.assertRaisesRegex(RuntimeError, "has already been set"): diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 0a5b6faab2f63..692a37b193d5e 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -512,7 +512,7 @@ def test_print_in_fake_mode(self): def test_upsample_bilinear_small_channels(self): out = [] mode = FakeTensorMode() - for i, context in enumerate([contextlib.nullcontext, lambda: mode]): + for context in [contextlib.nullcontext, lambda: mode]: with context(): arg0_1 = torch.empty_strided( (3, 427, 640), (1, 1920, 3), dtype=torch.float32, device="cuda" @@ -1058,6 +1058,17 @@ def add(x, y): self.assertIsInstance(r[0], FakeTensor) self.assertIsInstance(r[1], FakeTensor) + def test_fast_div_int_to_float(self): + mode = FakeTensorMode() + with mode: + x = torch.empty(2, 2, device="cpu", dtype=torch.int32) + y = torch.empty(2, 2, device="cpu", dtype=torch.int32) + from torch._subclasses.fake_impls import get_fast_op_impls + + fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor] + z = fast_div(mode, x, y) + self.assertEqual(z.dtype, torch.float32) + def test_fast_div(self): mode = FakeTensorMode() with mode: diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index 03eb15744b543..8f74644d1eb9d 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -270,7 +270,7 @@ def test_conv_transpose_loop(self): model = torch.nn.ConvTranspose2d(4, 8, (2, 2), stride=2) with FlopCounterMode() as mode: - for i in range(50): + for _ in range(50): out = model(x) out.sum().backward() self.assertExpectedInline(str(mode.get_total_flops()), """1536000""") diff --git a/test/test_foreach.py b/test/test_foreach.py index 12c2ec7ccc961..a266bcd071411 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -318,13 +318,11 @@ def clone(arg): return arg scalar_self_arg_test_complete = False - for i, sample in enumerate( - op.sample_inputs( - device, - dtype, - noncontiguous=not is_fastpath, - allow_higher_dtype_scalars=True, - ) + for sample in op.sample_inputs( + device, + dtype, + noncontiguous=not is_fastpath, + allow_higher_dtype_scalars=True, ): (rhs_arg,) = sample.args kwargs = {} or sample.kwargs diff --git a/test/test_fx.py b/test/test_fx.py index 4c4a6d8c619ae..880cc91edc067 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2365,7 +2365,7 @@ def test_deepcopy_recursion_depth(self): g = torch.fx.Graph() x = g.placeholder("x") - for i in range(depth): + for _ in range(depth): x = g.call_function(torch.relu, (x,)) g.output(x) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index d74a3febf171f..6fe3fe2355a1e 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -262,7 +262,7 @@ def __init__(self) -> None: self.embedding_layers = torch.nn.ModuleList() el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) - for i in range(3): + for _ in range(3): el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True) self.embedding_layers.append(el) el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True) @@ -272,7 +272,7 @@ def forward(self, a, b, offset): x = self.bottom_layers(a) y = [] c = [] - for i in range(len(self.embedding_layers)): + for _ in range(len(self.embedding_layers)): temp = torch.randint(10, (8,)) c.append(temp + b) for i in range(len(self.embedding_layers)): diff --git a/test/test_indexing.py b/test/test_indexing.py index cca7a21165d0c..f69c326939aa6 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -156,7 +156,7 @@ def consec(size, start=1): torch.DoubleTensor if not device.startswith("mps") else torch.FloatTensor ) tensor = _make_tensor(lst).to(device) - for _i in range(100): + for _ in range(100): idx1_start = random.randrange(10) idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1) idx1_step = random.randrange(1, 8) diff --git a/test/test_jit.py b/test/test_jit.py index 137979fcc4f15..99d7e711da305 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -2337,9 +2337,9 @@ def constant_prop(cond, iter): print("stays") while False: print("removed") - for _i in range(0): + for _ in range(0): print("removed") - for _i in range(-4): + for _ in range(-4): print("removed") return b @@ -3138,7 +3138,7 @@ def test_request_bailout(self): with enable_profiling_mode_for_profiling_tests(): def fct_loop(x): - for i in range(3): + for _ in range(3): x = torch.cat((x, x), 0) return x @@ -3245,7 +3245,7 @@ def test_not_const(x): def test_nested_bailouts(self): @torch.jit.script def fct_loop(x): - for i in range(3): + for _ in range(3): x = torch.cat((x, x), 0) return x @@ -3907,7 +3907,7 @@ def select_expr_or_var(): else: return f'v{idx - len(exprs)}' - for i in range(50): + for _ in range(50): n = None while n is None or n > len(exprs) + n_variables: template = random.choice(templates) @@ -3922,7 +3922,7 @@ def select_expr_or_var(): src_lines.append(' return ({})\n'.format(''.join(f'v{i},' for i in range(n_variables)))) return '\n'.join(src_lines) - for i in range(100): + for _ in range(100): g = {'torch': torch} code = gen_code() builtins.exec(code, g, None) @@ -4602,7 +4602,7 @@ def test_block_input_grad_in_loop(self): y = torch.randn(3, 3, requires_grad=True) def grad_in_loop(x, y): - for i in range(100): + for _ in range(100): x = y @ x return x @@ -5559,7 +5559,7 @@ def test_resize(): @torch.jit.script def test(x): after_resize_alias = torch.zeros([2]) - for _i in range(5): + for _ in range(5): b = x + 1 f = [1] before_resize_alias = b.sub_(1) @@ -5950,7 +5950,7 @@ def fib(x): # type: (int) -> int prev = 1 v = 1 - for i in range(x): + for _ in range(x): save = v v = v + prev prev = save @@ -7785,7 +7785,7 @@ def test(y): while int(tensor.add_(1)) < 4: if y == 1: continue - for i in range(y): + for _ in range(y): continue ret += 1 ret += 1 @@ -7896,7 +7896,7 @@ def assign_after_break(y): def assign_after_break_nested(y): # type: (int) x = 0 - for i in range(y): + for _ in range(y): if y == 1: x = 5 break @@ -7916,7 +7916,7 @@ def assign_after_break_nested(y): def may_break(y): # type: (int) x = 0 - for i in range(y): + for _ in range(y): if y == 1: x = 5 else: @@ -7988,7 +7988,7 @@ def test_will_break_after_guard(x): def test_varexit(cond): # type: (int) m = 0 - for i in range(3): + for _ in range(3): if cond == 2: if cond == 2: m = 2 @@ -8376,7 +8376,7 @@ def contained_blocks(node): # find the last output, then all subsequent uses fc.check(out_name[-1] + " : ") # skip past node body - for i in range(contained_blocks(node)): + for _ in range(contained_blocks(node)): fc.check("->") if (node.kind() == "prim::If"): fc.check("->").check("->").check("\n") @@ -8429,7 +8429,7 @@ def test_loop(x, iter): a = 1 b = 2 c = 3 - for i in range(iter): + for _ in range(iter): a = 4 b = 5 c = 6 @@ -8445,7 +8445,7 @@ def loop_unused(iter): a = 1 b = 2 c = 3 - for i in range(iter): + for _ in range(iter): c = c + 1 b = b + 1 a = a + 1 @@ -10938,7 +10938,7 @@ def forward(self, x: torch.Tensor): # Test symbolic differentiation # Run Forward and Backward thrice to trigger autodiff graph - for i in range(3): + for _ in range(3): y = jit_module(x) y.backward(grad) x.grad.zero_() @@ -11030,7 +11030,7 @@ def jitted_foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor): W.data /= 4 with enable_profiling_mode_for_profiling_tests(): - for i in range(4): + for _ in range(4): self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None)) @@ -11822,7 +11822,7 @@ def fn_enumerate_zip(x, y): def test_for_in_tensors(self): def test_sizes(x): sumz = 0 - for s in x: + for _ in x: sumz += 1 return sumz self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) @@ -11834,7 +11834,7 @@ def test_for_in_tensors_rank0(self): @torch.jit.script def test_sizes(x): sumz = 0 - for s in x: + for _ in x: sumz += 1 return sumz @@ -11846,7 +11846,7 @@ def test_for_in_tensors_fail_scalar(self): def test_sizes(x): # type: (float) -> int sumz = 0 - for s in x: + for _ in x: sumz += 1 return sumz @@ -11856,7 +11856,7 @@ def test_for_in_tensors_nested(self): def test_sizes(x): sumz = 0 for n in x: - for t in n: + for _ in n: sumz += 1 return sumz @@ -12316,7 +12316,7 @@ def __init__(self) -> None: @torch.jit.script_method def forward(self, x): - for _i in range(4): + for _ in range(4): x += self.param return x @@ -12840,7 +12840,7 @@ def test_file_reader_no_memory_leak(self): # Load from filename tracemalloc.start() - for i in range(num_iters): + for _ in range(num_iters): torch._C.PyTorchFileReader(filename) _, peak_from_string = tracemalloc.get_traced_memory() tracemalloc.stop() @@ -12848,7 +12848,7 @@ def test_file_reader_no_memory_leak(self): # Load from stream tracemalloc.start() with open(filename, 'rb') as f: - for i in range(num_iters): + for _ in range(num_iters): f.seek(0) torch._C.PyTorchFileReader(f) _, peak_from_file = tracemalloc.get_traced_memory() @@ -13287,7 +13287,7 @@ def call(self): def test_pass(self): def foo(x): # type: (bool) -> int - for _i in range(3): + for _ in range(3): pass if x: pass @@ -13903,7 +13903,7 @@ def test_if_might(x): def test_loop_no_escape(x): # type: (int) if x >= 0: - for i in range(x): + for _ in range(x): raise RuntimeError("hi") else: return 5 @@ -14116,7 +14116,7 @@ def loop_ret(x, y): def test_will_ret(y): # type: (int) -> int - for i in range(y): + for _ in range(y): return 2 return 1 @@ -14125,8 +14125,8 @@ def test_will_ret(y): def test_loop_nest_ret(y): # type: (int) -> int - for i in range(y): - for i in range(y - 2): + for _ in range(y): + for _ in range(y - 2): return 10 return 5 return 0 @@ -15387,7 +15387,7 @@ def is_tensor_value(item): if isinstance(item, list): return is_tensor_value(item[0]) return False - for name, value, the_type in self.get_pickle_values(): + for name, value, _the_type in self.get_pickle_values(): if is_tensor_value(value): continue self.assertEqual(value, getattr(loaded, "_" + name)) @@ -15768,7 +15768,7 @@ def fn(d): def test_for_else(self): def fn(): c = 0 - for i in range(4): + for _ in range(4): c += 10 else: print("In else block of for...else") diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index dba28f98cbf98..c3018be817d9b 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -94,7 +94,7 @@ def strip_profiling_nodes(nodes): def warmup_forward(f, *args, profiling_count=2): - for i in range(profiling_count): + for _ in range(profiling_count): results = f(*args) return results @@ -2284,7 +2284,7 @@ def _test_fwd_bwd(self, fn): x = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) xs = torch.arange(-10, 10, dtype=torch.float32, requires_grad=True) script = torch.jit.script(fn) - for i in range(11): + for _ in range(11): y = fn(x) g0 = torch.rand_like(y) y.backward(g0) @@ -2514,7 +2514,7 @@ def fum(x, y, z): x, y, z = gen(n), gen(n), gen(n) func_s(x, y, z) - for incr in range(3): + for _incr in range(3): func_s(*[gen(n + 1) for _ in range(3)]) g = torch.jit.last_executed_optimized_graph() @@ -2678,7 +2678,7 @@ def f(x, y): f_traced = torch.jit.trace(f, (x, y)) - for i in range(4): + for _ in range(4): # make sure this doesn't error out res = f_traced(x, y) @@ -2697,7 +2697,7 @@ def fn(p: torch.Tensor, gamma: float = 2.0) -> torch.Tensor: ref = fn(x) script_fn = torch.jit.script(fn) - for i in range(4): + for _ in range(4): res = script_fn(x) self.assertEqual(ref, res) diff --git a/test/test_jiterator.py b/test/test_jiterator.py index 813552f33a9cb..55ad64adb6b34 100644 --- a/test/test_jiterator.py +++ b/test/test_jiterator.py @@ -115,7 +115,7 @@ def ref_fn(x, mask, y): @parametrize("num_inputs", [1, 5, 8]) def test_various_num_inputs(self, num_inputs): inputs = [] - for i in range(num_inputs): + for _ in range(num_inputs): inputs.append(torch.rand(3, device='cuda').mul(10)) input_string = ",".join([f"T i{i}" for i in range(num_inputs)]) diff --git a/test/test_linalg.py b/test/test_linalg.py index 01a6dd5c8ecd6..41a223763d474 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -755,10 +755,11 @@ def cholesky_test_helper(n, batch_dims, upper): cholesky_test_helper(3, batchsize, upper) @precisionOverride({torch.float32: 1e-4, torch.complex64: 1e-4}) + @skipIfRocmArch(MI300_ARCH) @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) - @tf32_on_and_off(0.1 if TEST_WITH_ROCM else 0.01) + @tf32_on_and_off(0.01) @reduced_f32_on_and_off(0.01) def test_old_cholesky(self, device, dtype): from torch.testing._internal.common_utils import random_hermitian_pd_matrix @@ -1526,7 +1527,7 @@ def test_vector_norm_dim_tuple_arg(self, device): if error is None: torch.linalg.vector_norm(input, dim=dim) else: - with self.assertRaises(error): + with self.assertRaises(error, msg=error_msg): torch.linalg.vector_norm(input, dim=dim) # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that @@ -7132,7 +7133,7 @@ def tracker(worker): elapsed_ortho_general = 0 elapsed_scipy = 0 elapsed_general_scipy = 0 - for i in range(repeat): + for _ in range(repeat): start = time.time() torch.lobpcg(A1, X=X1, niter=niter, method='ortho', tol=tol) end = time.time() @@ -7328,9 +7329,11 @@ def _test_addmm_impl(self, func, activation, device, dtype): m2 = torch.randn(50, 25, device=device).to(dtype) self._test_addmm_addmv(func, M, m1, m2, activation=activation) - # vector-shaped bias and beta=1 result in epilogue fusion in CUDA + # vector-shaped bias (or with 1-len dims on the left from the leading dim) + # and beta=1 result in epilogue fusion in CUDA V = torch.randn(25, device=device).to(dtype) self._test_addmm_addmv(func, V, m1, m2, beta=1, activation=activation) + self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, activation=activation) # Test 0-strided M = torch.randn(10, 1, device=device).to(dtype).expand(10, 25) @@ -7357,8 +7360,9 @@ def maybe_transpose(cond, m): self._test_addmm_addmv(func, M, m1, m2, transpose_out=t4, activation=activation) if t1: - # use vector V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) + # use vector/(1 by k)-shaped V instead of matrix M for epilogue fusion in CUDA (doesn't depend on t1) self._test_addmm_addmv(func, V, m1, m2, beta=1, transpose_out=t4, activation=activation,) + self._test_addmm_addmv(func, V.unsqueeze(0), m1, m2, beta=1, transpose_out=t4, activation=activation,) @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) @@ -7407,9 +7411,10 @@ def test_addmm_relu_tunableop_rocm(self, device, dtype): def test_addmm_gelu(self, device, dtype): self._test_addmm_impl(torch._addmm_activation, "gelu", device, dtype) + @skipIfRocmArch(MI300_ARCH) @dtypes(torch.float, torch.double) @dtypesIfCUDA(*floating_and_complex_types()) - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @tf32_on_and_off(0.005) @reduced_f32_on_and_off(0.005) def test_addmm_sizes(self, device, dtype): for m in [0, 1, 25]: @@ -8764,7 +8769,7 @@ def run_test(*n): num_matrices = tensors_batch.size(0) tensors_list = [] - for i in range(num_matrices): + for _ in range(num_matrices): tensors_list.append(torch.randn(n[-2], n[-1], dtype=dtype, device=device)) for i in range(num_matrices): @@ -9366,8 +9371,8 @@ def dims_full_for_fn(): r1 = fntorch(t0_full, t1, t2) self.assertEqual(r0, r1) - # ROCm 6.4 passes with tf32=on, but 6.4.1 needed tolerance reduced slightly - @tf32_on_and_off(0.002 if torch.version.hip else 0.001) + @skipIfRocmArch(MI300_ARCH) + @tf32_on_and_off(0.001) @reduced_f32_on_and_off(0.001) def test_broadcast_batched_matmul(self, device): n_dim = random.randint(1, 8) @@ -9704,7 +9709,8 @@ def fn(torchfn, *args): self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), fn(torch.slogdet, (0, 0))) - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @skipIfRocmArch(MI300_ARCH) + @tf32_on_and_off(0.005) @reduced_f32_on_and_off(0.07, 0.005) def test_tensordot(self, device): a = torch.arange(60., device=device).reshape(3, 4, 5) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 29e232778d464..5e54a851812e0 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,6 +5,7 @@ import unittest from itertools import product from functools import partial +from typing import Callable import torch @@ -49,6 +50,8 @@ decorateIf, ) +from torch.testing._internal.inductor_utils import IS_BIG_GPU + from torch._inductor.test_case import TestCase as InductorTestCase _IS_SM8X = False @@ -88,14 +91,21 @@ def tearDown(self): torch.backends.cuda.matmul.allow_tf32 = True super().tearDown() - def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False, fp16_accumulate: bool = False): + def cublas_addmm( + self, + size: int, + dtype: torch.dtype, + reduced_precision: bool = False, + fp16_accumulate: bool = False, + bias_shape_modifier: Callable | None = None, + ): # # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between # results from the CUDA invocation of torch.addmm and the CPU invocation # (which does not use CUDA backend). # # Get dims - n, m, p = (size + 1, size, size + 2) + m, k, n = (size + 1, size, size + 2) # Disable reduced precision reductions in BFloat16 to bypass some kernels # which fail the threshold check orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction @@ -107,10 +117,12 @@ def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = # Make random tensors on CPU (seed set on common_utils.py import) # (Not using numpy because it does not support bfloat16) make_arg = partial(make_tensor, dtype=dtype, device="cpu") + + bias_shape_modifier = (lambda shape: shape) if bias_shape_modifier is None else bias_shape_modifier + m_input = make_arg(bias_shape_modifier((m, n))) + m_1 = make_arg((m, k)) + m_2 = make_arg((k, n)) m_beta = make_arg(1) - m_input = make_arg((n, p)) - m_1 = make_arg((n, m)) - m_2 = make_arg((m, p)) # scale to abate overflows in fp16 accum if fp16_accumulate: m_1 = m_1 / 100 @@ -177,6 +189,25 @@ def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype, bac with blas_library_context(backend): self.cublas_addmm(size, dtype, True) + + @onlyCUDA + # imported 'tol' as 'xtol' to avoid aliasing in code above + @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=1e-4), + torch.bfloat16: xtol(atol=1e-3, rtol=1e-4), + torch.float32: xtol(atol=1e-3, rtol=1e-4)}) + @dtypes(torch.bfloat16, torch.float16, torch.float32) + @parametrize("size", [128]) + @parametrize("backend", ["cublas", "cublaslt"]) + def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend): + with blas_library_context(backend): + # 2D bias + self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: shape) + # 1D bias which is row-broadcast to 2D + self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (1, shape[-1])) + # 1D bias which row-broadcasts + self.cublas_addmm(size, dtype, bias_shape_modifier=lambda shape: (shape[-1],)) + + @onlyCUDA @dtypes(torch.float16) # m == 4 chooses OUTPUT_TYPE reduction on H200 @@ -619,8 +650,12 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) raise AssertionError(f"Invalid op: {op}") C_ref = f_ref(A, B.transpose(-2, -1), offs=offs) - C = f(A, B.transpose(-2, -1), offs=offs) - torch.testing.assert_close(C, C_ref) + if not IS_BIG_GPU and max_autotune: + with self.assertRaisesRegex(torch._inductor.exc.InductorError, "NoValidChoicesError"): + C = f(A, B.transpose(-2, -1), offs=offs) + else: + C = f(A, B.transpose(-2, -1), offs=offs) + self.assertEqual(C, C_ref) @onlyCUDA diff --git a/test/test_mps.py b/test/test_mps.py index 83d5b46d46821..fad09c2f5eb28 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -130,7 +130,7 @@ def __exit__(self, exc_type, exc_value, traceback): discrepancy_detected = True # Query memory multiple items to ensure leak was not transient - for n in range(3): + for _ in range(3): caching_allocator_mem_allocated = torch.mps.current_allocated_memory() driver_mem_allocated = torch.mps.driver_allocated_memory() @@ -4984,7 +4984,7 @@ def helper(shape): input_xs.append(torch.ones(prod, dtype=torch.int).reshape(shape).bool()) input_xs.append(torch.zeros(prod, dtype=torch.int).reshape(shape).bool()) - for i, cpu_x in enumerate(input_xs): + for cpu_x in input_xs: x = cpu_x.detach().clone().to('mps') y = torch.all(x) ref_y = torch.all(cpu_x) @@ -9601,7 +9601,7 @@ def get_mps_memory_usage(): key = torch.randn(batch_size, num_heads, seq_len, head_dim, device="mps", dtype=torch.float32) value = torch.randn(batch_size, num_heads, seq_len, head_dim, device="mps", dtype=torch.float32) memory_footprints = [] - for i in range(100): + for _ in range(100): output = F.scaled_dot_product_attention(query, key, value) current_mem, driver_mem = get_mps_memory_usage() memory_footprints.append((current_mem, driver_mem)) diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 76e50375bba15..45a09a9312ced 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -58,7 +58,7 @@ def run(self): def _test_cuda_ipc_deadlock_actor(queue, iterations): - for i in range(iterations): + for _ in range(iterations): if not queue.empty(): queue.get() time.sleep(0.01) @@ -66,7 +66,7 @@ def _test_cuda_ipc_deadlock_actor(queue, iterations): def _test_cuda_ipc_deadlock_learner(queue, iterations): net = torch.nn.LSTM(1, 1).cuda() - for i in range(iterations): + for _ in range(iterations): if not queue.full(): queue.put(copy.deepcopy(net.state_dict())) time.sleep(0.01) @@ -138,7 +138,7 @@ def send_tensor_with_untyped_storage(queue, event): def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5): s = torch.full([size], 0, device=device, dtype=dtype) - for i in range(count): + for _ in range(count): t = queue.get() s += t out_queue.put(s) @@ -146,7 +146,7 @@ def receive_and_send_sum(queue, out_queue, event, device, dtype, count, size=5): def receive_and_send(queue, out_queue, event, count): - for i in range(count): + for _ in range(count): t = queue.get() out_queue.put(t.clone()) event.wait() diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index c8a9ca33efb0f..3829c0a5de5a5 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -1002,7 +1002,7 @@ def test_cummax_cummin(self): def test_ops(op): for device in get_all_device_types(): names = ('N', 'D') - tensor = torch.rand(2, 3, names=names) + tensor = torch.rand(2, 3, names=names, device=device) result = op(tensor, 0) self.assertEqual(result[0].names, names) self.assertEqual(result[1].names, names) @@ -1012,15 +1012,15 @@ def test_ops(op): def test_logcumsumexp(self): for device in get_all_device_types(): names = ('N', 'D') - tensor = torch.rand(2, 3, names=names) + tensor = torch.rand(2, 3, names=names, device=device) result = torch.logcumsumexp(tensor, 'D') self.assertEqual(result.names, names) def test_bitwise_not(self): for device in get_all_device_types(): names = ('N', 'D') - tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) - result = torch.empty(0, dtype=torch.bool) + tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device) + result = torch.empty(0, dtype=torch.bool, device=device) self.assertEqual(tensor.bitwise_not().names, names) self.assertEqual(torch.bitwise_not(tensor, out=result).names, names) @@ -1029,8 +1029,8 @@ def test_bitwise_not(self): def test_logical_not(self): for device in get_all_device_types(): names = ('N', 'D') - tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) - result = torch.empty(0, dtype=torch.bool) + tensor = torch.zeros(2, 3, names=names, dtype=torch.bool, device=device) + result = torch.empty(0, dtype=torch.bool, device=device) self.assertEqual(tensor.logical_not().names, names) self.assertEqual(torch.logical_not(tensor, out=result).names, names) @@ -1039,8 +1039,8 @@ def test_logical_not(self): def test_bernoulli(self): for device in get_all_device_types(): names = ('N', 'D') - tensor = torch.rand(2, 3, names=names) - result = torch.empty(0) + tensor = torch.rand(2, 3, names=names, device=device) + result = torch.empty(0, device=device) self.assertEqual(tensor.bernoulli().names, names) torch.bernoulli(tensor, out=result) diff --git a/test/test_nn.py b/test/test_nn.py index eac0d887c4252..034cf51d49ff0 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -32,7 +32,7 @@ from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes, floating_types from torch.testing._internal.common_utils import dtype_name, freeze_rng_state, run_tests, TestCase, \ - skipIfNoLapack, skipIfRocm, \ + skipIfNoLapack, skipIfRocm, MI300_ARCH, skipIfRocmArch, \ TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ download_file, get_function_arglist, load_tests, skipIfMPS, \ IS_PPC, \ @@ -1238,7 +1238,7 @@ def test_ParameterDict(self): def check(): self.assertEqual(len(parameter_dict), len(parameters)) - for i, (k1, (k2, m2)) in enumerate(zip(parameters, parameter_dict.named_parameters())): + for (k1, (k2, m2)) in zip(parameters, parameter_dict.named_parameters()): self.assertEqual(k1, k2) self.assertIs(parameters[k1], m2) for k1, k2 in zip(parameters, parameter_dict): @@ -2958,7 +2958,7 @@ def perm_fn(x): batch_first=batch_first) # set constant weights of the model - for idx, p in enumerate(model.parameters()): + for p in model.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape @@ -3108,7 +3108,7 @@ def perm_fn(x): activation, batch_first=batch_first) # set constant weights of the model - for idx, p in enumerate(model.parameters()): + for p in model.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape @@ -3185,7 +3185,7 @@ def get_a_test_layer(use_cuda, activation, batch_first=False): with torch.no_grad(): # set constant weights of the model - for idx, p in enumerate(layer.parameters()): + for p in layer.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape @@ -8378,8 +8378,9 @@ def test_affine_2d_rotate0(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @skipIfRocmArch(MI300_ARCH) @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 - @tf32_on_and_off(0.01 if TEST_WITH_ROCM else 0.001) + @tf32_on_and_off(0.001) @reduced_f32_on_and_off(0.001) def test_affine_2d_rotate90(self, device): # scipy before 1.0.0 do not support homogeneous coordinate @@ -8526,8 +8527,9 @@ def test_avg_pool_large_tensor2(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") + @skipIfRocmArch(MI300_ARCH) @expectedFailureMPS # Unsupported Border padding mode https://github.com/pytorch/pytorch/issues/125098 - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @tf32_on_and_off(0.005) @reduced_f32_on_and_off(0.005) def test_affine_2d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate @@ -8579,7 +8581,8 @@ def test_affine_2d_rotateRandom(self, device): @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'), "Scipy v1.0 and/or numpy not found") - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @skipIfRocmArch(MI300_ARCH) + @tf32_on_and_off(0.005) @reduced_f32_on_and_off(0.005) def test_affine_3d_rotateRandom(self, device): # scipy before 1.0.0 do not support homogeneous coordinate @@ -9243,8 +9246,10 @@ def test_TransformerDecoder_empty(self, device): @expectedFailureMPS # Float64 is not supported @onlyNativeDeviceTypes def test_Transformer_empty(self, device): - for batch_first, src_shape, tgt_shape in [(True, (10, 0, 512), (20, 0, 512))]: - transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, dtype=torch.double).to(device) + for batch_first, src_shape, tgt_shape in [(True, (0, 10, 512), (0, 20, 512)), + (False, (10, 0, 512), (20, 0, 512))]: + transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12, dtype=torch.double, + batch_first=batch_first).to(device) src = torch.rand(*src_shape, requires_grad=True, device=device, dtype=torch.double) tgt = torch.rand(*tgt_shape, requires_grad=True, device=device, dtype=torch.double) self._test_module_empty_inputs(transformer_model, [src, tgt]) @@ -9456,8 +9461,9 @@ def test_Unfold_empty(self, device): unfold(inp) @onlyCUDA + @skipIfRocmArch(MI300_ARCH) @dtypes(torch.float, torch.double) - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @tf32_on_and_off(0.005) def test_rnn_fused(self, device, dtype): def copy_rnn(rnn1, rnn2): @@ -11936,10 +11942,11 @@ def test_ctc_loss_error(self, device): with self.assertRaisesRegex(RuntimeError, "log_probs tensor must not be empty"): F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction='none') + @skipIfRocmArch(MI300_ARCH) @expectedFailureMPS # RuntimeError: LSTM with projections is not currently supported with MPS. @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypes(torch.float) - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @tf32_on_and_off(0.005) @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons") def test_variable_sequence(self, device, dtype): def pad(var, length): @@ -13129,7 +13136,7 @@ def perm_fn(x): model = model.eval() # set constant weights of the model - for idx, p in enumerate(model.parameters()): + for p in model.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape @@ -13349,7 +13356,7 @@ def perm_fn(x): model = model.eval() # set constant weights of the model - for idx, p in enumerate(model.parameters()): + for p in model.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index aea2441c61b9b..b2f0f873a853a 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -1,8 +1,16 @@ # Owner(s): ["module: custom-operators"] +import random + import torch from torch._dynamo.test_case import run_tests, TestCase +from torch._library.fake_class_registry import FakeScriptObject from torch._library.opaque_object import register_opaque_type +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) class OpaqueQueue: @@ -11,24 +19,39 @@ def __init__(self, queue: list[torch.Tensor], init_tensor_: torch.Tensor) -> Non self.queue = queue self.init_tensor_ = init_tensor_ + # For testing purposes + self._push_counter = 0 + self._pop_counter = 0 + self._size_counter = 0 + def push(self, tensor: torch.Tensor) -> None: + self._push_counter += 1 self.queue.append(tensor) def pop(self) -> torch.Tensor: + self._pop_counter += 1 if len(self.queue) > 0: return self.queue.pop(0) return self.init_tensor_ def size(self) -> int: + self._size_counter += 1 return len(self.queue) +class RNGState: + def __init__(self, seed): + self.rng = random.Random(seed) + + +register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") +register_opaque_type(RNGState, "_TestOpaqueObject_RNGState") + + class TestOpaqueObject(TestCase): def setUp(self): self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901 - register_opaque_type(OpaqueQueue, "_TestOpaqueObject_OpaqueQueue") - torch.library.define( "_TestOpaqueObject::queue_push", "(_TestOpaqueObject_OpaqueQueue a, Tensor b) -> ()", @@ -43,6 +66,10 @@ def push_impl(queue: OpaqueQueue, b: torch.Tensor) -> None: assert isinstance(queue, OpaqueQueue) queue.push(b) + @torch.library.register_fake("_TestOpaqueObject::queue_push", lib=self.lib) + def push_impl_fake(q: OpaqueQueue, b: torch.Tensor) -> None: + pass + self.lib.define( "queue_pop(_TestOpaqueObject_OpaqueQueue a) -> Tensor", ) @@ -53,6 +80,15 @@ def pop_impl(queue: OpaqueQueue) -> torch.Tensor: self.lib.impl("queue_pop", pop_impl, "CompositeExplicitAutograd") + def pop_impl_fake(q: OpaqueQueue) -> torch.Tensor: + # This is not accurate since the queue could have tensors that are + # not rank 1 + ctx = torch.library.get_ctx() + u0 = ctx.new_dynamic_size() + return torch.empty(u0) + + self.lib._register_fake("queue_pop", pop_impl_fake) + @torch.library.custom_op( "_TestOpaqueObject::queue_size", mutates_args=[], @@ -61,6 +97,34 @@ def size_impl(queue: OpaqueQueue) -> int: assert isinstance(queue, OpaqueQueue) return queue.size() + @size_impl.register_fake + def size_impl_fake(q: OpaqueQueue) -> int: + ctx = torch._custom_op.impl.get_ctx() + u0 = ctx.new_dynamic_size() + torch._check_is_size(u0) + return u0 + + torch.library.define( + "_TestOpaqueObject::noisy_inject", + "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + + @torch.library.impl( + "_TestOpaqueObject::noisy_inject", "CompositeExplicitAutograd", lib=self.lib + ) + def noisy_inject(x: torch.Tensor, rng_state: RNGState) -> torch.Tensor: + assert isinstance(rng_state, RNGState) + out = x.clone() + for i in range(out.numel()): + out.view(-1)[i] += rng_state.rng.random() + return out + + @torch.library.register_fake("_TestOpaqueObject::noisy_inject", lib=self.lib) + def noisy_inject_fake(x: torch.Tensor, obj: RNGState) -> torch.Tensor: + return torch.empty_like(x) + super().setUp() def tearDown(self): @@ -79,6 +143,99 @@ def test_ops(self): size = torch.ops._TestOpaqueObject.queue_size(queue) self.assertEqual(size, 0) + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) + def test_make_fx(self, make_fx_tracing_mode): + class M(torch.nn.Module): + def forward(self, queue, x): + torch.ops._TestOpaqueObject.queue_push(queue, x.tan()) + torch.ops._TestOpaqueObject.queue_push(queue, x.cos()) + torch.ops._TestOpaqueObject.queue_push(queue, x.sin()) + pop1 = torch.ops._TestOpaqueObject.queue_pop(queue) + size1 = torch.ops._TestOpaqueObject.queue_size(queue) + pop2 = torch.ops._TestOpaqueObject.queue_pop(queue) + size2 = torch.ops._TestOpaqueObject.queue_size(queue) + x_cos = pop1 + size1 + x_sin = pop2 - size2 + return x_sin + x_cos + + q1 = OpaqueQueue([], torch.empty(0).fill_(-1)) + q2 = OpaqueQueue([], torch.empty(0).fill_(-1)) + + x = torch.ones(2, 3) + gm = make_fx(M(), tracing_mode=make_fx_tracing_mode)(q1, x) + self.assertTrue(torch.allclose(gm(q1, x), M()(q2, x))) + self.assertEqual(q1._push_counter, 3) + self.assertEqual(q1._pop_counter, 2) + self.assertEqual(q1._size_counter, 2) + self.assertEqual(q1.size(), 1) + self.assertExpectedInline( + gm.code.strip("\n"), + """\ +def forward(self, arg0_1, arg1_1): + tan = torch.ops.aten.tan.default(arg1_1) + queue_push = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, tan); tan = queue_push = None + cos = torch.ops.aten.cos.default(arg1_1) + queue_push_1 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, cos); cos = queue_push_1 = None + sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None + queue_push_2 = torch.ops._TestOpaqueObject.queue_push.default(arg0_1, sin); sin = queue_push_2 = None + queue_pop = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1) + queue_size = torch.ops._TestOpaqueObject.queue_size.default(arg0_1) + queue_pop_1 = torch.ops._TestOpaqueObject.queue_pop.default(arg0_1) + queue_size_1 = torch.ops._TestOpaqueObject.queue_size.default(arg0_1); arg0_1 = None + add = torch.ops.aten.add.Tensor(queue_pop, queue_size); queue_pop = queue_size = None + sub = torch.ops.aten.sub.Tensor(queue_pop_1, queue_size_1); queue_pop_1 = queue_size_1 = None + add_1 = torch.ops.aten.add.Tensor(sub, add); sub = add = None + return add_1 + """, + ) + + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) + def test_bad_fake(self, make_fx_tracing_mode): + torch.library.define( + "_TestOpaqueObject::bad_fake", + "(Tensor x, _TestOpaqueObject_RNGState obj) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + + def f(q, x): + torch.ops._TestOpaqueObject.bad_fake(x, q) + return x.cos() + + def bad_fake1(x, rng_state) -> torch.Tensor: + self.assertTrue(isinstance(rng_state, FakeScriptObject)) + out = x.clone() + for i in range(out.numel()): + out.view(-1)[i] += rng_state.rng.random() # bad: accessing attributes + return out + + torch.library.register_fake( + "_TestOpaqueObject::bad_fake", bad_fake1, lib=self.lib, allow_override=True + ) + + with self.assertRaisesRegex( + AttributeError, + "Tried to call __getattr__ with attr", + ): + make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3)) + + def bad_fake2(x, rng_state) -> torch.Tensor: + rng_state.rng = "foo" + return torch.empty_like(x) + + torch.library.register_fake( + "_TestOpaqueObject::bad_fake", bad_fake2, lib=self.lib, allow_override=True + ) + + with self.assertRaisesRegex( + AttributeError, + "Tried to call __setattr__ with attr", + ): + make_fx(f, tracing_mode=make_fx_tracing_mode)(RNGState(0), torch.ones(3)) + + +instantiate_parametrized_tests(TestOpaqueObject) + if __name__ == "__main__": run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 7427de04bf839..165b284b76d5c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2040,7 +2040,7 @@ def check_cow_input( # Convert strided tensor inputs to COW tensors and make copies of # all inputs - for idx, arg in enumerate(args_raw): + for arg in args_raw: if is_strided_tensor(arg): args_copy.append(arg.detach().clone()) args.append(torch._lazy_clone(arg)) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index dff4e9c014c78..7a9f8f3aa317f 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -358,6 +358,7 @@ def onerror(modname): "torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture", "torch.testing._internal.distributed.rpc_utils", "torch._inductor.codegen.cuda.cuda_template", + "torch._inductor.codegen.cutedsl._cutedsl_utils", "torch._inductor.codegen.cuda.gemm_template", "torch._inductor.codegen.cpp_template", "torch._inductor.codegen.cpp_gemm_template", diff --git a/test/test_rename_privateuse1_to_existing_device.py b/test/test_rename_privateuse1_to_existing_device.py index 539412a322385..40941ca4e77dd 100644 --- a/test/test_rename_privateuse1_to_existing_device.py +++ b/test/test_rename_privateuse1_to_existing_device.py @@ -1,7 +1,7 @@ # Owner(s): ["module: PrivateUse1"] import torch -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase class DummyPrivateUse1Module: @@ -31,6 +31,9 @@ def get_amp_supported_dtype(): class TestRenamePrivateuseoneToExistingBackend(TestCase): + @skipIfTorchDynamo( + "TorchDynamo exposes https://github.com/pytorch/pytorch/issues/166696" + ) def test_external_module_register_with_existing_backend(self): torch.utils.rename_privateuse1_backend("maia") with self.assertRaisesRegex(RuntimeError, "has already been set"): diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 204153e971b83..216978142d5b4 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -42,12 +42,13 @@ TestCase, ) from torch.testing._internal.common_quantized import ( - _f32_to_floatx_unpacked, + _bfloat16_to_float4_e2m1fn_x2, _floatx_unpacked_to_f32, ceil_div, to_blocked, - to_mxfp8, + to_mxfp, from_blocked_format, generate_jagged_offs, + pack_uint4, ) @@ -279,11 +280,13 @@ def scaled_grouped_mm_wrap( -def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: +def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale y_fp32 = y.to(torch.float) / y_scale out_fp32 = torch.mm(x_fp32, y_fp32) + if bias is not None: + out_fp32 += bias.to(torch.float) return out_fp32.to(out_dtype) @@ -451,18 +454,6 @@ def data_to_nvfp4_with_global_scale(x, block_size): return x_fp4, S_dec_b_e4m3, S_dec.float() -def down_size(size): - assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" - return (*size[:-1], size[-1] // 2) - - -def pack_uint4(uint8_data) -> torch.Tensor: - # converting to uint8 for operations - shape = uint8_data.shape - assert shape[-1] % 2 == 0 - uint8_data = uint8_data.contiguous().view(-1) - return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape)) - def unpack_uint4(uint8_data) -> torch.Tensor: # Take a packed uint8 tensor (i.e. nvfp4) and unpack into # a tensor twice as wide. Useful for dequant operations. @@ -482,13 +473,6 @@ def unpack_uint4(uint8_data) -> torch.Tensor: return out.view(shape) -def _bfloat16_to_float4_e2m1fn_x2(x): - assert x.dtype == torch.bfloat16 - x = _f32_to_floatx_unpacked(x.float(), FP4_EBITS, FP4_MBITS) - x = pack_uint4(x) - x = x.view(torch.float4_e2m1fn_x2) - return x - def _convert_to_nvfp4_with_hp_ref(t): # Convert a tensor to nvfp4, returning: # t_hp : reconstructed bf16 version of t_lp @@ -509,17 +493,34 @@ def _convert_to_nvfp4_with_hp_ref(t): return t_hp, t_lp, t_scale, t_global_scale +def _convert_to_mxfp4_with_hp_ref(t): + # Convert a tensor to mxfp8, returning: + # t_hp : reconstructed bf16 version of t_lp + # t_lp : fp8_e4m3 tensor + # t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled) + t_scale, t_lp = to_mxfp(t, format="mxfp4") + t_hp = from_blocked_format( + _floatx_unpacked_to_f32( + unpack_uint4(t_lp), + FP4_EBITS, + FP4_MBITS), + t_scale, + blocksize=32 + ) + + return t_hp, t_lp, t_scale + def _convert_to_mxfp8_with_hp_ref(t): # Convert a tensor to mxfp8, returning: # t_hp : reconstructed bf16 version of t_lp # t_lp : fp8_e4m3 tensor # t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled) - t_scale, t_lp = to_mxfp8(t) + t_scale, t_lp = to_mxfp(t, format="mxfp8") t_hp = from_blocked_format(t_lp, t_scale, blocksize=32) return t_hp, t_lp, t_scale -def _2d_grouped_tensor_to_mxfp8_blocked_scaled(t, MN, G, offs, format='mxfp8'): +def _2d_grouped_tensor_to_blocked_scaled(t, MN, G, offs, format='mxfp8'): # Convert scales to blocked format. either mxfp8 or nvfp4 th_list = [] t_list = [] @@ -547,15 +548,18 @@ def round_up(x: int, y: int) -> int: t_slice, ) t_global_scale_list.append(tq_global) + elif format == 'mxfp4': + th_slice, tq_slice, t_scale_slice = _convert_to_mxfp4_with_hp_ref(t_slice) else: raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"') t_list.append(tq_slice) th_list.append(th_slice) # Convert scales to blocked format. - t_scale_slice_blocked = to_blocked( - t_scale_slice - ) # (round_up(M, 128), round_up(K_group//32, 4)) + if torch.version.cuda: + t_scale_slice_blocked = to_blocked( + t_scale_slice + ) # (round_up(M, 128), round_up(K_group//32, 4)) t_blocked_scale_list.append(t_scale_slice_blocked) # Assemble the full XQ and WQ @@ -576,15 +580,18 @@ def round_up(x: int, y: int) -> int: def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format): # Build some standard args that are wordy - # Note: if/when ROCm support added, need to change swizzle handling + swizzle = SwizzleType.NO_SWIZZLE + if torch.version.cuda: + swizzle = SwizzleType.SWIZZLE_32_4_4 + kwargs = { 'mxfp8': { 'scale_a': scale_a, 'scale_b': scale_b, 'scale_recipe_a': ScalingType.BlockWise1x32, 'scale_recipe_b': ScalingType.BlockWise1x32, - 'swizzle_a': SwizzleType.SWIZZLE_32_4_4, - 'swizzle_b': SwizzleType.SWIZZLE_32_4_4, + 'swizzle_a': swizzle, + 'swizzle_b': swizzle, 'offs': offs, # (G,) 'out_dtype': torch.bfloat16, 'wrap_v2': True, @@ -594,13 +601,15 @@ def _build_scaled_grouped_mm_kwargs(scale_a, scale_b, offs, format): 'scale_b': scale_b, 'scale_recipe_a': [ScalingType.BlockWise1x16, ScalingType.TensorWise], 'scale_recipe_b': [ScalingType.BlockWise1x16, ScalingType.TensorWise], - 'swizzle_a': SwizzleType.SWIZZLE_32_4_4, - 'swizzle_b': SwizzleType.SWIZZLE_32_4_4, + 'swizzle_a': swizzle, + 'swizzle_b': swizzle, 'offs': offs, # (G,) 'out_dtype': torch.bfloat16, 'wrap_v2': True, }, } + # MXFP4 is exactly the same setup as mxfp8 + kwargs['mxfp4'] = kwargs['mxfp8'] return kwargs[format] class TestFP8Matmul(TestCase): @@ -665,7 +674,7 @@ def test_float8_scale(self, device) -> None: @parametrize("M", [2048, 2049]) @parametrize("N", [8192]) @parametrize("K", [16640]) - @parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else [])) + @parametrize("format", ["mxfp8"] + (["nvfp4", "mxfp4"] if torch.version.cuda else [])) def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): torch.manual_seed(42) total_K = K # Alias for clarity, communicating this consists of several groups along this dim @@ -675,14 +684,14 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): X = torch.randn((M, total_K), dtype=torch.bfloat16, device="cuda") * 0.1 W = torch.randn((N, total_K), dtype=torch.bfloat16, device="cuda") * 0.01 - xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled( + xh, xq, x_blocked_scales, x_global_scales = _2d_grouped_tensor_to_blocked_scaled( X, M, G, input_group_end_offsets, format=format ) - wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_mxfp8_blocked_scaled( + wh, wq, w_blocked_scales, w_global_scales = _2d_grouped_tensor_to_blocked_scaled( W, N, G, input_group_end_offsets, format=format ) - if format == "mxfp8": + if format in ["mxfp4", "mxfp8"]: kwargs = _build_scaled_grouped_mm_kwargs( x_blocked_scales, w_blocked_scales, @@ -697,7 +706,7 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): format, ) else: - raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"') + raise ValueError(f'format must be mxfp8|nvfp4|mxfp4, got "{format}"') if format == 'nvfp4': assert x_global_scales.numel() == w_global_scales.numel() @@ -718,7 +727,7 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): ) # Assert no NaNs - assert not y_lp.isnan().any(), "mxfp8 output contains NaN" + assert not y_lp.isnan().any(), "low-precision output contains NaN" # Assert outputs are close torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2) @@ -728,7 +737,7 @@ def test_mxfp8_nvfp4_scaled_grouped_mm_2d_2d(self, G, M, N, K, format): @parametrize("M", [16640]) @parametrize("N", [8192]) @parametrize("K", [4096]) - @parametrize("format", ["mxfp8"] + (["nvfp4"] if torch.version.cuda else [])) + @parametrize("format", ["mxfp8"] + (["nvfp4", "mxfp4"] if torch.version.cuda else [])) def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, format): torch.manual_seed(42) # Simulate 2d-3d grouped gemm `out = input @ weight.t()` @@ -752,15 +761,17 @@ def _3d_to_blocked_scaled(W, G, format): if format == "mxfp8": wh, wq, w_scale = _convert_to_mxfp8_with_hp_ref(W[i]) elif format == "nvfp4": - w_scale, wq = to_mxfp8(W[i]) + w_scale, wq = to_mxfp(W[i], format="mxfp8") wh, wq, w_scale, w_global_scale = _convert_to_nvfp4_with_hp_ref(W[i]) w_global_scale_list.append(w_global_scale) + elif format == "mxfp4": + wh, wq, w_scale = _convert_to_mxfp4_with_hp_ref(W[i]) else: - raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"') + raise ValueError(f'format must be mxfp8|nvfp4|mxfp4, got "{format}"') # Swizzle scaled - # TODO(slayton): gate on cuda/hip - w_scale = to_blocked(w_scale) + if torch.version.cuda: + w_scale = to_blocked(w_scale) wh_list.append(wh) wq_list.append(wq) @@ -795,10 +806,13 @@ def _2d_to_blocked_scaled(X, K, G, offs, format): elif format == "nvfp4": xh, xq, x_scale, x_global_scale = _convert_to_nvfp4_with_hp_ref(x_slice) x_global_scale_list.append(x_global_scale) + elif format == "mxfp4": + xh, xq, x_scale = _convert_to_mxfp4_with_hp_ref(x_slice) else: - raise ValueError(f'format must be mxfp8|nvfp4, got "{format}"') + raise ValueError(f'format must be mxfp8|nvfp4|mxfp4, got "{format}"') - x_scale = to_blocked(x_scale) + if torch.version.cuda: + x_scale = to_blocked(x_scale) xh_list.append(xh) xq_list.append(xq) x_scale_list.append(x_scale) @@ -817,7 +831,7 @@ def _2d_to_blocked_scaled(X, K, G, offs, format): xh, xq, x_blocked_scales, x_global_scales = _2d_to_blocked_scaled(X, K, G, input_group_end_offsets, format) - if format == "mxfp8": + if format in ["mxfp8", "mxfp4"]: kwargs = _build_scaled_grouped_mm_kwargs( x_blocked_scales, w_blocked_scales, @@ -1169,7 +1183,7 @@ def e5m2(): @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") - @parametrize("base_dtype", [torch.bfloat16, torch.float32]) + @parametrize("base_dtype", [torch.bfloat16, torch.float16, torch.float32]) @with_tf32_off def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): # Fp32 out_dtype is only supported by cuBLAS, which however only started @@ -1182,12 +1196,21 @@ def test_scaled_mm_vs_emulated_row_wise(self, base_dtype): if torch.cuda.get_device_capability() < (9, 0): raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS") + if base_dtype is torch.float16: + if torch.version.hip: + raise unittest.SkipTest("hipblaslt rowwise _scaled_mm only supports BFloat16") + if torch.cuda.get_device_capability() < (9, 0): + raise unittest.SkipTest("Need sm90+ for row-wise fp8 w/ cuBLAS") + torch.manual_seed(42) input_dtype = e4m3_type output_dtype = base_dtype x = torch.randn(16, 16, device="cuda", dtype=base_dtype) y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t() + bias = None + if base_dtype in {torch.bfloat16, torch.float16}: + bias = torch.randn((32,), device="cuda", dtype=base_dtype) x_scales = tensor_to_scale(x, input_dtype, dim=1).float() y_scales = tensor_to_scale(y, input_dtype, dim=0).float() @@ -1202,12 +1225,13 @@ def test(): y_fp8, scale_a=x_scales.reciprocal(), scale_b=y_scales.reciprocal(), - out_dtype=output_dtype + out_dtype=output_dtype, + bias=bias ) # Calculate emulated F8 mm out_emulated = mm_float8_emulated( - x_fp8, x_scales, y_fp8, y_scales, output_dtype + x_fp8, x_scales, y_fp8, y_scales, output_dtype, bias ) if base_dtype in {torch.bfloat16, torch.float16}: @@ -1222,7 +1246,7 @@ def test(): if torch.cuda.get_device_capability() != (9, 0) and output_dtype == torch.float: with self.assertRaisesRegex( ValueError, - "Only bf16 high precision output types are supported for row-wise scaling." + "Only bf16 and fp16 high precision output types are supported for row-wise scaling." ): test() else: @@ -1238,6 +1262,7 @@ def test(): @parametrize("output_dtype", [torch.bfloat16, torch.float32]) @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) @parametrize("M,N,K", [(256, 768, 512)]) + @with_tf32_off def test_scaled_mm_vs_emulated_block_wise(self, output_dtype, lhs_block, rhs_block, M, N, K): torch.manual_seed(42) @@ -1337,6 +1362,56 @@ def test_scaled_mm_vs_emulated_block_wise_verify_small_shapes( # Verify that emulated F8 mm doesn't error mm_float8_emulated_block(x_fp8, x_scales, y_fp8.t(), y_scales.t(), output_dtype) + @skipIfRocm + @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) + @unittest.skipIf(IS_SM90, "cuBLAS blockwise scaling works on sm90") + @unittest.skipIf( + _get_torch_cuda_version() < (12, 9), + "cuBLAS blockwise scaling added in CUDA 12.9", + ) + @parametrize("output_dtype", [torch.bfloat16, ]) + @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) + @parametrize("M,N,K", [(256, 256, 256), (256, 256, 512)]) + def test_scaled_mm_deepseek_error_messages( + self, output_dtype, lhs_block, rhs_block, M, N, K + ): + torch.manual_seed(42) + + x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3) + y = torch.randn(N, K, device="cuda", dtype=output_dtype).pow(3) + + x_fp8, x_scales = tensor_to_scale_block(x, e4m3_type, lhs_block, 128) + y_fp8, y_scales = tensor_to_scale_block(y, e4m3_type, rhs_block, 128) + + # 1x128 blocks need scales to be outer-dim-major + if lhs_block == 1: + x_scales = x_scales.t().contiguous().t() + lhs_recipe = ScalingType.BlockWise1x128 + else: + lhs_recipe = ScalingType.BlockWise128x128 + + if rhs_block == 1: + y_scales = y_scales.t().contiguous().t() + rhs_recipe = ScalingType.BlockWise1x128 + else: + rhs_recipe = ScalingType.BlockWise128x128 + + # Verify that actual F8 mm doesn't error + with self.assertRaisesRegex( + NotImplementedError, + ".*DeepSeek.*scaling.*only supported in CUDA for SM90.*" + ): + scaled_mm_wrap( + x_fp8, + y_fp8.t(), + scale_a=x_scales, + scale_recipe_a=lhs_recipe, + scale_b=y_scales.t(), + scale_recipe_b=rhs_recipe, + out_dtype=output_dtype, + ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("which_dim_zero", [0, 1, 2]) @parametrize("use_torch_compile", [False, True]) @@ -1484,8 +1559,13 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: A, A_scale, A_global_scale = data_to_nvfp4_with_global_scale(A_ref, BLOCK_SIZE) B, B_scale, B_global_scale = data_to_nvfp4_with_global_scale(B_ref, BLOCK_SIZE) - A_scale = to_blocked(A_scale) - B_scale = to_blocked(B_scale) + + if torch.version.cuda: + A_scale = to_blocked(A_scale) + B_scale = to_blocked(B_scale) + swizzle = [SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE] + else: + swizzle = [SwizzleType.NO_SWIZZLE, SwizzleType.NO_SWIZZLE] C_ref = A_ref @ B_ref.t() @@ -1496,8 +1576,8 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: scale_recipe_a=[ScalingType.BlockWise1x16, ScalingType.TensorWise], scale_b=[B_scale, B_global_scale], scale_recipe_b=[ScalingType.BlockWise1x16, ScalingType.TensorWise], - swizzle_a=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE], - swizzle_b=[SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE], + swizzle_a=swizzle, + swizzle_b=swizzle, output_dtype=torch.bfloat16, ) diff --git a/test/test_serialization.py b/test/test_serialization.py index e378c6c2789d6..dcf67fe3ccf14 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -295,7 +295,7 @@ def test_serialization_fake_zip(self): 5, 6 ] - for i in range(100): + for _ in range(100): data.append(0) t = torch.tensor(data, dtype=torch.uint8) diff --git a/test/test_sparse.py b/test/test_sparse.py index eb6877b419d0b..809c30b92a8b7 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -223,6 +223,12 @@ def assert_uncoalesced(self, x): else: existing_indices.add(index) + def test_negative_indices(self): + indices = torch.tensor([[0, 1, -1], [2, 0, 1]]) + values = torch.tensor([1, 2, 3]) + shape = torch.Size([3, 3]) + self.assertRaisesRegex(RuntimeError, "found negative index", lambda: torch.sparse_coo_tensor(indices, values, shape)) + def randn(self, *args, **kwargs): """ Variant of torch.randn that also works in the TEST_CUDA case. @@ -1433,7 +1439,7 @@ def test_bmm(self, device, dtype, coalesced): def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): a_list = [] b_list = [] - for mat_idx in range(num_mats): + for _ in range(num_mats): a_mat = self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0] b_mat = torch.randn([dim_j, dim_k], dtype=dtype, device=device) a_list.append(a_mat) @@ -1489,7 +1495,7 @@ def test_bmm_deterministic(self, device, dtype, coalesced): def test_shape(num_mats, dim_i, dim_j, dim_k, nnz): a_list = [] b_list = [] - for mat_idx in range(num_mats): + for _ in range(num_mats): a_list.append(self._gen_sparse(2, nnz, [dim_i, dim_j], dtype, device, coalesced)[0]) b_list.append(torch.randn([dim_j, dim_k], dtype=dtype, device=device)) @@ -3558,7 +3564,7 @@ def softmax_jacobian_autograd(x, dim, log=False): values = torch.ones((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device) else: ranges = [] - for j, sz in enumerate(shape[:sparse_dim]): + for sz in shape[:sparse_dim]: ranges.append(list(range(sz))) indices = torch.tensor(list(itertools.product(*ranges)), dtype=torch.long, device=device).t() values = torch.zeros((indices.shape[1],) + shape[sparse_dim:], dtype=dtype, device=device) diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index e1bfd3f146991..9e9670b17d37b 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -4252,9 +4252,9 @@ def get_meta_with_checks(M, K, N, warn_count=0, sparsity=None): # Test warn_once when requesting non-existing tuned parameters multiple times f = io.StringIO() with redirect_stderr(f): - for i in range(5): + for _ in range(5): get_meta(16, 16, 16) - for i in range(5): + for _ in range(5): get_meta(16, 16, 32) msg = f.getvalue() diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 03c62a272286d..6284be2aebe9e 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -481,7 +481,7 @@ def test_fftn_noop_transform(self, device, dtype): torch.fft.ifft2, ]: inp = make_tensor((10, 10), device=device, dtype=dtype) - out = torch.fft.fftn(inp, dim=[]) + out = op(inp, dim=[]) expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype) expect = inp.to(expect_dtype) @@ -1315,7 +1315,7 @@ def _test_istft_is_inverse_of_stft(stft_kwargs): istft_kwargs = stft_kwargs.copy() del istft_kwargs['pad_mode'] for sizes in data_sizes: - for i in range(num_trials): + for _ in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) stft = torch.stft(original, return_complex=True, **stft_kwargs) inversed = torch.istft(stft, length=original.size(1), **istft_kwargs) @@ -1386,7 +1386,7 @@ def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs): del stft_kwargs['size'] istft_kwargs = stft_kwargs.copy() del istft_kwargs['pad_mode'] - for i in range(num_trials): + for _ in range(num_trials): original = torch.randn(*sizes, dtype=dtype, device=device) stft = torch.stft(original, return_complex=True, **stft_kwargs) with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."): @@ -1501,7 +1501,7 @@ def test_istft_linearity(self, device, dtype): complex_dtype = corresponding_complex_dtype(dtype) def _test(data_size, kwargs): - for i in range(num_trials): + for _ in range(num_trials): tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype) tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype) a, b = torch.rand(2, dtype=dtype, device=device) diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index df1e0c3e34faa..f7efe9b929168 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -139,7 +139,7 @@ def fork_wait_graph_exception(input1, input2): def loop_graph(a, b, iters: int): c = a + b * 2 - for i in range(iters): + for _ in range(iters): c = c + b c *= 2 c -= a diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index b9256b322bb8a..b5a0fe5e9d3fa 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -81,7 +81,7 @@ def _generate_input(shape, dtype, device, with_extremal): # TODO: replace with make_tensor def _rand_shape(dim, min_size, max_size): shape = [] - for i in range(dim): + for _ in range(dim): shape.append(random.randint(min_size, max_size)) return tuple(shape) @@ -466,8 +466,8 @@ def test_torch_complex_floating_dtype_error(self, device, dtype): b = torch.tensor([3, 4], device=device, dtype=dtype) error = r"Expected both inputs to be Half, Float or Double tensors but " \ r"got [A-Za-z]+ and [A-Za-z]+" - with self.assertRaisesRegex(RuntimeError, error): - op(a, b) + with self.assertRaisesRegex(RuntimeError, error): + op(a, b) @onlyNativeDeviceTypes @dtypes(torch.float32, torch.float64) @@ -942,7 +942,7 @@ def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype num_tensors = random.randint(1, 5) torch_input = [] # Create tensors with shape being different along one axis only - for param in range(num_tensors): + for _ in range(num_tensors): shape[i] = random.randint(1, 5) torch_input.append(_generate_input(tuple(shape), dtype, device, with_extremal=False)) @@ -997,7 +997,7 @@ def test_vstack_row_stack(self, device, dtype): ops = ((torch.vstack, np.vstack), (torch.row_stack, np.vstack)) for torch_op, np_op in ops: self._test_special_stacks(0, 2, torch_op, np_op, device, dtype) - for i in range(5): + for _ in range(5): # Test dimension change for 1D tensor of size (N) and 2D tensor of size (1, N) n = random.randint(1, 10) input_a = _generate_input((n,), dtype, device, with_extremal=False) @@ -1012,7 +1012,7 @@ def test_vstack_row_stack(self, device, dtype): @dtypes(*all_types_and_complex_and(torch.half)) def test_dstack(self, device, dtype): self._test_special_stacks(2, 3, torch.dstack, np.dstack, device, dtype) - for i in range(5): + for _ in range(5): # Test dimension change for 1D tensor of size (N), 2D tensor of size (1, N), and 3D tensor of size (1, N, 1) n = random.randint(1, 10) input_a = _generate_input((n,), dtype, device, with_extremal=False) @@ -2885,7 +2885,7 @@ def test_signal_window_functions(self, device, dtype, window): @dtypesIfCUDA(torch.float, torch.double, torch.bfloat16, torch.half, torch.long) @dtypes(torch.float, torch.double, torch.long, torch.bfloat16, torch.float16) def test_kaiser_window(self, device, dtype): - for num_test in range(50): + for _ in range(50): self._test_signal_window_functions('kaiser', dtype, device, beta=random.random() * 30) def _test_signal_windows_functions(self, name, dtype, device, **kwargs): @@ -2918,7 +2918,7 @@ def test_signal_windows_functions(self, device, dtype, window): @unittest.skipIf(not TEST_SCIPY, "Scipy not found") @dtypes(torch.float, torch.double) def test_kaiser(self, device, dtype): - for num_test in range(50): + for _ in range(50): self._test_signal_windows_functions('kaiser', dtype, device, beta=random.random() * 30) def test_tensor_factories_empty(self, device): @@ -3808,6 +3808,137 @@ def test_full_like_inference(self, device): self.assertEqual(torch.full_like(like, 1., dtype=torch.complex64).dtype, torch.complex64) + def test_rand_like(self, device): + like_tensor = torch.zeros(100, 100, device=device) + + def seed(generator): + if generator is None: + torch.manual_seed(123456) + else: + generator.manual_seed(123456) + return generator + + for generator in (None, torch.Generator(device)): + generator = seed(generator) + res1 = torch.rand_like(like_tensor, generator=generator) + + generator = seed(generator) + res2 = torch.empty_like(like_tensor) + res2 = torch.rand_like(like_tensor, generator=generator) + + self.assertEqual(res1, res2) + self.assertTrue((res1 >= 0).all().item()) + self.assertTrue((res1 < 1).all().item()) + self.assertEqual(res1.shape, like_tensor.shape) + + gen0 = torch.Generator(device) + gen1 = torch.Generator(device) + gen2 = torch.Generator(device) + gen0.manual_seed(42) + gen1.manual_seed(42) + gen2.manual_seed(123456) + + tensor0 = torch.rand_like(like_tensor, generator=gen0) + tensor1 = torch.rand_like(like_tensor, generator=gen1) + tensor2 = torch.rand_like(like_tensor, generator=gen2) + self.assertEqual(tensor0, tensor1) + self.assertNotEqual(tensor2, tensor0) + self.assertNotEqual(tensor2, tensor1) + + tensor0 = torch.rand_like(like_tensor, generator=gen0) + self.assertNotEqual(tensor0, tensor1) + + def test_randn_like(self, device): + like_tensor = torch.zeros(100, 100, device=device) + + def seed(generator): + if generator is None: + torch.manual_seed(123456) + else: + generator.manual_seed(123456) + return generator + + for generator in (None, torch.Generator(device)): + generator = seed(generator) + res1 = torch.randn_like(like_tensor, generator=generator) + + generator = seed(generator) + res2 = torch.empty_like(like_tensor) + res2 = torch.randn_like(like_tensor, generator=generator) + + self.assertEqual(res1, res2) + self.assertEqual(res1.shape, like_tensor.shape) + + gen0 = torch.Generator(device) + gen1 = torch.Generator(device) + gen2 = torch.Generator(device) + gen0.manual_seed(42) + gen1.manual_seed(42) + gen2.manual_seed(123456) + + tensor0 = torch.randn_like(like_tensor, generator=gen0) + tensor1 = torch.randn_like(like_tensor, generator=gen1) + tensor2 = torch.randn_like(like_tensor, generator=gen2) + self.assertEqual(tensor0, tensor1) + self.assertNotEqual(tensor2, tensor0) + self.assertNotEqual(tensor2, tensor1) + + tensor0 = torch.randn_like(like_tensor, generator=gen0) + self.assertNotEqual(tensor0, tensor1) + + + def test_randint_like(self, device): + like_tensor = torch.zeros(100, 100, device=device, dtype=torch.long) + + def seed(generator): + if generator is None: + torch.manual_seed(123456) + else: + generator.manual_seed(123456) + return generator + + for generator in (None, torch.Generator(device)): + generator = seed(generator) + res1 = torch.randint_like(like_tensor, 0, 10, generator=generator) + + generator = seed(generator) + res2 = torch.empty_like(like_tensor) + res2 = torch.randint_like(like_tensor, 0, 10, generator=generator) + + generator = seed(generator) + res3 = torch.randint_like(like_tensor, 10, generator=generator) + + generator = seed(generator) + res4 = torch.empty_like(like_tensor) + res4 = torch.randint_like(like_tensor, 10, generator=generator) + + self.assertEqual(res1, res2) + self.assertEqual(res3, res4) + self.assertTrue((res1 >= 0).all().item()) + self.assertTrue((res1 < 10).all().item()) + self.assertTrue((res3 >= 0).all().item()) + self.assertTrue((res3 < 10).all().item()) + self.assertEqual(res1.shape, like_tensor.shape) + self.assertEqual(res3.shape, like_tensor.shape) + + gen0 = torch.Generator(device) + gen1 = torch.Generator(device) + gen2 = torch.Generator(device) + gen0.manual_seed(42) + gen1.manual_seed(42) + gen2.manual_seed(123456) + + tensor0 = torch.randint_like(like_tensor, 0, 10, generator=gen0) + tensor1 = torch.randint_like(like_tensor, 0, 10, generator=gen1) + tensor2 = torch.randint_like(like_tensor, 0, 10, generator=gen2) + self.assertEqual(tensor0, tensor1) + self.assertNotEqual(tensor2, tensor0) + self.assertNotEqual(tensor2, tensor1) + + tensor0 = torch.randint_like(like_tensor, 0, 10, generator=gen0) + self.assertNotEqual(tensor0, tensor1) + + # Tests for the `frombuffer` function (only work on CPU): # Constructs tensors from Python objects that implement the buffer protocol, # without copying data. diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 628e45ed8eb6d..cf2c836486c80 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -1216,7 +1216,7 @@ def test_loop(self): @torch.jit.script def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor: b = y - for i in range(z): + for _ in range(z): a = x + y b = b + y return b diff --git a/test/test_testing.py b/test/test_testing.py index 1735bcdcbb060..c660eb83b8042 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2367,6 +2367,7 @@ def test_circular_dependencies(self) -> None: "torch.onnx._internal", # depends on onnx-script "torch._inductor.runtime.triton_helpers", # depends on triton "torch._inductor.codegen.cuda", # depends on cutlass + "torch._inductor.codegen.cutedsl", # depends on cutlass "torch.distributed.benchmarks", # depends on RPC and DDP Optim "torch.distributed.examples", # requires CUDA and torchvision "torch.distributed.tensor.examples", # example scripts diff --git a/test/test_throughput_benchmark.py b/test/test_throughput_benchmark.py index fe838928b8e0a..f98e837611d9e 100644 --- a/test/test_throughput_benchmark.py +++ b/test/test_throughput_benchmark.py @@ -46,7 +46,7 @@ def linear_test(self, Module, profiler_output_path=""): inputs = [] - for i in range(NUM_INPUTS): + for _ in range(NUM_INPUTS): inputs.append([torch.randn(B, D_in), torch.randn(B, D_in)]) bench = ThroughputBenchmark(module) diff --git a/test/test_torch.py b/test/test_torch.py index 47e65ab6a12e1..b54ae93baa647 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -138,7 +138,7 @@ class TestTorchDeviceType(TestCase): # TODO: move all tensor creation to common ops def _rand_shape(self, dim, min_size, max_size): shape = [] - for i in range(dim): + for _ in range(dim): shape.append(random.randint(min_size, max_size)) return tuple(shape) @@ -172,7 +172,7 @@ def rand_byte(): element_size = torch._utils._element_size(dtype) - for i in range(10): + for _ in range(10): bytes_list = [rand_byte() for _ in range(element_size)] scalar = bytes_to_scalar(bytes_list, dtype, device) self.assertEqual(scalar.storage().untyped().tolist(), bytes_list) @@ -2012,7 +2012,7 @@ def test_scatter_add_one_dim_deterministic(self, device) -> None: res = x.scatter_add(dim, idx, src) # Checking if scatter_add is deterministic - for i in range(5): + for _ in range(5): res_next = x.scatter_add(dim, idx, src) self.assertEqual(res, res_next, atol=0, rtol=0) res = res_next @@ -2264,7 +2264,7 @@ def check(t, correction=1, fweights=None, aweights=None): fweights = torch.randint(1, 10, (num_observations,), device=device) aweights = make_tensor((num_observations,), dtype=torch.float, device=device, low=1) for correction, fw, aw in product([0, 1, 2], [None, fweights], [None, aweights]): - check(x, correction, fweights, aweights) + check(x, correction, fw, aw) @skipIfNoSciPy @dtypes(*floating_types_and(torch.half, torch.bfloat16)) @@ -2479,7 +2479,8 @@ def test_cdist_cuda_backward(self, device): self.assertEqual(x1.grad, x2.grad, rtol=0, atol=0.001) self.assertEqual(y1.grad, y2.grad, rtol=0, atol=0.001) - @tf32_on_and_off(0.05 if TEST_WITH_ROCM else 0.005) + @skipIfRocmArch(MI300_ARCH) + @tf32_on_and_off(0.005) @reduced_f32_on_and_off(0.08) def test_cdist_large(self, device): for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']: @@ -5151,7 +5152,7 @@ def test_multinomial_deterministic(self, device, dtype): prob_dist = torch.rand(10000, 1000, device=device, dtype=dtype) n_sample = 1 - for i in range(trials): + for _ in range(trials): gen.manual_seed(seed) samples_1 = torch.multinomial(prob_dist, n_sample, True, generator=gen) @@ -5229,7 +5230,7 @@ def _test_memory_format_transformations(self, device, input_generator_fn, transf # TODO copy _like constructors to stride permutation instead of just layout if not TEST_WITH_TORCHINDUCTOR: x = torch.randn((3, 4, 5, 6, 7, 8, 9), device=device) - for i in range(10): + for _ in range(10): permutation = list(range(len(x.shape))) random.shuffle(permutation) x = x.permute(permutation) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index adfdd755bc7bc..3b864aae4f477 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -13,6 +13,7 @@ import torch from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON class TestFuzzerCompileIssues(TestCase): @@ -220,67 +221,6 @@ def foo(arg0, arg1, arg2): out_compiled.sum().backward() print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #164086") - def test_fuzzer_issue_164086(self): - torch.manual_seed(0) - - def foo(arg0, arg1, arg2, arg3, arg4, arg5): - t0 = arg0 # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda - t1 = torch.tanh( - t0 - ) # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda - t2 = t1.clone() - t2.zero_() # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda - t3 = ( - arg1 # size=(50000, 128), stride=(50000, 1), dtype=float16, device=cuda - ) - t4 = arg2 # size=(46, 128), stride=(46, 1), dtype=float16, device=cuda - t5 = torch.nn.functional.linear( - t3, t4 - ) # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda - t6 = arg3 # size=(50000, 4, 46), stride=(184, 46, 1), dtype=float16, device=cuda - t7 = t6.max( - dim=1 - ).values # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda - t8 = arg4 # size=(25786, 46), stride=(46, 1), dtype=float16, device=cuda - t9 = arg5 # size=(24214, 46), stride=(46, 1), dtype=float16, device=cuda - t10 = torch.cat( - [t8, t9], dim=0 - ) # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda - t11 = torch.pow( - torch.pow(torch.pow(torch.pow(t5, t7), t10), t5), t7 - ) # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda - t12 = torch.nn.functional.embedding( - torch.clamp(t2, 0, t11.size(0) - 1).to(torch.long), t11 - ) # size=(42, 56, 46), stride=(2576, 46, 1), dtype=float16, device=cuda - output = t12 - return output - - arg0 = torch.randint(0, 1000, [42, 56], dtype=torch.int64, device="cuda") - arg1 = torch.rand( - [50000, 128], dtype=torch.float16, device="cuda", requires_grad=True - ) - arg2 = torch.rand( - [46, 128], dtype=torch.float16, device="cuda", requires_grad=True - ) - arg3 = torch.rand( - [50000, 4, 46], dtype=torch.float16, device="cuda", requires_grad=True - ) - arg4 = torch.rand( - [25786, 46], dtype=torch.float16, device="cuda", requires_grad=True - ) - arg5 = torch.rand( - [24214, 46], dtype=torch.float16, device="cuda", requires_grad=True - ) - - out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5) - out_eager.sum().backward() - print("Eager Success! ✅") - compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True) - out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5) - out_compiled.sum().backward() - print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #163877") def test_fuzzer_issue_163877(self): torch.manual_seed(0) diff --git a/test/test_transformers.py b/test/test_transformers.py index 4dea431246999..56e1365d33c44 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -24,6 +24,8 @@ from torch.testing._internal.common_utils import ( TEST_WITH_ROCM, skipIfRocm, + skipIfRocmArch, + MI300_ARCH, skipIfTorchDynamo, TEST_FAIRSEQ, run_tests, @@ -303,7 +305,7 @@ def test_train_with_pad_and_catch_error(self, device): encoder = nn.TransformerEncoder(layer, 2).to(device) optimizer = optim.SGD(encoder.parameters(), lr=0.1, momentum=0.9) encoder.train() - for i in range(iters): + for _ in range(iters): encoder.train() optimizer.zero_grad() inputs = torch.cat([torch.randn(1, 2, 2), torch.zeros(1, 2, 2)], dim=1).to(device) @@ -427,7 +429,8 @@ def hook(module, inputs, output): # remove hook handle.remove() - @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001) + @skipIfRocmArch(MI300_ARCH) + @tf32_on_and_off(0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) @@ -537,7 +540,7 @@ def test_transformerencoder_square_input(self, with_no_grad, training, enable_ne with torch.no_grad(): # set constant weights of the model - for idx, p in enumerate(model.parameters()): + for p in model.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape @@ -587,7 +590,7 @@ def get_a_test_layer(activation, batch_first=False): with torch.no_grad(): # set constant weights of the model - for idx, p in enumerate(layer.parameters()): + for p in layer.parameters(): x = p.data sz = x.view(-1).size(0) shape = x.shape @@ -2845,6 +2848,30 @@ def test_cudnn_attention_seqlen1_dropout_heuristic(self): out = torch.nn.functional.scaled_dot_product_attention(q, q, q, dropout_p=0.5) out.backward(grad) + @skipIfRocm + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") + def test_cudnn_attention_broken_166211(self): + # https://github.com/pytorch/pytorch/issues/166211#issue-3551350377 + shape = (20, 4, 4, 32) + scale = 10 + for i in range(100): + q = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale + k = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale + v = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale + q.requires_grad = True + k.requires_grad = True + v.requires_grad = True + + grad_attn_output = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale + + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v) + dq, dk, dv = torch.autograd.grad(outputs=attn_output, inputs=(q, k, v), grad_outputs=grad_attn_output) + + self.assertFalse(dq.isnan().any()) + self.assertFalse(dk.isnan().any()) + self.assertFalse(dv.isnan().any()) + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]): diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py index f249adf21a528..1a8371eee6345 100644 --- a/test/test_varlen_attention.py +++ b/test/test_varlen_attention.py @@ -5,22 +5,29 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.attention import varlen_attn +from torch.nn.attention.varlen import varlen_attn from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase -from torch.testing._internal.common_utils import parametrize, run_tests +from torch.testing._internal.common_utils import parametrize, run_tests, skipIfRocm +from torch.utils._python_dispatch import TorchDispatchMode VarlenShape = namedtuple( "VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"] ) -default_tolerances = { - torch.float16: {"atol": 1e-1, "rtol": 1e-1}, - torch.bfloat16: {"atol": 9e-2, "rtol": 5e-2}, - torch.float32: {"atol": 1e-5, "rtol": 1.3e-6}, -} + +class OpLoggingMode(TorchDispatchMode): + """Logging mode that captures all dispatched operations""" + + def __init__(self): + self.called_ops = [] + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + op_name = str(func) + self.called_ops.append(op_name) + return func(*args, **(kwargs or {})) class AttentionBlock(nn.Module): @@ -39,12 +46,9 @@ def __init__( embed_dim, embed_dim, bias=False, device=device, dtype=dtype ) - def forward_varlen( + def get_varlen_qkv( self, x_packed: torch.Tensor, - cu_seq: torch.Tensor, - max_len: int, - is_causal: bool = False, ): qkv = self.qkv_proj(x_packed) q, k, v = qkv.chunk(3, dim=-1) @@ -53,24 +57,50 @@ def forward_varlen( k = k.view(-1, self.num_heads, self.head_dim) v = v.view(-1, self.num_heads, self.head_dim) - attn_out = varlen_attn( - q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal - ) + return q, k, v + + def forward_varlen( + self, + x_packed: torch.Tensor, + cu_seq: torch.Tensor, + max_len: int, + is_causal: bool = False, + ): + q, k, v = self.get_varlen_qkv(x_packed) + + attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal) attn_out = attn_out.view(-1, self.embed_dim) return self.out_proj(attn_out) - def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False): + def forward_sdpa( + self, + x_padded: torch.Tensor, + seq_lengths: torch.Tensor, + is_causal: bool = False, + ): batch_size, seq_len, _ = x_padded.shape qkv = self.qkv_proj(x_padded) q, k, v = qkv.chunk(3, dim=-1) + mask = ( + torch.arange(seq_len, device=x_padded.device)[None, :] + < seq_lengths[:, None] + ) + + attn_mask = mask[:, None, None, :].expand( + batch_size, self.num_heads, seq_len, seq_len + ) + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) + attn_out = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=is_causal + ) + attn_out = ( attn_out.transpose(1, 2) .contiguous() @@ -91,7 +121,9 @@ def create_variable_length_batch( seq_lengths = torch.tensor(seq_lengths, device=device) total_tokens = seq_lengths.sum().item() - x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype) + x_packed = torch.randn( + total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True + ) cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32) cu_seq[1:] = seq_lengths.cumsum(0) @@ -106,6 +138,7 @@ def create_variable_length_batch( end_idx = start_idx + seq_len x_padded[i, :seq_len] = x_packed[start_idx:end_idx] start_idx = end_idx + x_padded = x_padded.clone().detach().requires_grad_() return { "seq_lengths": seq_lengths, @@ -118,6 +151,7 @@ def create_variable_length_batch( class TestVarlenAttention(NNTestCase): + @skipIfRocm(msg="ROCM does not support variable length attention") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @@ -133,7 +167,11 @@ def test_basic_functionality(self, device, dtype): total_tokens = shape.batch_size * shape.max_seq_len x_packed = torch.randn( - total_tokens, shape.embed_dim, device=device, dtype=dtype + total_tokens, + shape.embed_dim, + device=device, + dtype=dtype, + requires_grad=True, ) cu_seq = torch.tensor( [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 @@ -147,6 +185,131 @@ def test_basic_functionality(self, device, dtype): self.assertEqual(output.device, torch.device(device)) self.assertEqual(output.dtype, dtype) + varlen_grad_out = torch.ones_like(output) + + varlen_grad = torch.autograd.grad( + outputs=output, + inputs=x_packed, + grad_outputs=varlen_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + self.assertIsNotNone(varlen_grad) + self.assertEqual(varlen_grad.shape, x_packed.shape) + self.assertEqual(varlen_grad.dtype, x_packed.dtype) + + @skipIfRocm(msg="ROCM does not support variable length attention") + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_custom_op_compliance(self, device, dtype): + torch.manual_seed(42) + + shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) + + attention_block = AttentionBlock( + shape.embed_dim, shape.num_heads, device, dtype + ) + + total_tokens = shape.batch_size * shape.max_seq_len + x_packed = torch.randn( + total_tokens, + shape.embed_dim, + device=device, + dtype=dtype, + ) + cu_seq = torch.tensor( + [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 + ) + + q, k, v = attention_block.get_varlen_qkv(x_packed) + + torch.library.opcheck( + torch.ops.torch_attn._varlen_attn, + (q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False), + ) + + out, lse, rng_state = torch.ops.torch_attn._varlen_attn( + q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False + ) + grad_out = torch.randn_like(out) + + # we don't support double backward + # skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static + torch.library.opcheck( + torch.ops.torch_attn._varlen_attn_backward, + ( + grad_out, + q, + k, + v, + out, + lse, + cu_seq, + cu_seq, + shape.max_seq_len, + shape.max_seq_len, + False, + rng_state, + ), + test_utils=["test_schema", "test_faketensor"], + ) + + @skipIfRocm(msg="ROCM does not support variable length attention") + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_custom_op_registration(self, device, dtype): + torch.manual_seed(42) + + shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) + + attention_block = AttentionBlock( + shape.embed_dim, shape.num_heads, device, dtype + ) + + total_tokens = shape.batch_size * shape.max_seq_len + x_packed = torch.randn( + total_tokens, + shape.embed_dim, + device=device, + dtype=dtype, + requires_grad=True, + ) + cu_seq = torch.tensor( + [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 + ) + + compiled_forward = torch.compile( + attention_block.forward_varlen, backend="eager", fullgraph=True + ) + with OpLoggingMode() as mode: + output = compiled_forward( + x_packed, cu_seq, shape.max_seq_len, is_causal=False + ) + + varlen_grad_out = torch.ones_like(output) + _ = torch.autograd.grad( + outputs=output, + inputs=x_packed, + grad_outputs=varlen_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + called_ops = mode.called_ops + + custom_ops_called = any( + "torch_attn._varlen_attn" in op for op in called_ops + ) and any("torch_attn._varlen_attn_backward" in op for op in called_ops) + assert custom_ops_called + + @skipIfRocm(msg="ROCM does not support variable length attention") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @@ -163,7 +326,14 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal): shape.embed_dim, shape.num_heads, device, dtype ) + golden_attention_block = AttentionBlock( + shape.embed_dim, shape.num_heads, device, torch.float32 + ) + variable_length_batch_data = create_variable_length_batch(shape, device, dtype) + golden_variable_length_batch_data = create_variable_length_batch( + shape, device, torch.float32 + ) varlen_output = attention_block.forward_varlen( variable_length_batch_data["x_packed"], @@ -172,18 +342,90 @@ def test_varlen_vs_sdpa(self, device, dtype, is_causal): is_causal=is_causal, ) sdpa_output = attention_block.forward_sdpa( - variable_length_batch_data["x_padded"], is_causal=is_causal + variable_length_batch_data["x_padded"], + variable_length_batch_data["seq_lengths"], + is_causal=is_causal, + ) + + golden_sdpa_output = golden_attention_block.forward_sdpa( + golden_variable_length_batch_data["x_padded"], + golden_variable_length_batch_data["seq_lengths"], + is_causal=is_causal, ) - tolerances = default_tolerances[dtype] start_idx = 0 for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]): end_idx = start_idx + seq_len varlen_seq = varlen_output[start_idx:end_idx] sdpa_seq = sdpa_output[i, :seq_len] + golden_sdpa_seq = golden_sdpa_output[i, :seq_len] + + fwd_atol = ( + 2 * (golden_sdpa_seq + 0.3 - 0.3 - golden_sdpa_seq).abs().max().item() + ) + + varlen_error = (varlen_seq - fwd_atol).abs().max().item() + sdpa_error = (sdpa_seq - fwd_atol).abs().max().item() + + assert varlen_error <= 2 * sdpa_error + fwd_atol + + start_idx = end_idx + + varlen_grad_out = torch.ones_like(varlen_output) + sdpa_grad_out = torch.ones_like(sdpa_output) + golden_sdpa_grad_out = torch.ones_like(golden_sdpa_output) + + start_idx = 0 + for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]): + end_idx = start_idx + seq_len + sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx] + start_idx = end_idx + + varlen_grad = torch.autograd.grad( + outputs=varlen_output, + inputs=variable_length_batch_data["x_packed"], + grad_outputs=varlen_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + sdpa_grad = torch.autograd.grad( + outputs=sdpa_output, + inputs=variable_length_batch_data["x_padded"], + grad_outputs=sdpa_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + golden_sdpa_grad = torch.autograd.grad( + outputs=golden_sdpa_output, + inputs=golden_variable_length_batch_data["x_padded"], + grad_outputs=golden_sdpa_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + start_idx = 0 + for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]): + end_idx = start_idx + seq_len + + varlen_grad_seq = varlen_grad[start_idx:end_idx] + sdpa_grad_seq = sdpa_grad[i, :seq_len] + golden_sdpa_seq = golden_sdpa_grad[i, :seq_len] + + fwd_atol = ( + 2 * (golden_sdpa_seq + 0.3 - 0.3 - golden_sdpa_seq).abs().max().item() + ) + + varlen_error = (varlen_grad_seq - fwd_atol).abs().max().item() + sdpa_error = (sdpa_grad_seq - fwd_atol).abs().max().item() + + assert varlen_error <= sdpa_error + fwd_atol - torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances) start_idx = end_idx diff --git a/test/test_xpu.py b/test/test_xpu.py index 9daa4b5501176..61dd91e5bfafc 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -489,6 +489,7 @@ def test_set_per_process_memory_fraction(self): torch.xpu.empty_cache() total_memory = torch.xpu.get_device_properties().total_memory fraction = 0.5 + orig_fraction = torch.xpu.get_per_process_memory_fraction() with self.assertRaisesRegex(ValueError, "invalid fraction:"): torch.xpu.set_per_process_memory_fraction(-0.1) with self.assertRaisesRegex(ValueError, "invalid fraction:"): @@ -503,11 +504,13 @@ def test_set_per_process_memory_fraction(self): gc.collect() torch.xpu.empty_cache() + self.assertEqual(fraction, torch.xpu.get_per_process_memory_fraction()) + application_memory = int(total_memory * 0.51) with self.assertRaises(torch.OutOfMemoryError): _ = torch.empty(application_memory, dtype=torch.int8, device="xpu") - torch.xpu.set_per_process_memory_fraction(1.0) + torch.xpu.set_per_process_memory_fraction(orig_fraction) def test_memory_allocation(self): torch.xpu.empty_cache() diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index f20e3ec7a166b..cc5e64874a05e 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -6131,7 +6131,7 @@ def __array__(self): assert res is not base_arr for copy in self.false_vals: - res = np.array(arr, copy=False) + res = np.array(arr, copy=copy) assert_array_equal(res, base_arr) assert res is base_arr # numpy trusts the ArrayLike diff --git a/third_party/METADATA.bzl b/third_party/METADATA.bzl deleted file mode 100644 index 6a1a9a4ca9762..0000000000000 --- a/third_party/METADATA.bzl +++ /dev/null @@ -1,7 +0,0 @@ -METADATA = { - "maintainers": [ - "pytorch_dev_infra", - ], - "name": "third_party", - "owner": "pytorch_dev_infra", -} diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 8bb99f982e0be..88e0a316f9d09 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1073,8 +1073,8 @@ - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) LU_data: lu_unpack_backward(grad_L, grad_U, LU_data.sym_size(-2), LU_data.sym_size(-1)) LU_pivots: non_differentiable - L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril(-1)" - U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu()" + L: "LU_data_t.sym_size(-2) >= LU_data_t.sym_size(-1) ? LU_data_t.tril_symint(-1) : LU_data_t.narrow_symint(-1, 0, LU_data_t.sym_size(-2)).tril_symint(-1)" + U: "LU_data_t.sym_size(-1) >= LU_data_t.sym_size(-2) ? LU_data_t.triu_symint() : LU_data_t.narrow_symint(-2, 0, LU_data_t.sym_size(-1)).triu_symint()" output_differentiability: [False, True, True] - name: masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor @@ -1782,12 +1782,12 @@ self, B: linalg_solve_triangular_backward(grad, self, result, upper, left, unitriangular, grad_input_mask) result: linalg_solve_triangular_forward_AD(self_t, B_t, self_p, result, upper, left, unitriangular) -- name: tril(Tensor self, int diagonal=0) -> Tensor - self: grad.tril(diagonal) +- name: tril(Tensor self, SymInt diagonal=0) -> Tensor + self: grad.tril_symint(diagonal) result: auto_linear -- name: triu(Tensor self, int diagonal=0) -> Tensor - self: grad.triu(diagonal) +- name: triu(Tensor self, SymInt diagonal=0) -> Tensor + self: grad.triu_symint(diagonal) result: auto_linear - name: trunc(Tensor self) -> Tensor diff --git a/tools/dynamo/verify_dynamo.py b/tools/dynamo/verify_dynamo.py index b6ec848922f5a..a8ce085e864ea 100644 --- a/tools/dynamo/verify_dynamo.py +++ b/tools/dynamo/verify_dynamo.py @@ -216,9 +216,8 @@ def main() -> None: f"ROCM version: {rocm_ver}\n" ) for args in _SANITY_CHECK_ARGS: - if sys.version_info >= (3, 13): - warnings.warn("Dynamo not yet supported in Python 3.13. Skipping check.") - continue + if sys.version_info >= (3, 14): + warnings.warn("Dynamo not yet supported in Python 3.14. Skipping check.") check_dynamo(*args) print("All required checks passed") diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index b2917a557b4da..5d95b708c34c1 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -374,6 +374,22 @@ def build_collectives( return tracebacks, collectives, nccl_calls +def transform_ft( + details: dict[str, dict[str, Any]], group_world_size: int +) -> dict[str, dict[str, Any]]: + for dump_key, dump in details.items(): + rank = dump["rank"] + for key, pg_config in dump["pg_config"].items(): + if pg_config["desc"] == "default_pg": + ranks = eval(pg_config["ranks"]) + replica_id = rank // group_world_size + first_rank = replica_id * group_world_size + new_ranks = [r + first_rank for r in ranks] + details[dump_key]["pg_config"][key]["ranks"] = f"{new_ranks}" + + return details + + def build_db( details: dict[str, dict[str, Any]], args: argparse.Namespace, version: str ) -> Database: diff --git a/tools/flight_recorder/components/config_manager.py b/tools/flight_recorder/components/config_manager.py index d43022444e447..1a2336c28c505 100644 --- a/tools/flight_recorder/components/config_manager.py +++ b/tools/flight_recorder/components/config_manager.py @@ -74,6 +74,17 @@ def __init__(self: "JobConfig"): default=10, help="Maximum number of mismatches we print (from earliest).", ) + self.parser.add_argument( + "--transform-ft", + action="store_true", + help="Transform PG config to use global ranks to analyze traces produced by torchft", + ) + self.parser.add_argument( + "--group-world-size", + type=int, + default=None, + help="The number of ranks in 1 torchft replica group. Must be specified if --transform-ft is True", + ) def parse_args( self: "JobConfig", args: Optional[Sequence[str]] diff --git a/tools/flight_recorder/fr_trace.py b/tools/flight_recorder/fr_trace.py index 3ba262832f57e..8989bcdfebd93 100644 --- a/tools/flight_recorder/fr_trace.py +++ b/tools/flight_recorder/fr_trace.py @@ -32,7 +32,7 @@ from collections.abc import Sequence from typing import Optional -from tools.flight_recorder.components.builder import build_db +from tools.flight_recorder.components.builder import build_db, transform_ft from tools.flight_recorder.components.config_manager import JobConfig from tools.flight_recorder.components.loader import read_dir from tools.flight_recorder.components.types import types @@ -46,6 +46,12 @@ def main(args: Optional[Sequence[str]] = None) -> None: assert args.trace_dir, "Trace directory trace_dir is required" # pyrefly: ignore [bad-argument-type] details, version = read_dir(args) + # pyrefly: ignore [missing-attribute] + if args.transform_ft: + # pyrefly: ignore [missing-attribute] + assert args.group_world_size, "World size is required for transform_ft" + # pyrefly: ignore [bad-argument-type] + details = transform_ft(details, args.group_world_size) # pyrefly: ignore [bad-argument-type] db = build_db(details, args, version) # pyrefly: ignore [missing-attribute] diff --git a/tools/linter/adapters/pyrefly_linter.py b/tools/linter/adapters/pyrefly_linter.py index 77ed9c681e522..57d4a99bde18d 100644 --- a/tools/linter/adapters/pyrefly_linter.py +++ b/tools/linter/adapters/pyrefly_linter.py @@ -165,8 +165,6 @@ def check_files( errors = result.get("errors", []) else: errors = [] - # For now filter out deprecated warnings and only report type errors as warnings - # until we remove mypy errors = [error for error in errors if error["name"] != "deprecated"] rc = [ LintMessage( @@ -178,9 +176,9 @@ def check_files( line=error["line"], char=error["column"], code=code, - severity=LintSeverity.ADVICE, - # uncomment and replace when we switch to pyrefly - # severity=LintSeverity.ADVICE if error["name"] == "deprecated" else LintSeverity.ERROR, + severity=LintSeverity.ADVICE + if error["name"] == "deprecated" + else LintSeverity.ERROR, original=None, replacement=None, ) diff --git a/tools/nightly.py b/tools/nightly.py index a829f4729e77a..927bbe66ff423 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -53,6 +53,7 @@ import contextlib import functools import itertools +import json import logging import os import re @@ -128,47 +129,19 @@ class PipSource(NamedTuple): accelerator: str -PYTORCH_NIGHTLY_PIP_INDEX_URL = "https://download.pytorch.org/whl/nightly" -PIP_SOURCES = { - "cpu": PipSource( - name="cpu", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cpu", - supported_platforms={"Linux", "macOS", "Windows"}, - accelerator="cpu", - ), - # NOTE: Sync with CUDA_ARCHES in .github/scripts/generate_binary_build_matrix.py - "cuda-12.6": PipSource( - name="cuda-12.6", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu126", - supported_platforms={"Linux", "Windows"}, - accelerator="cuda", - ), - "cuda-12.8": PipSource( - name="cuda-12.8", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu128", - supported_platforms={"Linux", "Windows"}, - accelerator="cuda", - ), - "cuda-13.0": PipSource( - name="cuda-13.0", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/cu130", - supported_platforms={"Linux", "Windows"}, - accelerator="cuda", - ), - # NOTE: Sync with ROCM_ARCHES in .github/scripts/generate_binary_build_matrix.py - "rocm-6.4": PipSource( - name="rocm-6.4", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm6.4", - supported_platforms={"Linux"}, - accelerator="rocm", - ), - "rocm-7.0": PipSource( - name="rocm-7.0", - index_url=f"{PYTORCH_NIGHTLY_PIP_INDEX_URL}/rocm7.0", - supported_platforms={"Linux"}, - accelerator="rocm", - ), -} +# Generate: .github/scripts/nightly_source_matrix.json +GENERATE_MATRIX_SCRIPT = ( + REPO_ROOT / ".github" / "scripts" / "generate_binary_build_matrix.py" +) +subprocess.check_call( + [sys.executable, str(GENERATE_MATRIX_SCRIPT)], + cwd=GENERATE_MATRIX_SCRIPT.parent, +) + +# See: .github/scripts/nightly_source_matrix.json +NIGHTLY_SOURCE_FILE = GENERATE_MATRIX_SCRIPT.with_name("nightly_source_matrix.json") +NIGHTLY_SOURCE_MATRIX = json.loads(NIGHTLY_SOURCE_FILE.read_text(encoding="utf-8")) +PIP_SOURCES = {name: PipSource(**data) for name, data in NIGHTLY_SOURCE_MATRIX.items()} class Formatter(logging.Formatter): diff --git a/tools/rules/METADATA.bzl b/tools/rules/METADATA.bzl deleted file mode 100644 index a1e9c277630cf..0000000000000 --- a/tools/rules/METADATA.bzl +++ /dev/null @@ -1,9 +0,0 @@ -# THIS FILE IS AUTOMATICALLY GENERATED FROM INFORMATION STORED IN -# THIRD-PARTY METADATA SERVICE. YOUR MANUAL CHANGES TO THIS FILE WILL -# BE PRESERVED AND WILL SERVE AS THE SOURCE OF TRUTH FOR METADATA OF -# THIS PACKAGE. -# TPMS-GENERATED: b3448f8fd2a893772f944f37627e63917b77dede -METADATA = { - "name": "rules", - "owner": "pytorch_dev_infra", -} diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 45660d3ff7a9e..4acffdb1997f9 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -43,7 +43,9 @@ from torch._C import ( from torch._prims_common import DeviceLikeType from torch.autograd.graph import Node as _Node from torch.cuda import _POOL_HANDLE +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor._op_schema import OpSchema +from torch.distributed.tensor.placement_types import Placement from torch.fx.node import Node as FxNode from torch.package import PackageExporter from torch.storage import TypedStorage, UntypedStorage @@ -1306,6 +1308,7 @@ def _group_tensors_by_device_and_dtype( tuple[list[list[Tensor | None]], list[_int]], ]: ... def _initCrashHandler() -> None: ... +def _set_warn_on_accumulate_grad_stream_mismatch(enabled: _bool) -> None: ... # NB: There is no Capsule type in typing, see # https://github.com/python/cpython/issues/109562 @@ -1956,6 +1959,9 @@ _TensorBase = TensorBase def _DTensor_OpSchema_post_init(self: OpSchema) -> None: ... def _DTensor_OpSchema_recompute_comparison_key(self: OpSchema) -> None: ... +def _DTensor_compute_global_tensor_info( + tensor: Tensor, mesh: DeviceMesh, placements: Sequence[Placement] +) -> tuple[list[_int], list[_int]]: ... # Defined in torch/csrc/multiprocessing/init.cpp def _multiprocessing_init() -> None: ... @@ -2038,6 +2044,8 @@ def _cuda_getDeviceCount() -> _int: ... def _cuda_set_sync_debug_mode(warn_level: _int | str) -> None: ... def _cuda_get_sync_debug_mode() -> _int: ... def _cuda_sleep(cycles: _int) -> None: ... +def _cuda_busy_wait_for_flag() -> None: ... +def _cuda_clear_flag() -> None: ... def _cuda_synchronize() -> None: ... def _cuda_ipc_collect() -> None: ... def _cuda_getArchFlags() -> str | None: ... @@ -2391,6 +2399,7 @@ def _xpu_resetAccumulatedMemoryStats(device: _int) -> None: ... def _xpu_resetPeakMemoryStats(device: _int) -> None: ... def _xpu_getMemoryInfo(device: _int) -> tuple[_int, _int]: ... def _xpu_canDeviceAccessPeer(device: _int, peer: _int) -> _bool: ... +def _xpu_getMemoryFraction(device: _int) -> _float: ... def _xpu_setMemoryFraction(fraction: _float, device: _int) -> None: ... class _XpuDeviceProperties: diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index da59123625e84..737362be62b48 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -849,7 +849,9 @@ class _SymmetricMemory: class ProcessGroupXCCL(Backend): class Options(Backend.Options): - def __init__(self): ... + is_high_priority_stream: bool + + def __init__(self, is_high_priority_stream: bool = False): ... def __init__( self, diff --git a/torch/__init__.py b/torch/__init__.py index b830704254f6a..1dd6aa0250243 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2644,7 +2644,16 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: from torch._inductor.compiler_bisector import CompilerBisector if bisect_backend := CompilerBisector.get_backend(): - backend = bisect_backend + import torch._inductor.config as inductor_config + + # don't override the backend for use cases like vllm + # which leverages their custom backend. + if not ( + inductor_config.test_configs.bisect_keep_custom_backend_for_inductor + and bisect_backend == "inductor" + and not isinstance(backend, str) + ): + backend = bisect_backend guard_filter_fn = None if options and isinstance(options, dict): diff --git a/torch/_appdirs.py b/torch/_appdirs.py index 291963f6f6f62..9d8ad9487e255 100644 --- a/torch/_appdirs.py +++ b/torch/_appdirs.py @@ -445,7 +445,7 @@ def user_log_dir(appname=None, appauthor=None, version=None, opinion=True): return path -class AppDirs(object): +class AppDirs: """Convenience wrapper for getting application dirs.""" def __init__( diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index c4396932818d3..a321a49ac142e 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -58,7 +58,7 @@ def _should_decompose_because_unsafe_op(op: torch._ops.OperatorBase) -> bool: return False if torch.Tag.maybe_aliasing_or_mutating in op.tags: return True - return op == torch.ops.aten.native_batch_norm.default + return op is torch.ops.aten.native_batch_norm.default def _add_op_to_registry(registry, op, fn): @@ -404,6 +404,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.max_unpool3d, aten.mish, aten.mish_, + aten.mish_backward, aten.mse_loss, aten.mse_loss_backward, aten.multi_margin_loss, @@ -419,6 +420,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.native_dropout_backward, aten.native_group_norm_backward, aten.native_layer_norm_backward, + aten._fused_rms_norm, aten._fused_rms_norm_backward, aten.new_empty, aten.new_full, @@ -475,6 +477,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.silu, aten.silu_, aten.silu_backward.grad_input, + aten.silu_backward, aten.sinc, aten.sinc_, aten.slice_backward, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index ad08d26521908..a69a46c48b5f1 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1757,6 +1757,61 @@ def native_layer_norm_backward_out( return grad_input +@register_decomposition(aten._fused_rms_norm.default) +def _fused_rms_norm( + input: Tensor, + normalized_shape: list[int], + weight: Optional[Tensor], + eps: Optional[float], +) -> tuple[Tensor, Tensor]: + dims_to_reduce: list[int] = [] + for i in range(len(normalized_shape)): + dims_to_reduce.append(input.dim() - i - 1) + + # upcast is needed for fp16 and bf16 + computation_dtype = utils.get_computation_dtype(input.dtype) + upcasted_input = input.to(computation_dtype) + + # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble] + if eps is None: + if computation_dtype in (torch.float32, torch.complex64): + eps_val = torch.finfo(torch.float32).eps + else: + eps_val = torch.finfo(torch.float64).eps + else: + eps_val = eps + + rqrst_input = torch.rsqrt( + # NB: don't inplace here, will violate functional IR invariant + # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp + torch.ops.aten.add.Scalar( + torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val + ) + ) + + upcasted_result = upcasted_input.mul(rqrst_input) + + if weight is not None: + upcasted_result = upcasted_result.mul(weight) + + # NB: nested should be dead here, just here for fidelity + is_nested = input.is_nested or (weight is not None and weight.is_nested) + memory_format = utils.suggest_memory_format(input) + is_channels_last = memory_format in ( + torch.channels_last, + torch.channels_last_3d, + ) + + if not is_nested and not is_channels_last: + upcasted_result = upcasted_result.contiguous() + rqrst_input = rqrst_input.contiguous() + + # Cast normalized result back to original input type + result = upcasted_result.type_as(input) + + return result, rqrst_input + + @register_decomposition(aten._fused_rms_norm_backward.default) def _fused_rms_norm_backward( grad_out: Tensor, diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index e6b3f09c22fc2..98f6ccf78bb89 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -115,7 +115,7 @@ def make_crossref_functionalize( from torch._subclasses.fake_tensor import FakeTensorMode # This case is pretty weird, suppress it for now - if op == torch.ops.aten.lift_fresh.default: + if op is torch.ops.aten.lift_fresh.default: return final_key def handler(*args: _P.args, **kwargs: _P.kwargs) -> _T: diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 8666d12bddcbc..28a77d20ea3b0 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -26,6 +26,7 @@ allow_in_graph, assume_constant_result, disable, + disable_nested_graph_breaks, disallow_in_graph, dont_skip_tracing, error_on_graph_break, @@ -78,6 +79,7 @@ "assume_constant_result", "config", "disable", + "disable_nested_graph_breaks", "disallow_in_graph", "dont_skip_tracing", "export", @@ -153,7 +155,6 @@ def reset() -> None: GenerationTracker.clear() TensorifyState.clear() torch._dynamo.utils.warn_once_cache.clear() - torch._dynamo.utils.user_obj_id_to_weakref.clear() torch._C._autograd._saved_tensors_hooks_set_tracing(False) diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index c7679a9300a01..1de308b803702 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -140,7 +140,7 @@ def __torch_function__( args: tuple[object, ...] = (), kwargs: Optional[dict[str, object]] = None, ) -> object: - if func == torch.Tensor.__getitem__: + if func is torch.Tensor.__getitem__: index_args = pytree.tree_leaves(args[1]) if all(isinstance(x, torch.Tensor) for x in index_args): return mod_index(args[0], index_args) diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index 4b13d677f5abb..000d977d29f36 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -50,6 +50,7 @@ class CompileArtifacts: compiled_fn: SerializableCallable original_code: types.CodeType closure: Optional[tuple[Any, ...]] + argdefs: Optional[tuple[Any, ...]] source_info: "SourceInfo" device_type: str system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current) @@ -111,7 +112,10 @@ def __post_init__(self) -> None: } # pyrefly: ignore [read-only] self.fn = types.FunctionType( - self._artifacts.bytecode, f_globals, closure=self._artifacts.closure + self._artifacts.bytecode, + f_globals, + closure=self._artifacts.closure, + argdefs=self._artifacts.argdefs, ) if self._artifacts.guard_manager is None: @@ -266,6 +270,7 @@ def new_guard_filter_fn( compiled_fn=compiled_fn, original_code=fn.__code__, closure=fn.__closure__, + argdefs=fn.__defaults__, source_info=source_info, device_type=device_type, ) diff --git a/torch/_dynamo/aot_compile_types.py b/torch/_dynamo/aot_compile_types.py index 2d605531bd094..547a0bbdc915d 100644 --- a/torch/_dynamo/aot_compile_types.py +++ b/torch/_dynamo/aot_compile_types.py @@ -48,7 +48,7 @@ def serialize_compile_artifacts( @classmethod def deserialize_compile_artifacts(cls, data: bytes) -> Any: - from torch._functorch._aot_autograd.autograd_cache import ( + from torch._functorch._aot_autograd.aot_autograd_result import ( deserialize_bundled_cache_entry, ) diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index 781315d95346e..2ffd9523bdf15 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -104,7 +104,7 @@ def _wrapped_bw_compiler(*args: P.args, **kwargs: P.kwargs) -> R: # debug asserts slow down compile time noticeably, # So only default them on when the aot_eager backend is used. - if self.kwargs.get("fw_compiler", None) == nop: + if self.kwargs.get("fw_compiler", None) is nop: patch_config: contextlib.AbstractContextManager[Any] = patch( "functorch.compile.config.debug_assert", True ) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 360a3d7335303..0e62e08cf1fc9 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -369,7 +369,7 @@ def relu_compile_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: for node in gm.graph.nodes: - if node.target == torch.relu: + if node.target is torch.relu: raise ReluCompileError return gm @@ -379,7 +379,7 @@ def relu_runtime_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: for node in gm.graph.nodes: - if node.target == torch.relu: + if node.target is torch.relu: node.target = torch._assert node.args = (False, "ReluRuntimeError") gm.recompile() @@ -391,7 +391,7 @@ def relu_accuracy_error_TESTING_ONLY( gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor] ) -> torch.fx.GraphModule: for node in gm.graph.nodes: - if node.target == torch.relu: + if node.target is torch.relu: node.target = torch.add node.args = (node.args[0], 1) gm.recompile() diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index ed6e17f8aa25d..31c1d243de721 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -415,6 +415,7 @@ def create_call_function_ex( and not ignore_314_kwargs_push ): output.append(create_instruction("PUSH_NULL")) + has_kwargs = True if push_null: output.append(create_instruction("PUSH_NULL")) # 3.13 swapped NULL and callable diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 3a933f3de34a4..1861b20105265 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -153,7 +153,7 @@ def add_push_null( self.clear_tos() def __call__( - self, value: Union[VariableTracker, Source], allow_cache: bool = True + self, value: Union[VariableTracker, Source, None], allow_cache: bool = True ) -> None: """ Generate code such that top-of-stack (TOS) is set to value. @@ -188,7 +188,7 @@ def __call__( value to handle aliasing (check side_effects.py and search for allow_cache=False). - b) If value.source is None, this is not allowed. TODO - assert this. + b) If value.source is None, this is not allowed Notable effects: 1. `self.top_of_stack` will be set to `value`, if we don't codegen @@ -197,6 +197,7 @@ def __call__( `top_of_stack` or cached `tempvars`, or (b). `value` has special VT types like `NNModuleVariable`, etc. """ + assert value is not None if isinstance(value, Source): # If the source needs to be overridden, use the new one. source = self.overridden_sources.get(value, value) @@ -289,7 +290,8 @@ def __call__( self.load_graph_output(graph_outputs[graph_outputs_key].index) output.append( self.create_load_global( - value.global_mangled_class_name(self.tx), add=True + value.global_mangled_class_name(self.tx), # type: ignore[arg-type] + add=True, ) ) output.extend(create_call_function(2, False)) diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 5af72310b3a7f..cace23af20565 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -147,7 +147,7 @@ def prep_with_graph(self, graph: torch.fx.Graph) -> None: # so Compiled Autograd will always lift the param and # this should always be true assert ( - param_node.target == operator.getitem + param_node.target is operator.getitem and param_node.args[0] is inputs_node # type: ignore[possibly-undefined] and isinstance(param_node.args[1], int) ) @@ -573,7 +573,7 @@ def make_unique(node_name: str) -> str: result.name = make_unique(node.name) value_remap[node] = result elif node.op == "call_function": - if node.target == torch.ops.aten.view.default: + if node.target is torch.ops.aten.view.default: # this aot bwd graph is being lazily compiled # we must manually apply the view_to_reshape post grad pass # since it was already applied to the aot fwd, and baked into the gradients @@ -755,7 +755,6 @@ def proxy_call( self, fn: Callable[..., Any], args: Any, output_metadata: Sequence[Any] ) -> Sequence[torch.Tensor]: """Proxies a call to fn(*args) into the graph""" - flat_args, _ = pytree.tree_flatten(args) proxy_args = pytree.tree_map(lambda e: self.to_proxy(e), args) proxy_out = self.fx_tracer.create_proxy( "call_function", fn, args=proxy_args, kwargs={} @@ -997,7 +996,7 @@ def remove_unused_sizes(self) -> set[int]: assert sizes_node.name == "sizes" for getitem_node in sizes_node.users.keys(): - assert getitem_node.target == operator.getitem + assert getitem_node.target is operator.getitem if getitem_node.users: used_sizes.append(getitem_node) else: @@ -1159,7 +1158,7 @@ def get_all_nodes(args: Sequence[Any]) -> list[torch.fx.Node]: def is_placeholder(node: torch.fx.Node) -> bool: if node.op == "placeholder" or ( node.op == "call_function" - and node.target == operator.getitem + and node.target is operator.getitem and node.args[0].op == "placeholder" # type: ignore[union-attr, arg-type] ): return True @@ -1176,7 +1175,7 @@ def reorder_accumulate_grad_nodes(self) -> None: ): param_node, grad_node = node.args[0], node.args[1] getitem_node = None - if grad_node.target == operator.getitem: + if grad_node.target is operator.getitem: getitem_node = grad_node grad_node = getitem_node.args[0] @@ -1241,7 +1240,7 @@ def reorder_pre_hook_nodes_to_schedule_asap(self) -> None: to_append = [] hook_block = [node] # contain the hook and hook args getitem for n in input_nodes: - if n.op == "call_function" and n.target == operator.getitem: + if n.op == "call_function" and n.target is operator.getitem: to_append.append(n.args[0]) to_remove.append(n) hook_block.append(n) @@ -1279,7 +1278,7 @@ def reorder_pre_hook_nodes_to_mimic_eager(self) -> None: # users are all getitem ops and they are used by same registered node assert all( - user.op == "call_function" and user.target == operator.getitem + user.op == "call_function" and user.target is operator.getitem for user in users ) registered_node = next(iter(users[0].users.keys())) @@ -1314,7 +1313,7 @@ def reorder_post_acc_grad_hook_nodes(self) -> None: # find the corresponding acc_grad node acc_grad_node = None for n in list(param_node.users.keys()): - if n.op == "call_function" and n.target == call_accumulate_grad: + if n.op == "call_function" and n.target is call_accumulate_grad: acc_grad_node = n break @@ -1357,19 +1356,19 @@ def reorder_post_hook_nodes(self) -> None: for user in list(input_node.users.keys()) if not ( user.op == "call_function" - and user.target == call_hook + and user.target is call_hook and node.kwargs.get("hook_type", None) == "post_hook" ) ) arg = max(input_nodes_and_users) # last input users - if arg.op == "call_function" and arg.target == call_accumulate_grad: + if arg.op == "call_function" and arg.target is call_accumulate_grad: param_node = arg.args[0] post_acc_grad_hook_node = None for n in list(param_node.users.keys()): if ( n.op == "call_function" - and n.target == call_hook + and n.target is call_hook and n.kwargs.get("hook_type", None) == "post_acc_grad_hook" ): post_acc_grad_hook_node = n diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 74a53c6d9c4bf..875f640194e42 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -882,7 +882,9 @@ def build_guards( strict_error=strict_error, ) - def graph_capture_output(self) -> GraphCaptureOutput: + def graph_capture_output( + self, argdefs: Optional[tuple[Any, ...]] = None + ) -> GraphCaptureOutput: output_graph = self.tracer_output.output_graph assert output_graph is not None return GraphCaptureOutput( @@ -897,6 +899,7 @@ def graph_capture_output(self) -> GraphCaptureOutput: output_graph.traced_code, self.bytecode, self.tracer_output.closure, + argdefs, ) @@ -929,6 +932,7 @@ class GraphCaptureOutput: traced_code: list[CodeType] bytecode: CodeType closure: Optional[tuple[Any, ...]] + argdefs: Optional[tuple[Any, ...]] def build_guards( self, @@ -984,6 +988,7 @@ def forward_callable(self) -> Callable[..., Any]: self.graph_capture_output.bytecode, f_globals, closure=self.graph_capture_output.closure, + argdefs=self.graph_capture_output.argdefs, ) @@ -995,7 +1000,10 @@ def get_traced_fn(mod: Any) -> tuple[FunctionType, Optional[object]]: import inspect if isinstance(mod, torch.nn.Module): - mod = mod.forward + if len(mod._forward_pre_hooks) == 0 and len(mod._forward_hooks) == 0: + mod = mod.forward + else: + mod = mod.__call__ if hasattr(mod, "__self__"): # pyrefly: ignore [missing-attribute] return mod.__func__, mod.__self__ @@ -1044,6 +1052,7 @@ def _get_frame( f_locals, builtins.__dict__, closure=fn.__closure__ or (), # type: ignore[arg-type] + argdefs=fn.__defaults__, ) @@ -1093,6 +1102,7 @@ class FrameInfo: locals: dict[str, object] builtins: dict[str, object] closure: tuple[CellType] + argdefs: Optional[tuple[Any, ...]] def _fullgraph_capture_frame( @@ -1146,7 +1156,7 @@ def fullgraph_compiler( raise e.with_traceback(None) from e.__cause__ # User compiler error return CaptureOutput( - dynamo_output.graph_capture_output(), + dynamo_output.graph_capture_output(frame.argdefs), backend_input, ) diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 47e5fdb12dfcc..e16fa11ed08f6 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -455,7 +455,7 @@ def cast_dtype_args_to_fp64(model: torch.fx.GraphModule) -> torch.fx.GraphModule for node in model.graph.nodes: if ( node.op == "call_function" - and node.target == torch.ops.prims.convert_element_type.default + and node.target is torch.ops.prims.convert_element_type.default ): assert len(node.args) == 2 if is_float_dtype(node.args[1]) and node.args[1] != torch.float64: diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 50fdadbb8fbbd..144f0ea7eeefa 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -890,6 +890,7 @@ def __exit__( "allow_unspec_int_on_nn_module", "skip_torchrec", "dont_skip_tracing", + "nested_graph_breaks", ) from . import config @@ -965,6 +966,26 @@ def dont_skip_tracing(fn: Optional[Any] = None) -> Any: return ctx +@overload +def disable_nested_graph_breaks(fn: None = None) -> DynamoConfigPatchProxy: ... + + +@overload +def disable_nested_graph_breaks(fn: Callable[_P, _R]) -> Callable[_P, _R]: ... + + +def disable_nested_graph_breaks(fn: Optional[Any] = None) -> Any: + """ + Context manager/decorator to disable nested graph breaks when tracing + this function and any nested functions. Used when nested graph breaks + is causing problems. + """ + ctx = patch_dynamo_config(nested_graph_breaks=False) + if fn: + return ctx(fn) + return ctx + + class ErrorOnGraphBreakDecoratorContextManager: def __init__(self, error_on_graph_break: bool) -> None: self.error_on_graph_break = error_on_graph_break diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 070d26a4699c4..cbcb8c5de9be9 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -42,7 +42,7 @@ from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union from unittest.mock import patch import sympy @@ -395,13 +395,6 @@ def __init__(self, mod: torch.nn.Module, dynamo_ctx: _TorchDynamoContext) -> Non self._initialize() self.training = self._orig_mod.training - def __len__(self) -> int: - # Proxy the len call to the original module - if isinstance(self._orig_mod, Sized): - return len(self._orig_mod) - # Mimic python's default behavior for objects without a length - raise TypeError(f"{type(self._orig_mod).__name__} does not support len()") - def _initialize(self) -> None: # Do this stuff in constructor to lower overhead slightly if isinstance(self.dynamo_ctx, DisableContext): @@ -1791,7 +1784,7 @@ def produce_matching( for i, val in enumerate(sources): dict_of_source_vals[id(val)] = i - for i, val in enumerate(candidates): + for val in candidates: if isinstance(val, tuple(common_constant_types)): matched_elements_positions.append(None) elif id(val) not in dict_of_source_vals: diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index c6bed34c0d8d5..857c694ad92c6 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -273,7 +273,8 @@ class RecompileLimitExceeded(Unsupported): # debug exception thrown when tracing torch._dynamo.step_unsupported() class StepUnsupported(TorchDynamoException): - pass + def __init__(self) -> None: + self.real_stack = torch._guards.TracingContext.extract_stack() class UnsafeScriptObjectError(TorchDynamoException): diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index d7c4f6a300101..23b02e69a5640 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -16,6 +16,7 @@ from torch._dynamo.exc import UserErrorType from torch._dynamo.utils import dynamo_timed, get_metrics_context from torch._export.utils import _compiling_state_context +from torch._guards import TracingContext from torch.export.dynamic_shapes import _RelaxedConstraint, Constraint from torch.fx import Node from torch.fx.experimental.proxy_tensor import make_fx @@ -138,10 +139,10 @@ def _process_source_fn(source_fn_stack): node.meta["nn_module_stack"] = _process_nn_module_stack( node.meta["nn_module_stack"].copy() ) - if "source_fn_stack" in node.meta: - node.meta["source_fn_stack"] = _process_source_fn( - node.meta["source_fn_stack"].copy() - ) + + source_fn_stack = node.meta.get("source_fn_stack", None) + if source_fn_stack: + node.meta["source_fn_stack"] = _process_source_fn(source_fn_stack.copy()) if "dynamo_flat_name_to_original_fqn" in graph_module.meta: # Clean up flat name to original fqn mapping @@ -449,6 +450,14 @@ def _suggest_or_raise_constraint_violation( raise constraint_violation_error +def _normalize_shuffle_graph(shuffle_gm: torch.fx.GraphModule) -> None: + shuffle_gm.graph.eliminate_dead_code() + shuffle_gm.recompile() + for name, buffer in list(shuffle_gm.named_buffers()): + delattr(shuffle_gm, name) + setattr(shuffle_gm, name, buffer) + + @dataclass(frozen=True) class PyTreeifyOutput: graph_module: torch.fx.GraphModule @@ -525,6 +534,7 @@ def backend_dummy(*example_inputs): in_shuffle_graph = make_fx( InShuffle(), tracing_mode="symbolic", proxy_module_inputs=True )(*flat_real_args) + _normalize_shuffle_graph(in_shuffle_graph) output_node = next(iter(reversed(backend_input.graph_module.graph.nodes))) @@ -572,6 +582,7 @@ def backend_dummy(*example_inputs): out_shuffle_graph = make_fx( out_shuffle, tracing_mode="symbolic", proxy_module_inputs=True )(*flat_out_shuffle_args) + _normalize_shuffle_graph(out_shuffle_graph) assert out_shuffle.out_spec is not None return PyTreeifyOutput( @@ -650,6 +661,10 @@ def inner(*args: Any, **kwargs: Any) -> Any: ) assert out.backend_input is not None graph_module.meta["fake_mode"] = out.backend_input.fake_mode # type: ignore[attr-defined] + graph_module.meta["fake_mode"].allow_non_fake_inputs = True + tracing_context = TracingContext(graph_module.meta["fake_mode"]) + tracing_context.tensor_to_context = out.backend_input.tensor_to_context # type: ignore[attr-defined] + graph_module.meta["tracing_context"] = tracing_context return graph_module return inner diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 1896240317a29..dc56f42b52ad9 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -2495,6 +2495,30 @@ } ], "GB0249": [ + { + "Gb_type": "bad device argument to torch.accelerator.current_stream", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + }, + { + "Gb_type": "bad device argument to torch.get_device_module", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + }, + { + "Gb_type": "bad device argument to torch.accelerator.current_stream", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + }, { "Gb_type": "bad device argument to torch.get_device_module", "Context": "args={args}, kwargs={kwargs}", @@ -2851,5 +2875,45 @@ "Move the Placement usage outside the compiled region" ] } + ], + "GB0283": [ + { + "Gb_type": "Failed to make weakref to graph-created external object", + "Context": "user_object: {example_value}", + "Explanation": "Object does not allow us to make a weakref to it", + "Hints": [] + } + ], + "GB0284": [ + { + "Gb_type": "cannot resume from torch._dynamo.step_unsupported()", + "Context": "", + "Explanation": "traced torch._dynamo.step_unsupported(), but Dynamo is instructed to error on graph break. This graph break is used for debugging only.", + "Hints": [ + "Remove the torch._dynamo.step_unsupported() call.", + "Make sure fullgraph=False and error_on_graph_break=False.", + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], + "GB0285": [ + { + "Gb_type": "unsupported arguments to torch.accelerator.current_stream", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "torch.accelerator.current_stream accepts one optional argument `device`", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } + ], + "GB0286": [ + { + "Gb_type": "bad device argument to torch.get_device_module", + "Context": "args={args}, kwargs={kwargs}", + "Explanation": "Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + "Hints": [ + "Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled." + ] + } ] } diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 7836478b51782..979950cf3bd1b 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,9 +1,11 @@ import weakref -from typing import Any +from typing import Any, Callable from torch._dynamo.source import Source +PyCodegen = Any + # This file is to handle types that we don't want to support # as explicit FX graph inputs. This uses a sidetable which # we populate in bytecode and is loaded during graph execution @@ -11,44 +13,70 @@ # We use a dynamo-generated index as a level of indirection # this allows us to register objects externally in pre-graph bytecode that we want # to pass to the graph, but not support their types as graph inputs -index_to_source: dict[int, Source] = {} +index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {} + +index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {} -index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {} +keep_alive: list[Any] = [] def has_user_objects() -> bool: - return bool(index_to_source) + return bool(index_to_bytecode_constructor) -def get_user_object_by_index(index: int) -> Any: - assert index in index_to_user_object_weakref, ( +def get_external_object_by_index(index: int) -> Any: + assert index in index_to_external_object_weakref, ( "Index not registered in index_to_user_object_weakref" ) - obj = index_to_user_object_weakref[index]() + obj = index_to_external_object_weakref[index]() assert obj is not None, "User object is no longer alive" - return index_to_user_object_weakref[index]() + return index_to_external_object_weakref[index]() def store_user_object_weakrefs(*args: Any) -> None: - global index_to_user_object_weakref - index_to_user_object_weakref.clear() - index_to_user_object_weakref.update( + global index_to_external_object_weakref + index_to_external_object_weakref.clear() + index_to_external_object_weakref.update( {i: weakref.ref(arg) for i, arg in enumerate(args)} ) def reset_user_object_tracking() -> None: - index_to_source.clear() - index_to_user_object_weakref.clear() + index_to_bytecode_constructor.clear() + index_to_external_object_weakref.clear() + keep_alive.clear() + + +def register_graph_created_object( + example_value: Any, construct_fn: Callable[[int, PyCodegen], None] +) -> int: + global index_to_bytecode_constructor + global keep_alive + keep_alive.append(example_value) + index = len(index_to_bytecode_constructor) + index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg) + try: + index_to_external_object_weakref[index] = weakref.ref(example_value) + except TypeError as e: + from .exc import unimplemented_v2 + + unimplemented_v2( + gb_type="Failed to make weakref to graph-created external object", + context=f"user_object: {example_value}", + explanation="Object does not allow us to make a weakref to it", + hints=[], + from_exc=e, + ) + return index # Register a user object to be used in the graph def register_user_object(value: Any, source: Source) -> int: - global index_to_source - index = len(index_to_source) - index_to_source[index] = source + global index_to_bytecode_constructor + index = len(index_to_bytecode_constructor) + index_to_bytecode_constructor[index] = lambda cg: cg(source) try: - index_to_user_object_weakref[index] = weakref.ref(value) + index_to_external_object_weakref[index] = weakref.ref(value) except TypeError as e: from .exc import unimplemented_v2 diff --git a/torch/_dynamo/graph_deduplication.py b/torch/_dynamo/graph_deduplication.py index dcc6302dd7c64..0517fd5c1df8b 100644 --- a/torch/_dynamo/graph_deduplication.py +++ b/torch/_dynamo/graph_deduplication.py @@ -456,10 +456,10 @@ def _add_mutation_dependencies( for user in mutated_arg.users: if user is node: continue - # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore # unsupported-operation elif user < node: node_to_additional_deps[node].add(user) - # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore # unsupported-operation elif user > node: node_to_additional_deps[user].add(node) @@ -529,7 +529,7 @@ def _is_tuple_node(node: Node) -> bool: def _get_children_getitems(node: Node) -> Generator[Node, None, None]: for user in node.users: - if user.target == operator.getitem and isinstance(user.args[1], int): + if user.target is operator.getitem and isinstance(user.args[1], int): yield user diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d5869b9b29f51..2792ce512d8a1 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -132,6 +132,7 @@ CodeSource, ConstantSource, ConstDictKeySource, + CurrentStreamSource, DataclassFieldsSource, DefaultsSource, DictGetItemSource, @@ -181,6 +182,7 @@ common_constant_types, dataclass_fields, dict_keys, + get_current_stream, get_custom_getattr, get_torch_function_mode_stack, get_torch_function_mode_stack_at, @@ -317,7 +319,7 @@ def visit_dict_manager(node: DictGuardManager) -> bool: is_diff_guard_node = ( node.get_source() in self.diff_guard_sources or node.fail_count() > 0 ) - for idx, (key_mgr, val_mgr) in sorted( + for _idx, (key_mgr, val_mgr) in sorted( node.get_key_value_managers().items() ): is_diff_guard_node |= visit(key_mgr) | visit(val_mgr) @@ -440,7 +442,7 @@ def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]: is_subtree_tag_safe = True # Recurse to get the tag safe roots from subtree. - for idx, (key_mgr, val_mgr) in sorted( + for _idx, (key_mgr, val_mgr) in sorted( node.get_key_value_managers().items() ): if key_mgr is not None: @@ -448,9 +450,7 @@ def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]: if val_mgr is not None: tag_safe_roots.extend(visit(val_mgr)) - for idx, (key_mgr, val_mgr) in sorted( - node.get_key_value_managers().items() - ): + for key_mgr, val_mgr in node.get_key_value_managers().values(): if key_mgr: is_subtree_tag_safe &= key_mgr.is_tag_safe() @@ -761,6 +761,7 @@ def _get_closure_vars() -> dict[str, object]: "___dataclass_fields": dataclass_fields, "___namedtuple_fields": lambda x: x._fields, "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at, + "___get_current_stream": get_current_stream, "__math_isnan": math.isnan, "__numpy_isnan": None if np is None else np.isnan, "inf": float("inf"), @@ -1450,6 +1451,13 @@ def get_guard_manager_from_source(self, source: Source) -> GuardManager: example_value=example_value, guard_manager_enum=guard_manager_enum, ) + elif istype(source, CurrentStreamSource): + out = root_guard_manager.lambda_manager( + python_lambda=lambda _: get_current_stream(source.device), + source=source_name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, + ) elif istype(source, GradSource): assert base_guard_manager # to make mypy happy out = base_guard_manager.grad_manager( @@ -2112,7 +2120,6 @@ def hooks_ids_fn( if not are_inline_hooks(hooks): return None - pack_hook, unpack_hook = hooks return tuple(map(id, hooks)) guard_hooks_ids = hooks_ids_fn(get_hooks()) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 50638ccbba0eb..77f5d6cb05a01 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -101,7 +101,7 @@ unimplemented_v2, unimplemented_v2_with_warning, ) -from .graph_bytecode_inputs import has_user_objects, index_to_source +from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor from .graph_deduplication import apply_graph_deduplication from .graph_region_tracker import GraphRegionTracker from .guards import GuardBuilder, install_guard @@ -1060,7 +1060,7 @@ def update_co_names(self, name: str) -> None: def module_key_name(*names: Any) -> str: # create a new unique name name = "_".join(map(str, names)) - # Strip _buffers[..]/_parmeters[..]/_modules[..] names + # Strip _buffers[..]/_parameters[..]/_modules[..] names name = re.sub( r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]", r".\2", name ) @@ -1303,6 +1303,7 @@ def handle_aliases_for_stolen_lists( # A small codegen optimization because we might have different # VariableTrackers that share the same source. + assert x.source is not None list_idx = x.source.index # type: ignore[attr-defined] if list_idx not in visited: alias_name = self.new_var( @@ -1321,6 +1322,7 @@ def handle_aliases_for_stolen_lists( ) # operate on alias, handled by suffix codegen + assert x.source is not None old_source = x.source overridden_sources[old_source] = LocalSource(visited[list_idx]) @@ -1539,9 +1541,19 @@ def compile_subgraph( "store_user_object_weakrefs", ) ) - for source in reversed(index_to_source.values()): - codegen(source) - codegen.call_function(len(index_to_source), False) + tmp_vars = [] + for constructor in reversed(index_to_bytecode_constructor.values()): + constructor(codegen) + var_name = ( + self.new_var() + ) # keep alive any temp objects for the rest of the frame + codegen.store(var_name) + tmp_vars.append(var_name) + + for var_name in tmp_vars: + codegen.append_output(codegen.create_load(var_name)) + + codegen.call_function(len(index_to_bytecode_constructor), False) codegen.pop_top() self.add_output_instructions(codegen.get_instructions()) @@ -1864,7 +1876,6 @@ def compile_subgraph( and isinstance(var.value, _ExportModuleSpecTrackerDict) ): potential_side_effects.append(var) - side_effect_refs = [ _get_source_debug_name(var.source) for var in potential_side_effects ] @@ -2204,10 +2215,21 @@ def compile_and_call_fx_graph( with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting + + # Why create a new FakeTensorMode? + # + # The reason this needs to be done is because when we do Dynamo tracing, fake + # tensors can have their metadata mutated. Thus, the fake tensor we allocated + # for any given tensor may no longer be valid for the beginning trace of the + # graph. Nor is it convenient to "clone" the input tensors before mutating them, + # since you have to preserve aliasing. So we just reconstruct the FakeTensorMode + # from scratch when we go to AOTAutograd. But the ShapeEnv must be preserved as + # Dynamo made decisions about what is dynamic or not / guards from the user code + # that is not in graph. backend_fake_mode = torch._subclasses.FakeTensorMode( shape_env=old_fake_mode.shape_env, ) - # TODO(voz): Ostensibily, this should be scoped and + # TODO(voz): Ostensibly, this should be scoped and # restore back to old_fake_mode, but doing so currently violates # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode self.tracing_context.fake_mode = backend_fake_mode @@ -3404,7 +3426,7 @@ def lift_tracked_freevar_to_input( if proxy in self.lifted_freevars: return self.lifted_freevars[proxy] - # We first lift proxy to parent's graph then lift to current grpah's input + # We first lift proxy to parent's graph then lift to current graph's input # so that when we bind symints of the sizes in current graph, those symints # would already be lifted as inputs to parent graph. if proxy.tracer != self.parent: @@ -3452,7 +3474,7 @@ def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any: def track_produced_symints( self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy] ) -> None: - # When binding the symbols in an exmaple_value, we bind the symbols + # When binding the symbols in an example_value, we bind the symbols # to the proxy's associated Tracer instead of current tracer. # This is because: # 1. We may be calling wrap_tensors during speculate_subgraph because diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 7a0efe79d5cfd..94f3c2d689b6a 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -35,6 +35,8 @@ from typing import Any, IO, Optional, TYPE_CHECKING, Union from typing_extensions import Unpack +import sympy + try: from triton.runtime.autotuner import Autotuner, Heuristics @@ -403,6 +405,7 @@ def generate_compiler_repro_string( # pyrefly: ignore [missing-attribute] kernel._fn_name if isinstance(kernel, JITFunction) + # pyrefly: ignore # missing-attribute else kernel.fn._fn_name ) fn_name = fn_name.split(".")[-1] @@ -441,22 +444,36 @@ def generate_compiler_repro_string( # Extract symbolic variables from the same arguments # pyrefly: ignore [unbound-name] - if isinstance(arg, torch.SymInt): - sym_name = str(arg.node) - if arg.node.hint is not None: - used_syms[sym_name] = arg.node.hint + if ( + # pyrefly: ignore [unbound-name] + isinstance(arg, torch.SymInt) + # By checking sympy.Symbol, we are excluding any symbolic expressions. + # TODO: we may need to solve expressions to extract symbol definitions. + and isinstance(arg.node.expr, sympy.Symbol) + and arg.node.hint is not None + ): + used_syms[str(arg.node)] = arg.node.hint # pyrefly: ignore [unbound-name] elif isinstance(arg, torch.Tensor): # Extract symbolic variables from tensor shapes and strides for dim in arg.shape: # pyrefly: ignore [unbound-name] - if isinstance(dim, torch.SymInt) and dim.node.hint is not None: + if ( + # pyrefly: ignore [unbound-name] + isinstance(dim, torch.SymInt) + and isinstance(dim.node.expr, sympy.Symbol) + and dim.node.hint is not None + ): used_syms[str(dim.node)] = dim.node.hint for stride in arg.stride(): # pyrefly: ignore [unbound-name] - if isinstance(stride, torch.SymInt) and stride.node.hint is not None: + if ( + # pyrefly: ignore [unbound-name] + isinstance(stride, torch.SymInt) + and isinstance(stride.node.expr, sympy.Symbol) + and stride.node.hint is not None + ): used_syms[str(stride.node)] = stride.node.hint - # Add symbolic variable definitions to the top of the generated code if used_syms: hint_lines = "\n".join( diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index eed90ed5a9c67..bd38e9295a05a 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -258,6 +258,7 @@ def check_allowed_side_effect(self, item: VariableTracker) -> bool: "Dynamo needs to fully exhaust the generator, which may cause " "unintended variable modifications." ) + assert item.mutation_type is not None if not is_side_effect_safe(item.mutation_type): # TODO plumb HOP information here unimplemented_v2( @@ -373,7 +374,7 @@ def is_modified(self, item: VariableTracker) -> bool: if self.is_attribute_mutation(item): return item in self.store_attr_mutations - + assert item.mutation_type is not None return item.mutation_type.is_modified # type: ignore[attr-defined] def _track_obj( @@ -625,6 +626,12 @@ def is_live(var: VariableTracker) -> bool: cur_tx: Optional[InstructionTranslatorBase] = tx while cur_tx is not None: init_live_vars.extend([cur_tx.stack, cur_tx.symbolic_locals]) + if cur_tx.parent is not None: + # for non-root tx'es, also keep the cells/freevars alive so they get codegen'd properly + # TODO see if we could prune dead cells - cell pruning information needs to be forwarded + # to the resume function creation as well. + assert cur_tx.post_prune_cell_and_freevars is not None + init_live_vars.append(cur_tx.post_prune_cell_and_freevars) cur_tx = cur_tx.parent VariableTracker.visit( visit, diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 9fb4f32d68ad4..8edd8f7540e31 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -23,6 +23,7 @@ from collections.abc import Callable from typing import Any, Optional, TYPE_CHECKING, Union +from torch import device as device_type from torch._guards import ChainedSource, Guard, GuardSource, Source from . import utils @@ -111,11 +112,14 @@ def is_constant_source(source: Source) -> bool: return False -def _get_source_debug_name(source: Source) -> str: - try: - return source.name() - except NotImplementedError: +def _get_source_debug_name(source: Optional[Source]) -> str: + if source is None: return "" + else: + try: + return source.name() + except NotImplementedError: + return "" @dataclasses.dataclass(frozen=True) @@ -1079,6 +1083,30 @@ def guard_source(self) -> GuardSource: return GuardSource.SHAPE_ENV +@dataclasses.dataclass(frozen=True) +class CurrentStreamSource(Source): + device: device_type + + def name(self) -> str: + return f"___get_current_stream(torch.device('{self.device.type}', {self.device.index}))" + + def reconstruct(self, codegen: "PyCodegen") -> None: + num_args = 1 + codegen.add_push_null( + lambda: codegen.load_import_from(utils.__name__, "get_current_stream") + ) + codegen.add_push_null(lambda: codegen.load_import_from("torch", "device")) + codegen.extend_output([codegen.create_load_const(self.device.type)]) + if self.device.index is not None: + num_args += 1 + codegen.extend_output([codegen.create_load_const(self.device.index)]) + codegen.extend_output(create_call_function(num_args, False)) + codegen.extend_output(create_call_function(1, False)) + + def guard_source(self) -> GuardSource: + return GuardSource.GLOBAL + + @dataclasses.dataclass(frozen=True) class BackwardStateSource(Source): def name(self) -> str: diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 2c7b09ee3a31e..9d0d87c5f8a06 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -171,6 +171,7 @@ UnknownVariable, ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable +from .variables.streams import SymbolicStreamState from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.torch_function import ( SymbolicTorchFunctionState, @@ -538,91 +539,35 @@ def _detect_and_normalize_assert_statement( explain = False -def log_graph_break( - code_options: dict[str, Any], - reason: str = "", - exc_info: bool = False, - user_stack: Optional[StackSummary] = None, - latest_bytecode_log: Optional[str] = None, -) -> None: - if user_stack is None: - user_stack = torch._guards.TracingContext.extract_stack() - - try: - frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) - except IndexError: - # first instruction - frame_loc = ( - code_options["co_filename"], - code_options["co_firstlineno"], - ) - - stack_above_dynamo_formatted = "" - if config.verbose: - stack_above_dynamo = get_stack_above_dynamo() - stack_above_dynamo_formatted = "".join( - traceback.format_list(stack_above_dynamo) - ) - else: - user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment] - # pyrefly: ignore [bad-argument-type] - user_stack = collapse_resume_frames(user_stack) - user_stack_formatted = "".join(traceback.format_list(user_stack)) - user_stack_trace = ( - f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n" - f"Graph Break Reason: {reason}\n" - "User code traceback:\n" - ) - - if config.verbose: - user_stack_trace += ( - f"{stack_above_dynamo_formatted}\n" - "========== most recent `torch.compile` tracing attempt started here ==========\n\n" - f"{user_stack_formatted}\n" - "NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! " - "This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another " - "Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python " - "function, which Dynamo intercepts as a top-level frame.\n" - ) - else: - user_stack_trace += str(user_stack_formatted) - - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "dynamo_graph_break_reason", - "encoding": "string", - }, - payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc() if exc_info else ''}", - ) - - # torch._dynamo.explain() formats this a little nicer, and presents a slightly - # more actionable user code pointer - if ( - graph_break_log.isEnabledFor(logging.DEBUG) - and not explain - and graph_break_dup_warning_checker.add(frame_loc) - ): - # This log line MUST contain the string "Graph break in user code", - # This log line is exercised from - # python test/dynamo/test_exc.py -k test_graph_break_log - if latest_bytecode_log and config.verbose: - user_stack_trace += "Most recent bytecode instructions traced (max 20):\n" - user_stack_trace += latest_bytecode_log - - graph_break_log.debug( - user_stack_trace, - ) - else: - # This log line MUST not contain the string "Graph break in user code", - # exercised by - # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log - graph_break_log.debug( - "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s", - frame_loc[0], - frame_loc[1], - reason, - ) +# [NOTE] graph break handling in symbolic_convert +# There are 4 possible graph break cases that InstructionTranslatorBase handles: +# 1. Regular graph breaks from CALL, BINARY_SUBSCR, etc. (implemented by break_graph_if_unsupported) +# 2. Data-dependent condition graph breaks (implemented by generic_jump) +# 3. STORE_ATTR graph breaks (implemented in InstructionTranslatorBase.STORE_ATTR) +# 4. All other unhandled graph breaks - unsupported step graph breaks (implemented in InstructionTranslatorBase.step) +# +# Graph breaks are handled in the following manner: +# 1. The Unsupported exception is caught. If we cannot compile a partial graph (should_compile_partial_graph() is False), +# then propagate the exception upward. For unsupported step graph breaks, the condition to abort partial compilation is +# more restrictive (see InstructionTranslatorBase.step). +# 2. If the Unsupported exception escapes symbolic_convert.py, then we are done. +# Otherwise, we want to attempt partial compilation. +# Log the graph break via log_graph_break. If we're handling a data-dependent graph break (type 2.), then we can immediately +# codegen the compiled graph and resume function and we're done. This is because the jump instruction we graph break on is +# limited in how it can manipulate Python state (say, in comparison, to CALL, which can modify Python state arbitrarily). +# Otherwise, we need to restart compilation. We need to restart because by processing the unsupported instruction, +# we may have modified the VariableTrackers, and we need all of our VariableTrackers to be in the state BEFORE tracing the +# unsupported instruction. +# 3. During the first compilation, we updated a speculation log, indicating points in the code that we can resume from. +# On the second compilation, we will stop tracing at the first speculation log that fails. Then we compile the partial +# graph and resume function. +# +# Logging invariants: +# 1. No logs need to be made if Unsupported escapes symbolic_convert.py. Python's default exception printing will +# print out all of the necessary information and no partial compilation will be attempted. +# 2. log_graph_break should be called as soon as Unsupported is caught and we determined we want to partial compile. +# This always happens on the first compilation, NOT the restart handling this graph +# 3. Any compile_subgraph call should be preceded immediately by a log in the form of "... triggered compile". def generic_jump( @@ -645,7 +590,8 @@ def jump_graph_break( value: VariableTracker, extra_msg: str = "", ) -> None: - log_graph_break( + assert self.should_compile_partial_graph() + self.log_graph_break( self.code_options, reason=format_graph_break_message( gb_type=_gb_type, @@ -654,7 +600,6 @@ def jump_graph_break( hints=_hints, ), ) - assert self.should_compile_partial_graph() # compile a partial subgraph prefix then jump into user code if self.maybe_has_backedge(): msg = ( @@ -928,12 +873,10 @@ def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None: if not self.should_compile_partial_graph(): raise - log_graph_break( + self.log_graph_break( self.code_options, - exc_info=True, reason=str(excp), user_stack=excp.real_stack, - latest_bytecode_log="\n".join(self.latest_bytecode_queue), ) if self.maybe_has_backedge(): @@ -966,6 +909,7 @@ def handle_graph_break( else: stack_effect = dis.stack_effect(inst.opcode, inst.arg) + log.debug("%s triggered compile", inst.opname) all_stack_locals_metadata = self.output.compile_subgraph( self, reason=reason, stack_pops=push - stack_effect ) @@ -1161,6 +1105,7 @@ class InstructionTranslatorBase( symbolic_locals: dict[str, VariableTracker] symbolic_globals: dict[str, VariableTracker] symbolic_torch_function_state: SymbolicTorchFunctionState + symbolic_stream_state: SymbolicStreamState post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]] stack: list[VariableTracker] instruction_pointer: Optional[int] @@ -1376,8 +1321,33 @@ def step(self) -> bool: except (ReturnValueOp, YieldValueOp): return False except (Unsupported, StepUnsupported) as e: + # More restrictive condition than should_compile_partial_graph: + # if this condition is true, then we SHOULD NOT attempt to find + # a previous checkpoint to resume from and try to resume - we should + # immediately error out. + # The condition is more restrictive because, it may be possible to resume significantly earlier + # in the code (the most recent speculation point). This happens, for example, in the case + # of a graph break in a try block. + if ( + self.one_graph + or self.error_on_graph_break + or self.is_tracing_resume_prologue + ): + if isinstance(e, StepUnsupported): + unimplemented_v2( + gb_type="cannot resume from torch._dynamo.step_unsupported()", + context="", + explanation="traced torch._dynamo.step_unsupported(), but Dynamo is instructed " + "to error on graph break. This graph break is used for debugging only.", + hints=[ + "Remove the torch._dynamo.step_unsupported() call.", + "Make sure fullgraph=False and error_on_graph_break=False.", + *graph_break_hints.DYNAMO_BUG, + ], + ) + raise if self.current_speculation is None: - log.debug("empty checkpoint") + log.debug("empty checkpoint - cannot resume from graph break") if isinstance(e, StepUnsupported): unimplemented_v2( gb_type="torch._dynamo.step_unsupported() with empty checkpoint", @@ -1392,7 +1362,17 @@ def step(self) -> bool: ], ) raise - log.debug("step triggered compile", exc_info=True) + reason = ( + "Encountered graph break that we cannot resume from. " + "Compiling up to the previous resumable state, " + "then skipping the rest of the function. " + f"Graph break encountered:\n{str(e)}" + ) + self.log_graph_break( + self.code_options, + reason=reason, + user_stack=e.real_stack, + ) self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break) return False @@ -1464,6 +1444,7 @@ def step_graph_break(self, continue_inst: Instruction) -> None: # NOTE: if we support non-empty self.stack in the future, the `stack_pops` argument # below should be set to the stack length to ensure that the stack is codegen'd # for the rest of the function. + log.debug("step triggered compile") all_stack_locals_metadata = self.output.compile_subgraph( self, partial_convert=True, @@ -2087,7 +2068,7 @@ def _create_exception_type(self, val: VariableTracker) -> VariableTracker: def _raise_exception_variable(self, val: VariableTracker) -> NoReturn: # User can raise exception in 2 ways # 1) raise exception type - raise NotImplementedError - # 2) raise exception instance - raise NotImplemetedError("foo") + # 2) raise exception instance - raise NotImplementedError("foo") # 1) when user raises exception type val = self._create_exception_type(val) @@ -2138,7 +2119,7 @@ def RAISE_VARARGS(self, inst: Instruction) -> None: try: self._raise_exception_variable(val) finally: - # Update __cause__/__supppress_context__ in the raised exception + # Update __cause__/__suppress_context__ in the raised exception curr_exc = self.exn_vt_stack.get_current_exception() cause = self._create_exception_type(from_vt) curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) # type: ignore[arg-type, union-attr, assignment] @@ -2415,8 +2396,8 @@ def check_if_exc_matches(self) -> bool: # Users can check exception in 3 ways # 1) except NotImplementedError --> BuiltinVariable - # 2) except CustomException --> UserDefinedExceptionClasVariable - # 3) except (NotImplemetedError, AttributeError) -> TupleVariable + # 2) except CustomException --> UserDefinedExceptionClassVariable + # 3) except (NotImplementedError, AttributeError) -> TupleVariable if not isinstance( expected_exc_types, @@ -2656,13 +2637,17 @@ def STORE_ATTR(self, inst: Instruction) -> None: except Unsupported as e: if not self.should_compile_partial_graph(): raise - log.debug("STORE_ATTR triggered compile", exc_info=True) + reason = f"Encountered graph break when attempting to store an object's attribute (STORE_ATTR):\n\n{str(e)}" + self.log_graph_break( + self.code_options, + reason=reason, + user_stack=e.real_stack, + ) e.remove_from_stats() e.add_to_stats("graph_break") speculation.fail_and_restart_analysis(self.error_on_graph_break) def store_attr_graph_break(self, inst: Instruction) -> None: - log_graph_break(self.code_options, reason="STORE_ATTR-caused graph break") if not self.should_compile_partial_graph(): unimplemented_v2( gb_type="Should not compile partial graph (STORE_ATTR)", @@ -2671,6 +2656,7 @@ def store_attr_graph_break(self, inst: Instruction) -> None: "STORE_ATTR instruction (i.e. `obj.attr = val`) that it should not compile the partial graph.", hints=[], ) + log.debug("STORE_ATTR triggered compile") all_stack_locals_metadata = self.output.compile_subgraph( self, reason=GraphCompileReason("store_attr", [self.frame_summary()]), @@ -3192,7 +3178,7 @@ def codegen_call_resume( ] ) - # TOS: resumes, frames (popped), frame 1 stack + locals + # TOS: resume 1, remaining resumes, frames (popped), frame 1 stack + locals cg.extend_output( [ *create_rot_n(3), @@ -3203,12 +3189,8 @@ def codegen_call_resume( ] ) - # TOS: [resumes, frames, *(frame 1 stack + locals)] - cg.extend_output( - [ - *create_call_function_ex(False, True), - ] - ) + # TOS: resume 1, [remaining resumes, frames, *(frame 1 stack + locals)] + cg.extend_output(create_call_function_ex(False, True)) def should_compile_partial_graph(self) -> bool: if sys.version_info >= (3, 11): @@ -4165,6 +4147,93 @@ def speculate(self) -> SpeculationEntry: self.instructions[self.instruction_pointer - 1], ) + def log_graph_break( + self, + code_options: dict[str, Any], + reason: str = "", + user_stack: Optional[StackSummary] = None, + ) -> None: + if user_stack is None: + user_stack = torch._guards.TracingContext.extract_stack() + + try: + frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) + except IndexError: + # first instruction + frame_loc = ( + code_options["co_filename"], + code_options["co_firstlineno"], + ) + + stack_above_dynamo_formatted = "" + if config.verbose: + stack_above_dynamo = get_stack_above_dynamo() + stack_above_dynamo_formatted = "".join( + traceback.format_list(stack_above_dynamo) + ) + else: + user_stack = get_stack_above_dynamo() + user_stack # type: ignore[assignment] + # pyrefly: ignore [bad-argument-type] + user_stack = collapse_resume_frames(user_stack) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = ( + f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n" + f"Graph Break Reason: {reason}\n" + "User code traceback:\n" + ) + + if config.verbose: + user_stack_trace += ( + f"{stack_above_dynamo_formatted}\n" + "========== most recent `torch.compile` tracing attempt started here ==========\n\n" + f"{user_stack_formatted}\n" + "NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! " + "This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another " + "Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python " + "function, which Dynamo intercepts as a top-level frame.\n" + ) + else: + user_stack_trace += str(user_stack_formatted) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}", + ) + + # torch._dynamo.explain() formats this a little nicer, and presents a slightly + # more actionable user code pointer + if ( + graph_break_log.isEnabledFor(logging.DEBUG) + and not explain + and graph_break_dup_warning_checker.add(frame_loc) + ): + # This log line MUST contain the string "Graph break in user code", + # This log line is exercised from + # python test/dynamo/test_exc.py -k test_graph_break_log + if config.verbose: + user_stack_trace += ( + "\nMost recent bytecode instructions traced (max 20):\n" + ) + user_stack_trace += "\n".join(self.latest_bytecode_queue) + "\n" + + graph_break_log.debug( + user_stack_trace, + ) + else: + # This log line MUST not contain the string "Graph break in user code", + # exercised by + # python test/dynamo/test_misc.py -k test_duplicate_graph_break_log + graph_break_log.debug( + "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s", + frame_loc[0], + frame_loc[1], + reason, + ) + def __init__( self, output: OutputGraph, @@ -4176,6 +4245,7 @@ def __init__( symbolic_locals: dict[str, VariableTracker], symbolic_globals: dict[str, VariableTracker], symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_stream_state: SymbolicStreamState, f_code: types.CodeType, export: bool, inline_depth: int, @@ -4195,6 +4265,7 @@ def __init__( self.symbolic_locals = symbolic_locals self.symbolic_globals = symbolic_globals self.symbolic_torch_function_state = symbolic_torch_function_state + self.symbolic_stream_state = symbolic_stream_state # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals) # in order to generate any nested closures self.post_prune_cell_and_freevars = None @@ -4349,6 +4420,7 @@ def __init__( # A global var is inserted only after a STORE_GLOBAL happens to it symbolic_globals={}, symbolic_torch_function_state=None, # type: ignore[arg-type] # set below + symbolic_stream_state=None, # type: ignore[arg-type] # set below f_code=f_code, export=export, inline_depth=0, @@ -4453,6 +4525,8 @@ def __init__( torch_function_mode_stack ) + self.symbolic_stream_state = SymbolicStreamState() + if export: # export gets confused if we never realize unused inputs # in export mode just eagerly realize everything @@ -4542,7 +4616,7 @@ def _return(self, inst: Instruction) -> None: logging.INFO, f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})", ) - log.debug("%s triggered compile", inst.opname) + log.debug("return triggered compile") all_stack_locals_metadata = self.output.compile_subgraph( self, reason=GraphCompileReason( @@ -4779,6 +4853,7 @@ def get_trace_call_log_str() -> str: sub_locals, parent.symbolic_globals, parent.symbolic_torch_function_state, + parent.symbolic_stream_state, func, ) else: @@ -4790,6 +4865,7 @@ def get_trace_call_log_str() -> str: sub_locals, parent.symbolic_globals, parent.symbolic_torch_function_state, + parent.symbolic_stream_state, # pyrefly: ignore [bad-argument-type] func, ) @@ -4873,6 +4949,7 @@ def __init__( symbolic_locals: dict[str, VariableTracker], symbolic_globals: dict[str, VariableTracker], symbolic_torch_function_state: SymbolicTorchFunctionState, + symbolic_stream_state: SymbolicStreamState, funcvar: BaseUserFunctionVariable, ) -> None: f_globals = funcvar.get_globals() # type: ignore[attr-defined] @@ -4906,6 +4983,7 @@ def __init__( symbolic_locals=symbolic_locals, symbolic_globals=symbolic_globals, symbolic_torch_function_state=symbolic_torch_function_state, + symbolic_stream_state=symbolic_stream_state, instructions=instructions, code_options={k: getattr(code, k) for k in get_code_keys()}, f_code=code, @@ -5130,7 +5208,7 @@ def SEND(self, inst: Instruction) -> None: ): if isinstance(val, ConstantVariable) and val.value is None: try: - val = tos.next_variable(self) + val = tos.next_variable(self) # type: ignore[arg-type] except (StopIteration, exc.ObservedUserStopIteration) as ex: # To implement SEND, we have to look at the implementation # when the iterator returns StopIteration. This translates to this code diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 595ba182a597c..6a162350039d7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2590,6 +2590,8 @@ "torch.cuda._set_rng_state_offset", "torch.cuda._set_stream_by_id", "torch.cuda._sleep", + "torch.cuda._busy_wait_for_flag", + "torch.cuda._clear_flag", "torch.cuda._transform_uuid_to_ordinals", "torch.cuda._utils._get_device_index", "torch.cuda.amp.autocast_mode._cast", diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index d07ad52ab32c1..644081ab68579 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2707,6 +2707,7 @@ def to_subclass(t: Any, cls: type) -> Any: dict_getitem = dict.__getitem__ +@torch.fx.wrap def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: # Call dict(d) to prevent calling overridden __iter__/keys dict_class = dict @@ -4695,6 +4696,10 @@ def clear_torch_function_mode_stack() -> None: _pop_torch_function_stack() +def get_current_stream(device: torch.device) -> torch.Stream: + return torch.accelerator.current_stream(device) + + # call from C dynamo in order to inspect values in pdb def _breakpoint_for_c_dynamo(*args: Any) -> None: breakpoint() @@ -4759,33 +4764,8 @@ def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]: return tensor_dict -# This is useful for reconstructing within the Dynamo graph the non-graph-input objects -# whose lifetime is governed by the user. -# e.g. torch.cuda.Event is a prime example. -user_obj_id_to_weakref: dict[int, weakref.ReferenceType[object]] = {} - - -# TODO: mlazos to remove after replacing w/ above API -def get_user_object_from_id(obj_id: int) -> Any: - obj = user_obj_id_to_weakref[obj_id]() - assert obj is not None, "User object is no longer alive" - return obj - - -def store_user_object_weakref(obj: object) -> None: - obj_id = id(obj) - try: - user_obj_id_to_weakref[obj_id] = weakref.ref(obj) - except TypeError as e: - from .exc import unimplemented_v2 - - unimplemented_v2( - gb_type="Failed to make weakref to User Object when storing by ID", - context=f"user_objected: {obj}", - explanation="Object does not allow us to make a weakref to it", - hints=[], - from_exc=e, - ) +def build_stream(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Stream: + return torch._C.Stream(*args, **kwargs) class CompileTimeInstructionCounter: diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 731c29a365ad7..0abf2cc91e784 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Core variable tracking functionality for Dynamo. This module defines the fundamental classes and systems used to track and manage variables during Dynamo's operation. @@ -18,7 +16,10 @@ import collections from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView from enum import Enum -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, NoReturn, Optional, TYPE_CHECKING + +from torch._guards import Guard +from torch.fx.proxy import Node from .. import graph_break_hints, variables from ..current_scope_id import current_scope_id @@ -30,7 +31,7 @@ if TYPE_CHECKING: from ..codegen import PyCodegen - from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase + from ..symbolic_convert import InstructionTranslator class SourceType(Enum): @@ -115,10 +116,10 @@ class ValueMutationNew(MutationType): def __init__(self) -> None: super().__init__(SourceType.New) - def __hash__(self): + def __hash__(self) -> int: return id(self) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self is other @@ -139,7 +140,7 @@ class ValueMutationExisting(MutationType): # filter out which pre-existing values it needs to generate mutation for. is_modified: bool - def __init__(self, is_modified: bool = False): + def __init__(self, is_modified: bool = False) -> None: super().__init__(SourceType.Existing) self.is_modified = is_modified @@ -150,7 +151,7 @@ class AttributeMutation(MutationType): allows mutation on the value's attributes. """ - def __init__(self, typ: SourceType): + def __init__(self, typ: SourceType) -> None: super().__init__(typ) @@ -166,7 +167,7 @@ class AttributeMutationExisting(AttributeMutation): be used afterwards in Python. """ - def __init__(self): + def __init__(self) -> None: super().__init__(SourceType.Existing) @@ -182,16 +183,16 @@ class AttributeMutationNew(AttributeMutation): the Python world. """ - def __init__(self, cls_source: Optional[Source] = None): + def __init__(self, cls_source: Optional[Source] = None) -> None: super().__init__(SourceType.New) self.cls_source = cls_source -def _is_top_level_scope(scope_id): +def _is_top_level_scope(scope_id: int) -> bool: return scope_id == 1 -def is_side_effect_safe(m: MutationType): +def is_side_effect_safe(m: MutationType) -> bool: scope_id = current_scope_id() # In the top-level scope (if no HigherOrderOperators are involved), @@ -209,15 +210,15 @@ def is_side_effect_safe(m: MutationType): class AsPythonConstantNotImplementedError(NotImplementedError): vt: "VariableTracker" - def __init__(self, vt: "VariableTracker"): + def __init__(self, vt: "VariableTracker") -> None: super().__init__(f"{vt} is not a constant") self.vt = vt class VariableTrackerMeta(type): - all_subclasses = [] + all_subclasses: list[type] = [] - def __instancecheck__(cls, instance) -> bool: + def __instancecheck__(cls: type, instance: object) -> bool: """Make isinstance work with LazyVariableTracker""" # This is super expensive - just having it costs over 4% of tracing # time! @@ -227,8 +228,10 @@ def __instancecheck__(cls, instance) -> bool: instance = instance.realize() return type.__instancecheck__(cls, instance) - def __init__(cls, name, bases, attrs) -> None: - super().__init__(name, bases, attrs) + def __init__( + cls: type, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> None: + super().__init__(name, bases, attrs) # type: ignore[misc] VariableTrackerMeta.all_subclasses.append(cls) @@ -252,7 +255,7 @@ class VariableTracker(metaclass=VariableTrackerMeta): "user_code_variable_name", } - def clone(self, **kwargs): + def clone(self, **kwargs: Any) -> "VariableTracker": """Shallow copy with some (optional) changes""" args = dict(self.__dict__) args.update(kwargs) @@ -295,14 +298,14 @@ def visit( def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def debug_repr(self): + def debug_repr(self) -> str: # Intended to be overridden to provide more info try: return repr(self.as_python_constant()) except NotImplementedError: return repr(self) - def python_type(self): + def python_type(self) -> type: """ Abstract method to be implemented by subclasses of VariableTracker. @@ -331,17 +334,17 @@ def python_type(self): except NotImplementedError: raise NotImplementedError(f"{self} has no type") from None - def python_type_name(self): + def python_type_name(self) -> str: try: return self.python_type().__name__ except NotImplementedError: return "" - def as_python_constant(self): + def as_python_constant(self) -> Any: """For constants""" raise AsPythonConstantNotImplementedError(self) - def guard_as_python_constant(self): + def guard_as_python_constant(self) -> Any: """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants""" try: return self.as_python_constant() @@ -353,18 +356,20 @@ def guard_as_python_constant(self): hints=[], ) - def is_python_constant(self): + def is_python_constant(self) -> bool: try: self.as_python_constant() return True except NotImplementedError: return False - def make_guard(self, fn): + def make_guard(self, fn: Callable[..., Any]) -> Guard: if self.source: return self.source.make_guard(fn) raise NotImplementedError + # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase` + # and cascade that (large blast radius) def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: """getattr(self, name) returning a python constant""" raise NotImplementedError @@ -381,17 +386,17 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) return variables.ConstantVariable.create(value, source=source) - def is_proxy(self): + def is_proxy(self) -> bool: try: self.as_proxy() return True except NotImplementedError: return False - def as_proxy(self): + def as_proxy(self) -> Any: raise NotImplementedError(str(self)) - def maybe_fx_node(self): + def maybe_fx_node(self) -> Optional[Node]: try: proxy = self.as_proxy() import torch.fx @@ -402,13 +407,13 @@ def maybe_fx_node(self): except NotImplementedError: return None - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: raise NotImplementedError - def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: raise NotImplementedError - def force_unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: # like unpack_var_sequence, but should only be used when it is # safe to eagerly (vs. lazily) unpack this variable. # e.g. map(f, x) is normally evaluated lazily but sometimes @@ -417,7 +422,7 @@ def force_unpack_var_sequence(self, tx) -> list["VariableTracker"]: # it should only be called once. return self.unpack_var_sequence(tx) - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: Any) -> bool: try: self.unpack_var_sequence(tx) return True @@ -425,13 +430,15 @@ def has_unpack_var_sequence(self, tx) -> bool: return False # NB: don't call force_unpack_var_sequence, especially if it mutates! - def has_force_unpack_var_sequence(self, tx) -> bool: + def has_force_unpack_var_sequence(self, tx: Any) -> bool: return self.has_unpack_var_sequence(tx) # Forces unpacking the var sequence while also applying a function to each element. # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence). # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True! - def force_apply_to_var_sequence(self, tx, fn) -> None: + def force_apply_to_var_sequence( + self, tx: Any, fn: Callable[["VariableTracker"], Any] + ) -> None: assert self.has_force_unpack_var_sequence(tx) for v in self.unpack_var_sequence(tx): fn(v) @@ -444,9 +451,7 @@ def inspect_parameter_names(self) -> list[str]: hints=[], ) - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + def call_obj_hasattr(self, tx: Any, name: str) -> "VariableTracker": unimplemented_v2( gb_type="Unsupported hasattr call", context=f"call_obj_hasattr {self} {name}", @@ -459,9 +464,9 @@ def call_obj_hasattr( def call_function( self, - tx: "InstructionTranslator", + tx: Any, args: Sequence["VariableTracker"], - kwargs: "dict[str, VariableTracker]", + kwargs: dict[str, "VariableTracker"], ) -> "VariableTracker": unimplemented_v2( gb_type="Unsupported function call", @@ -475,10 +480,10 @@ def call_function( def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", + tx: Any, + name: str, + args: list["VariableTracker"], + kwargs: dict[str, "VariableTracker"], ) -> "VariableTracker": if name == "__len__" and self.has_unpack_var_sequence(tx): assert not (args or kwargs) @@ -562,7 +567,7 @@ def call_method( hints=hints, ) - def set_name_hint(self, name): + def set_name_hint(self, name: str) -> None: pass def realize(self) -> "VariableTracker": @@ -573,11 +578,11 @@ def unwrap(self) -> "VariableTracker": """Used by LazyVariableTracker to return the real VariableTracker if it already exists""" return self - def is_realized(self): + def is_realized(self) -> bool: """Used by LazyVariableTracker to indicate an unrealized node""" return True - def next_variable(self, tx): + def next_variable(self, tx: Any) -> "VariableTracker": unimplemented_v2( gb_type="Unsupported next() call", context=f"next({self})", @@ -585,20 +590,20 @@ def next_variable(self, tx): hints=[*graph_break_hints.USER_ERROR], ) - def is_strict_mode(self, tx): - return tx.strict_checks_fn and tx.strict_checks_fn(self) + def is_strict_mode(self, tx: Any) -> bool: + return bool(tx.strict_checks_fn and tx.strict_checks_fn(self)) - def is_mutable(self): + def is_mutable(self) -> bool: """Whether Dynamo allows mutation on this variable.""" return not self.is_immutable() - def is_immutable(self): + def is_immutable(self) -> bool: """Whether Dynamo bans mutation on this variable.""" return self.mutation_type is None @staticmethod def build( - tx: "InstructionTranslatorBase", + tx: Any, value: Any, source: Optional[Source] = None, ) -> Any: @@ -611,8 +616,8 @@ def build( def __init__( self, *, - source: Source = None, - mutation_type: MutationType = None, + source: Optional[Source] = None, + mutation_type: Optional[MutationType] = None, ) -> None: super().__init__() self.source = source @@ -636,12 +641,12 @@ def __init__( assert source is not None -def raise_type_error_exc(tx: "InstructionTranslator", msg_str: str) -> None: +def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn: msg = variables.ConstantVariable.create(msg_str) raise_observed_exception(TypeError, tx, args=[msg]) -def typestr(*objs): +def typestr(*objs: object) -> str: if len(objs) == 1: (obj,) = objs if isinstance(obj, VariableTracker): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 2a1cff0211f5b..81baaa236b0a8 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -46,7 +46,7 @@ from torch import SymInt from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.graph_bytecode_inputs import ( - get_user_object_by_index, + get_external_object_by_index, register_user_object, ) from torch._dynamo.utils import ( @@ -1057,13 +1057,12 @@ def build_key_value(i, k, v): self.install_guards(GuardBuilder.TYPE_MATCH) index = register_user_object(value, self.source) stream_proxy = self.tx.output.create_proxy( - "call_function", get_user_object_by_index, (index,), {} + "call_function", get_external_object_by_index, (index,), {} ) set_example_value(stream_proxy.node, value) var = StreamVariable( stream_proxy, value, - value.device, source=self.source, ) return self.tx.output.side_effects.track_object_existing(value, var) @@ -1078,7 +1077,7 @@ def build_key_value(i, k, v): index = register_user_object(value, self.source) event_proxy = self.tx.output.create_proxy( "call_function", - get_user_object_by_index, + get_external_object_by_index, (index,), {}, ) @@ -2930,12 +2929,12 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe hasattr(proxy.node.target, "__name__") and proxy.node.target.__name__ == "set_state" and isinstance(proxy.node.target.__self__, torch._C.Generator) - or proxy.node.target == torch.random.set_rng_state + or proxy.node.target is torch.random.set_rng_state ): return TorchInGraphFunctionVariable(proxy.node.target) elif ( - proxy.node.target == torch._C._DisableFuncTorch - or proxy.node.target == torch.cuda._is_in_bad_fork + proxy.node.target is torch._C._DisableFuncTorch + or proxy.node.target is torch.cuda._is_in_bad_fork ): return UserDefinedObjectVariable(example_value) elif istype(example_value, torch.Size) and all( @@ -3006,14 +3005,15 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe set_example_value(proxy.node, example_value) return SymNodeVariable(proxy, example_value, **options) elif ( - inspect.isclass(proxy.node.target) - and issubclass(proxy.node.target, torch.Stream) + isinstance(example_value, torch.Stream) + and proxy.node.target + in (get_external_object_by_index, torch.accelerator.current_stream) ) or proxy.node.target in [ device_interface.current_stream for _, device_interface in get_registered_device_interfaces() ]: set_example_value(proxy.node, example_value) - return StreamVariable(proxy, example_value, example_value.device, **options) + return StreamVariable(proxy, example_value, **options) elif ( inspect.isclass(proxy.node.target) and issubclass(proxy.node.target, torch.Event) @@ -3743,7 +3743,7 @@ def create(tx: "InstructionTranslator", value) -> VariableTracker: return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable( value ) - elif isinstance(value, types.GenericAlias): + elif isinstance(value, (types.GenericAlias, types.UnionType)): return TypingVariable(value) elif is_namedtuple(value): output = [ diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 2382d4ef5df4a..4e68cf6f3071f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1828,17 +1828,20 @@ def _call_tuple_list(self, tx, obj=None, *args, **kwargs): return self._call_iter_tuple_list(tx, obj, *args, **kwargs) def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): - if isinstance(obj, variables.IteratorVariable): - ret = obj - elif isinstance(obj, variables.RangeVariable): - ret = obj.call_method(tx, "__iter__", [], {}) - elif isinstance(obj, variables.LocalGeneratorObjectVariable): - ret = obj # type: ignore[assignment] + # avoid the overhead of tracing the polyfill if we already know the class implemented __iter__ + if isinstance( + obj, + ( + variables.ListVariable, + variables.RangeVariable, + variables.IteratorVariable, + variables.ConstDictVariable, + variables.NNModuleVariable, + variables.TensorVariable, + ), + ): + return obj.call_method(tx, "__iter__", [], {}) else: - # Handle the case where we are iterating over a tuple, list or iterator - ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) - - if ret is None: # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. @@ -1854,7 +1857,7 @@ def call_iter(self, tx: "InstructionTranslator", obj, *args, **kwargs): # Wrap the return value in a IteratorVariable subclass (LazyObjectIteratorVariable) # that forwards the next_variable call to the object. ret = variables.ObjectIteratorVariable(ret) - return ret + return ret call_tuple = _call_tuple_list call_list = _call_tuple_list diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 34573c5cfc773..1793f5c10844e 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Constant and enum variable tracking in Dynamo. @@ -8,8 +6,9 @@ maintaining type safety through the compilation process. """ +import enum import operator -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING, Union import torch from torch._dynamo.source import AttrSource, GetItemSource @@ -23,7 +22,7 @@ np, raise_args_mismatch, ) -from .base import VariableTracker +from .base import ValueMutationNew, VariableTracker if TYPE_CHECKING: @@ -40,7 +39,7 @@ class ConstantVariable(VariableTracker): """ @staticmethod - def create(value, **kwargs) -> VariableTracker: + def create(value: Any, **kwargs: Any) -> VariableTracker: """ Create a `ConstantVariable` based on the given value, and supports automatic routing for collection types like `tuple` (in which case we'd @@ -76,7 +75,7 @@ def create(value, **kwargs) -> VariableTracker: return ConstantVariable(value, **kwargs) - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) assert ConstantVariable.is_base_literal(value), f""" Cannot construct `ConstantVariable` for value of type {type(value)}. @@ -92,48 +91,52 @@ def __init__(self, value, **kwargs) -> None: else: self.value = value - def as_proxy(self): + def as_proxy(self) -> Any: return self.value def __repr__(self) -> str: return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value - def is_python_constant(self): + def is_python_constant(self) -> Literal[True]: return True @property - def items(self): + def items(self) -> list[VariableTracker]: """ Need this when adding a BaseListVariable and a ConstantVariable together. Happens in detectron2. """ return self.unpack_var_sequence(tx=None) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: return ConstantVariable.create( self.value[arg.as_python_constant()], ) @staticmethod - def is_base_literal(obj): + def is_base_literal(obj: object) -> bool: return type(obj) in common_constant_types @staticmethod - def is_literal(obj): + def is_literal(obj: object) -> bool: if type(obj) in (list, tuple, set, frozenset, torch.Size): - return all(ConstantVariable.is_literal(x) for x in obj) + return all(ConstantVariable.is_literal(x) for x in obj) # type: ignore[attr-defined] return ConstantVariable.is_base_literal(obj) - def unpack_var_sequence(self, tx): + def unpack_var_sequence( + self, tx: Optional["InstructionTranslator"] + ) -> list[VariableTracker]: try: return [ConstantVariable.create(x) for x in self.as_python_constant()] except TypeError as e: raise NotImplementedError from e - def const_getattr(self, tx: "InstructionTranslator", name): + def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if not hasattr(self.value, name): raise_observed_exception(AttributeError, tx, args=[name]) member = getattr(self.value, name) @@ -144,10 +147,10 @@ def const_getattr(self, tx: "InstructionTranslator", name): def call_method( self, tx: "InstructionTranslator", - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: from .tensor import SymNodeVariable if name == "format" and istype(self.value, str): @@ -168,6 +171,14 @@ def call_method( return ConstantVariable.create(self.value.join(arg_const)) except NotImplementedError: return super().call_method(tx, name, args, kwargs) + elif name == "__iter__" and istype(self.value, str): + # this could be some generic iterator to avoid the circular import, + # but ListIterator does what we want + from .lists import ListIteratorVariable + + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) if any(isinstance(x, SymNodeVariable) for x in args): # Promote to SymNodeVariable for operations involving dynamic shapes. @@ -230,10 +241,12 @@ def call_method( raise_observed_exception(type(e), tx) if name == "__len__" and not (args or kwargs): + # pyrefly: ignore [bad-argument-type] return ConstantVariable.create(len(self.value)) elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): try: return ConstantVariable.create( + # pyrefly: ignore [no-matching-overload] round(self.value, args[0].as_python_constant()) ) except Exception as e: @@ -244,6 +257,7 @@ def call_method( assert not kwargs search = args[0].as_python_constant() try: + # pyrefly: ignore [unsupported-operation] result = search in self.value return ConstantVariable.create(result) except TypeError as e: @@ -254,7 +268,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: result = hasattr(self.value, name) return variables.ConstantVariable.create(result) @@ -266,12 +280,14 @@ class EnumVariable(VariableTracker): both standard Enum and IntEnum with proper value tracking and comparison. """ - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Union[enum.Enum, enum.IntEnum], **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value @classmethod - def create(cls, cls_type, value_vt, options): + def create( + cls, cls_type: Any, value_vt: VariableTracker, options: Any + ) -> "EnumVariable": if isinstance(value_vt, variables.ConstantVariable): for member in list(cls_type): if member.value == value_vt.as_python_constant(): @@ -285,7 +301,7 @@ def create(cls, cls_type, value_vt, options): hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE], ) - def as_proxy(self): + def as_proxy(self) -> Union[enum.Enum, int]: if isinstance(self.value, int): return int(self.value) # convert IntEnum to a normal int return self.value @@ -293,10 +309,10 @@ def as_proxy(self): def __repr__(self) -> str: return f"EnumVariable({type(self.value)})" - def as_python_constant(self): + def as_python_constant(self) -> Union[enum.Enum, enum.IntEnum]: return self.value - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if not hasattr(self.value, name): raise NotImplementedError if name in cmp_name_to_op_mapping: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 481aa7e1a302b..4f1f84a55b0b0 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -34,7 +34,7 @@ from ..bytecode_transformation import create_call_function, create_instruction from ..exc import raise_observed_exception, unimplemented_v2 from ..guards import GuardBuilder, install_guard -from ..source import is_from_local_source +from ..source import is_constant_source, is_from_local_source from ..utils import ( cmp_name_to_op_mapping, dict_items, @@ -46,6 +46,7 @@ ) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable +from .lists import ListIteratorVariable if TYPE_CHECKING: @@ -53,7 +54,7 @@ from torch._dynamo.symbolic_convert import InstructionTranslator -# [Adding a new supported class within the keys of ConstDictVarialble] +# [Adding a new supported class within the keys of ConstDictVariable] # - Add its tracker type to is_hashable # - (perhaps) Define how it is compared in _HashableTracker._eq_impl @@ -779,6 +780,12 @@ def call_method( elif name == "__ior__": self.call_method(tx, "update", args, kwargs) return self + elif name == "__iter__": + if self.source and not is_constant_source(self.source): + tx.output.guard_on_key_order.add(self.source) + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) else: return super().call_method(tx, name, args, kwargs) @@ -787,12 +794,16 @@ def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] def call_obj_hasattr(self, tx, name): - # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. - # OrderedDict though requires side effects tracking because it supports arbitrary setattr. - if self.user_cls is dict: - if name in self.user_cls.__dict__: + # dict not allow setting arbitrary attributes. OrderedDict and + # defaultdict allow arbitrary setattr, but not deletion of default attrs + if any( + self.user_cls is t + for t in (dict, collections.OrderedDict, collections.defaultdict) + ): + if hasattr(self.user_cls, name): return ConstantVariable.create(True) - return ConstantVariable.create(False) + if self.user_cls is dict: + return ConstantVariable.create(False) msg = f"hasattr on {self.user_cls} is not supported" unimplemented_v2( @@ -879,6 +890,13 @@ def call_method( ) return self.dv_dict.call_method(tx, name, args, kwargs) + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.python_type() is types.MappingProxyType: + return ConstantVariable.create(name in types.MappingProxyType.__dict__) + return super().call_obj_hasattr(tx, name) + class NNModuleHooksDictVariable(ConstDictVariable): # Special class to avoid adding any guards on the nn module hook ids. @@ -1388,6 +1406,10 @@ def call_method( ) -> "VariableTracker": if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) + elif name == "__iter__": + return ListIteratorVariable( + self.view_items_vt, mutation_type=ValueMutationNew() + ) return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 37878abbb37e4..eb39dd8fa3e07 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Distributed computing variable tracking classes for PyTorch Dynamo. @@ -22,7 +20,7 @@ import functools import inspect -from typing import TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState @@ -40,6 +38,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -54,7 +53,7 @@ class DistributedVariable(VariableTracker): and hold the tracking value for the corresponding distributed object. """ - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) if not DistributedVariable.is_available(): unimplemented_v2( @@ -67,16 +66,16 @@ def __init__(self, value, **kwargs) -> None: ) self.value = value - def python_type(self): + def python_type(self) -> type: return type(self.value) @staticmethod - def is_available(): + def is_available() -> bool: # check if the distributed package is available or not return torch.distributed.is_available() -def is_from_local(value): +def is_from_local(value: object) -> bool: if not DistributedVariable.is_available(): return False from torch.distributed.tensor import DTensor @@ -84,7 +83,7 @@ def is_from_local(value): return inspect.isfunction(value) and value is DTensor.from_local -def is_constant_pg_functions(value): +def is_constant_pg_functions(value: object) -> bool: if not DistributedVariable.is_available(): return False @@ -114,7 +113,7 @@ class WorldMetaClassVariable(DistributedVariable): """ @classmethod - def is_group_member_type(cls, value): + def is_group_member_type(cls, value: object) -> bool: if not cls.is_available(): return False @@ -124,10 +123,12 @@ def is_group_member_type(cls, value): def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "WORLD": + assert self.source source = AttrSource(base=self.source, member="WORLD") install_guard(source.make_guard(GuardBuilder.ID_MATCH)) return ProcessGroupVariable(self.value.WORLD) elif name == "NON_GROUP_MEMBER": + assert self.source source = AttrSource(base=self.source, member="NON_GROUP_MEMBER") install_guard(source.make_guard(GuardBuilder.ID_MATCH)) return EnumVariable(self.value.NON_GROUP_MEMBER) @@ -136,7 +137,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker class PlacementClassVariable(DistributedVariable): @staticmethod - def is_placement_type(value): + def is_placement_type(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False @@ -145,15 +146,15 @@ def is_placement_type(value): return isinstance(value, type) and issubclass(value, Placement) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.source: # NOTE: we don't need to track mutations to the placement class as they # are supposed to be immutable. @@ -168,16 +169,15 @@ def call_function( class PlacementVariable(DistributedVariable): @staticmethod - def is_placement(value): + def is_placement(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False - from torch.distributed.tensor.placement_types import Placement return isinstance(value, Placement) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -187,11 +187,11 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: from . import ConstantVariable # Placement types dynamo tracking only allows following methods @@ -228,15 +228,16 @@ def call_method( args = [x.as_python_constant() for x in args] kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} + assert method is not None if name == "__setattr__": method(self.value, *args, **kwargs) return self constant_val = method(self.value, *args, **kwargs) return ConstantVariable.create(constant_val) - return super().call_method(tx, name, args, kwargs) + return super().call_method(tx, name, args, kwargs) # type: ignore[arg-type] - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: # Reconstruct the Placement object by calling its constructor # e.g., Shard(0), Replicate(), Partial() from torch.distributed.tensor.placement_types import Partial, Replicate, Shard @@ -263,7 +264,7 @@ def reconstruct(self, codegen): class DeviceMeshVariable(DistributedVariable): @staticmethod - def is_device_mesh(value): + def is_device_mesh(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False @@ -272,7 +273,7 @@ def is_device_mesh(value): return istype(value, DeviceMesh) - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -289,11 +290,11 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "size": const_args = [x.as_python_constant() for x in args] const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} @@ -338,16 +339,16 @@ class ProcessGroupVariable(DistributedVariable): or just graph-break whenever one of our special cases is not hit? """ - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "rank": return variables.ConstantVariable.create(self.value.rank()) if name == "size": @@ -357,7 +358,7 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "group_name": return variables.ConstantVariable.create(self.value.group_name) if name in ["rank", "size"]: @@ -368,7 +369,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return super().var_getattr(tx, name) @staticmethod - def is_process_group(value): + def is_process_group(value: object) -> bool: # we can't rely on importing/accessing torch distributed, it is not always built. if not DistributedVariable.is_available(): return False @@ -386,11 +387,11 @@ class BackwardHookVariable(VariableTracker): @staticmethod def create( - tx, + tx: "InstructionTranslator", module: VariableTracker, user_hooks: VariableTracker, user_pre_hooks: VariableTracker, - ): + ) -> "BackwardHookVariable": if not compiled_autograd.compiled_autograd_enabled: unimplemented_v2( gb_type="Module-level backwards hooks require compiled autograd.", @@ -401,7 +402,9 @@ def create( ], ) - def _in_graph_bw_hooks(bw_state: BackwardState): + def _in_graph_bw_hooks( + bw_state: BackwardState, + ) -> torch.utils.hooks.BackwardHook: """ Rather than installing the user hooks in the graph (which don't survive AotAutograd), we install hooks that will call @@ -448,7 +451,7 @@ def __init__( module: VariableTracker, user_hooks: VariableTracker, user_pre_hooks: VariableTracker, - **options, + **options: Any, ) -> None: super().__init__(**options) self.proxy = proxy @@ -456,13 +459,13 @@ def __init__( self.user_hooks = user_hooks self.user_pre_hooks = user_pre_hooks - def as_proxy(self): + def as_proxy(self) -> torch.fx.Proxy: return self.proxy def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: @@ -470,7 +473,9 @@ def call_method( return self._setup_hook(tx, name, *args, **kwargs) return super().call_method(tx, name, args, kwargs) - def _setup_hook(self, tx: "InstructionTranslator", hook_method_name, args): + def _setup_hook( + self, tx: "InstructionTranslator", hook_method_name: str, args: VariableTracker + ) -> VariableTracker: from .builder import wrap_fx_proxy return wrap_fx_proxy( diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index f15be35cac99d..0752a413fce6e 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -154,27 +154,32 @@ def bind_args_cached(func, tx, fn_source, args, kwargs): rem_kw = dict(kwargs) # 1) Bind all positional (pos-only + pos-or-kw) + # 1.1) Apply pos-defaults first (maybe overridden later) + for name, idx in spec.pos_default_map.items(): + default_source = None + if fn_source and not ( + ConstantVariable.is_literal(spec.defaults[idx]) + and config.skip_guards_on_constant_func_defaults + ): + default_source = DefaultsSource(fn_source, idx) + ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source) + # 1.2) Fill in provided positional args for i, name in enumerate(spec.all_pos_names): if i < len(args): + # Maybe override pos-defaults applied above ba[name] = wrap_bound_arg(tx, args[i]) - elif name in rem_kw: - if name in spec.posonly_names: - raise_observed_exception( - TypeError, - tx, - args=[ConstantVariable.create(f"{name} is positional-only")], - ) + elif name in rem_kw and ( + # `kwargs` can have the same key as a pos-only arg `name`. + # If this case happens, we should not consume the `name` here and + # keep it in `kwargs`: + # >>> def fn(a, /, **kwargs): return (a, kwargs) + # >>> fn(1, a=2) + # (1, {'a': 2}) + name not in spec.posonly_names + ): + # Maybe override pos-defaults applied above ba[name] = wrap_bound_arg(tx, rem_kw.pop(name)) - elif name in spec.pos_default_map: - idx = spec.pos_default_map[name] - default_source = None - if fn_source and not ( - ConstantVariable.is_literal(spec.defaults[idx]) - and config.skip_guards_on_constant_func_defaults - ): - default_source = DefaultsSource(fn_source, idx) - ba[name] = wrap_bound_arg(tx, spec.defaults[idx], default_source) - else: + elif name not in ba: raise_observed_exception( TypeError, tx, diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 7b1a1cc83dbc9..c330a700fd66b 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -2595,7 +2595,7 @@ def _call_function( from torch.utils.checkpoint import noop_context_fn context_fn = None - if "context_fn" in kwargs and kwargs["context_fn"] != noop_context_fn: + if "context_fn" in kwargs and kwargs["context_fn"] is not noop_context_fn: ctx = kwargs.pop("context_fn") if isinstance(ctx, torch._dynamo.variables.UserFunctionVariable): context_fn = ctx.fn @@ -3560,7 +3560,7 @@ def should_wrap_in_hop(cls, value): if type(value) is not type(_local_map_wrapped): return False - return value == _local_map_wrapped and cls._enabled + return value is _local_map_wrapped and cls._enabled @staticmethod def build(**options): diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index be1b7bf433f3a..5970ba0e1dda7 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This module provides iterator-related variable tracking functionality for Dynamo. It implements variable classes for handling Python iterators and itertools functions @@ -16,7 +14,8 @@ """ import itertools -from typing import TYPE_CHECKING, Union +from collections.abc import Callable +from typing import Any, Sequence, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( @@ -45,20 +44,20 @@ class ItertoolsVariable(VariableTracker): - def __init__(self, value, **kwargs) -> None: + def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value def __repr__(self) -> str: return f"ItertoolsVariable({self.value})" - def as_python_constant(self): + def as_python_constant(self) -> Any: return self.value def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", + args: Sequence["VariableTracker"], kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": # See also: module `torch._dynamo.polyfills.itertools` @@ -111,7 +110,7 @@ def call_function( hints=[*graph_break_hints.USER_ERROR], ) - def retrieve_const_key(key): + def retrieve_const_key(key: VariableTracker) -> Any: if isinstance(key, variables.SymNodeVariable): return key.evaluate_expr() elif isinstance(key, variables.ConstantVariable): @@ -144,18 +143,19 @@ def retrieve_const_key(key): if "key" in kwargs: - def keyfunc(x): + def keyfunc(x: VariableTracker) -> Any: return retrieve_const_key( - kwargs.get("key").call_function(tx, [x], {}) + kwargs.get("key").call_function(tx, [x], {}) # type: ignore[union-attr] ) else: - def keyfunc(x): + def keyfunc(x: VariableTracker) -> Any: return retrieve_const_key(x) result = [] try: + # pyrefly: ignore [unbound-name] for k, v in itertools.groupby(seq, key=keyfunc): result.append( variables.TupleVariable( @@ -219,10 +219,10 @@ def keyfunc(x): class IteratorVariable(VariableTracker): - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: unimplemented_v2( gb_type="Unimplemented next() call", context=f"next({self})", @@ -234,12 +234,16 @@ def next_variable(self, tx): # Normally, iterators are accessed lazily. # Example of safe eager unpacking: list(map(f, seq)) # Example of unsafe eager unpacking: list(islice(map(f, seq), 5)) - def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: - result = [] + def force_unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list[VariableTracker]: + result: list[VariableTracker] = [] self.force_apply_to_var_sequence(tx, result.append) return result - def force_apply_to_var_sequence(self, tx, fn) -> None: + def force_apply_to_var_sequence( + self, tx: "InstructionTranslator", fn: Callable[[Any], Any] + ) -> None: while True: try: fn(self.next_variable(tx)) @@ -249,9 +253,29 @@ def force_apply_to_var_sequence(self, tx, fn) -> None: # don't call force_unpack_var_sequence since it can mutate # IteratorVariable state! - def has_force_unpack_var_sequence(self, tx) -> bool: + def has_force_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return True + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if name == "__iter__" or name == "__next__": + return variables.ConstantVariable.create(True) + return super().call_obj_hasattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__iter__": + return self + elif name == "__next__": + return self.next_variable(tx) + return super().call_method(tx, name, args, kwargs) + class ObjectIteratorVariable(IteratorVariable): """ @@ -267,12 +291,12 @@ class ObjectIteratorVariable(IteratorVariable): > list(b) # empty list """ - def __init__(self, obj: VariableTracker, **kwargs): + def __init__(self, obj: VariableTracker, **kwargs: Any) -> None: super().__init__(**kwargs) self.obj = obj self.generator_exhausted = False - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: if self.generator_exhausted: raise_observed_exception(StopIteration, tx) @@ -286,15 +310,15 @@ def next_variable(self, tx): class RepeatIteratorVariable(IteratorVariable): - def __init__(self, item: VariableTracker, **kwargs) -> None: + def __init__(self, item: VariableTracker, **kwargs: Any) -> None: super().__init__(**kwargs) self.item = item # Repeat needs no mutation, clone self - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: return self.item - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.extend_output( [ @@ -308,7 +332,12 @@ def reconstruct(self, codegen: "PyCodegen"): class CountIteratorVariable(IteratorVariable): - def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: + def __init__( + self, + item: Union[int, VariableTracker] = 0, + step: Union[int, VariableTracker] = 1, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) if not isinstance(item, VariableTracker): item = ConstantVariable.create(item) @@ -317,14 +346,14 @@ def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None: self.item = item self.step = step - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: assert self.is_mutable() old_item = self.item tx.output.side_effects.mutation(self) self.item = self.item.call_method(tx, "__add__", [self.step], {}) return old_item - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.extend_output( [ @@ -353,7 +382,7 @@ def __init__( self, iterables: list[VariableTracker], strict: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) assert isinstance(iterables, list) @@ -362,16 +391,18 @@ def __init__( self.index = 0 self.strict = strict - def python_type(self): + def python_type(self) -> type[zip]: # type: ignore[type-arg] return zip - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return all( isinstance(it, list) or it.has_unpack_var_sequence(tx) for it in self.iterables ) - def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: assert self.has_unpack_var_sequence(tx) iterables = [] for it in self.iterables: @@ -383,7 +414,7 @@ def unpack_var_sequence(self, tx) -> list["VariableTracker"]: zipped = zip(*iterables, **kwargs) return [variables.TupleVariable(list(var)) for var in zipped] - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: assert self.is_mutable() if len(self.iterables) == 0: @@ -392,7 +423,9 @@ def next_variable(self, tx): old_index = self.index args = [] - def get_item(it): + def get_item( + it: Union[list[VariableTracker], VariableTracker], + ) -> VariableTracker: if isinstance(it, list): if old_index >= len(it): raise_observed_exception(StopIteration, tx) @@ -421,7 +454,7 @@ def get_item(it): raise handle_observed_exception(tx) raise UserError( - ValueError, + ValueError, # type: ignore[arg-type] "zip() has one argument of len differing from others", ) from None raise @@ -430,7 +463,7 @@ def get_item(it): self.index += 1 return variables.TupleVariable(args) - def reconstruct_items(self, codegen: "PyCodegen"): + def reconstruct_items(self, codegen: "PyCodegen") -> None: for it in self.iterables: if isinstance(it, list): remaining_items = it[self.index :] @@ -439,7 +472,7 @@ def reconstruct_items(self, codegen: "PyCodegen"): else: codegen(it) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True ) @@ -463,23 +496,23 @@ class MapVariable(ZipVariable): def __init__( self, fn: VariableTracker, - iterables: list[Union[list[VariableTracker], VariableTracker]], - **kwargs, + iterables: list[VariableTracker], + **kwargs: Any, ) -> None: super().__init__(iterables, **kwargs) self.fn = fn - def python_type(self): + def python_type(self) -> type: return map - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return False - def next_variable(self, tx): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: args = super().next_variable(tx) - return self.fn.call_function(tx, args.items, {}) + return self.fn.call_function(tx, args.items, {}) # type: ignore[attr-defined] - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True ) @@ -506,23 +539,25 @@ class FilterVariable(IteratorVariable): def __init__( self, fn: VariableTracker, - iterable: Union[list[VariableTracker], VariableTracker], - **kwargs, + iterable: list[VariableTracker], + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.fn = fn self.iterable = iterable self.index = 0 - def python_type(self): + def python_type(self) -> type: return filter - def has_unpack_var_sequence(self, tx) -> bool: + def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return isinstance(self.iterable, list) or self.iterable.has_unpack_var_sequence( tx ) - def unpack_var_sequence(self, tx) -> list["VariableTracker"]: + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: assert self.has_unpack_var_sequence(tx) it = None if isinstance(self.iterable, list): @@ -532,8 +567,8 @@ def unpack_var_sequence(self, tx) -> list["VariableTracker"]: filtered = self.fn.call_function(tx, it, {}) return [variables.TupleVariable([filtered])] - def next_variable(self, tx): - def _next(): + def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: + def _next() -> VariableTracker: old_index = self.index if isinstance(self.iterable, list): if old_index >= len(self.iterable): @@ -556,7 +591,7 @@ def _next(): if pred_res.as_python_constant(): return item - def reconstruct_items(self, codegen: "PyCodegen"): + def reconstruct_items(self, codegen: "PyCodegen") -> None: if isinstance(self.iterable, list): remaining_items = self.iterable[self.index :] codegen.foreach(remaining_items) @@ -564,7 +599,7 @@ def reconstruct_items(self, codegen: "PyCodegen"): else: codegen(self.iterable) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) codegen(self.fn) self.reconstruct_items(codegen) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 909892b3da686..11a199e99eadc 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -294,6 +294,8 @@ def call_method( [variables.BuiltinVariable(cmp_name_to_op_mapping[name]), left, right], {}, ) + elif name == "__iter__": + return ListIteratorVariable(self.items, mutation_type=ValueMutationNew()) return super().call_method(tx, name, args, kwargs) @@ -472,9 +474,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> "VariableTracker": - if self.python_type() is not range: - return super().call_obj_hasattr(tx, name) - return variables.ConstantVariable.create(hasattr(range(0), name)) + if self.python_type() is range: + return variables.ConstantVariable.create(name in range.__dict__) + return super().call_obj_hasattr(tx, name) def range_equals(self, other: "RangeVariable"): r0, r1 = self, other @@ -1064,6 +1066,13 @@ def call_method( self.items[:] = self.items[slice_within_maxlen] return result + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> "VariableTracker": + if self.python_type() is collections.deque: + return variables.ConstantVariable.create(name in collections.deque.__dict__) + return super().call_obj_hasattr(tx, name) + class TupleVariable(BaseListVariable): def python_type(self): @@ -1584,6 +1593,7 @@ def __init__(self, items, index: int = 0, **kwargs) -> None: # assert all(isinstance(x, VariableTracker) for x in items) self.items = items self.index = index + self.is_exhausted = False def __repr__(self) -> str: return f"{self.__class__.__name__}(length={len(self.items)}, index={repr(self.index)})" @@ -1591,7 +1601,8 @@ def __repr__(self) -> str: def next_variable(self, tx): assert self.is_mutable() old_index = self.index - if old_index >= len(self.items): + if old_index >= len(self.items) or self.is_exhausted: + self.is_exhausted = True raise_observed_exception(StopIteration, tx) tx.output.side_effects.mutation(self) @@ -1613,15 +1624,19 @@ def has_unpack_var_sequence(self, tx): return True def unpack_var_sequence(self, tx): - r = list(self.items[self.index :]) - self.index = len(self.items) - return r + if self.is_exhausted: + return [] + self.is_exhausted = True + return list(self.items[self.index :]) def force_unpack_var_sequence(self, tx) -> list[VariableTracker]: return self.unpack_var_sequence(tx) def reconstruct(self, codegen: "PyCodegen") -> None: - remaining_items = self.items[self.index :] + if not self.is_exhausted: + remaining_items = self.items[self.index :] + else: + remaining_items = [] codegen.foreach(remaining_items) codegen.extend_output( [ diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index a8653ffda2f97..4c099b6644902 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -1346,7 +1346,7 @@ def call_method( if name == "__getitem__" and len(args) == 1: new_typing = self.value[args[0].as_python_constant()] return TypingVariable(new_typing) - unimplemented("unsupported method call on typing variablel") + unimplemented("unsupported method call on typing variable") def var_getattr(self, tx: "InstructionTranslator", name: str): from .builder import SourcelessBuilder, VariableBuilder @@ -1368,6 +1368,8 @@ def as_python_constant(self): return self.value def reconstruct(self, codegen: "PyCodegen") -> None: + if not isinstance(self.value, types.GenericAlias): + return super().reconstruct(codegen) # We're just trying to load the type here. Reconstructing the type from # scratch is tricky - for a type like `typing.List[int]` we'd need to # deconstruct the origin and args. The origin for `List[int]` is `list` diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 1e294950d9283..794fdf607220a 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -796,6 +796,10 @@ def gen_source(source, name): f"{len(args)} args and {len(kwargs)} kwargs", ) return ConstantVariable.create(len(module)) + elif name == "__iter__": + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) elif ( name == "__contains__" and isinstance(module, (torch.nn.ModuleDict, torch.nn.ParameterDict)) diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 776f7f34d9c37..289cebbe8129b 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This module implements variable tracking for PyTorch optimizers during Dynamo tracing. @@ -24,9 +22,11 @@ import logging import weakref -from typing import TYPE_CHECKING +from typing import Any, Iterable, Optional, TYPE_CHECKING import torch +from torch._dynamo.variables.tensor import TensorVariable +from torch._guards import Source from torch._logging import getArtifactLogger from torch.utils._pytree import tree_map_only @@ -63,13 +63,14 @@ class GuardInstallException(Exception): perf_hint_log = getArtifactLogger(__name__, "perf_hints") -def _is_static_for_cudagraphs(x): +def _is_static_for_cudagraphs(x: torch.Tensor) -> bool: from torch._inductor.cudagraph_trees import get_manager if x.is_cuda: manager = get_manager(x.device.index, False) is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None if manager: + assert manager.current_node is not None return ( is_static_address or manager.current_node._is_cuda_graph_recorded_tensor(x) @@ -91,26 +92,31 @@ class OptimizerVariable(UserDefinedObjectVariable): def __init__( self, - value, - grad_to_source=None, - static_tensor_names=None, - tensor_to_source=None, - **kwargs, + value: torch.optim.Optimizer, + grad_to_source: Optional[dict[Any, GradSource]] = None, + static_tensor_names: Optional[set[str]] = None, + tensor_to_source: Optional[dict[torch.Tensor, Source]] = None, + **kwargs: Any, ) -> None: super().__init__(value, **kwargs) + # pyrefly: ignore [bad-override] + self.value: torch.optim.Optimizer = value self.grad_to_source = grad_to_source or {} self.tensor_to_source = tensor_to_source or {} self.static_tensor_names = static_tensor_names or set() def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], ) -> "VariableTracker": """This is an optimization to avoid tracing the very slow initialization of the optimizer""" if name == "_init_group": + if not hasattr(self.value, "_init_group"): + # Fallback: if the optimizer does not have _init_group, trace normally + return super().call_method(tx, name, args, kwargs) try: self.graph_break_if_pending_mutation(tx) self.move_step_if_cpu() @@ -135,11 +141,12 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: # Note: this allows us to intercept the call in call_method # in the typical case, we return a UserMethodVariable # which will directly inline if name in ("_init_group", "step"): + assert self.source return GetAttrVariable(self, name, source=AttrSource(self.source, name)) if name == "param_groups": @@ -153,7 +160,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return super().var_getattr(tx, name) - def graph_break_if_pending_mutation(self, tx): + def graph_break_if_pending_mutation(self, tx: "InstructionTranslator") -> None: # If there are pending mutations on a parameter (due to using closure) # then we need to graph break to allow the python version of the parameter # to update, so that running _init_group will initialize the states with @@ -167,12 +174,12 @@ def graph_break_if_pending_mutation(self, tx): raise Unsupported("Pending mutation on parameter") - def _set_capturable(self, tx): + def _set_capturable(self, tx: "InstructionTranslator") -> None: from . import LazyVariableTracker # We only set capturable if params are on cuda # and the state is not initialized - def safe_to_set_capturable(group): + def safe_to_set_capturable(group: dict[str, Any]) -> bool: all_uninitialized = True all_gpu = True @@ -199,10 +206,12 @@ def safe_to_set_capturable(group): ) param_group_vt.items[key] = ConstantVariable.create(True) - def get_python_args(self, *args, **kwargs): + def get_python_args( + self, *args: Any, **kwargs: Any + ) -> tuple[list[Any], dict[str, Any]]: """Get python values equivalent to the variable tracker args""" - def map_arg(arg): + def map_arg(arg: Any) -> Any: if isinstance(arg, ConstantVariable): return arg.as_python_constant() elif isinstance(arg, ListVariable) and not arg.items: @@ -227,19 +236,19 @@ def map_arg(arg): # if this is the case, move it to the GPU # corresponding to the parameter # in most cases this is a no-op because the state is empty - def move_step_if_cpu(self): + def move_step_if_cpu(self) -> None: for p, state in self.value.state.items(): if "step" in state and state["step"].is_cpu: state["step"] = state["step"].to(p.device) - def map_sources_and_install_guards(self, tx): + def map_sources_and_install_guards(self, tx: "InstructionTranslator") -> None: from ..decorators import mark_static_address from .lazy import LazyVariableTracker self.grad_to_source = {} self.tensor_to_source = {} - def mark_static(x): + def mark_static(x: Any) -> None: mark_static_address(x, guard=True) tree_map_only(torch.Tensor, mark_static, self.value.state) @@ -252,12 +261,12 @@ def mark_static(x): ) state_source = self.source and AttrSource(self.source, "state") - state_vt = VariableTracker.build(tx, self.value.state, state_source) # We need to realize the top level state dict to populate # the guard locals state_vt.realize() + assert state_source is not None tx.output.guard_on_key_order.add(state_source) # Populate self.grad_to_source and self.tensor_to_source so that we can @@ -289,9 +298,7 @@ def mark_static(x): params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params")) all_static = True non_static_grads = [] - for p_ind, (p, p_vt) in enumerate( - zip(group["params"], params_vt.unpack_var_sequence(tx)) - ): + for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)): param_source = p_vt.source self.tensor_to_source[p] = param_source grad_source = GradSource( @@ -310,24 +317,24 @@ def mark_static(x): # Note: to avoid spam logs only warn if perf hint artifact is enabled # (NB: artifacts are only enabled at the debug or warning level) if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG): - non_static_grads = [src.name() for src in non_static_grads] + non_static_grad_names = [src.name() for src in non_static_grads] perf_hint_log.warning( ( "Grad tensors %s will be copied during cudagraphs execution." "If using cudagraphs and the grad tensor addresses will be the same across runs," " use torch._dynamo.decorators.mark_static_address to elide this copy.", ), - non_static_grads, + non_static_grad_names, ) # We have to again iterate over the state dict to collect the # tensor_to_source dict. This is used for the finalizer. - for idx, (p, value) in enumerate(self.value.state.items()): + for idx, value in enumerate(self.value.state.values()): p_state_source = DictGetItemSource( state_source, ConstDictKeySource(state_source, idx) ) tx.output.guard_on_key_order.add(p_state_source) - for inner_idx, (k, v) in enumerate(value.items()): + for inner_idx, v in enumerate(value.values()): if ( isinstance(v, torch.Tensor) and v not in self.grad_to_source @@ -337,7 +344,9 @@ def mark_static(x): p_state_source, ConstDictKeySource(p_state_source, inner_idx) ) - def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): + def wrap_tensor( + self, tx: "InstructionTranslator", tensor_value: torch.Tensor + ) -> TensorVariable: """Wrap state tensor in a TensorVariable""" from ..decorators import mark_static_address @@ -364,8 +373,13 @@ def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): return VariableTracker.build(tx, tensor_value, source) def update_list_args( - self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs - ): + self, + tx: "InstructionTranslator", + args: Iterable[VariableTracker], + kwargs: Any, + py_args: Iterable[Any], + py_kwargs: Any, + ) -> None: """Update the args and kwargs to the traced optimizer call""" for arg, py_arg in zip(args, py_args): if isinstance(arg, ListVariable): @@ -380,13 +394,13 @@ def update_list_args( source = arg.source and GetItemSource(arg.source, i) arg.items.append(VariableTracker.build(tx, val, source)) - def create_finalizer(self, tx): + def create_finalizer(self, tx: "InstructionTranslator") -> None: names_to_delete = self.static_tensor_names value = self.value tc = tx.output.tracing_context - def init_finalizer(gm): - def clear_static_tensor_refs(): + def init_finalizer(gm: torch.fx.GraphModule) -> None: + def clear_static_tensor_refs() -> None: for name in names_to_delete: gm._buffers.pop(name, None) gm._parameters.pop(name, None) diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index a120ab488ed95..85977104977fb 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs - """ This module implements variable tracking for TorchScript objects during Dynamo tracing. @@ -22,8 +19,13 @@ """ import functools +from collections.abc import Callable +from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from typing_extensions import ParamSpec import torch +from torch._guards import Source +from torch.fx.proxy import Proxy from .. import graph_break_hints from ..exc import unimplemented_v2, UnsafeScriptObjectError, Unsupported @@ -31,10 +33,19 @@ from .user_defined import UserDefinedObjectVariable -def _raise_hard_error_if_graph_break(reason): - def deco(fn): +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +def _raise_hard_error_if_graph_break( + reason: str, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def deco(fn: Callable[_P, _T]) -> Callable[_P, _T]: @functools.wraps(fn) - def graph_break_as_hard_error(*args, **kwargs): + def graph_break_as_hard_error(*args: _P.args, **kwargs: _P.kwargs) -> _T: try: return fn(*args, **kwargs) except Unsupported as e: @@ -49,26 +60,26 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable): _fake_script_object_cache: dict[int, "TorchScriptObjectVariable"] = {} @classmethod - def is_matching_cls(cls, user_cls: type): + def is_matching_cls(cls, user_cls: type) -> bool: return issubclass(user_cls, torch.ScriptObject) @staticmethod - def create(proxy, value, **options): + def create(proxy: Proxy, value: Any, **options: Any) -> "TorchScriptObjectVariable": return TorchScriptObjectVariable(proxy, value, **options) - def __init__(self, proxy, value, source, **kwargs) -> None: + def __init__(self, proxy: Proxy, value: Any, source: Source, **kwargs: Any) -> None: super().__init__(value, **kwargs) self.proxy = proxy self.proxy.node.meta["example_value"] = value self.source = source - def as_proxy(self): + def as_proxy(self) -> Proxy: return self.proxy @_raise_hard_error_if_graph_break( "Dynamo cannot safely trace script object due to graph break." ) - def var_getattr(self, tx, name: str) -> VariableTracker: + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: from torch._higher_order_ops.torchbind import call_torchbind from ..source import AttrSource @@ -95,7 +106,7 @@ def var_getattr(self, tx, name: str) -> VariableTracker: "Use method calls instead of attribute access.", ], ) - + assert self.source is not None return TorchHigherOrderOperatorVariable.make( call_torchbind, source=AttrSource(self.source, name), @@ -110,7 +121,13 @@ def var_getattr(self, tx, name: str) -> VariableTracker: @_raise_hard_error_if_graph_break( "Dynamo cannot safely trace script object due to graph break." ) - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: unimplemented_v2( gb_type="Weird method call on TorchScript object", context=f"value={self.value}, method={name}", diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index e63edf8e2b036..75928842cf297 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,7 +1,9 @@ -# mypy: ignore-errors - from inspect import getattr_static -from typing import TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING, TypeGuard + +from torch._guards import Source +from torch.backends.cuda import SDPAParams +from torch.fx.proxy import Proxy from ..bytecode_transformation import create_call_function from ..exc import Unsupported @@ -29,9 +31,9 @@ class SDPAParamsVariable(VariableTracker): This is a read-only container.""" @staticmethod - def create(tx: "InstructionTranslator", value, source): - from torch.backends.cuda import SDPAParams - + def create( + tx: "InstructionTranslator", value: Any, source: Source + ) -> VariableTracker: from .torch import TorchInGraphFunctionVariable params = [ @@ -40,12 +42,14 @@ def create(tx: "InstructionTranslator", value, source): ] return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) - def __init__(self, proxy, param_vars, **kwargs) -> None: + def __init__( + self, proxy: Proxy, param_vars: Sequence[VariableTracker], **kwargs: Any + ) -> None: self.proxy = proxy self.param_vars = param_vars super().__init__(**kwargs) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: assert self.source is None assert self.param_vars is not None codegen.add_push_null( @@ -54,7 +58,7 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.foreach(self.param_vars) codegen.extend_output(create_call_function(len(self.param_vars), False)) - def as_proxy(self): + def as_proxy(self) -> Proxy: return self.proxy def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: @@ -80,7 +84,5 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return wrap_fx_proxy(tx=tx, proxy=proxy) @staticmethod - def is_sdpa_params(value): - from torch.backends.cuda import SDPAParams - + def is_sdpa_params(value: Any) -> TypeGuard["SDPAParams"]: return value is SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index cc79e1467264f..fbc0eed3a99ff 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,15 +1,18 @@ -from typing import Any, Optional +import collections +from typing import Any, Callable, Optional import torch +from torch._dynamo.variables.dicts import ConstDictVariable +from torch._dynamo.variables.lists import TupleVariable from torch.fx import Proxy from .. import graph_break_hints -from ..device_interface import get_interface_for_device +from ..bytecode_transformation import create_call_function from ..exc import TYPE_CHECKING, unimplemented_v2 from .base import VariableTracker from .constant import ConstantVariable -from .ctx_manager import ContextWrappingVariable -from .misc import GetAttrVariable +from .ctx_manager import FxTracebackAnnotateVariable +from .lazy import LazyVariableTracker if TYPE_CHECKING: @@ -63,103 +66,89 @@ def _( pass -class StreamContextVariable(ContextWrappingVariable): +class SymbolicStreamState: + """Track the currently entered stream if any""" + + def __init__(self) -> None: + from ..source import CurrentStreamSource + + cur_stack: list[StreamVariable] = [] + if torch.accelerator.is_available(): + stream_var = LazyVariableTracker.create( + torch.accelerator.current_stream(), + source=CurrentStreamSource(torch.accelerator.current_stream().device), + ) + cur_stack = [stream_var] # type: ignore[list-item] + + self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque( + cur_stack + ) + + def enter_stream(self, stream: "StreamVariable") -> None: + self.cur_stream_stack.append(stream) + + def exit_stream(self) -> None: + self.cur_stream_stack.pop() + + def cur_stream(self, device: Optional[torch.device] = None) -> "StreamVariable": + if device is not None: + for stream in reversed(self.cur_stream_stack): + if stream.device == device: + return stream + + return self.cur_stream_stack[-1] + + def in_stream_context(self) -> bool: + return len(self.cur_stream_stack) > 0 + + +class StreamContextVariable(FxTracebackAnnotateVariable): """This represents torch.cuda.StreamContext""" @staticmethod def create( tx: "InstructionTranslator", - target_value: "StreamVariable", + stream_to_enter: "StreamVariable", **kwargs: dict[str, Any], ) -> "StreamContextVariable": return StreamContextVariable( - target_values=[target_value], - initial_values=[ - StreamContextVariable._get_current_stream(target_value.device, tx) - ], - device=target_value.device, + stream_to_enter, **kwargs, ) def __init__( self, - target_values: list["StreamVariable"], - device: torch.device, - initial_values: Optional[list["StreamVariable"]] = None, + stream: Optional["StreamVariable"], **kwargs: dict[str, Any], ) -> None: + self.stream = stream super().__init__( - target_values=target_values, initial_values=initial_values, **kwargs + target_values={"stream": self.get_stream().user_object_index}, + initial_values=None, + **kwargs, ) - self.device = device - def enter(self, tx: "InstructionTranslator") -> "VariableTracker": + def enter( + self, tx: "InstructionTranslator", *args: tuple[Any] + ) -> "VariableTracker": # to stream, from stream is the order of the arguments # we are entering the target, and leaving the initial stream - tx.output.create_proxy( - "call_function", - torch.ops.streams.fork.default, - self._target_stream_proxies() + self._initial_stream_proxies(), - {}, - ) - return ConstantVariable.create(None) + tx.symbolic_stream_state.enter_stream(self.get_stream()) + return super().enter(tx) def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker": # to stream, from stream is the order of the arguments # we are leaving the target, and entering the initial stream - tx.output.create_proxy( - "call_function", - torch.ops.streams.join.default, - self._initial_stream_proxies() + self._target_stream_proxies(), - {}, - ) - return ConstantVariable.create(None) - - def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]: - assert self.initial_values, "No initial stream to move from" - return StreamContextVariable._extract_stream_properties( - self.initial_values[0].as_proxy() - ) - - def _target_stream_proxies(self) -> tuple[Proxy, Proxy]: - return StreamContextVariable._extract_stream_properties( - self._get_target_values()[0].as_proxy() - ) - - @staticmethod - def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]: - stream_index = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id") - stream_device = GetAttrVariable.create_getattr_proxy(stream_proxy, "device") - return stream_index, stream_device - - @staticmethod - def _get_current_stream( - device: torch.device, tx: "InstructionTranslator" - ) -> "StreamVariable": - from .builder import wrap_fx_proxy_cls - - current_stream_method = get_interface_for_device(device).current_stream - current_stream = wrap_fx_proxy_cls( - StreamVariable, - tx, - tx.output.create_proxy( - "call_function", - current_stream_method, - (None,), - {}, - ), - ) - return current_stream - - def _get_target_values(self) -> list["StreamVariable"]: - # We need this to be overridable, since StreamVariable does - # not store target values (it does not require any arguments) - # and captures the current stream at the time of entering the context - return self.target_values + tx.symbolic_stream_state.exit_stream() + return super().exit(tx, *args) def supports_graph_breaks(self) -> bool: return True + def get_stream(self) -> "StreamVariable": + assert self.stream, "Stream context should have a separate stream" + return self.stream + class StreamVariable(StreamContextVariable): """Represents the device-agnostic torch.Stream class""" @@ -168,19 +157,21 @@ def __init__( self, proxy: Proxy, value: torch.Stream, - device: torch.device, **kwargs: Any, ) -> None: + # Index into the user object table + # used to pass arbitrary objects to the graph + user_object_index = kwargs.pop("user_obj_index", None) if proxy is not None and "example_value" in proxy.node.meta: assert proxy.node.meta["example_value"] == value - assert value.device.type == device.type, ( - "stream value is not equal to the passed device" - ) - super().__init__(target_values=[], initial_values=None, device=device, **kwargs) + self.proxy = proxy self.value = value # pyrefly: ignore [read-only] - self.device = device + self.device = value.device + # pyrefly: ignore [read-only] + self.user_object_index = user_object_index + super().__init__(None, **kwargs) def python_type(self) -> type: return torch.Stream @@ -231,6 +222,7 @@ def call_method( return ConstantVariable.create(NotImplemented) if other.source: + assert self.source is not None install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) return ConstantVariable.create( cmp_name_to_op_mapping[name](self.value, other.value) # type: ignore[arg-type] @@ -238,15 +230,6 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def enter(self, tx: "InstructionTranslator") -> "VariableTracker": - # NB: Set initial values when we enter - # Don't do this at object creation, as we need to record the current stream - # at the time the context is entered. - self.initial_values = [ - StreamContextVariable._get_current_stream(self.device, tx) - ] - return super().enter(tx) - def as_proxy(self) -> Proxy: return self.proxy @@ -260,18 +243,39 @@ def reconstruct(self, codegen: "PyCodegen") -> None: # If we got here, this stream is fully subsumed by the graph - this means it is # not an input or global assert not self.source - # Since we just proved that - for other such structures, like lists and dicts, reconstruction - # is fine and sound according to dynamo principles of treating collectives. However, - # streams are special in that we want to preserve the identity of the stream as the same as in the graph - # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not - # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending - # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there. - prefix = f"_stream_{self.device}" - name = codegen.tx.output.install_global_by_id(prefix, self.value) - codegen.append_output(codegen.create_load_global(name, add=True)) + if self.user_object_index is not None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.graph_bytecode_inputs.__name__, + "get_external_object_by_index", + ) + ) + codegen.append_output(codegen.create_load_const(self.user_object_index)) + codegen.extend_output(create_call_function(1, False)) + else: + # TODO mlazos: evaluate if we still need this + prefix = f"_stream_{self.device}" + name = codegen.tx.output.install_global_by_id(prefix, self.value) + codegen.append_output(codegen.create_load_global(name, add=True)) + + def get_stream(self) -> "StreamVariable": + return self + + @staticmethod + def make_construct_in_graph_stream_fn( + args: TupleVariable, kwargs: ConstDictVariable + ) -> Callable[[int, "PyCodegen"], None]: + def fn(index: int, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from( + torch._dynamo.utils.__name__, "build_stream" + ) + ) + codegen(args) + codegen(kwargs) + codegen.extend_output(create_call_function(2, False)) - def _get_target_values(self) -> list["StreamVariable"]: - return [self] + return fn class EventVariable(VariableTracker): diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 1e66d48cec495..1e8351115c079 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -66,9 +66,9 @@ set_example_value, tensortype_to_dtype, ) -from .base import AttributeMutationNew, VariableTracker +from .base import AttributeMutationNew, ValueMutationNew, VariableTracker from .constant import ConstantVariable -from .lists import SizeVariable +from .lists import ListIteratorVariable, SizeVariable from .user_defined import UserDefinedClassVariable @@ -427,7 +427,7 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): # Today, var_getattr returns GetAttrVariable for both non-existent # attributes and existing attributes. This is a bug and requires more # deep dive. - if name in ("size", "stride"): + if name in ("size", "stride", "__iter__"): return ConstantVariable(True) try: @@ -1079,6 +1079,14 @@ def method___len__(self): tx = InstructionTranslator.current_tx() return self.call_method(tx, "size", [ConstantVariable.create(0)], {}) + def method___iter__(self): + from ..symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + def method_addcmul_(self, tensor1, tensor2, *, value=None): from ..symbolic_convert import InstructionTranslator @@ -1612,7 +1620,7 @@ def call_method( ), hints=[*graph_break_hints.FUNDAMENTAL], ) - if name in ["__len__", "size", "tolist"]: + if name in ["__len__", "size", "tolist", "__iter__"]: # delegate back to TensorVariable return super().call_method(tx, name, args, kwargs) if name in ("tostring", "tobytes", "__delattr__"): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index e48a488101549..c2e3df8e4adce 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -834,12 +834,13 @@ def handle_addcdiv(self, tx: "InstructionTranslator", *args, **kwargs): @register(torch.full) def handle_full(self, tx, size, fill_value, **kwargs): if isinstance(fill_value, TensorVariable): - result = TorchInGraphFunctionVariable( - torch.ops.aten._local_scalar_dense - ).call_function(tx, [fill_value], {}) - return TorchInGraphFunctionVariable(torch.full).call_function( - tx, [size, result], kwargs + # Decompose: create empty tensor and fill it + # This avoids the scalar extraction at compile time + empty_result = TorchInGraphFunctionVariable(torch.empty).call_function( + tx, [size], kwargs ) + # Call fill_ method on the empty tensor + return empty_result.call_method(tx, "fill_", [fill_value], {}) @register(torch._foreach_lerp_) def handle_inplace_foreach_lerp_scalar( @@ -1268,6 +1269,35 @@ def handle_get_device_module(self, tx, *args, **kwargs): # pyrefly: ignore [unbound-name] return VariableTracker.build(tx, module, new_source) + @register(torch.accelerator.current_stream) + def handle_current_stream(self, tx: "InstructionTranslator", *args, **kwargs): + if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): + unimplemented_v2( + gb_type="unsupported arguments to torch.accelerator.current_stream", + context=f"args={args}, kwargs={kwargs}", + explanation="torch.accelerator.current_stream accepts one optional argument `device`", + hints=[ + *graph_break_hints.USER_ERROR, + ], + ) + try: + if kwargs: + device = torch.device(kwargs["device"].as_python_constant()) + elif args: + device = torch.device(args[0].as_python_constant()) + else: + device = None + + return tx.symbolic_stream_state.cur_stream(device) + except Exception as e: + unimplemented_v2( + gb_type="bad device argument to torch.accelerator.current_stream", + context=f"args={args}, kwargs={kwargs}", + explanation="Expected valid string/torch.device argument ('cpu', 'cuda', etc.)", + hints=[*graph_break_hints.USER_ERROR], + from_exc=e, + ) + @register(torch.set_default_device) def handle_set_default_device( self, tx: "InstructionTranslator", *args, **kwargs @@ -1464,6 +1494,7 @@ def patched_fn(*args, **kwargs): ): # constant fold functions need to be guarded. if self.value in constant_fold_functions_need_guards: + assert self.source is not None source = CallFunctionNoArgsSource(self.source) install_guard(source.make_guard(GuardBuilder.EQUALS_MATCH)) # constant fold diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 817385ff149c0..71993a62434cc 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """TorchDynamo support for __torch_function__ tensor subclasses. This module implements support for tensor subclasses with __torch_function__ overrides. @@ -31,7 +29,8 @@ import functools import inspect import operator -from typing import TYPE_CHECKING +from types import TracebackType +from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree @@ -125,34 +124,134 @@ banned_attrs = [ - fn.__self__.__name__ + fn.__self__.__name__ # type: ignore[attr-defined] for fn in get_default_nowrap_functions() if is_tensor_base_attr_getter(fn) ] @functools.cache -def get_prev_stack_var_name(): +def get_prev_stack_var_name() -> str: from ..bytecode_transformation import unique_id return unique_id("___prev_torch_function_mode_stack") +class TorchFunctionModeVariable(GenericContextWrappingVariable): + @staticmethod + def is_supported_torch_function_mode(ty: type[TorchFunctionMode]) -> bool: + # Supported in this sense means we can support graph breaks under the + # context. + # We are able to trace custom modes but if there are graph breaks under them + # and they have a custom __enter__/__exit__ we don't handle this for the + # same reason we don't handle generic context managers: there may be side effects + # that are now affected by executing the function across two frames instead of one + # Today we support the enter/exit of the default TorchFunctionMode as well as + # DeviceContext (which is used for set_default_device) + return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( + not class_has_getattribute(ty) + and inspect.getattr_static(ty, "__enter__") is TorchFunctionMode.__enter__ + and inspect.getattr_static(ty, "__exit__") is TorchFunctionMode.__exit__ + ) + + def __init__( + self, + value: Optional[TorchFunctionMode], + source: Optional[Source] = None, + **kwargs: Any, + ): + if value is not None: + super().__init__(value, **kwargs) + self.value = value + self.cm_obj = value # needed for BC with calling enter from CM code + self.source = source # type: ignore[assignment] + + def reconstruct(self, codegen: "PyCodegen") -> None: + # This shouldn't be called unless we have a source + assert self.source + self.source.reconstruct(codegen) + + def module_name(self) -> str: + return self.value.__module__ + + def fn_name(self) -> str: + return type(self.value).__name__ + + def python_type(self) -> type: + return type(self.value) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + return call_torch_function( + tx, + get_torch_function_fn(tx, self), # type: ignore[arg-type] + fn, + types, + args, + kwargs, + ) + + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + if isinstance(self.value, NoEnterTorchFunctionMode): + return ConstantVariable.create(None) + + TorchInGraphFunctionVariable( + torch._C._push_on_torch_function_stack + ).call_function(tx, [self], {}) + return ConstantVariable.create(None) + + def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: + from .torch import TorchInGraphFunctionVariable + + TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( + tx, [], {} + ) + return ConstantVariable.create(None) + + def reconstruct_type(self, codegen: "PyCodegen") -> None: + ty = NoEnterTorchFunctionMode + codegen( + AttrSource( + codegen.tx.import_source(ty.__module__), + ty.__name__, + ) + ) + + def supports_graph_breaks(self) -> bool: + return True + + def exit_on_graph_break(self) -> bool: + return False + + # Used to clear/restore the python torch function mode stack and temporarily restore it as needed class TorchFunctionModeStackStateManager: - def __init__(self): - self.stack = [] + def __init__(self) -> None: + self.stack: list[Any] = [] - def __enter__(self): + def __enter__(self) -> None: self.stack = torch.overrides._get_current_function_mode_stack() clear_torch_function_mode_stack() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: set_torch_function_mode_stack(self.stack) self.stack = [] @contextlib.contextmanager - def temp_restore_stack(self): + def temp_restore_stack(self) -> Generator[None, None, None]: prev = torch.overrides._get_current_function_mode_stack() set_torch_function_mode_stack(self.stack) try: @@ -165,7 +264,7 @@ def temp_restore_stack(self): class SymbolicTorchFunctionState: - def __init__(self, py_stack): + def __init__(self, py_stack: Iterable[Any]) -> None: # This is annoyingly complicated because of how the torch function subclass + mode C API was designed # There are two exposed C knobs here as contexts: torch._C.DisableTorchFunction and torch._C.DisableTorchFunctionSubclass # These are their definitions: @@ -199,32 +298,41 @@ def __init__(self, py_stack): for i, val in enumerate(py_stack): self.mode_stack.append( - LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) + LazyVariableTracker.create(val, source=TorchFunctionModeStackSource(i)) # type: ignore[arg-type] ) - def in_torch_function_mode(self): + def in_torch_function_mode(self) -> bool: return len(self.mode_stack) > 0 - def pop_torch_function_mode(self): + def pop_torch_function_mode(self) -> TorchFunctionModeVariable: return self.mode_stack.pop() - def push_torch_function_mode(self, mode_var): + def push_torch_function_mode(self, mode_var: TorchFunctionModeVariable) -> None: self.mode_stack.append(mode_var) - def call_torch_function_mode(self, tx, fn, types, args, kwargs): + def call_torch_function_mode( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: with self._pop_mode_for_inlining() as cur_mode: return cur_mode.call_torch_function(tx, fn, types, args, kwargs) @contextlib.contextmanager - def _pop_mode_for_inlining(self): + def _pop_mode_for_inlining( + self, + ) -> Generator[TorchFunctionModeVariable, None, None]: old_mode = self.cur_mode - self.cur_mode = self.pop_torch_function_mode() + self.cur_mode = self.pop_torch_function_mode() # type: ignore[assignment] try: - yield self.cur_mode + yield self.cur_mode # type: ignore[misc] finally: mode = self.cur_mode self.cur_mode = old_mode - self.push_torch_function_mode(mode) + self.push_torch_function_mode(mode) # type: ignore[arg-type] class TorchFunctionModeStackVariable(VariableTracker): @@ -244,16 +352,20 @@ class TorchFunctionModeStackVariable(VariableTracker): # each of the indices of other modes should be shifted left by 1 (-1) offset = 0 - def __init__(self, source, symbolic_stack): + def __init__( + self, + source: Source, + symbolic_stack: collections.deque[TorchFunctionModeVariable], + ) -> None: self.source = source self.symbolic_stack = symbolic_stack @classmethod - def reset(cls): + def reset(cls) -> None: cls.offset = 0 @classmethod - def register_mutation(cls, tx: "InstructionTranslator"): + def register_mutation(cls, tx: "InstructionTranslator") -> None: if cls.stack_value_singleton not in tx.output.side_effects: var = cls( source=Source(), @@ -263,7 +375,7 @@ def register_mutation(cls, tx: "InstructionTranslator"): tx.output.side_effects.mutation(var) @classmethod - def register_device_context_insertion(cls, tx: "InstructionTranslator"): + def register_device_context_insertion(cls, tx: "InstructionTranslator") -> None: stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): return @@ -277,109 +389,28 @@ def register_device_context_insertion(cls, tx: "InstructionTranslator"): ) @classmethod - def clear_default_device(cls, tx: "InstructionTranslator"): + def clear_default_device(cls, tx: "InstructionTranslator") -> None: stack = tx.symbolic_torch_function_state.mode_stack if stack and cls.is_device_context(stack[0]): stack.popleft() cls.offset -= 1 @staticmethod - def is_device_context(var): + def is_device_context(var: TorchFunctionModeVariable) -> bool: return isinstance(var.value, DeviceContext) or var.value is None @classmethod - def get_mode_index(cls, ind): + def get_mode_index(cls, ind: int) -> int: return ind + cls.offset -class TorchFunctionModeVariable(GenericContextWrappingVariable): - @staticmethod - def is_supported_torch_function_mode(ty): - # Supported in this sense means we can support graph breaks under the - # context. - # We are able to trace custom modes but if there are graph breaks under them - # and they have a custom __enter__/__exit__ we don't handle this for the - # same reason we don't handle generic context managers: there may be side effects - # that are now affected by executing the function across two frames instead of one - # Today we support the enter/exit of the default TorchFunctionMode as well as - # DeviceContext (which is used for set_default_device) - return issubclass(ty, (NoEnterTorchFunctionMode, DeviceContext)) or ( - not class_has_getattribute(ty) - and inspect.getattr_static(ty, "__enter__") == TorchFunctionMode.__enter__ - and inspect.getattr_static(ty, "__exit__") == TorchFunctionMode.__exit__ - ) - - def __init__(self, value, source=None, **kwargs): - if value is not None: - super().__init__(value, **kwargs) - self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code - self.source = source - - def reconstruct(self, codegen: "PyCodegen"): - # This shouldn't be called unless we have a source - assert self.source - self.source.reconstruct(codegen) - - def module_name(self): - return self.value.__module__ - - def fn_name(self): - return type(self.value).__name__ - - def python_type(self): - return type(self.value) - - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): - return call_torch_function( - tx, - get_torch_function_fn(tx, self), - fn, - types, - args, - kwargs, - ) - - def enter(self, tx): - from .torch import TorchInGraphFunctionVariable - - if isinstance(self.value, NoEnterTorchFunctionMode): - return ConstantVariable.create(None) - - TorchInGraphFunctionVariable( - torch._C._push_on_torch_function_stack - ).call_function(tx, [self], {}) - return ConstantVariable.create(None) - - def exit(self, tx: "InstructionTranslator", *args): - from .torch import TorchInGraphFunctionVariable - - TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( - tx, [], {} - ) - return ConstantVariable.create(None) - - def reconstruct_type(self, codegen: "PyCodegen"): - ty = NoEnterTorchFunctionMode - codegen( - AttrSource( - codegen.tx.import_source(ty.__module__), - ty.__name__, - ) - ) - - def supports_graph_breaks(self): - return True - - def exit_on_graph_break(self): - return False - - -def _get_all_args(args, kwargs): +def _get_all_args( + args: Iterable[Any], kwargs: dict[str, Any] +) -> Iterable[VariableTracker]: return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs)) -def _flatten_vts(vts): +def _flatten_vts(vts: Iterable[VariableTracker]) -> list[VariableTracker]: from collections import deque from .dicts import ConstDictVariable @@ -391,7 +422,7 @@ def _flatten_vts(vts): while vts: vt = vts.popleft() - if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): + if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): # type: ignore[attr-defined] vt.realize() if vt.is_realized(): @@ -407,21 +438,28 @@ def _flatten_vts(vts): return output -def _get_subclass_type(var): +def _get_subclass_type(var: VariableTracker) -> type: assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) return var.python_type() -def _get_subclass_type_var(tx: "InstructionTranslator", var): - assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)) +def _get_subclass_type_var( + tx: "InstructionTranslator", var: VariableTracker +) -> VariableTracker: if isinstance(var, TensorWithTFOverrideVariable): return var.class_type_var(tx) elif isinstance(var, UserDefinedObjectVariable): source = var.source and TypeSource(var.source) return VariableTracker.build(tx, var.python_type(), source) + else: + raise AssertionError(f"Unexpected type {type(var)}") -def _is_attr_overridden(tx: "InstructionTranslator", var, name): +def _is_attr_overridden( + tx: "InstructionTranslator", var: VariableTracker, name: str +) -> bool: + if not isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable)): + return False import torch overridden = False @@ -434,7 +472,14 @@ def _is_attr_overridden(tx: "InstructionTranslator", var, name): return overridden -def call_torch_function(tx, torch_function_var, fn, types, args, kwargs): +def call_torch_function( + tx: "InstructionTranslator", + torch_function_var: VariableTracker, + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: # This emulates calling __torch_function__, which has a signature # def __torch_function__(cls, func, types, args=(), kwargs=None): # @@ -451,7 +496,9 @@ def call_torch_function(tx, torch_function_var, fn, types, args, kwargs): return torch_function_var.call_function(tx, tf_args, {}) -def get_torch_function_fn(tx: "InstructionTranslator", vt): +def get_torch_function_fn( + tx: "InstructionTranslator", vt: VariableTracker +) -> VariableTracker: # The underlying function could be a classmethod, staticmethod, regular # function or a function with C-implementation. It doesn't matter as long as # they satisfy the calling convention in `call_torch_function`. @@ -462,7 +509,9 @@ def get_torch_function_fn(tx: "InstructionTranslator", vt): return func_vt -def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): +def can_dispatch_torch_function( + tx: "InstructionTranslator", args: Iterable[Any], kwargs: dict[str, Any] +) -> bool: has_overridden_args = any( has_torch_function(arg) for arg in _get_all_args(args, kwargs) ) @@ -472,7 +521,12 @@ def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): ) -def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs): +def dispatch_torch_function( + tx: "InstructionTranslator", + fn: VariableTracker, + args: Iterable[Any], + kwargs: dict[str, Any], +) -> Any: """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args""" all_args = _get_all_args(args, kwargs) @@ -518,7 +572,13 @@ class TensorWithTFOverrideVariable(TensorVariable): """ @classmethod - def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): + def from_tensor_var( + cls, + tx: "InstructionTranslator", + tensor_var: VariableTracker, + class_type: type, + cls_source: Source, + ) -> "TensorWithTFOverrideVariable": # [Note: __torch_function__] coerce `tensor_var` into a # TensorWithTFOverrideVariable. In eager, this is just a type change. import torch @@ -533,7 +593,7 @@ def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): var.install_global(tx) return var - def install_global(self, tx): + def install_global(self, tx: "InstructionTranslator") -> None: # stash the subclass type to rewrap an output tensor if needed # this is needed because the actual type needs to be available # each time the compiled artifact is run and outputs a wrapped tensor. @@ -543,20 +603,20 @@ def install_global(self, tx): self.global_mangled_class_name(tx), self.class_type ) - def python_type(self): + def python_type(self) -> type: return self.class_type - def class_type_var(self, tx): + def class_type_var(self, tx: "InstructionTranslator") -> VariableTracker: return TensorSubclassVariable( self.class_type, source=GlobalSource(self.global_mangled_class_name(tx)) ) - def global_mangled_class_name(self, tx): + def global_mangled_class_name(self, tx: "InstructionTranslator") -> str: return get_safe_global_name( tx, f"__subclass_{self.class_type.__name__}", self.class_type ) - def var_getattr(self, tx: "InstructionTranslator", name): + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: # [Note: __torch_function__] We currently only support attributes that are defined on # base tensors, custom attribute accesses will graph break. import torch @@ -581,7 +641,8 @@ def var_getattr(self, tx: "InstructionTranslator", name): and not attr_is_overridden and not inspect.ismethoddescriptor(getattr(torch.Tensor, name)) ): - args, kwargs = [self], {} + args = [self] + kwargs: dict[Any, Any] = {} if can_dispatch_torch_function(tx, args, kwargs): get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) @@ -636,7 +697,14 @@ def var_getattr(self, tx: "InstructionTranslator", name): return super().var_getattr(tx, name) - def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: TupleVariable, + args: Iterable[Any], + kwargs: dict[str, Any], + ) -> Any: # NOTE this assumes `__torch_function__` isn't modified during tracing. if not hasattr(self, "torch_function_fn"): self.torch_function_fn = get_torch_function_fn(tx, self) @@ -652,8 +720,8 @@ def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwar def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index f4419fbbfe79b..707ad7b3d9d18 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -58,6 +58,7 @@ raise_observed_exception, unimplemented_v2, ) +from ..graph_bytecode_inputs import get_external_object_by_index from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, @@ -94,7 +95,7 @@ unpatched_nn_module_getattr, ) from .base import raise_type_error_exc, ValueMutationNew, VariableTracker -from .dicts import DefaultDictVariable +from .dicts import ConstDictVariable, DefaultDictVariable from .lists import SizeVariable @@ -809,14 +810,44 @@ def deque_signature(iterable=None, maxlen=None): ) args = [stacked] - tensor_variable = wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - self.value, - *proxy_args_kwargs(args, kwargs), - ), - ) + if issubclass(self.value, torch.Stream): + from .constant import ConstantVariable + from .lists import TupleVariable + + # Register newly created stream for reconstruction + var_kwargs = ConstDictVariable( + {ConstantVariable(k): v for k, v in kwargs.items()} + ) + var_args = TupleVariable(list(args)) + stream = self.value( + *(var_args.as_python_constant()), + **(var_kwargs.as_python_constant()), + ) + from ..graph_bytecode_inputs import register_graph_created_object + from .streams import StreamVariable + + ind = register_graph_created_object( + stream, + StreamVariable.make_construct_in_graph_stream_fn( + var_args, var_kwargs + ), + ) + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", get_external_object_by_index, (ind,), {} + ), + user_obj_index=ind, + ) + else: + tensor_variable = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + *proxy_args_kwargs(args, kwargs), + ), + ) return tensor_variable elif self.value is random.Random: diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 1a928f011bbed..89b6e3297933f 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -765,7 +765,7 @@ def convert_prim_device(self, node: torch._C.Node): raise ValueError(f"Unsupported JitType ({input_type}) when get device") def convert_prim_GetAttr(self, node: torch._C.Node): - # Build fully qulified name + # Build fully qualified name attr_fqn = get_attribute_fqn_from_ts_node(self.name_to_attribute_fqn, node) output_name = node.output().debugName() self.name_to_attribute_fqn[output_name] = attr_fqn @@ -971,7 +971,7 @@ def convert_aten_to(self, node: torch._C.Node): # "cannot mutate tensors with frozen storage" functionalization error. # To work around the issue, we override the copy to be True, so that the output # is for sure not an alias of input - if target == torch.ops.aten.to.dtype or target == torch.ops.aten.to.prim_dtype: + if target is torch.ops.aten.to.dtype or target is torch.ops.aten.to.prim_dtype: user_nodes = [use.user for use in node.output().uses()] user_targets = [ get_op_overload(user_node) @@ -1011,7 +1011,7 @@ def convert_aten_add(self, node: torch._C.Node): else: target = get_op_overload(node) - if target == torch.ops.aten.add.t: + if target is torch.ops.aten.add.t: # special handle python list/tuple add: "aten::add.t(t[] a, t[] b) -> t[]" for # RuntimeError: aten::add() Expected a value of type 'List[t]' for argument 'a' but instead found type 'immutable_list'. args, _kwargs = self.get_args_kwargs(node, target._schema) @@ -1455,7 +1455,7 @@ def convert(self) -> ExportedProgram: ) gm = graph_converter.convert() - # Post-proccessing step to deal with quantized operators. + # Post-processing step to deal with quantized operators. replace_quantized_ops_with_standard_ops(gm) log.info("GraphModule: %s", gm.print_readable(print_output=False)) diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py index a8522525fc28c..b2548bb61d69d 100644 --- a/torch/_export/pass_base.py +++ b/torch/_export/pass_base.py @@ -214,7 +214,7 @@ def call_function( ) -> ProxyValue: meta = NodeMetadata(self.node.meta) - if target == operator.getitem: + if target is operator.getitem: value, key = args return self.callback.call_getitem(value, key, meta) elif getattr(target, "__module__", None) in { @@ -236,10 +236,10 @@ def call_function( kwargs, meta, ) - elif target == torch.ops.higher_order.cond: + elif target is torch.ops.higher_order.cond: pred, true_fn, false_fn, inputs = args return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta) - elif target == torch.ops.higher_order.map_impl: + elif target is torch.ops.higher_order.map_impl: f, mapped_args, operands = args # type: ignore[assignment] return self.callback.call_map(f, mapped_args, operands, meta) # For other unregistered HigherOrderOps, just interpret them blindly diff --git a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py index d646b7edaaf06..345401e9f76e5 100644 --- a/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py +++ b/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py @@ -224,7 +224,7 @@ def maybe_get_symint(x): lhs = maybe_get_symint(lhs) rhs = maybe_get_symint(rhs) - if compare_op == operator.ge: + if compare_op is operator.ge: lhs, rhs = rhs, lhs if isinstance(lhs, sympy.Symbol) and isinstance(rhs, int): diff --git a/torch/_export/passes/collect_tracepoints_pass.py b/torch/_export/passes/collect_tracepoints_pass.py index 8162342e50c88..d9a8256488688 100644 --- a/torch/_export/passes/collect_tracepoints_pass.py +++ b/torch/_export/passes/collect_tracepoints_pass.py @@ -2,7 +2,7 @@ from __future__ import annotations import operator -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING import torch from torch.export.exported_program import ConstantArgument, TensorArgument @@ -29,8 +29,8 @@ def __init__( self.specs = specs self.sig = sig - def call(self, gm: torch.fx.GraphModule) -> Optional[PassResult]: - def get_arg_spec(arg) -> Union[TensorArgument, ConstantArgument]: + def call(self, gm: torch.fx.GraphModule) -> PassResult | None: + def get_arg_spec(arg) -> TensorArgument | ConstantArgument: if isinstance(arg, torch.fx.Node): if isinstance(arg.meta.get("val"), torch.Tensor): return TensorArgument(name=arg.name) @@ -48,7 +48,7 @@ def get_arg_spec(arg) -> Union[TensorArgument, ConstantArgument]: for node in module.graph.nodes: if node.op != "call_function": continue - if node.target == torch.ops.higher_order._export_tracepoint: + if node.target is torch.ops.higher_order._export_tracepoint: kind = node.kwargs["kind"] if kind == "module_call_outputs": nn_module_stack = node.meta["nn_module_stack"] @@ -64,7 +64,7 @@ def get_arg_spec(arg) -> Union[TensorArgument, ConstantArgument]: for node in reversed(module.graph.nodes): if node.op != "call_function": continue - if node.target == torch.ops.higher_order._export_tracepoint: + if node.target is torch.ops.higher_order._export_tracepoint: kind = node.kwargs["kind"] if kind == "module_call_inputs": nn_module_stack = node.meta["nn_module_stack"] @@ -94,7 +94,7 @@ def copy_sig(sig) -> ModuleCallSignature: for node in module.graph.nodes: if node.op != "call_function": continue - if node.target == torch.ops.higher_order._export_tracepoint: + if node.target is torch.ops.higher_order._export_tracepoint: # There's some subtlety worth noting. Here fqn corresponds to # the call name, whereas path corresponds to the module name. # They are not necessarily the same! When a submodule is shared @@ -130,7 +130,7 @@ def copy_sig(sig) -> ModuleCallSignature: if isinstance(arg, torch.fx.Node): for user in node.users: assert user.op == "call_function" - assert user.target == operator.getitem + assert user.target is operator.getitem assert isinstance(user.args[1], int) if user.args[1] == i: user.replace_all_uses_with(arg) diff --git a/torch/_export/passes/constant_folding.py b/torch/_export/passes/constant_folding.py index 5fdc92702a116..58534856422c7 100644 --- a/torch/_export/passes/constant_folding.py +++ b/torch/_export/passes/constant_folding.py @@ -65,7 +65,7 @@ def __init__( def is_impure(self, node: torch.fx.Node) -> bool: if ( - node.target == torch.ops.prims.convert_element_type.default + node.target is torch.ops.prims.convert_element_type.default and node.args[0].op == "get_attr" # type: ignore[union-attr] and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] and node.args[1] == torch.bfloat16 @@ -135,7 +135,7 @@ def set_env(arg): # TODO - fix errors with this if ( node.op == "call_function" - and node.target == aten._efficientzerotensor.default + and node.target is aten._efficientzerotensor.default ): return self.unknown_value diff --git a/torch/_export/passes/replace_autocast_with_hop_pass.py b/torch/_export/passes/replace_autocast_with_hop_pass.py index 71b90a3ff1bfb..14ab3e817ed70 100644 --- a/torch/_export/passes/replace_autocast_with_hop_pass.py +++ b/torch/_export/passes/replace_autocast_with_hop_pass.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING import torch from torch._higher_order_ops.wrap import wrap_with_autocast @@ -18,7 +18,7 @@ from torch.export.graph_signature import ExportGraphSignature -def _is_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: +def _is_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: return ( node and node.op == "call_function" @@ -30,19 +30,19 @@ def _is_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: ) -def _is_enter_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: +def _is_enter_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: return ( node and node.op == "call_function" - and node.target == torch.amp.autocast_mode._enter_autocast + and node.target is torch.amp.autocast_mode._enter_autocast ) -def _is_exit_autocast_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: +def _is_exit_autocast_node(node: torch.fx.Node) -> torch.fx.Node | bool: return ( node and node.op == "call_function" - and node.target == torch.amp.autocast_mode._exit_autocast + and node.target is torch.amp.autocast_mode._exit_autocast ) @@ -59,7 +59,7 @@ def _is_autocast_sub_mod(node: torch.fx.Node) -> bool: if ( first_non_ph and first_non_ph.op == "call_function" - and first_non_ph.target == torch.amp.autocast_mode._enter_autocast + and first_non_ph.target is torch.amp.autocast_mode._enter_autocast ): # TODO: check if current auto-cast type is the same as the args of # _enter_autocast. If so, return False, i.e. do not create a submodule. @@ -144,8 +144,8 @@ def node_call_back(node: torch.fx.Node) -> bool: def _sequential_split_and_maybe_inline_subgraphs( - gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] -) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: """ Helper function for replace_autocast_with_hop_pass(). Split the graph module into multiple subgraphs based on the autocast nodes. @@ -176,8 +176,8 @@ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node) -> None: def replace_autocast_with_hop_pass( - gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] -) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: """ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and then recursively call itself on each of the submodules. diff --git a/torch/_export/passes/replace_set_grad_with_hop_pass.py b/torch/_export/passes/replace_set_grad_with_hop_pass.py index 4c3a9c48d755f..5a15a59505755 100644 --- a/torch/_export/passes/replace_set_grad_with_hop_pass.py +++ b/torch/_export/passes/replace_set_grad_with_hop_pass.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING import torch from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled @@ -18,17 +18,17 @@ from torch.export.graph_signature import ExportGraphSignature -def _is_set_grad_enabled_node(node: torch.fx.Node) -> Union[torch.fx.Node, bool]: +def _is_set_grad_enabled_node(node: torch.fx.Node) -> torch.fx.Node | bool: return ( node and node.op == "call_function" - and node.target == torch._C._set_grad_enabled + and node.target is torch._C._set_grad_enabled ) def _is_set_grad_enabled_sub_mod( node: torch.fx.Node, omit_if_same_with_ambient: bool = False -) -> Union[bool, torch.Tensor]: +) -> bool | torch.Tensor: if node.op == "call_module": assert isinstance(node.target, str) subgm = getattr(node.graph.owning_module, node.target) @@ -38,7 +38,7 @@ def _is_set_grad_enabled_sub_mod( if ( first_non_ph and first_non_ph.op == "call_function" - and first_non_ph.target == torch._C._set_grad_enabled + and first_non_ph.target is torch._C._set_grad_enabled ): return ( first_non_ph.args[0] != torch.is_grad_enabled() @@ -80,8 +80,8 @@ def _remove_set_grad_and_inline(node: torch.fx.Node) -> None: def _sequential_split_and_maybe_inline_subgraphs( - gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] -) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: """ Helper function for replace_set_grad_with_hop_pass(). Split the graph module into multiple subgraphs based on the set_grad_enabled nodes. @@ -108,8 +108,8 @@ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node): def replace_set_grad_with_hop_pass( - gm: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] -) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: + gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature | None +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: """ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and then recursively call itself on each of the submodules. diff --git a/torch/_export/passes/replace_with_hop_pass_util.py b/torch/_export/passes/replace_with_hop_pass_util.py index 4579519fa3f2c..6ea3f1adde4f8 100644 --- a/torch/_export/passes/replace_with_hop_pass_util.py +++ b/torch/_export/passes/replace_with_hop_pass_util.py @@ -4,7 +4,7 @@ import contextlib import copy import operator -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import torch @@ -109,9 +109,9 @@ def set_hoo_node_meta(call_func_node): def _sequential_split_and_maybe_inline_subgraphs_helper( new_gm: torch.fx.GraphModule, - graph_signature: Optional[ExportGraphSignature], + graph_signature: ExportGraphSignature | None, maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None], -) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: """ Helper function for replacing graph nodse with higher order nodes. For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls @@ -159,12 +159,12 @@ def _sequential_split_and_maybe_inline_subgraphs_helper( def _replace_with_hop_pass_helper( gm: torch.fx.GraphModule, - graph_signature: Optional[ExportGraphSignature], + graph_signature: ExportGraphSignature | None, sequential_split_and_maybe_inline_subgraphs: Callable[ - [torch.fx.GraphModule, Optional[ExportGraphSignature]], - tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]], + [torch.fx.GraphModule, ExportGraphSignature | None], + tuple[torch.fx.GraphModule, ExportGraphSignature | None], ], -) -> tuple[torch.fx.GraphModule, Optional[ExportGraphSignature]]: +) -> tuple[torch.fx.GraphModule, ExportGraphSignature | None]: """ Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and then recursively call itself on each of the submodules. diff --git a/torch/_export/serde/schema_check.py b/torch/_export/serde/schema_check.py index 0890b4b2dd84e..5ec1fdb9026b9 100644 --- a/torch/_export/serde/schema_check.py +++ b/torch/_export/serde/schema_check.py @@ -88,7 +88,7 @@ def dump_type(t, level: int) -> tuple[str, str, str]: f"std::optional<{cpp_type}>", f"optional {thrift_type}", ) - elif o == Annotated: + elif o is Annotated: return dump_type(t.__origin__, level) else: raise AssertionError(f"Type {t} is not supported in export schema.") @@ -129,7 +129,7 @@ def dump_field(f) -> tuple[dict[str, Any], str, Optional[str], str, int]: t, cpp_type, thrift_type = dump_type(f.type, 0) ret = {"type": t} cpp_default: Optional[str] = None - assert typing.get_origin(f.type) == Annotated, ( + assert typing.get_origin(f.type) is Annotated, ( f"Field {f.name} must be annotated with an integer id." ) thrift_id = f.type.__metadata__[0] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 7e706baa5f9bd..9c4629f13337d 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -11,6 +11,7 @@ import logging import math import operator +import re import traceback import typing from collections import namedtuple, OrderedDict @@ -511,6 +512,15 @@ def __new__(metacls, name, bases, classdict): return type.__new__(metacls, name, bases, dict(classdict)) +def is_metadata_matched(config, entry_metadata): + metadata_attrs = ["num_cpu_threads", "num_warps", "num_stages", "num_ctas"] + for attr in metadata_attrs: + if hasattr(config, attr) and hasattr(entry_metadata, attr): + if getattr(config, attr) != getattr(entry_metadata, attr): + return False + return True + + def get_triton_kernel_and_cache_entry(node: torch.fx.Node): assert ( node.target @@ -519,50 +529,115 @@ def get_triton_kernel_and_cache_entry(node: torch.fx.Node): assert has_triton(), "triton required to serialize triton kernels" from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction assert isinstance(node.kwargs["kernel_idx"], int) kernel = torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.get_kernel( node.kwargs["kernel_idx"] ) - kNumWarpsDefault = 4 + # For Autotuner, we need to look at the underlying JITFunction's cache + # since the Autotuner itself doesn't have a cache + is_autotuner = isinstance(kernel, Autotuner) + # pyrefly: ignore [missing-attribute] + actual_kernel = kernel.fn if is_autotuner else kernel - # currently we only support specialization of - # num_warps -- so search for the entry that - # matches the value from the associated kernel - if isinstance(kernel, Autotuner): - assert len(kernel.configs) == 1 - num_warps = kernel.configs[0].num_warps - assert kernel.configs[0].num_ctas == 1, ( - "serialization only supports num_ctas == 1" - ) - kernel = kernel.fn - else: - num_warps = kNumWarpsDefault - - if hasattr(kernel, "device_caches"): - caches = kernel.device_caches + if hasattr(actual_kernel, "device_caches"): + caches = actual_kernel.device_caches assert len(caches.keys()) == 1 cache = next(iter(caches.values()))[0] - elif hasattr(kernel, "cache"): + elif hasattr(actual_kernel, "cache"): # old path, still used for cpu triton builds - caches = kernel.cache + caches = actual_kernel.cache assert len(caches.keys()) == 1 cache = next(iter(caches.values())) else: - raise AssertionError(f"kernel caches not found for kernel {kernel.__name__}") + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"kernel caches not found for kernel {actual_kernel.__name__}" + ) - # can also get num_warps, num_ctas, etc. from here ig if len(cache.keys()) == 1: - return kernel, next(iter(cache.values())) + return actual_kernel, next(iter(cache.values())) + + has_constexprs = ( + isinstance(actual_kernel, JITFunction) + and hasattr(actual_kernel, "constexprs") + and len(actual_kernel.constexprs) > 0 + ) + + if has_constexprs: + constexpr_vals = {} + # pyrefly: ignore [missing-attribute] + for constexpr_idx in actual_kernel.constexprs: + # pyrefly: ignore [missing-attribute] + if constexpr_idx < len(actual_kernel.arg_names): + # pyrefly: ignore [missing-attribute] + param_name = actual_kernel.arg_names[constexpr_idx] + kwargs_dict = node.kwargs.get("kwargs", {}) + if isinstance(kwargs_dict, dict): + if param_name in kwargs_dict: + constexpr_vals[param_name] = kwargs_dict[param_name] + + expected_values = [ + # pyrefly: ignore [missing-attribute] + constexpr_vals[actual_kernel.arg_names[idx]] + # pyrefly: ignore [missing-attribute] + for idx in actual_kernel.constexprs + # pyrefly: ignore [missing-attribute] + if actual_kernel.arg_names[idx] in constexpr_vals + ] + + matching_entries = [] + for sig_key, cache_entry in cache.items(): + constexpr_matches = re.findall(r"\('constexpr',\s*([^)]+)\)", sig_key) + if constexpr_matches: + constexpr_values = [] + for match in constexpr_matches: + if match in ("True", "False"): + constexpr_values.append(match == "True") + elif "." in match or "e" in match or "E" in match: + constexpr_values.append(float(match)) + else: + constexpr_values.append(int(match)) + + if constexpr_values == expected_values: + matching_entries.append((sig_key, cache_entry)) else: - for cache_entry in cache.values(): - if cache_entry.metadata.num_warps == num_warps: - return kernel, cache_entry + matching_entries = list(cache.items()) + + if len(matching_entries) == 0: raise AssertionError( - f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {kernel.__name__}" + # pyrefly: ignore [missing-attribute] + f"couldn't find a kernel cache entry with metadata matching the autotuner configs for kernel {actual_kernel.__name__}. " + f"Available cache keys: {list(cache.keys())}" ) + if len(matching_entries) == 1: + return actual_kernel, matching_entries[0][1] + + if is_autotuner: + for sig_key, cache_entry in matching_entries: + entry_metadata = cache_entry.metadata + # pyrefly: ignore [missing-attribute] + for config in kernel.configs: + if is_metadata_matched(config, entry_metadata): + return actual_kernel, cache_entry + + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"Multiple cache entries found for autotuned kernel {actual_kernel.__name__} " + f"{'with same constexpr values' if has_constexprs else 'with no constexpr'} " + f"and couldn't disambiguate using configs. " + ) + + raise AssertionError( + # pyrefly: ignore [missing-attribute] + f"Multiple cache entries found for non-autotuned kernel {actual_kernel.__name__} " + f"{'with same constexpr values' if has_constexprs else 'with no constexpr'}. " + f"This should not happen. Available cache keys: {[key for key, _ in matching_entries]}" + ) + @final class GraphModuleSerializer(metaclass=Final): @@ -763,8 +838,12 @@ def serialize_tensor_list_output(node): i += 1 assert isinstance(node.kwargs["grid"], list) + + kernel_name_with_hash = ( + f"{kernel.fn.__name__}_{kernel_cache_metadata.hash}" + ) kwargs_new = { - "name": kernel.fn.__name__, + "name": kernel_name_with_hash, "grid": node.kwargs["grid"][0], "output_indices": output_indices, "num_warps": kernel_cache_metadata.num_warps, @@ -2287,14 +2366,14 @@ def _is_single_tensor_return(target) -> bool: ) # handle ShapeEnv asserts - if target == torch.ops.aten._assert_scalar.default: + if target is torch.ops.aten._assert_scalar.default: if not isinstance((arg := fx_node.args[0]), bool): expr = arg.meta["val"] # type: ignore[union-attr] if isinstance(expr, torch.SymBool): self.shape_env.guard_or_defer_runtime_assert( expr.node.expr, "", fx_node ) - elif target == torch.ops.aten.sym_constrain_range_for_size.default: + elif target is torch.ops.aten.sym_constrain_range_for_size.default: sym = fx_node.args[0].meta["val"] # type: ignore[union-attr] if isinstance(sym, torch.SymInt): self.shape_env._constrain_range_for_size(sym.node.expr) @@ -3152,7 +3231,7 @@ def serialize( def _dict_to_dataclass(cls, data): assert not isinstance(cls, str), f"Unresolved class type: '{cls}'." - if typing.get_origin(cls) == Annotated: + if typing.get_origin(cls) is Annotated: return _dict_to_dataclass(cls.__origin__, data) if typing.get_origin(cls) == typing.Union and type(None) in typing.get_args(cls): if data is None: diff --git a/torch/_export/utils.py b/torch/_export/utils.py index cc7cbee8dff47..648e32758e5fa 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -681,7 +681,7 @@ def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: for node in gm.graph.nodes: if node.target in aten_to_variants: if ( - node.prev.target == torch.ops.aten._assert_tensor_metadata.default + node.prev.target is torch.ops.aten._assert_tensor_metadata.default and node.args[0] == node.prev.args[0] ): # skip if already guarded @@ -850,7 +850,7 @@ def node_inline_(call_mod_node: torch.fx.Node) -> Optional[torch.fx.GraphModule] get_item_users = nodes_filter( list(call_mod_node.users.keys()), lambda node: node.op == "call_function" - and node.target == operator.getitem, + and node.target is operator.getitem, ) # get_item_node.args[1] is the idx referring to new_output[idx] nodes_map( @@ -1477,7 +1477,7 @@ def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any: flattened, _ = flatten_fn(obj) # NOTE: This helper function will replicate an nn.Module in the exactly same - # structure to be used together with _reparametrize_module. This will + # structure to be used together with _reparameterize_module. This will # create a clone of the module with the new parameters and buffers without # affecting the original module. def copy_module(mod: torch.nn.Module): diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py new file mode 100644 index 0000000000000..ce01e37f03243 --- /dev/null +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -0,0 +1,674 @@ +# mypy: allow-untyped-defs +""" +This module provides result classes for AOT Autograd compilation. + +Similar to how torch._inductor.output_code provides OutputCode classes for inductor +compilation results, this module provides AOTAutogradResult classes that represent +the compiled artifacts produced by AOT Autograd. + +These results are: +- Serializable: can be saved/loaded from disk without recompilation +- Addressable: can be stored in caches with keys for later retrieval +- Reusable: can be used for both caching and ahead-of-time compilation (precompile) + +The main result types are: +- GenericAOTAutogradResult: Abstract base for all AOT Autograd results +- AOTAutogradResult: Regular result that references FxGraphCache entries +- BundledAOTAutogradResult: Result that bundles the entire compiled code directly +""" + +from __future__ import annotations + +import json +import logging +from abc import ABC, abstractmethod +from copy import copy +from dataclasses import dataclass +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar + +import torch +from torch._dynamo.precompile_context import BackendCacheArtifact +from torch._inductor.codecache import FxGraphCache +from torch._inductor.output_code import ( + CompiledFxGraph, + CompiledFxGraphConstants, + OutputCode, +) +from torch._inductor.utils import should_use_remote_fx_graph_cache + +from .runtime_wrappers import ( + AOTDispatchAutograd, + AOTDispatchSubclassWrapper, + CachedAutogradLazyBackwardCompileInfo, + CompilerWrapper, + FunctionalizedRngRuntimeWrapper, + post_compile, + RuntimeWrapper, + SerializableCompiledFunction, + SubclassMeta, +) +from .schemas import AOTAutogradCacheInfo # noqa: F401 +from .utils import simple_wraps + + +if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxKwargs + + from .schemas import AOTConfig, ViewAndMutationMeta + +log = logging.getLogger(__name__) + + +TOut = TypeVar("TOut", bound=OutputCode) + + +class InductorOutput(ABC, Generic[TOut]): + """ + Class representing a single inductor output + """ + + @abstractmethod + def pre_save(self) -> None: ... + + @abstractmethod + def load(self, example_inputs) -> TOut: ... + + @abstractmethod + def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ... + + +TOutputCode = TypeVar("TOutputCode", bound=OutputCode) + + +@dataclass +class BundledOutputCodeLoadable(InductorOutput[TOutputCode], Generic[TOutputCode]): + """ + A generic wrapper for OutputCode objects that are bundled directly in the cache + (rather than looked up via FxGraphCache). + + This works for any OutputCode subclass (CompiledFxGraph, RegionalOutputCode, etc.) + """ + + result: TOutputCode + + def pre_save(self) -> None: + disk_result = copy(self.result) + disk_result.prepare_for_serialization() + self.result = disk_result + return + + def load(self, example_inputs) -> TOutputCode: + self.example_inputs = example_inputs + return self.result + + def post_compile( + self, result: TOutputCode, fx_config: _CompileFxKwargs + ) -> TOutputCode: + constants = CompiledFxGraphConstants() + + # Special handling for CompiledFxGraph - needs FxGraphCache.cache_hit_post_compile + if isinstance(result, CompiledFxGraph): + graph, cache_info = FxGraphCache.cache_hit_post_compile( + result, {}, constants + ) + if graph is None: + raise RuntimeError("Failed to reload cache entry from disk") + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_bundled_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + result = graph # type: ignore[assignment] + + # Run normal post compile + result.post_compile(self.example_inputs, constants, fx_config) + return result + + +# Backwards compatibility alias +CompiledFxGraphLoadable: type[BundledOutputCodeLoadable[CompiledFxGraph]] = ( + BundledOutputCodeLoadable[CompiledFxGraph] +) + + +@dataclass +class FxGraphCacheLoadable(InductorOutput[CompiledFxGraph]): + fx_graph_cache_info: tuple[str, list[str]] + fx_graph_guard_expr: Optional[str] + + def pre_save(self): + return + + def _is_backward(self) -> bool: + return False + + def load(self, example_inputs) -> CompiledFxGraph: + from .autograd_cache import FXGraphCacheMiss + + # [Note: AOTAutogradCache and FXGraphCache Guard interactions] + # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. + # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. + # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly + # the same as the ones it passes to inductor, for both the forward and backward passes. + # (This does not mean that the tensor values passed in are the same: only that their symints are). + # That is, AOTAutograd and Inductor never create new guards based on symints with different sources + # than those passed to it by inductor. + # We pass the post compile function, which sets various fx_config boxed values, + # so we can call it only after we're sure both forward and backward have + # Clear CompiledTritonKernels before loading from FXGraphCache + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + remote_cache = None + constants = CompiledFxGraphConstants() + if should_use_remote_fx_graph_cache(): + remote_cache = FxGraphCache.get_remote_cache() + (cache_key, debug_lines) = self.fx_graph_cache_info + + def check_exact_guard_match(guard_expr, _hints): + """ + AOTAutogradCache tracks its own guards, so we just need to treat these guard expressions as a second + cache key of sorts: we just check for equality, i.e. the FXGraphCache entry with + the exact same guards as we originally saved into the cache. + """ + return guard_expr == self.fx_graph_guard_expr + + result, cache_info = FxGraphCache.load_with_key( + cache_key, + debug_lines, + example_inputs, + local=True, + remote_cache=remote_cache, + is_backward=self._is_backward(), + constants=constants, + evaluate_guards=check_exact_guard_match, + ) + if result is None: + log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_info) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_miss", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + + raise FXGraphCacheMiss + + # No need to log chromium event because AOTAutograd will log that immediately for us + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "fx_graph_cache_hit", # always a hit + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) + self.example_inputs = example_inputs + self.constants = constants + return result + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + """ + Called after FXGraphCacheLoadable.load, mutates fx_config + """ + result.post_compile(self.example_inputs, self.constants, fx_config) + return result + + +@dataclass +class CompiledForward(FxGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + def _is_backward(self) -> bool: + return False + + +@dataclass +class GenericCompiledBackward(InductorOutput[TOut]): + # Used by AOTDispatchAutograd.post_compile + backward_state_indices: list[int] + num_symints_saved_for_bw_: int + + +@dataclass +class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable): + """ + Cacheable entry for a forward function + """ + + def _is_backward(self) -> bool: + return True + + def post_compile( + self, result: CompiledFxGraph, fx_config: _CompileFxKwargs + ) -> CompiledFxGraph: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + + +# Generic bundled forward/backward classes that work with any OutputCode type +@dataclass +class BundledCompiledForward( + BundledOutputCodeLoadable[TOutputCode], Generic[TOutputCode] +): + """ + Generic forward function for bundled compilation. + Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) + """ + + +@dataclass +class BundledCompiledBackward( + GenericCompiledBackward[TOutputCode], + BundledOutputCodeLoadable[TOutputCode], + Generic[TOutputCode], +): + """ + Generic backward function for bundled compilation. + Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) + """ + + def post_compile( + self, result: TOutputCode, fx_config: _CompileFxKwargs + ) -> TOutputCode: + compiled_bw = super().post_compile(result, fx_config) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + + +@dataclass +class SerializedGraphModule: + fn: Callable[[dict[Any, Any], str], torch.nn.Module] + args: tuple[Any, ...] + + def __init__(self, gm: torch.fx.GraphModule): + self.fn, self.args = gm.__reduce__() + + def deserialize(self) -> torch.fx.GraphModule: + gm = self.fn(*self.args) + assert isinstance(gm, torch.fx.GraphModule) + return gm + + +def serialize_graph_module(gm: torch.fx.GraphModule) -> SerializedGraphModule: + # NOTE: mutates the graph module + gm.meta = {} + for node in gm.graph.nodes: + node.meta = {} + return SerializedGraphModule(gm) + + +TForward = TypeVar("TForward", bound=InductorOutput) +TBackward = TypeVar("TBackward", bound=GenericCompiledBackward) + + +@dataclass +class GenericAOTAutogradResult(Generic[TForward, TBackward]): + """A single result from AOT Autograd compilation, genericized by Forward and Backward types. + + A TForward is always an InductorOutput of some sort, which represents the + forward graph of the compile. + A TBackward is an InductorOutput + metadata about the backward, useful for specific + backward-only wrappers. This type is encapsulated by GenericCompiledBackward. + + Each AOTAutogradResult is essentially parameterized by 1. the method of loading + from the cache (either Bundled or UnBundled), and 2. The type of the output. For now, + the only type of output we support is Python Wrapper output, i.e. OutputCode.CompiledFxGraph, + but the same technique works for C++ wrapper code; we'd just add an extra InductorOutput type. + """ + + # Forward and Backward info + compiled_fw: TForward + compiled_bw: Optional[TBackward] + + # Code of the joint graph using print_readable() + # Used for logging purposes + aot_joint_graph_str: Optional[str] + aot_forward_graph_str: Optional[str] + aot_backward_graph_str: Optional[str] + + # Runtime_metadata saved right before compilation + runtime_metadata: ViewAndMutationMeta + + # Wrappers that run after each aot_dispatch_* function + dispatch_wrappers: list[CompilerWrapper] + + # Used by AOTSubclassWrapper + maybe_subclass_meta: Optional[SubclassMeta] + num_fw_outs_saved_for_bw: Optional[int] + + # Used by RuntimeWrapper + indices_of_inps_to_detach: list[int] + + # Time taken to trace/compile the forward + # forward_time_taken includes AOTAutograd tracing time + inductor compilation time + # backward_time_taken is essentially just the time inductor took to compile + forward_time_taken_ns: int + backward_time_taken_ns: int + + # Used by standalone_compile + sanitized_aot_config: AOTConfig + + guards_expr: Optional[str] + + # Used by Compiled Autograd + serialized_bw_module: Optional[SerializedGraphModule] + + def pre_save(self): + """ + Perform any preparations to make the result ready for serialization. + """ + self.compiled_fw.pre_save() + if self.compiled_bw is not None: + self.compiled_bw.pre_save() + + # Turn result into the original callable + def wrap_post_compile( + self, + args: list[torch.Tensor], + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ) -> Callable: + """ + This function takes a result and carefully reconstructs the original callable + that AOTAutograd returned the first time it was run. It does this by running the various + post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. + + In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. + In the autograd path, this consists of AOTAutogradDispatch.post_compile. + + The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + + Notably absent from the cached path are: + - DebugAssertWrapper + - FakifiedOutWrapper + + Which we'll handle separately later on, if necessary. + """ + from torch._dynamo.utils import CompileEventLogger, dynamo_timed + + # Log the output of AOTAutogradCache + if aot_config.enable_log: + # TODO: maybe also log to aot_graphs_log + # Unfortunately aot_graphs_log uses + # slightly different formatting though + if self.aot_joint_graph_str is not None: + torch._logging.trace_structured( + "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str + ) + + if self.aot_forward_graph_str is not None: + from torchgen.utils import dataclass_repr + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.runtime_metadata), + ) + if self.maybe_subclass_meta is not None: + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), + ) + + # It's called an inference graph if not running with autograd + name = ( + "aot_forward_graph" + if self.aot_backward_graph_str is not None + else "aot_inference_graph" + ) + torch._logging.trace_structured( + name, payload_fn=lambda: self.aot_forward_graph_str + ) + + if self.aot_backward_graph_str is not None: + torch._logging.trace_structured( + "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str + ) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + else: + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) + + # Wrap the forward function in post compile wrappers + compiled_fw_func = AOTDispatchSubclassWrapper( + trace_joint=needs_autograd, + fw_only=None, + maybe_subclass_meta=self.maybe_subclass_meta, + num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + req_subclass_dispatch = self.maybe_subclass_meta is not None + CompileEventLogger.try_add_pt2_compile( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + + # In autograd case, functionalizedRngWrapper should not modify outs + return_new_outs = not needs_autograd + compiled_fw_func = FunctionalizedRngRuntimeWrapper( + return_new_outs=return_new_outs + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + disable_amp = torch._C._is_any_autocast_enabled() + + if needs_autograd: + assert self.compiled_bw is not None + + cached_lazy_backward = None + if self.serialized_bw_module is not None: + cached_lazy_backward = CachedAutogradLazyBackwardCompileInfo( + self.serialized_bw_module.deserialize + ) + # This function is run on both cache miss and cache hit, either here + # or in aot_dispatch_autograd. On a cache hit, + # 1. the bw is already compiled + # 2. we don't need to save to the cache again + # so those corresponding arguments are set to None. + compiled_function = AOTDispatchAutograd.post_compile( + compiled_fw_func, + compiled_bw_func, + self.maybe_subclass_meta, + self.compiled_bw.num_symints_saved_for_bw_, + self.compiled_bw.backward_state_indices, + disable_amp, + self.indices_of_inps_to_detach, + cached_lazy_backward, + aot_config, + fw_metadata=self.runtime_metadata, + try_save_cache_entry=None, + ) + + else: + compiled_function = RuntimeWrapper( + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + trace_joint=False, + disable_amp=disable_amp, + ).post_compile( + compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata + ) + + # Add serialization function back onto object + compiled_function, _ = post_compile( + self.dispatch_wrappers, + compiled_function, + aot_config, + runtime_metadata=self.runtime_metadata, + ) + + # Now that we're pretty sure it's a successful load, add guards + # to the existing shape environment from the cache + if self.guards_expr: + from .autograd_cache import AOTAutogradCache + + symints = AOTAutogradCache._filter_backed_symints(args) + check = bool(AOTAutogradCache.evaluate_guards(self.guards_expr, symints)) + assert check is True + + return compiled_function + + +class AOTAutogradResult(GenericAOTAutogradResult[CompiledForward, CompiledBackward]): + """ + Regular AOTAutogradResult: saves the forward/backward FxGraphCache keys + and looks them up in FxGraphCache on load + """ + + +class BundledAOTAutogradResult( + GenericAOTAutogradResult[ + BundledCompiledForward[TOutputCode], BundledCompiledBackward[TOutputCode] + ], + Generic[TOutputCode], +): + """ + Generic AOTAutogradResult where we bundle the entire OutputCode directly + (rather than looking it up via FxGraphCache). + + This works with any OutputCode type: + - CompiledFxGraph: Traditional inductor compilation + - RegionalOutputCode: Regional inductor compilation with GraphPickler serialization + - Any future OutputCode subclasses + + Type parameter: + TOutputCode: The OutputCode subclass (e.g., CompiledFxGraph, RegionalOutputCode) + + Usage with CompiledFxGraph: + entry = BundledAOTAutogradResult[CompiledFxGraph]( + compiled_fw=BundledCompiledForward(result=CompiledFxGraph(...)), + compiled_bw=BundledCompiledBackward( + result=CompiledFxGraph(...), + backward_state_indices=[...], + num_symints_saved_for_bw_=..., + ), + ... + ) + + Usage with RegionalOutputCode: + entry = BundledAOTAutogradResult[RegionalOutputCode]( + compiled_fw=BundledCompiledForward(result=RegionalOutputCode(gm)), + compiled_bw=BundledCompiledBackward( + result=RegionalOutputCode(gm), + backward_state_indices=[...], + num_symints_saved_for_bw_=..., + ), + ... + ) + """ + + +def deserialize_bundled_cache_entry(entry: BundledAOTAutogradResult) -> Callable: + from copy import deepcopy + + from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.utils import BoxedBool + + # In the precompile use case, guards are already serialized + # by dynamo, so we don't need to add them to the environment + entry.guards_expr = None + # TODO: this isn't exactly right, because cudagraphs needs to be a shared config + # which is set by compile_fx. But in precompile, we never actually call compile_fx + # so we don't have a place to track cudagraphs here. + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + boxed_forward_device_index = BoxedDeviceIndex(None) + # We need to make a clean copy of the cache entry + # in case it needs to be serialized again + serializable_copy = deepcopy(entry) + + from torch._subclasses import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + context = torch._guards.TracingContext.try_get() + if context is None: + # Create a clean environment when running fx graph post compile + # if one is not available + context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv())) + with torch._guards.tracing(context): + compiled_fn = entry.wrap_post_compile( + [], + entry.sanitized_aot_config, + { + "cudagraphs": cudagraphs, + "boxed_forward_device_index": boxed_forward_device_index, + }, + ) + # Ensure the deserialized cache entry is still serializable + + compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: serializable_copy) + + # TODO: this ignores flat_params, which can exist + # if inline_builtin_nn_modules=False + @simple_wraps(compiled_fn) + def forward(*runtime_args: tuple[Any]): + return compiled_fn(list(runtime_args)) + + assert hasattr(compiled_fn, "serialize") + forward.serialize = compiled_fn.serialize # type: ignore[attr-defined] + + return forward + + +@dataclass +class BundledAOTAutogradCacheArtifact(BackendCacheArtifact[Callable]): + def after_deserialization(self) -> Callable: + return deserialize_bundled_cache_entry(self.content) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index f60bf274b8fb9..e411b4c7f6d86 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -15,22 +15,14 @@ import shutil import time import traceback -from abc import ABC, abstractmethod -from collections.abc import Callable -from copy import copy, deepcopy -from dataclasses import dataclass -from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from copy import copy +from typing import Any, Optional, TYPE_CHECKING, Union from typing_extensions import override import torch -from torch._dynamo.precompile_context import BackendCacheArtifact, PrecompileContext +from torch._dynamo.precompile_context import PrecompileContext from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions -from torch._dynamo.utils import ( - chromium_event_log_active, - CompileEventLogger, - counters, - dynamo_timed, -) +from torch._dynamo.utils import chromium_event_log_active, CompileEventLogger, counters from torch._functorch import config from torch._inductor.codecache import ( _ident, @@ -45,12 +37,7 @@ sha256_hash, write_atomic, ) -from torch._inductor.cudagraph_utils import BoxedDeviceIndex -from torch._inductor.output_code import ( - CompiledFxGraph, - CompiledFxGraphConstants, - OutputCode, -) +from torch._inductor.output_code import OutputCode from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.utils import BoxedBool, should_use_remote_fx_graph_cache from torch._logging import LazyString @@ -62,28 +49,35 @@ ) from torch.fx.experimental.symbolic_shapes import hint_int from torch.utils._triton import has_triton_package -from torchgen.utils import dataclass_repr +from .aot_autograd_result import ( + AOTAutogradResult, + BundledAOTAutogradCacheArtifact, + BundledAOTAutogradResult, + BundledCompiledBackward, + BundledCompiledForward, + CompiledBackward, + CompiledForward, + GenericAOTAutogradResult, + SerializedGraphModule, +) from .runtime_wrappers import ( - AOTDispatchAutograd, - AOTDispatchSubclassWrapper, - CachedAutogradLazyBackwardCompileInfo, CompilerWrapper, - FunctionalizedRngRuntimeWrapper, - post_compile, - RuntimeWrapper, SerializableCompiledFunction, SubclassMeta, ) from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta # noqa: F401 -from .utils import simple_wraps if TYPE_CHECKING: + from collections.abc import Callable + from torch._inductor.compile_fx import _CompileFxKwargs + from torch._inductor.cudagraph_utils import BoxedDeviceIndex from torch._inductor.remote_cache import JsonDataTy, RemoteCache from torch.fx.node import Node + log = logging.getLogger(__name__) @@ -506,498 +500,6 @@ def autograd_cache_key( return key, debug_lines -TOut = TypeVar("TOut", bound=OutputCode) - - -class InductorOutput(ABC, Generic[TOut]): - """ - Class representing a single inductor output - """ - - @abstractmethod - def pre_save(self) -> None: ... - - @abstractmethod - def load(self, example_inputs) -> TOut: ... - - @abstractmethod - def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: ... - - -@dataclass -class CompiledFxGraphLoadable(InductorOutput[CompiledFxGraph]): - """ - A full compiled fx graph that doesn't need to lookup the FxGraphCache - to run - """ - - result: CompiledFxGraph - - def pre_save(self) -> None: - disk_compiled_graph = copy(self.result) - disk_compiled_graph.prepare_for_serialization() - self.result = disk_compiled_graph - return - - def load(self, example_inputs) -> CompiledFxGraph: - self.example_inputs = example_inputs - - return self.result - - def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: - constants = CompiledFxGraphConstants() - # Cache hit specific post compile - graph, cache_info = FxGraphCache.cache_hit_post_compile(result, {}, constants) - if graph is None: - raise BypassAOTAutogradCache("Failed to reload cache entry from disk") - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "fx_graph_bundled_cache_hit", # always a hit - "encoding": "json", - }, - payload_fn=lambda: json.dumps(cache_info), - ) - # Run normal post compile - graph.post_compile(self.example_inputs, constants, fx_config) - return graph - - -@dataclass -class FxGraphCacheLoadable(InductorOutput[CompiledFxGraph]): - fx_graph_cache_info: tuple[str, list[str]] - fx_graph_guard_expr: Optional[str] - - def pre_save(self): - return - - def _is_backward(self) -> bool: - return False - - def load(self, example_inputs) -> CompiledFxGraph: - # [Note: AOTAutogradCache and FXGraphCache Guard interactions] - # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. - # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. - # The invariant that AOTAutograd uses here is that the sources for symints given to it by dynamo are exactly - # the same as the ones it passes to inductor, for both the forward and backward passes. - # (This does not mean that the tensor values passed in are the same: only that their symints are). - # That is, AOTAutograd and Inductor never create new guards based on symints with different sources - # than those passed to it by inductor. - - # We pass the post compile function, which sets various fx_config boxed values, - # so we can call it only after we're sure both forward and backward have - - # Clear CompiledTritonKernels before loading from FXGraphCache - torch._inductor.async_compile.CompiledTritonKernels.cache_clear() - remote_cache = None - constants = CompiledFxGraphConstants() - if should_use_remote_fx_graph_cache(): - remote_cache = FxGraphCache.get_remote_cache() - (cache_key, debug_lines) = self.fx_graph_cache_info - - def check_exact_guard_match(guard_expr, _hints): - """ - AOTAutogradCache tracks its own guards, so we just need to treat these guard expressions as a second - cache key of sorts: we just check for equality, i.e. the FXGraphCache entry with - the exact same guards as we originally saved into the cache. - """ - return guard_expr == self.fx_graph_guard_expr - - result, cache_info = FxGraphCache.load_with_key( - cache_key, - debug_lines, - example_inputs, - local=True, - remote_cache=remote_cache, - is_backward=self._is_backward(), - constants=constants, - evaluate_guards=check_exact_guard_match, - ) - if result is None: - log.info("FXGraphCache cache miss for key %s", self.fx_graph_cache_info) - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "fx_graph_cache_miss", # always a hit - "encoding": "json", - }, - payload_fn=lambda: json.dumps(cache_info), - ) - - raise FXGraphCacheMiss - - # No need to log chromium event because AOTAutograd will log that immediately for us - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "fx_graph_cache_hit", # always a hit - "encoding": "json", - }, - payload_fn=lambda: json.dumps(cache_info), - ) - self.example_inputs = example_inputs - self.constants = constants - return result - - def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: - """ - Called after FXGraphCacheLoadable.load, mutates fx_config - """ - result.post_compile(self.example_inputs, self.constants, fx_config) - return result - - -@dataclass -class CompiledForward(FxGraphCacheLoadable): - """ - Cacheable entry for a forward function - """ - - def _is_backward(self) -> bool: - return False - - -@dataclass -class GenericCompiledBackward(InductorOutput[TOut]): - # Used by AOTDispatchAutograd.post_compile - backward_state_indices: list[int] - num_symints_saved_for_bw_: int - - -@dataclass -class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable): - """ - Cacheable entry for a forward function - """ - - def _is_backward(self) -> bool: - return True - - def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: - compiled_bw = super().post_compile(result, fx_config) - # See note [Wrapping bw_compiler in disable] - # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py - # But since on cache hit we do not call the bw_compiler, we need to reapply the disable - return torch._dynamo.disable( # type: ignore[return-value] - compiled_bw, reason="do not trace generated backwards pass" - ) - - -# Forward types don't have any extra parameters, so this is just a TypeAlias, in essence -class BundledCompiledForward(CompiledFxGraphLoadable): - pass - - -@dataclass -class BundledCompiledBackward( - GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable -): - def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: - compiled_bw = super().post_compile(result, fx_config) - # See note [Wrapping bw_compiler in disable] - # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py - # But since on cache hit we do not call the bw_compiler, we need to reapply the disable - return torch._dynamo.disable( # type: ignore[return-value] - compiled_bw, reason="do not trace generated backwards pass" - ) - - -@dataclass -class SerializedGraphModule: - fn: Callable[[dict[Any, Any], str], torch.nn.Module] - args: tuple[Any, ...] - - def __init__(self, gm: torch.fx.GraphModule): - self.fn, self.args = gm.__reduce__() - - def deserialize(self) -> torch.fx.GraphModule: - gm = self.fn(*self.args) - assert isinstance(gm, torch.fx.GraphModule) - return gm - - -def serialize_graph_module(gm: torch.fx.GraphModule) -> SerializedGraphModule: - # NOTE: mutates the graph module - gm.meta = {} - for node in gm.graph.nodes: - node.meta = {} - return SerializedGraphModule(gm) - - -TForward = TypeVar("TForward", bound=InductorOutput) -TBackward = TypeVar("TBackward", bound=GenericCompiledBackward) - - -@dataclass -class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]): - """A single entry into the cache, genericized by Forward and Backward types. - - A TForward is always an InductorOutput of some sort, which represents the - forward graph of the compile. - A TBackward is an InductorOutput + metadata about the backward, useful for specific - backward-only wrappers. This type is encapsulated by GenericCompiledBackward. - - Each AOTAutogradCacheEntry is essentially parameterized by 1. the method of loading - from the cache (either Bundled or UnBundled), and 2. The type of the output. For now, - the only type of output we support is Python Wrapper output, i.e. OutputCode.CompiledFxGraph, - but the same technique works for C++ wrapper code; we'd just add an extra InductorOutput type. - """ - - # Forward and Backward info - compiled_fw: TForward - compiled_bw: Optional[TBackward] - - # Code of the joint graph using print_readable() - # Used for logging purposes - aot_joint_graph_str: Optional[str] - aot_forward_graph_str: Optional[str] - aot_backward_graph_str: Optional[str] - - # Runtime_metadata saved right before compilation - runtime_metadata: ViewAndMutationMeta - - # Wrappers that run after each aot_dispatch_* function - dispatch_wrappers: list[CompilerWrapper] - - # Used by AOTSubclassWrapper - maybe_subclass_meta: Optional[SubclassMeta] - num_fw_outs_saved_for_bw: Optional[int] - - # Used by RuntimeWrapepr - indices_of_inps_to_detach: list[int] - - # Time taken to trace/compile the forward - # forward_time_taken includes AOTAutograd tracing time + inductor compilation time - # backward_time_taken is essentially just the time inductor took to compile - forward_time_taken_ns: int - backward_time_taken_ns: int - - # Used by standalone_compile - sanitized_aot_config: AOTConfig - - guards_expr: Optional[str] - - # Used by Compiled Autograd - serialized_bw_module: Optional[SerializedGraphModule] - - def pre_save(self): - """ - Perform any preparations to make the cache entry ready for serialization. - """ - self.compiled_fw.pre_save() - if self.compiled_bw is not None: - self.compiled_bw.pre_save() - - # Turn cache entry into the original callable - def wrap_post_compile( - self, - args: list[torch.Tensor], - aot_config: AOTConfig, - fx_config: _CompileFxKwargs, - ) -> Callable: - """ - This function takes a cache entry and carefully reconstructs the original callable - that AOTAutograd returned the first time it was run. It does this by running the various - post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. - - In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. - In the autograd path, this consists of AOTAutogradDispatch.post_compile. - - The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. - - Notably absent from the cached path are: - - DebugAssertWrapper - - FakifiedOutWrapper - - Which we'll handle separately later on, if necessary. - """ - # Log the output of AOTAutogradCache - if aot_config.enable_log: - # TODO: maybe also log to aot_graphs_log - # Unfortunately aot_graphs_log uses - # slightly different formatting though - if self.aot_joint_graph_str is not None: - torch._logging.trace_structured( - "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str - ) - - if self.aot_forward_graph_str is not None: - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "aot_forward_graph_fw_metadata", - "encoding": "string", - }, - payload_fn=lambda: dataclass_repr(self.runtime_metadata), - ) - if self.maybe_subclass_meta is not None: - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "aot_forward_graph_fw_subclass_metadata", - "encoding": "string", - }, - payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), - ) - - # It's called an inference graph if not running with autograd - name = ( - "aot_forward_graph" - if self.aot_backward_graph_str is not None - else "aot_inference_graph" - ) - torch._logging.trace_structured( - name, payload_fn=lambda: self.aot_forward_graph_str - ) - - if self.aot_backward_graph_str is not None: - torch._logging.trace_structured( - "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str - ) - with dynamo_timed("AOTAutogradCache.inductor_load"): - compiled_fw_func = self.compiled_fw.load(args) - compiled_bw_func = None - if self.compiled_bw is not None: - compiled_bw_func = self.compiled_bw.load(args) - needs_autograd = True - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - # Now that we've loaded forward and backward, call post compile on both - # This avoids setting things like BoxedBools in fx_config until - # after both forward and backward cache hit - fw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - bw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": True, - } - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, fw_fx_config - ) - compiled_bw_func = self.compiled_bw.post_compile( - compiled_bw_func, bw_fx_config - ) - else: - inference_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - - needs_autograd = False - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, inference_fx_config - ) - - # Wrap the forward function in post compile wrappers - compiled_fw_func = AOTDispatchSubclassWrapper( - trace_joint=needs_autograd, - fw_only=None, - maybe_subclass_meta=self.maybe_subclass_meta, - num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, - ).post_compile( - compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata - ) - - req_subclass_dispatch = self.maybe_subclass_meta is not None - CompileEventLogger.try_add_pt2_compile( - "backend_compile", requires_subclass_dispatch=req_subclass_dispatch - ) - - # In autograd case, functionalizedRngWrapper should not modify outs - return_new_outs = not needs_autograd - compiled_fw_func = FunctionalizedRngRuntimeWrapper( - return_new_outs=return_new_outs - ).post_compile( - compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata - ) - disable_amp = torch._C._is_any_autocast_enabled() - - if needs_autograd: - assert self.compiled_bw is not None - - cached_lazy_backward = None - if self.serialized_bw_module is not None: - cached_lazy_backward = CachedAutogradLazyBackwardCompileInfo( - self.serialized_bw_module.deserialize - ) - # This function is run on both cache miss and cache hit, either here - # or in aot_dispatch_autograd. On a cache hit, - # 1. the bw is already compiled - # 2. we don't need to save to the cache again - # so those corresponding arguments are set to None. - compiled_function = AOTDispatchAutograd.post_compile( - compiled_fw_func, - compiled_bw_func, - self.maybe_subclass_meta, - self.compiled_bw.num_symints_saved_for_bw_, - self.compiled_bw.backward_state_indices, - disable_amp, - self.indices_of_inps_to_detach, - cached_lazy_backward, - aot_config, - fw_metadata=self.runtime_metadata, - try_save_cache_entry=None, - ) - - else: - compiled_function = RuntimeWrapper( - indices_of_inps_to_detach=self.indices_of_inps_to_detach, - trace_joint=False, - disable_amp=disable_amp, - ).post_compile( - compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata - ) - - # Add serialization function back onto object - compiled_function, _ = post_compile( - self.dispatch_wrappers, - compiled_function, - aot_config, - runtime_metadata=self.runtime_metadata, - ) - - # Now that we're pretty sure it's a successful load, add guards - # to the existing shape environment from the cache - if self.guards_expr: - symints = AOTAutogradCache._filter_backed_symints(args) - check = bool(AOTAutogradCache.evaluate_guards(self.guards_expr, symints)) - assert check is True - - return compiled_function - - -class AOTAutogradCacheEntry( - GenericAOTAutogradCacheEntry[CompiledForward, CompiledBackward] -): - """ - Regular AOTAutogradCacheEntry: saves the forward/backward FxGraphCache keys - and looks them up in FxGraphCache on load - """ - - -class BundledAOTAutogradCacheEntry( - GenericAOTAutogradCacheEntry[BundledCompiledForward, BundledCompiledBackward] -): - """ - AOTAutogradCacheEntry where we save the entire CompiledFxGraph instead - of relying on cache keys from FxGraphCache - """ - - @contextlib.contextmanager def sanitize_gm_for_cache(gm: torch.fx.GraphModule): """ @@ -1042,62 +544,10 @@ def type(): return "aot_autograd" -def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Callable: - # In the precompile use case, guards are already serialized - # by dynamo, so we don't need to add them to the environment - entry.guards_expr = None - # TODO: this isn't exactly right, because cudagraphs needs to be a shared config - # which is set by compile_fx. But in precompile, we never actually call compile_fx - # so we don't have a place to track cudagraphs here. - cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) - boxed_forward_device_index = BoxedDeviceIndex(None) - # We need to make a clean copy of the cache entry - # in case it needs to be serialized again - serializable_copy = deepcopy(entry) - - from torch._subclasses import FakeTensorMode - from torch.fx.experimental.symbolic_shapes import ShapeEnv - - context = torch._guards.TracingContext.try_get() - if context is None: - # Create a clean environment when running fx graph post compile - # if one is not available - context = torch._guards.TracingContext(FakeTensorMode(shape_env=ShapeEnv())) - with torch._guards.tracing(context): - compiled_fn = entry.wrap_post_compile( - [], - entry.sanitized_aot_config, - { - "cudagraphs": cudagraphs, - "boxed_forward_device_index": boxed_forward_device_index, - }, - ) - # Ensure the deserialized cache entry is still serializable - - compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: serializable_copy) - - # TODO: this ignores flat_params, which can exist - # if inline_builtin_nn_modules=False - @simple_wraps(compiled_fn) - def forward(*runtime_args: tuple[Any]): - return compiled_fn(list(runtime_args)) - - assert hasattr(compiled_fn, "serialize") - forward.serialize = compiled_fn.serialize # type: ignore[attr-defined] - - return forward - - -@dataclass -class BundledAOTAutogradCacheArtifact(BackendCacheArtifact[Callable]): - def after_deserialization(self) -> Callable: - return deserialize_bundled_cache_entry(self.content) - - -class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): +class AOTAutogradCache(GuardedCache[GenericAOTAutogradResult]): """ Caches the results of running AOTAutograd. This class mostly handles the save and load logic, whereas - AOTAutogradCacheEntry handles the wrapping/unwrapping logic. + AOTAutogradResult handles the wrapping/unwrapping logic. Cache Inputs (AOTAutogradCacheDetails) - AOTAutogradCache takes in the following inputs, which are analogous to inputs given @@ -1115,11 +565,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]): In a later PR, we'll likely generate the cache key based on the FakeTensors AOTAutograd generates based on the real tensor inputs, which can contain symints. - # Cache Outputs (AOTAutogradCacheEntry) + # Cache Outputs (AOTAutogradResult) - AOTAutogradCache caches the following values: - The compiled forward and backward functions from inductor, via keys to the FXGraphCache - Metadata to reconstruct the AOTModule from the compiled inductor artifacts - - See AOTAutogradCacheEntry for more info + - See AOTAutogradResult for more info [Note: Caching guards generated by AOTAutograd and Inductor] AOTAutograd and inductor both can introduce new guards to the shape environment. FXGraphCache saves guards with each @@ -1167,7 +617,7 @@ def try_load( cache_key, debug_lines = autograd_cache_key( gm, args, aot_config, fx_config ) - result: Optional[tuple[GenericAOTAutogradCacheEntry, bytes]] = ( + result: Optional[tuple[GenericAOTAutogradResult, bytes]] = ( AOTAutogradCache._lookup( cache_key, local, remote, args, cache_info, aot_config ) @@ -1339,7 +789,7 @@ def _lookup( args: list[Any], cache_info: dict[str, Any], aot_config: Optional[AOTConfig], - ) -> Optional[tuple[GenericAOTAutogradCacheEntry, bytes]]: + ) -> Optional[tuple[GenericAOTAutogradResult, bytes]]: """Given a key generated by AOTAutogradCachePickler, look up its location in the cache.""" remote_cache: Optional[RemoteCache[JsonDataTy]] = None if remote: @@ -1403,7 +853,7 @@ def _write_to_local_cache(key: str, content: bytes): write_atomic(path, content) @staticmethod - def save(key: str, entry: GenericAOTAutogradCacheEntry, remote: bool): + def save(key: str, entry: GenericAOTAutogradResult, remote: bool): """Save a single entry into the cache.""" try: entry.pre_save() @@ -1469,8 +919,8 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: @staticmethod def make_entry( - compiled_fw_func: CompiledFxGraph, - compiled_bw_func: Optional[CompiledFxGraph], + compiled_fw_func: OutputCode, + compiled_bw_func: Optional[OutputCode], aot_joint_graph_str: Optional[str], aot_forward_graph_str: Optional[str], aot_backward_graph_str: Optional[str], @@ -1486,28 +936,28 @@ def make_entry( backward_state_indices: Optional[list[int]], num_symints_saved_for_bw: Optional[int], serialized_bw_module: Optional[SerializedGraphModule], - ) -> GenericAOTAutogradCacheEntry: + ) -> GenericAOTAutogradResult: if should_bundle_autograd_cache(): # Helper function to unwrap all the wrappers we added during aotdispatch # They get reapplied on cache load - def unwrap_compiled_fx_graph(obj): + def unwrap_output_code(obj): while hasattr(obj, "__wrapped__"): obj = obj.__wrapped__ - assert isinstance(obj, CompiledFxGraph) + assert isinstance(obj, OutputCode) return obj - compiled_fw_graph = unwrap_compiled_fx_graph(compiled_fw_func) + compiled_fw_graph = unwrap_output_code(compiled_fw_func) bundled_compiled_forward = BundledCompiledForward(compiled_fw_graph) bundled_compiled_backward = None if compiled_bw_func is not None: assert backward_state_indices is not None assert num_symints_saved_for_bw is not None - compiled_bw_graph = unwrap_compiled_fx_graph(compiled_bw_func) + compiled_bw_graph = unwrap_output_code(compiled_bw_func) bundled_compiled_backward = BundledCompiledBackward( compiled_bw_graph, backward_state_indices, num_symints_saved_for_bw ) - return BundledAOTAutogradCacheEntry( + return BundledAOTAutogradResult( compiled_fw=bundled_compiled_forward, compiled_bw=bundled_compiled_backward, aot_joint_graph_str=aot_joint_graph_str, @@ -1552,7 +1002,7 @@ def unwrap_compiled_fx_graph(obj): num_symints_saved_for_bw_=num_symints_saved_for_bw, ) - return AOTAutogradCacheEntry( + return AOTAutogradResult( compiled_fw=compiled_forward, compiled_bw=compiled_backward, aot_joint_graph_str=aot_joint_graph_str, diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 6f0a76d5d6f13..11cef0f9205a5 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -240,7 +240,7 @@ def inner(*flat_args): # Inspect the state of the input tensor functional wrapper to detect input mutation info # If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version - for i, (arg, f_arg) in enumerate(zip(flat_args, flat_f_args)): + for arg, f_arg in zip(flat_args, flat_f_args): # NB: Mutation of non-contiguous tensor subclass input can result in a mismatch in # strides between the functionalized arg inner tensors and non-functionalized arg inner # tensors. This is a problem as the inner tensor stride change may not be reflected diff --git a/torch/_functorch/_aot_autograd/frontend_utils.py b/torch/_functorch/_aot_autograd/frontend_utils.py index 83f98e34fc4bf..c36a71ae96318 100644 --- a/torch/_functorch/_aot_autograd/frontend_utils.py +++ b/torch/_functorch/_aot_autograd/frontend_utils.py @@ -14,6 +14,7 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass from .. import config +from .descriptors import BufferAOTInput, DifferentiableAOTInput, ParamAOTInput from .schemas import AOTConfig, FakifiedFlatArgs @@ -107,7 +108,10 @@ def construct_fake_mode( def _try_get_metadata_from_dynamo( - mod: torch.nn.Module, param_keys: KeysView[str], full_args_num: int + mod: torch.nn.Module, + param_keys: KeysView[str], + full_args_num: int, + full_args_descs: list[DifferentiableAOTInput], ) -> tuple[Optional[list[torch._guards.Source]], list[int]]: """ Metadata is forwarded from Dynamo to AOTDispatch via special fields on GraphModule. @@ -130,7 +134,12 @@ def _try_get_metadata_from_dynamo( if not hasattr(mod, "_param_name_to_source"): # is from export - return None, [] + static_input_indices = [ + i + for i, node in enumerate(full_args_descs) + if isinstance(node, (ParamAOTInput, BufferAOTInput)) + ] + return None, static_input_indices # We now know this came from dynamo, and (1) we care about guards, # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 958804e5c763f..fcbf861e537db 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -10,7 +10,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional import torch from torch import Tensor @@ -225,7 +224,7 @@ def gen_alias_from_base( aliased_base_tensor, target_meta_tensor, target_requires_grad, - target_view_meta_sequence: Optional[ViewMetaSequence] = None, + target_view_meta_sequence: ViewMetaSequence | None = None, *, replay_views: bool, ): @@ -337,8 +336,8 @@ class MetadataKey: layout: torch.layout is_sparse: bool # these are empty when is_sparse - stride: Optional[tuple[SymIntEqByExpr, ...]] - storage_offset: Optional[SymIntEqByExpr] + stride: tuple[SymIntEqByExpr, ...] | None + storage_offset: SymIntEqByExpr | None is_conj: bool is_neg: bool diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 5967bfdd6b850..60ee3bc2973b1 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -41,7 +41,7 @@ from torch._subclasses.meta_utils import is_sparse_any from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import is_sym_node -from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals +from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals, guard_or_true from torch.fx.graph_module import GraphModule from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars from torch.multiprocessing.reductions import StorageWeakRef @@ -50,10 +50,9 @@ from torchgen.utils import dataclass_repr from .. import config +from .aot_autograd_result import GenericAOTAutogradResult, serialize_graph_module from .autograd_cache import ( AOTAutogradCache, - GenericAOTAutogradCacheEntry, - serialize_graph_module, should_bundle_autograd_cache, should_use_remote_autograd_cache, ) @@ -89,7 +88,6 @@ ) from .subclass_utils import compute_inner_mutated_inp_indices_from_subclass_meta from .utils import ( - _get_symint_hints, contain_metadata_mutation_ops, get_cuda_generator_meta_val, make_boxed_func, @@ -397,7 +395,7 @@ def should_save_cache(): else: return hasattr(compiled_fw, "_fx_graph_cache_key") - entry: Optional[GenericAOTAutogradCacheEntry] = None + entry: Optional[GenericAOTAutogradResult] = None if cache_info is not None and should_save_cache(): time_taken_ns = time.time_ns() - cache_info.start_time_ns guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) @@ -1774,8 +1772,28 @@ def _aot_stage2b_bw_compile( # Comparing ph_arg.stride() with real_stride directly may # cause dynamic dimensions in ph_arg being specialized to static - # value. Using the hints to avoid that. - if _get_symint_hints(ph_arg.stride()) != real_stride: + # value. Using suppress_guards and guard_or_true to avoid that. + + stride_different = False + fake_mode = detect_fake_mode() + suppress_ctx = ( + fake_mode.shape_env.suppress_guards() + if fake_mode is not None and fake_mode.shape_env is not None + else nullcontext() + ) + + # Inductor can choose different strides for activations than + # what backward graph has. if we can't statically tell that + # strides are the same, we assume they are not. + with suppress_ctx: + for k in range(len(ph_arg.stride())): + # real_stride can't be symbolic. + # pyrefly: ignore [index-error] + if guard_or_true(ph_arg.stride()[k] != int(real_stride[k])): + stride_different = True + break + + if stride_different: # Note that here we use the stride of the real tensor to # restride a FakeTensor. This does not cause trouble # for dynamic shape since this code path only get @@ -2031,7 +2049,7 @@ def _cache_autograd_info( make_runtime_safe(fw_metadata, maybe_subclass_meta) try_save_cache_entry: Optional[Callable] = None - entry: Optional[GenericAOTAutogradCacheEntry] = None + entry: Optional[GenericAOTAutogradResult] = None if aot_config.cache_info is not None: forward_time_taken_ns = time.time_ns() - aot_config.cache_info.start_time_ns @@ -2044,7 +2062,7 @@ def try_save_cache_entry( # noqa: F811 bw_module: torch.fx.GraphModule, _fw_metadata: ViewAndMutationMeta, aot_config: AOTConfig, - ) -> Optional[GenericAOTAutogradCacheEntry]: + ) -> Optional[GenericAOTAutogradResult]: cache_info = aot_config.cache_info def should_save_cache(): @@ -2140,6 +2158,7 @@ def _aot_stage2b_compile_forward_or_inference( - FunctionalizedRngRuntimeWrapper - FakifiedOutWrapper """ + # Validation if not is_inference and num_fw_outs_saved_for_bw is None: raise ValueError( diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 33aea13c3365d..4846f1ca74edb 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -2041,7 +2041,7 @@ def maybe_coerce(x): assert len(meta.attrs) == len(runtime_subclass_keys) leaves = [] - for i, (attr, attr_meta) in enumerate(meta.attrs.items()): + for attr, attr_meta in meta.attrs.items(): elem = getattr(x, attr) new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent( elem, attr_meta diff --git a/torch/_functorch/_aot_autograd/subclass_parametrization.py b/torch/_functorch/_aot_autograd/subclass_parametrization.py index 3b7f80114bbf2..0ea6635a62e81 100644 --- a/torch/_functorch/_aot_autograd/subclass_parametrization.py +++ b/torch/_functorch/_aot_autograd/subclass_parametrization.py @@ -98,7 +98,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul module, name, UnwrapTensorSubclass() ) - for name, child in module.named_children(): + for child in module.children(): unwrap_tensor_subclass_parameters(child) return module diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 4fd88e53f3843..844f34bb576da 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -249,7 +249,7 @@ def maybe_to_fresh_input(idx, t, meta): def is_with_effects(node): return ( node.op == "call_function" - and node.target == torch.ops.higher_order.with_effects + and node.target is torch.ops.higher_order.with_effects ) @@ -295,7 +295,7 @@ def rewrite_output(module, node, output_token_nodes, other_output_args): for output_token_node in output_token_nodes: assert ( output_token_node.op == "call_function" - and output_token_node.target == operator.getitem + and output_token_node.target is operator.getitem and output_token_node.args[1] == 0 ) with module.graph.inserting_before(node): @@ -327,7 +327,7 @@ def do(module, subgraph, expected_num_erased): if ( isinstance(out, torch.fx.node.Node) and out.op == "call_function" - and out.target == operator.getitem + and out.target is operator.getitem and out.args[1] == 0 and out.args[0] in with_effect_nodes ): diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index f48cb04f08f98..2aa70a76e6e78 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -975,7 +975,9 @@ def prepare_aot_module_simplified( ( aot_autograd_arg_pos_to_source, static_input_indices, - ) = _try_get_metadata_from_dynamo(mod, params_buffers.keys(), len(full_args)) + ) = _try_get_metadata_from_dynamo( + mod, params_buffers.keys(), len(full_args), full_args_descs + ) dynamic_shapes = False for x in full_args: diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index cdf2e1855a093..49a1adacab6ef 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -99,7 +99,7 @@ def checkable_node(node: fx.Node) -> bool: # so it's not worth CSEing. or get_aten_target(n) is aten.empty or n in nodes_that_alias_outputs - # This CSE pass currently doesn't handle re-propogation of unbacked + # This CSE pass currently doesn't handle re-propagation of unbacked # meta where it'll sometimes eliminate a _local_scalar_dense but not # replace the meta of downstream users. eg. one bug we've seen is: # diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 89fd907619175..3dd2529b1b107 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -20,7 +20,7 @@ # [@compile_ignored: debug] _save_config_ignore = [ - # callable not serializeable + # callable not serializable "joint_custom_pass", ] diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 8e194a0f0ce77..e7f8075b0281e 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -180,6 +180,7 @@ def _extract_graph_with_inputs_outputs( outputs: list[fx.Node], outputs_descs: list[AOTOutput], subgraph: Optional[str] = None, + ignore_must_be_in_fw_bw: bool = False, ) -> fx.Graph: """ Given a graph, extracts out a subgraph that takes the specified nodes as @@ -203,13 +204,22 @@ def _extract_graph_with_inputs_outputs( env[node] = new_node for node in joint_graph.nodes: - if _must_be_in_backward(node) and subgraph != "backward" and node not in inputs: - env[node] = InvalidNode # type: ignore[assignment] - continue + if not ignore_must_be_in_fw_bw: + if ( + _must_be_in_backward(node) + and subgraph != "backward" + and node not in inputs + ): + env[node] = InvalidNode # type: ignore[assignment] + continue - if _must_be_in_forward(node) and subgraph != "forward" and node not in inputs: - env[node] = InvalidNode # type: ignore[assignment] - continue + if ( + _must_be_in_forward(node) + and subgraph != "forward" + and node not in inputs + ): + env[node] = InvalidNode # type: ignore[assignment] + continue if node in env: # Node must be one of our inputs. (Any member of env which wasn't an @@ -1086,7 +1096,7 @@ def is_mutated_later_in_fw(node): ): # Since we can't save tuple of tensor values, we need to flatten out what we're saving users = node.users - assert all(user.target == operator.getitem for user in users) + assert all(user.target is operator.getitem for user in users) saved_values.extend(users) else: backward_usages = [ @@ -1256,7 +1266,7 @@ def insert_node_in_graph(node): # Build the graph op-by-op by starting from the node all the way to the end # copy_ can be not using tangents at all, we must copy it. for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]: - if node.op == "call_function" and node.target == torch.ops.aten.copy_.default: + if node.op == "call_function" and node.target is torch.ops.aten.copy_.default: insert_node_in_graph(node) for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: @@ -1481,9 +1491,7 @@ def get_sample_rng_state(device: Optional[torch.device]): ) ) - for rng_count, (base_node, node_pair) in enumerate( - recomputable_rng_ops_map.items() - ): + for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()): # Step 2 - Modify the fwd pass such that fw_node = node_pair["fwd"] bw_node = node_pair["bwd"] @@ -1598,7 +1606,7 @@ def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None: if node.op == "output": continue - is_copy_ = node.target == torch.ops.aten.copy_.default + is_copy_ = node.target is torch.ops.aten.copy_.default if is_copy_: if _has_tag_must_be_in_backward(node): has_mutation_in_bw.add(node.args[0]) @@ -1744,7 +1752,7 @@ def is_materialized_backwards(node): def should_ban_recomputation(node): if node.op != "call_function": return False - if node.target == operator.getitem: + if node.target is operator.getitem: return False if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: return True @@ -2714,9 +2722,7 @@ def thread_graphsafe_rng_from_hops(module, is_backward): subgraph = getattr(module, hop_node.args[0].target) if isinstance(subgraph, fx.GraphModule): new_rng_inputs = [] - for idx, placeholder_node in enumerate( - subgraph.graph.find_nodes(op="placeholder") - ): + for placeholder_node in subgraph.graph.find_nodes(op="placeholder"): if rng_string in placeholder_node.name: # Found a rng state placeholder in the hop graph, lets add # the corresponding node in the outer graph diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py index 0a811ed86c21c..b76cd191c3cc9 100644 --- a/torch/_functorch/pyfunctorch.py +++ b/torch/_functorch/pyfunctorch.py @@ -116,7 +116,7 @@ def temporarily_restore_interpreter_stack(stack): pushed.append(s) yield finally: - for s in reversed(pushed): + for _ in reversed(pushed): # TODO: would be nice to assert that the layers are the same, but # Python object identity is not preserved pop_dynamic_layer_stack() diff --git a/torch/_guards.py b/torch/_guards.py index fc8f88f237c4c..bac59965a3aef 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -65,15 +65,15 @@ @dataclass(frozen=True, kw_only=True, slots=True) class CompileId: - frame_id: Optional[int] + frame_id: int | None # This id is per-frame, and counts how many times we've compiled this # frame. This could have been a global id but having this be per-frame # gives you a better intuitive sense for how many recompiles have occurred # so far. - frame_compile_id: Optional[int] + frame_compile_id: int | None # torch.compiling a compiled autograd graph - compiled_autograd_id: Optional[int] = None + compiled_autograd_id: int | None = None # TODO: consider also tracking the recompilation count # See Note: Updating CompileId @@ -210,8 +210,8 @@ class GuardBuilderBase: @dataclasses.dataclass(frozen=True) class SLoc: - framework_loc: Optional[Union[traceback.FrameSummary, str]] - maybe_user_loc: Optional[str] + framework_loc: traceback.FrameSummary | str | None + maybe_user_loc: str | None def __str__(self) -> str: floc = ( @@ -758,9 +758,7 @@ def __init__(self) -> None: self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()} - def get_cache( - self, op: torch._ops.HigherOrderOperator - ) -> Optional[HopSubgraphCache]: + def get_cache(self, op: torch._ops.HigherOrderOperator) -> HopSubgraphCache | None: if op not in self.hop_cache_map: return None return self.hop_cache_map[op] # type: ignore[index] @@ -794,12 +792,12 @@ def get() -> CompileContext: return _TLS.compile_context @staticmethod - def try_get() -> Optional[CompileContext]: + def try_get() -> CompileContext | None: return getattr(_TLS, "compile_context", None) def __init__(self, compile_id: Optional[CompileId]) -> None: assert compile_id is None or isinstance(compile_id, CompileId) - self.compile_id: Optional[CompileId] = compile_id + self.compile_id: CompileId | None = compile_id self.attempt = 0 # Verbose ShapeEnv guards produced. self.shape_env_guards: list[str] = [] @@ -830,7 +828,7 @@ class TracingContext: """ @staticmethod - def try_get() -> Optional[TracingContext]: + def try_get() -> TracingContext | None: return getattr(_TLS, "tracing_context", None) @staticmethod @@ -874,7 +872,7 @@ def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None: # careful not to accidentally induce guards on the SymInt if # you ever do change this in aot_autograd.py; you should check # on permutations preferentially.) - self.output_strides: Optional[list[Optional[tuple[int, ...]]]] = None + self.output_strides: list[tuple[int, ...] | None] | None = None # When this is True, whenever we encounter an int in Dynamo tracing, # we will (1) force unspec it and (2) force it as a size-like unbacked # integer. This is currently used when processing certain lists of diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 68942ee0b9032..3f93036836eec 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -44,7 +44,7 @@ def from_tree_spec(cls, tree_spec: pytree.TreeSpec): return cls(pytree.tree_unflatten([], tree_spec).schema) -# regsiter_constant allows us to get a tree_spec from pytree.tree_flatten(SchemaHolder(FunctionSchema)). +# register_constant allows us to get a tree_spec from pytree.tree_flatten(SchemaHolder(FunctionSchema)). # The tree_spec is proxable in the graph and we can get back the schema via # schema = pytree.tree_unflatten([], tree_spec).schema pytree.register_constant(SchemaHolder) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 0849d690eaf41..f2d3c96a5cbfd 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -102,8 +102,8 @@ def cond( Conditionally applies `true_fn` or `false_fn`. .. warning:: - `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types and - doesn't support training currently. Please look forward to a more stable implementation in a future version of PyTorch. + `torch.cond` is a prototype feature in PyTorch. It has limited support for input and output types. + Please look forward to a more stable implementation in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype `cond` is structured control flow operator. That is, it is like a Python if-statement, diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py index 38c07e37bdb85..5f6e409fb215e 100644 --- a/torch/_higher_order_ops/out_dtype.py +++ b/torch/_higher_order_ops/out_dtype.py @@ -106,7 +106,7 @@ def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args) def is_int_mm(op, output_dtype, args): return ( - op == torch.ops.aten.mm.default + op is torch.ops.aten.mm.default and output_dtype == torch.int32 and len(args) == 2 and args[0].dtype == torch.int8 diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 2c3067f2cce00..852339d11ece5 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -312,7 +312,7 @@ def _scan(init, xs): out_tensor_mask = get_tensor_mask(dummy_out) dummy_out_masked = mask_list(out_tensor_mask, dummy_out) - # Pre-alocate + # Pre-allocate # outs -> Output matrix # idxs -> Index matrix for scatter_ # out: (num_elems, M, N, ...) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 87cf42a950eb6..8ffab37699422 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -264,6 +264,7 @@ def generate_ttir( assert isinstance(kernel, JITFunction) + # pyrefly: ignore # missing-attribute context = triton._C.libtriton.ir.context() target = triton.runtime.driver.active.get_current_target() backend = triton.compiler.compiler.make_backend(target) @@ -305,6 +306,7 @@ def generate_ttir( base_tensor = torch.empty( [elements_per_dim] * len(block_shape), dtype=a.dtype ) + # pyrefly: ignore # bad-argument-type ordered_args[name] = TensorDescriptor.from_tensor(base_tensor, block_shape) elif isinstance(a, (FakeTensor, torch._inductor.ir.TensorBox)): with torch._C._DisableTorchDispatch(): @@ -368,6 +370,7 @@ def _get_specialization(args): # type: ignore[no-untyped-def] target = triton.runtime.driver.active.get_current_target() backend_ = triton.compiler.compiler.make_backend(target) + # pyrefly: ignore # missing-attribute return backend_.get_attrs_descriptor(args, kernel.params) else: assert ( @@ -384,6 +387,7 @@ def _get_specialization(args): # type: ignore[no-untyped-def] except TypeError: # Unknown arg `specialize_extra` # Older versions of Triton take specialize_extra as an arg to specialize_impl specialize_impl = functools.partial( + # pyrefly: ignore # missing-argument triton.runtime.jit.create_specialize_impl(), specialize_extra=backend.get_arg_specialization, ) @@ -468,6 +472,7 @@ def get_signature_value(idx: int, arg: Any) -> str: if i not in constexprs } + # pyrefly: ignore # missing-attribute triton._C.libtriton.ir.load_dialects(context) backend.load_dialects(context) @@ -477,22 +482,29 @@ def get_signature_value(idx: int, arg: Any) -> str: # backward compatibility here. make_ir_sig_params = len(inspect.signature(src.make_ir).parameters) get_codegen_implementation_sig_params = len( + # pyrefly: ignore # missing-attribute inspect.signature(backend.get_codegen_implementation).parameters ) if make_ir_sig_params == 2: + # pyrefly: ignore # missing-argument ttir_module = src.make_ir(options, context) elif make_ir_sig_params == 3: + # pyrefly: ignore # missing-attribute codegen_fns = backend.get_codegen_implementation() + # pyrefly: ignore # missing-argument ttir_module = src.make_ir(options, codegen_fns, context) elif make_ir_sig_params == 4: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] + # pyrefly: ignore # missing-attribute codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() ttir_module = src.make_ir(options, codegen_fns, module_map, context) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] + # pyrefly: ignore # missing-attribute codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() + # pyrefly: ignore # bad-argument-count ttir_module = src.make_ir(target, options, codegen_fns, module_map, context) if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") @@ -1102,6 +1114,7 @@ def triton_kernel_wrapper_mutation_dense( from triton.tools.tensor_descriptor import TensorDescriptor block_shape = stable_meta[0] + # pyrefly: ignore # bad-argument-type kwargs[k] = TensorDescriptor.from_tensor(tensor, block_shape) # move as many positional arguments from dicts to args as we @@ -1658,6 +1671,7 @@ def call_triton_kernel( "Passing multiple @triton.autotune decorators is not supported. " "Please use a single @triton.autotune decorator instead." ) + # pyrefly: ignore # missing-attribute iter_kernel = iter_kernel.fn # Process the @triton.heuristics decorator: @@ -1868,6 +1882,7 @@ def call_triton_kernel( # Both for grid's meta as well as for the kernel, we need combined # args and kwargs combined and normalized + # pyrefly: ignore # missing-attribute combined_args_raw = {**dict(zip(variable.kernel.arg_names, args)), **kwargs} # precompute the grid for the kernel @@ -2061,6 +2076,7 @@ def __init__( kernel_idx: Optional[int], grid: Optional["TritonGridType"], ) -> None: + # pyrefly: ignore # bad-assignment self.kernel = None self.grid = None tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index e734bd4df5e4e..160e149fd769f 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -708,7 +708,7 @@ def _stack_pytree(pytrees): # is partitioned into in order to recover it in saved_tensors_and_symints. # # In saved_tensors_and_symints, we can recover the original args by: -# iterating over the pos list and pop one item from the front of paritioned_args[pos[i]]. +# iterating over the pos list and pop one item from the front of partitioned_args[pos[i]]. # We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists. def save_tensors_and_symints_for_backward(ctx, args): assert all( @@ -907,7 +907,7 @@ def diff_tensor_meta( try: if val1 != val2: pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'") - except GuardOnDataDependentSymNode as _: + except GuardOnDataDependentSymNode: pair_diffs.append(f"'{meta_name}: {val1} vs {val2}'") continue return pair_diffs @@ -1197,7 +1197,7 @@ def wrapped_fn(*flat_args): # call_op preserves ordering of proxies via schema materialized_args = [] - for i, (proxy, arg) in enumerate(zip(arg_proxies, schema.arguments)): + for i, proxy in enumerate(arg_proxies): if ( isinstance(proxy, torch.fx.Node) and proxy.op == "get_attr" diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 4ada93c6e47c6..148f4c516bbd2 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -316,7 +316,7 @@ def _validate_cond_output(pred): if stack_output: outs: list[torch.Tensor] = [] - for i, out in enumerate(outputs): + for out in outputs: outs.append(torch.stack(out, dim=0)) return tuple(outs) @@ -660,7 +660,7 @@ def __call__( # # gx = gy0 * bw(y0, x), # -# where gy0 denotes the graident of loss with respect to y0, and bw(y0, x) denotes the graident of y0 with +# where gy0 denotes the gradient of loss with respect to y0, and bw(y0, x) denotes the gradient of y0 with # respect to x. Note that bw can be computed from forward body_fn easily using torch.autograd.grad. # We could substitute the unknowns gy0, gy1, ..., with chain rule until gy4: # @@ -769,7 +769,7 @@ def backward(ctx, *grads): # Note [Handle inputs that're not differentiable] # When a forward input is non-differentiable e.g. a symint or an integer tensor, their gradients # will be None. However, we don't want to return None in the subgraph because this complicates the - # inductor codegen, where we need to do a non-unform treatment for None and tensors. + # inductor codegen, where we need to do a non-uniform treatment for None and tensors. # So we set up masks and filter the None gradients so that only tensors are returned from each step. carries_tensor_masks = [ bool(isinstance(t, torch.Tensor) and t.dtype.is_floating_point) diff --git a/torch/_inductor/analysis/device_info.py b/torch/_inductor/analysis/device_info.py index 6fc271458c771..8d5edf1e7fd26 100644 --- a/torch/_inductor/analysis/device_info.py +++ b/torch/_inductor/analysis/device_info.py @@ -86,6 +86,28 @@ class DeviceInfo: ), # Source: # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\ + # /instinct-tech-docs/product-briefs/amd-instinct-mi350x-gpu-brochure.pdf + "AMD MI350X": DeviceInfo( + tops={ + torch.float64: 72.1, + torch.float32: 144.2, + # not specified, fall back to float32 numbers + "torch.tf32": 144.2, + torch.bfloat16: 2309.6, + torch.float16: 2309.6, + torch.float8_e8m0fnu: 4614.0, + torch.float8_e8m0fnu: 4614.0, + torch.float8_e4m3fnuz: 4614.0, + torch.float8_e5m2: 4614.0, + torch.float8_e5m2fnuz: 4614.0, + torch.float8_e8m0fnu: 4614.0, + torch.int8: 4614.0, + }, + dram_bw_gbs=8000.0, + dram_gb=288.0, + ), + # Source: + # @lint-ignore https://www.amd.com/content/dam/amd/en/documents\ # /instinct-tech-docs/data-sheets/amd-instinct-mi300a-data-sheet.pdf "AMD MI300A": DeviceInfo( tops={ @@ -151,6 +173,7 @@ class DeviceInfo: dram_gb=64.0, ), } +_device_mapping["AMD INSTINCT MI350X"] = _device_mapping["AMD MI350X"] _device_mapping["AMD INSTINCT MI300X"] = _device_mapping["AMD MI300X"] _device_mapping["AMD INSTINCT MI210X"] = _device_mapping["AMD MI210X"] diff --git a/torch/_inductor/augmented_graph_helper.py b/torch/_inductor/augmented_graph_helper.py index ac61c015888e6..81dca605940e5 100644 --- a/torch/_inductor/augmented_graph_helper.py +++ b/torch/_inductor/augmented_graph_helper.py @@ -26,6 +26,8 @@ def __init__( # Extra dependencies: node depends on dep (dep must come before node) self.extra_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + # Extra uses: reverse of extra_deps (node is used by user) + self.extra_uses: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) # Note: only reflect original ancestors, not maintained through additional deps # or merge sets self.node_ancestors = node_ancestors @@ -33,6 +35,12 @@ def __init__( def add_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None: """Add extra dependency: node depends on dep.""" self.extra_deps[n].add(dep) + self.extra_uses[dep].add(n) + + def remove_extra_dep(self, *, n: fx.Node, dep: fx.Node) -> None: + if dep in self.extra_deps[n]: + self.extra_deps[n].discard(dep) + self.extra_uses[dep].discard(n) def merge_to_set(self, existing_node: fx.Node, new_node: fx.Node) -> None: """ @@ -123,3 +131,51 @@ def has_path(self, source: fx.Node, target: fx.Node) -> bool: queue.append(dep) return False + + def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> None: + """ + Transfer all extra dependencies from erased nodes to their replacements, handling + cross-dependencies between erased nodes correctly. + """ + erased_merge_sets: dict[fx.Node, fx.Node] = {} + + for replaced, new in erased_to_new.items(): + for equiv in self.merge_sets[replaced]: + erased_merge_sets[equiv] = new + + # Transfer dependencies + for old_node, new_node in erased_merge_sets.items(): + # Transfer dependencies FROM old_node (what old_node depended on) + for extra_dep in self.extra_deps[old_node]: + # Redirect if dep is also being erased + updated_dep = erased_merge_sets.get(extra_dep, extra_dep) + self.extra_deps[new_node].add(updated_dep) + self.extra_uses[updated_dep].discard(old_node) + self.extra_uses[updated_dep].add(new_node) + + # Transfer dependencies TO old_node (what depended on old_node) + for extra_use in self.extra_uses[old_node]: + # Redirect if this user is also being erased + updated_use = erased_merge_sets.get(extra_use, extra_use) + + # Update the user's deps to point to new_node + self.extra_deps[updated_use].discard(old_node) + self.extra_deps[updated_use].add(new_node) + self.extra_uses[new_node].add(updated_use) + + # Clean up erased nodes + for old_node in erased_merge_sets.keys(): + self.extra_deps[old_node].clear() + self.extra_uses[old_node].clear() + del self.merge_sets[old_node] + + def get_all_extra_deps(self) -> dict[fx.Node, OrderedSet[fx.Node]]: + """ + Get all extra dependencies in a format suitable for topological sort. + Returns a copy to avoid external modifications. + """ + return { + node: OrderedSet(deps) + for node, deps in self.extra_deps.items() + if deps # Only include nodes with non-empty deps + } diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 512efeb633625..b6ef9006f8d2a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2606,7 +2606,7 @@ def convert_arg(arg: Any) -> Any: if isinstance(result, (list, tuple)): # unsafe_alloc_void_ptrs_from_tensors expects result contains tensor only result = [torch.tensor([]) if r is None else r for r in result] - for i, r in enumerate(result): + for r in result: assert isinstance(r, torch.Tensor), op + " returns a list of non-tensors" return torch._C._aoti.unsafe_alloc_void_ptrs_from_tensors(result) # type: ignore[arg-type] diff --git a/torch/_inductor/codegen/aoti_hipify_utils.py b/torch/_inductor/codegen/aoti_hipify_utils.py index eb71d4ee7f392..eca4f85ced926 100644 --- a/torch/_inductor/codegen/aoti_hipify_utils.py +++ b/torch/_inductor/codegen/aoti_hipify_utils.py @@ -1,7 +1,6 @@ import re import torch -from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE # It is not a good idea to directly apply hipify_torch to codegen, which will be vulnerable to cases like: @@ -15,6 +14,12 @@ def maybe_hipify_code_wrapper(source_codes: str, force_hipify: bool = False) -> if torch.version.hip is None and not force_hipify: return source_codes + try: + from torch.utils.hipify.hipify_python import PYTORCH_MAP, PYTORCH_TRIE + except ImportError: + # hipify not available for non-AMD builds + return source_codes + def c2_repl(m: re.Match[str]) -> object: return PYTORCH_MAP[m.group(0)] diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index ec967ca83c3bc..d81233e6026e9 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -780,7 +780,7 @@ def deduce_node_dtype(self, node: torch.fx.Node) -> Optional[torch.dtype]: # we can infer output node if it only have 1 arg return None - if node.target == operator.getitem: + if node.target is operator.getitem: node_arg = node.args[0] assert isinstance(node_arg, torch.fx.Node), type(node_arg) return self.deduce_node_dtype(node_arg) @@ -2080,6 +2080,7 @@ def __init__( self.compute = IndentedBuffer() self.stores = IndentedBuffer() + self.atomic_add_found = False self.num_load = 0 self.num_store = 0 self.num_reduction = 0 @@ -2184,6 +2185,7 @@ def partial_accumulate( name: str, reduction_type: ReductionType, value: CSEVariable, + extra_meta: dict[str, Any], ) -> None: raise NotImplementedError diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 65f5d37d0d852..a08a516bc6c17 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -165,6 +165,8 @@ def get_export_declaration(): torch.float16, torch.uint8, torch.int8, + torch.float8_e4m3fn, + torch.float8_e5m2, ] @@ -1126,6 +1128,7 @@ def partial_accumulate( name: str, reduction_type: str, value: CSEVariable, + extra_meta: dict[str, Any], ) -> None: raise NotImplementedError @@ -1185,7 +1188,7 @@ def wrapper(*args, **kwargs): # 3. int32 and fp32 in test_torchinductor_dynamic_shapes.py::test_avg_pool2d8_dynamic_shapes_cpu if len(new_args) == 2: new_args = promote_args(new_args) - elif func == CppVecOverrides.where: + elif func is CppVecOverrides.where: new_args[1:] = promote_args(new_args[1:]) # Broadcast scalar args to vector @@ -2541,7 +2544,7 @@ def codegen_loops(self, code, worksharing): @property def assert_function(self) -> str: if V.graph.aot_mode: - return "AOTI_TORCH_CHECK" + return "STD_TORCH_CHECK" else: return "TORCH_CHECK" @@ -5469,7 +5472,16 @@ def flush(self): src_code, self.kernel_group.scheduled_nodes ) self.codegen_comment(self.kernel_group.scheduled_nodes, kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_begin() + V.graph.wrapper_code.write_kernel_context_guard( + kernel_name, + self.kernel_group.scheduled_nodes, # type: ignore[arg-type] + ) self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_end() + self.reset_kernel_group() self._set_flush_status(False) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index e49498cce411d..55248b4e40629 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -22,6 +22,7 @@ from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config, cpp_builder, ir +from ..ir import ExternKernel from ..utils import _align, DeferredLineBase, LineContext, normalize_name from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper @@ -43,6 +44,8 @@ # At most, the list nesting can go one layer deep. _OUTPUT_ARGS_TYPE = list[Union[Optional[str], list[Optional[str]]]] + from ..scheduler import BaseSchedulerNode + class HasWriteLine(Protocol): def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None: ... @@ -233,6 +236,18 @@ def write_header(self): self.header.splice(f"""#include \"{self.model_class_name_suffix}.h\"""") self.header.splice("\n") + if config.cpp.enable_kernel_profile: + self.header.splice( + "#include " + ) + self.header.splice( + """ + namespace torch::aot_inductor { + thread_local KernelContext* tls_kernel_context = nullptr; + } + """ + ) + def _include_extra_header(self, header: str): # This is needed for cpp to python dtype conversion self.header.splice(f"#include <{header}>") @@ -1249,7 +1264,7 @@ def generate_c_shim_extern_kernel_call( device: str, *, debug_args: Optional[list[str]] = None, - debug_handle: Optional[int] = None, + stack_traces: Optional[OrderedSet[str]] = None, ) -> None: """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in place of args while preserving debug printer output.""" @@ -1266,21 +1281,26 @@ def generate_c_shim_extern_kernel_call( ] with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel, device) - self.write_provenance_debug_handle(shim_fn, debug_handle) - shim_fn_codes = ( + shim_fn_codes = [ f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" - ) + ] if enable_kernel_profile: - debug_handle_str = "" if debug_handle is None else f":{debug_handle}" - shim_fn_codes = textwrap.dedent( - f""" - {{ - RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr); - {shim_fn_codes} - }} - """ - ) - self.writeline(shim_fn_codes) + stack_trace_str = 'R"(' + if stack_traces: + for stack_trace in stack_traces: + for line in stack_trace.split("\n"): + stack_trace_str += f"\n{line}" + stack_trace_str += "\n" + stack_trace_str += ')"' + + shim_fn_codes = [ + "{", + f"""KernelContextGuard _ctx("{shim_fn}", {stack_trace_str});""", + f"""RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);""", + shim_fn_codes[0], + "}", + ] + self.writelines(shim_fn_codes) def generate_c_shim_extern_kernel_alloc( self, extern_kernel: ir.ExternKernelAlloc, args: list[str] @@ -1373,7 +1393,7 @@ def _generate_extern_kernel_out_helper( out_view: Optional[str], args: list[str], device: str, - debug_handle: Optional[int] = None, + stack_traces: Optional[OrderedSet[str]] = None, ) -> None: if out_view: out_name = f"{out}_as_strided" @@ -1383,7 +1403,7 @@ def _generate_extern_kernel_out_helper( args.insert(0, out) self.generate_c_shim_extern_kernel_call( - kernel, args, device, debug_handle=debug_handle + kernel, args, device, stack_traces=stack_traces ) def _get_scatter_reduce_enum(self, reduce): @@ -2897,3 +2917,54 @@ def create_tmp_raii_handle_var_if_needed( writer.writeline(call_str) return tmp_var_name + + def write_kernel_context_guard_begin( + self, + ): + # Beginning of a kernel context guarded block. + # The block looks like this: + # { + # KernelContextGuard _ctx("{kernel_name}", {stack_trace_str}); + # ... operations... + # } + self.writeline("{") + + def write_kernel_context_guard_end( + self, + ): + # End of a kernel context guarded block. + self.writeline("}") + + def write_kernel_context_guard( + self, + kernel_name: str, + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ): + def aggregate_stack_traces( + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ) -> OrderedSet[str]: + if isinstance(node_schedule, list): + return functools.reduce( + lambda a, b: a | b, + [ + # pyrefly: ignore [missing-attribute] + node.node.get_stack_traces() + for node in node_schedule + if hasattr(node, "node") and node.node + ], + OrderedSet(), + ) + elif isinstance(node_schedule, ExternKernel): + return node_schedule.get_stack_traces() + else: + return OrderedSet() + + stack_trace_str = 'R"(' + stack_traces = aggregate_stack_traces(node_schedule) + + for stack_trace in stack_traces: + for line in stack_trace.split("\n"): + stack_trace_str += f"\n{line}" + stack_trace_str += "\n" + stack_trace_str += ')"' + self.writeline(f'KernelContextGuard _ctx("{kernel_name}", {stack_trace_str});') diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index 9ed1cfb9adfcd..22d0981febecd 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -996,7 +996,7 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: No function arguments. Returns: - List[Tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + List[tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) tuples that are compatible with the operation requirements of this template. """ assert cutlass_utils.try_import_cutlass() @@ -1564,7 +1564,7 @@ def _define_gemm_instance( op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. Returns: - Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ code (render) and the second part is the string that specifies the operation type. """ assert cutlass_utils.try_import_cutlass() @@ -1852,7 +1852,7 @@ def _define_gemm_instance( op (cutlass_library.gemm_op.GemmOperation): This is the core GEMM operation that we are defining and rendering. Returns: - Tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ + tuple[str, str]: A tuple where the first part is a string that constitutes the defined GEMM operation in C++ code (render) and the second part is the string that specifies the operation type. """ assert cutlass_utils.try_import_cutlass() diff --git a/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py b/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py new file mode 100644 index 0000000000000..173d122781016 --- /dev/null +++ b/torch/_inductor/codegen/cutedsl/_cutedsl_utils.py @@ -0,0 +1,29 @@ +# mypy: disable-error-code=import-not-found +# pyrefly: ignore [import-error] +import cutlass.cute as cute + + +@cute.jit # type: ignore[misc] +def ssa_to_indexable(ssa_value: cute.TensorSSA, dtype: str) -> cute.Numeric: + """ + Convert SSA form to indexable non-SSA form. + + Workaround for lack of gather support: SSA values cannot be used directly + as indices in tensor loads. This converts SSA → fragment → scalar for indexing. + """ + frag = cute.make_fragment(1, dtype) + frag.store(ssa_value) + return frag[0] + + +@cute.jit # type: ignore[misc] +def result_to_ssa(value: cute.Numeric, dtype: str) -> cute.TensorSSA: + """ + Convert non-SSA result back to SSA form. + + After performing operations with non-SSA values (like indexed loads), + convert the result back to SSA form for further computation. + """ + frag = cute.make_fragment(1, dtype) + frag[0] = value + return frag.load() diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py index ac8ce6f917664..9809f89e0873e 100644 --- a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py +++ b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -65,6 +65,10 @@ class CuteDSLSubgraphInfo: body: IndentedBuffer template_mask: Optional[str] = None template_out: Optional[str] = None + cse: Optional[CSE[Any]] = None + + def __post_init__(self): + self.only_copy_if_non_none_fields = ("cse",) def to_dict(self): return { @@ -118,8 +122,6 @@ def __init__( def kexpr(self, expr: sympy.Expr) -> str: """Convert sympy expression to CuteDSL string representation.""" - # For CuteDSL, we use standard Python string conversion - # since CuteDSL uses Python syntax for expressions return str(expr) def gen_imports(self) -> str: @@ -134,6 +136,7 @@ def gen_imports(self) -> str: import cuda.bindings.driver as cuda from cutlass._mlir.dialects import math as mlir_math import operator + from torch._inductor.codegen.cutedsl._cutedsl_utils import ssa_to_indexable, result_to_ssa """ ) return imports.getvalue() @@ -191,10 +194,15 @@ def set_subgraph_body(self, body_name: str): body=IndentedBuffer(), template_mask=None, template_out=None, + cse=None, ) subgraph = self.subgraph_bodies[body_name] for key, value in subgraph.to_dict().items(): + if value is None and key in getattr( + subgraph, "only_copy_if_non_none_fields", () + ): + continue setattr(self, key, value) try: @@ -212,15 +220,17 @@ def set_subgraph_body(self, body_name: str): setattr(self, key, value) @contextlib.contextmanager - def create_subgraph_body(self, body_name: str): + def create_subgraph_body(self, body_name: str, *, clear_cse: bool = False): """Create a new subgraph body for template processing.""" assert body_name not in self.subgraph_bodies, ( f"Subgraph body '{body_name}' already exists" ) + new_cse = self.cse.clone() if clear_cse else None self.subgraph_bodies[body_name] = CuteDSLSubgraphInfo( body=IndentedBuffer(), template_mask=None, template_out=None, + cse=new_cse, ) with self.set_subgraph_body(body_name): yield @@ -294,7 +304,8 @@ def hook(): # Register the hook and return placeholder placeholder = "" - assert placeholder not in self.render_hooks + # TODO: I think double invoking is fine for this specific hook + # assert placeholder not in self.render_hooks self.render_hooks[placeholder] = hook return placeholder @@ -330,7 +341,7 @@ def modification( while f"mod_{subgraph_number}_{num}" in self.subgraph_bodies: num += 1 - with self.create_subgraph_body(f"mod_{subgraph_number}_{num}"): + with self.create_subgraph_body(f"mod_{subgraph_number}_{num}", clear_cse=True): subgraph = self._get_subgraph(subgraph_number) modification_handler = ModificationWrapperCuteDSL( self, subgraph_number, fixed_inputs, mask @@ -406,72 +417,32 @@ def _get_input_dtype(self, name: str) -> torch.dtype: def load(self, name: str, index: sympy.Expr): """Handle loading from tensor or fixed(template args) input for CuteDSL.""" if name not in self.fixed_inputs: - index_str = self._process_indexing(index) var = self._add_kernel_input(name) buffer = V.graph.get_buffer(name) var_dtype = buffer.dtype - # Get the CuteDSL dtype mapping cute_dtype = CuteDSLOpOverrides.TORCH_TO_CUTE_DTYPE.get( var_dtype, "cutlass.Float32" ) + renamed_index = self.kernel.rename_indexing(index) - # NB - # This assumes single-value loads which is not generally the case but is a workaround - # since we don't have gather support yet. We do loads in non-SSA form then convert - # back to SSA form for any remaining operations over the loaded values. - # - # Pattern: - # index_frag = cute.make_fragment(1, cutlass.Int32) - # index_frag.store(index) - # val_frag = cute.make_fragment(1, dtype) - # index = index_frag[0] - # val_frag[0] = tensor[index] - # result = val_frag.load() - - index_frag = self.kernel.cse.generate( - self.kernel.body, - "cute.make_fragment(1, cutlass.Int32)", - dtype=torch.int32, - bounds=ValueRanges.unknown(), + idx_var = self._emit_scalar_fragment( + self.kernel.kexpr(renamed_index), "cutlass.Int32", torch.int32 ) - self.kernel.cse.generate( - self.kernel.body, - f"{index_frag}.store({index_str})", - dtype=torch.int32, - bounds=ValueRanges.unknown(), + val_frag = self.kernel.cse.newvar(dtype=var_dtype) + self.kernel.body.writeline( + f"{val_frag} = cute.make_fragment(1, {cute_dtype})" ) - val_frag = self.kernel.cse.generate( - self.kernel.body, - f"cute.make_fragment(1, {cute_dtype})", - dtype=var_dtype, - bounds=ValueRanges.unknown(), - ) - - index_var = self.kernel.cse.generate( - self.kernel.body, - f"{index_frag}[0]", - dtype=torch.int32, - bounds=ValueRanges.unknown(), - ) - - self.kernel.cse.generate( - self.kernel.body, - f"{val_frag}[0] = ({var}[{index_var}])", - dtype=var_dtype, - bounds=ValueRanges.unknown(), - ) + self.kernel.body.writeline(f"{val_frag}[0] = ({var}[{idx_var}])") final_expr = f"{val_frag}.load()" - # Handle upcast to fp32 if needed if ( var_dtype in (torch.float16, torch.bfloat16) and config.triton.codegen_upcast_to_fp32 ): - # Apply dtype conversion after fragment load final_expr = f"({final_expr}).to(cutlass.Float32)" var_dtype = torch.float32 @@ -486,11 +457,25 @@ def load(self, name: str, index: sympy.Expr): value = self.fixed_inputs[name] dtype = self._get_input_dtype(name) - # ensure CSE wrapping return self.kernel.cse.generate( self.kernel.body, value, bounds=ValueRanges.unknown(), dtype=dtype ) + def _emit_scalar_fragment( + self, expr_str: str, cute_dtype: str, torch_dtype: torch.dtype + ) -> str: + """ + Convert SSA expression to indexable scalar for tensor loads. + + Workaround for lack of gather support: SSA values cannot be used directly + as indices. This generates code to convert SSA → indexable scalar. + """ + result = self.kernel.cse.newvar(dtype=torch_dtype) + self.kernel.body.writeline( + f"{result} = ssa_to_indexable({expr_str}, {cute_dtype})" + ) + return str(result) + def indirect_indexing(self, index_var: str, size, check, wrap_neg=True): """Convert index variable to symbolic form.""" return sympy_index_symbol(str(index_var)) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index cdad5f1c72426..60674d8a3bf43 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -576,6 +576,7 @@ def partial_accumulate( name: str, reduction_type: str, value: CSEVariable, + extra_meta: dict[str, Any], ) -> None: raise NotImplementedError diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 268d044db6bae..4c668ea194409 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -141,6 +141,15 @@ def _print_ToFloat(self, expr: sympy.Expr) -> str: x = self.doprint(expr.args[0]) return f"static_cast({x})" + def _print_Float(self, expr: sympy.Expr) -> str: + if expr.is_integer: + # sympy considers 0.0 to be integer, but Metal doesn't. + # this workaround prints the float as an integer + # xref: https://github.com/sympy/sympy/issues/26620 + return str(int(expr)) + else: + return str(expr) + def _print_FloorToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 x = self.doprint(expr.args[0]) @@ -895,7 +904,7 @@ def codegen_kernel(self, name: Optional[str] = None) -> str: else: dtype_str = self.dtype_to_str(dtype) code.writeline(f"constant {dtype_str}* {inner},") - for outer, inner in self.args.sizevars.items(): + for inner in self.args.sizevars.values(): code.writeline(f"constant long& {inner},") # Write dynamic values as inputs diff --git a/torch/_inductor/codegen/multi_kernel.py b/torch/_inductor/codegen/multi_kernel.py index 9bd0d780f824f..094164a1f08ca 100644 --- a/torch/_inductor/codegen/multi_kernel.py +++ b/torch/_inductor/codegen/multi_kernel.py @@ -218,7 +218,7 @@ def call_kernel(self, kernel_name): # the multi call kernel. multi_call_args = call_args multi_call_arg_types = arg_types - for i, kernel in enumerate(self.kernels): + for kernel in self.kernels: additional_call_args, additional_arg_types = ( kernel.additional_call_args_and_types() ) @@ -381,7 +381,12 @@ def inner(): return inner return [ - benchmarker.benchmark_gpu(wrap_fn(kernel, index), rep=40) + benchmarker.benchmark( + wrap_fn(kernel, index), + # Currently the kernel type must be a CachingAutotuner + device=kernel.device_props.type, + rep=40, + ) for index, kernel in enumerate(self.kernels) ] diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 7e5457f78ebb8..dd2742b68f0f7 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -35,7 +35,7 @@ from ..._dynamo.utils import counters from .. import config, ir, scheduler from ..analyze_preserves_zero_mask import prologue_preserves_zero_mask -from ..codecache import code_hash +from ..codecache import code_hash, PyCodeCache from ..dependencies import MemoryDep, StarDep, WeakDep @@ -43,7 +43,9 @@ from ..ir import IRNode from ..optimize_indexing import indexing_dtype_strength_reduction -from ..runtime.runtime_utils import green_text, yellow_text +from ..runtime.coordinate_descent_tuner import CoordescTuner +from ..runtime.hints import DeviceProperties +from ..runtime.runtime_utils import green_text, next_power_of_2, yellow_text from ..scheduler import BaseSchedulerNode, BaseScheduling, WhyNoFuse from ..utils import ( cache_property_on_self, @@ -1535,40 +1537,17 @@ def _split_mix_order_reduction_epilogue(self, node): epilogues.append(node) return reductions, epilogues - def _codegen_mix_order_reduction(self, node1, node2): - if not V.graph.sizevars.statically_known_gt( - node1.group[1][0], node1.group[1][1] - ): - return self._codegen_mix_order_reduction(node2, node1) - - # pyrefly: ignore [bad-assignment] - metrics.codegen_mix_order_reduction += 1 - - assert V.graph.sizevars.statically_known_gt( - node1.group[1][0], node1.group[1][1] - ) - - # split epilogue out of node2 - node2_reductions, node2_epilogue = self._split_mix_order_reduction_epilogue( - node2 - ) - - # decide the split size - nrow, ncol = node1.group[1] - split_size = 64 # TODO need add heuristics - nsplit = (nrow + split_size - 1) // split_size - - numel, rnumel = node1.group[1] + def _generate_kernel_code_for_mix_order_reduction( + self, kernel_features, split_size, for_benchmark + ): + """ + for_benchmark: + True if the generated code is for benchmarking. We need make + sure benchmark harness code is generated. + """ + numel, rnumel = kernel_features.numel, kernel_features.reduction_numel + node_schedule = kernel_features.node_schedule - converted_nodes = [] - for subnode in node2_reductions: - converted = subnode.extract_pw_from_reduction() - converted.swap_pw_red_dimension() - converted_nodes.append(converted) - node_schedule = self.generate_node_schedule( - node1.get_nodes() + converted_nodes, numel, rnumel - ) - kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel, None) kernel = self.create_kernel_choices( kernel_features, [{"x": numel, "r0_": rnumel}], @@ -1576,12 +1555,12 @@ def _codegen_mix_order_reduction(self, node1, node2): "features": kernel_features, "tiling_scores": None, "mix_order_reduction": True, + "override_persistent_reduction": True, }, )[0] assert kernel.persistent_reduction assert kernel.mix_order_reduction kernel.rsplit_size = split_size - self.codegen_node_schedule_with_kernel(node_schedule, kernel) # allocate workspace for this kernel @@ -1595,15 +1574,140 @@ def _codegen_mix_order_reduction(self, node1, node2): assert ws_off == 0, f"{ws_off=}" with kernel: kernel.codegen_body() - with V.set_kernel_handler(kernel): + + stack = contextlib.ExitStack() + with V.set_kernel_handler(kernel), stack: + if for_benchmark: + stack.enter_context(config.patch(benchmark_kernel=True)) src_code = kernel.codegen_kernel() + + if for_benchmark: + # only do this if we are doing benchmarking. + # When we are generating final code, the kernel name + # should be decided differently with node type, fx node name + # etc. + src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") + return kernel, ws_name, src_code + + def benchmark_codegened_module( + self, mod, n_spills_threshold=8, node_names: Optional[OrderedSet[str]] = None + ) -> tuple[float, str]: + raise NotImplementedError + + def _codegen_mix_order_reduction(self, node1, node2): + numel, rnumel = scheduler.MixOrderReduction.get_numel_rnumel(node1) + + if not V.graph.sizevars.statically_known_gt( + numel, + rnumel, + ): + return self._codegen_mix_order_reduction(node2, node1) + + def _pick_split_size(): + # the overridden has highest priority + if config.triton.mix_order_reduction_split_size is not None: + return config.triton.mix_order_reduction_split_size + + # heuristics based on number of SMs + device_prop = DeviceProperties.create(node1.get_device()) + num_sm = device_prop.multi_processor_count + estimated_num_splits = num_sm * 8 + split_size = max(next_power_of_2(numel // estimated_num_splits), 16) + split_size = min(split_size, 128) + return split_size + + split_size = _pick_split_size() + + # pyrefly: ignore [bad-assignment] + metrics.codegen_mix_order_reduction += 1 + + assert V.graph.sizevars.statically_known_gt( + numel, + rnumel, + ) + + # split epilogue out of node2 + node2_reductions, node2_epilogue = self._split_mix_order_reduction_epilogue( + node2 + ) + + converted_nodes = [] + for subnode in node2_reductions: + subnode.cancel_reduction_split() + converted = subnode.extract_pw_from_reduction() + converted.swap_pw_red_dimension() + converted_nodes.append(converted) + node_schedule = self.generate_node_schedule( + node1.get_nodes() + converted_nodes, numel, rnumel + ) + kernel_features = SIMDKernelFeatures(node_schedule, numel, rnumel) + + # The autotuning is skipped in deterministic mode + if ( + not torch._inductor.config.deterministic + and config.triton.mix_order_reduction_split_size is None + and config.triton.mix_order_reduction_autotune_split_size + ): + + def _bench(candidate_split_size): + _, _, src_code = self._generate_kernel_code_for_mix_order_reduction( + kernel_features, + split_size=candidate_split_size, + for_benchmark=True, + ) + mod = PyCodeCache.load(src_code) + ms, _ = self.benchmark_codegened_module(mod) + return ms + + split_size = CoordescTuner.autotune_single_field( + _bench, + split_size, + 8, + ) + # print(f"Autotuning pick split size {split_size}") + + kernel, ws_name, src_code = self._generate_kernel_code_for_mix_order_reduction( + kernel_features, + split_size=split_size, + for_benchmark=False, + ) + + # rename intermediate reduction output to final reduction + # output + is_split_reduction = bool(node2_reductions[0].node._split_size) + rename = {} + if is_split_reduction: + for subnode in node2_reductions: + bufname = subnode.get_outputs()[0].node.get_name() + username = ( + subnode.get_outputs()[0] + .users[0] + .node.get_outputs()[0] + .node.get_name() + ) + rename[bufname] = username + assert self.scheduler + self.scheduler.removed_ops.add( + subnode.get_outputs()[0].users[0].node.get_name() + ) + V.graph.removed_buffers.add(bufname) + + for partial_accum in kernel.saved_partial_accumulate: + partial_accum.buffer_name = rename.get( + partial_accum.buffer_name, partial_accum.buffer_name + ) + kernel_name = self.define_kernel(src_code, node_schedule, kernel) kernel.kernel_name = kernel_name kernel.code_hash = code_hash(src_code) with V.set_kernel_handler(kernel): for node in kernel_features.scheduler_nodes(): - node.mark_run() + # No need to allocate buffer for split reduction + # since we are gonna to allocate workspace to store the + # intermediate reduction reduction + if node.get_outputs()[0].node.get_name() not in rename: + node.mark_run() # workspace args is still needed after the call kernel.call_kernel(kernel.kernel_name, deallocate_ws=False) @@ -1612,10 +1716,9 @@ def _codegen_mix_order_reduction(self, node1, node2): # a extra round of reduction assert len(converted_nodes) == len(kernel.saved_partial_accumulate) - for idx, (buffer_name, partial_accum) in enumerate( - zip(node2.get_buffer_names(), kernel.saved_partial_accumulate) - ): - assert buffer_name == partial_accum.buffer_name + nsplit = (numel + split_size - 1) // split_size + for idx, partial_accum in enumerate(kernel.saved_partial_accumulate): + buffer_name = partial_accum.buffer_name stride_str = f"{nsplit} * {rnumel}" start = f"{idx} * {stride_str}" @@ -1627,9 +1730,13 @@ def _codegen_mix_order_reduction(self, node1, node2): opname = reduction_type2op.get( partial_accum.reduction_type, partial_accum.reduction_type ) + V.graph.wrapper_code.writeline( f"{buffer_name} = {ws_name}[{start} : {end}].view({nsplit}, {rnumel}).{opname}(dim=0)", ) + # mark the buffer as allocated, so we don't try to allocate + # it again when it's later used + V.graph.wrapper_code.allocated.add(buffer_name) kernel.deallocate_workspaces() @@ -1643,6 +1750,12 @@ def _codegen_nodes( nodes: Sequence[scheduler.SchedulerNode], coalesce_analysis: Optional[CoalesceVarAnalysis] = None, ): + assert self.scheduler + nodes = [ + node for node in nodes if node.get_name() not in self.scheduler.removed_ops + ] + if not nodes: + return _, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group node_schedule = self.generate_node_schedule(nodes, numel, rnumel) @@ -1658,14 +1771,24 @@ def codegen_node( """ Given a set of pre-fused nodes, generate a Triton kernel. """ + assert self.scheduler + nodes = [ + node + for node in node.get_nodes() + if node.get_name() not in self.scheduler.removed_ops + ] + if len(nodes) == 0: + return if torch._inductor.config.triton.coalesce_tiling_analysis: + if len(nodes) != len(node.get_nodes()): + assert self.scheduler + node = scheduler.FusedSchedulerNode(self.scheduler, nodes) coalesce_analysis = analyze_memory_coalescing(node) else: coalesce_analysis = None - nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment] - return self._codegen_nodes(nodes, coalesce_analysis) + return self._codegen_nodes(nodes, coalesce_analysis) # type: ignore[arg-type] @staticmethod def can_use_32bit_indexing( @@ -1707,6 +1830,9 @@ def can_use_32bit_indexing( return True def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): + """ + Generate code for nodes in kernel_features + """ node_schedule = kernel_features.node_schedule tiling, tiling_score = self.get_tiling_and_scores( @@ -1747,7 +1873,15 @@ def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures): node for node in node_schedule if isinstance(node, BaseSchedulerNode) ] self.codegen_comment(base_scheduler_nodes, final_kernel.kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_begin() + V.graph.wrapper_code.write_kernel_context_guard( + final_kernel.kernel_name, + base_scheduler_nodes, # type: ignore[arg-type] + ) final_kernel.call_kernel(final_kernel.kernel_name) + if config.cpp.enable_kernel_profile: + V.graph.wrapper_code.write_kernel_context_guard_end() if config.nan_asserts: final_kernel.codegen_nan_check() @@ -2521,6 +2655,7 @@ def compute_tiling_strategy( sympy_product(pw_ranges) == pointwise_numel, lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}", ) + torch._check( sympy_product(red_ranges) == reduction_numel, lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}", diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 7259bd3460054..828ee02e559ef 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -1,6 +1,6 @@ import itertools import logging -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union import torch import torch._inductor.config as config @@ -8,6 +8,7 @@ from torch._inductor.codegen.common import KernelTemplate from torch._inductor.ir import ( Buffer, + FixedLayout, get_free_symbols, get_symbolic_inputs, gm_original_output_strides, @@ -110,7 +111,11 @@ def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: bm_func([*sym_inputs, *args]) if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: bm_func([*sym_inputs, *args])) - return benchmarker.benchmark_gpu(lambda: bm_func([*sym_inputs, *args])) + return benchmarker.benchmark( + # Shallow clone args since bm_func may clear args + lambda: bm_func([*sym_inputs, *args]), + device=benchmarker.infer_device(*sym_inputs, *args), + ) def hash_key(self) -> str: return "-".join( @@ -181,9 +186,11 @@ def generate( # type: ignore[override] Generate a SubgraphChoiceCaller instance for autotuning. Args: + name: The name for this subgraph choice input_nodes: List of input nodes to the subgraph layout: Memory layout information for the output - example_inputs: Example tensor inputs used to trace and benchmark the subgraph + make_fx_graph: Callable that creates the FX graph for this subgraph + description: Optional description of this choice **kwargs: Additional keyword arguments Returns: @@ -197,3 +204,165 @@ def generate( # type: ignore[override] description=description, make_fx_graph=make_fx_graph, ) + + def generate_custom_op_choices( + self, + name: str, + decompositions: list[Callable[..., Any]], + input_nodes: list[Buffer], + non_tensor_args: list[dict[str, Any]], + default_impl: Optional[Callable[..., Any]] = None, + ) -> list[SubgraphChoiceCaller]: + """ + Generate multiple SubgraphChoiceCaller instances for custom op autotuning. + + This method extends SubgraphTemplate to support custom op decompositions, + allowing multiple implementations to compete in autotuning. + + Args: + name: Base name for the choices + decompositions: List of decomposition functions to compete in autotuning + input_nodes: List of tensor inputs. All tensor arguments must be passed here. + non_tensor_args: List of non-tensor kwargs only, one dict per corresponding decomposition. + default_impl: Default implementation for layout inference + + Returns: + List of SubgraphChoiceCaller instances for autotuning + """ + if not decompositions: + return [] + + assert len(decompositions) == len(non_tensor_args), ( + f"decompositions and non_tensor_args must have same length, " + f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs" + ) + + # Infer layouts and ensure layout consistency for fair autotuning comparison + layouts = [ + self._infer_custom_op_layout(input_nodes, decomp, kwargs, default_impl) + for decomp, kwargs in zip(decompositions, non_tensor_args) + ] + + # Validate all decompositions produce equivalent layouts for fair comparison + self._validate_layout_equivalence(name, decompositions, layouts) + layout = layouts[0] # All layouts are now validated to be equivalent + + choices: list[SubgraphChoiceCaller] = [] + for decomp, decomp_kwargs in zip(decompositions, non_tensor_args): + # Create make_fx_graph function for this decomposition + import functools + + def make_fx_graph( + *args: Any, + decomp: Callable[..., Any] = decomp, + decomp_kwargs: dict[str, Any] = decomp_kwargs, + ) -> Any: + # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs + from torch.fx.experimental.proxy_tensor import make_fx + + return make_fx(functools.partial(decomp, **decomp_kwargs))(*args) + + # Generate descriptive name for this variant + variant_name = self._generate_variant_name(decomp, decomp_kwargs) + + choice = self.generate( + name=f"{name}_{variant_name}", + input_nodes=input_nodes, + layout=layout, + make_fx_graph=make_fx_graph, + description=f"CustomOp {decomp.__name__}", + ) + choices.append(choice) + + return choices + + def _generate_variant_name( + self, decomp: Callable[..., Any], kwargs: dict[str, Any] + ) -> str: + """Generate a descriptive name for a decomposition variant with its parameters.""" + base_name = decomp.__name__ + if not kwargs: + return base_name + param_suffix = "_".join(f"{k}_{v}" for k, v in sorted(kwargs.items())) + return f"{base_name}_{param_suffix}" + + def _validate_non_tensor_kwargs(self, kwargs: dict[str, Any]) -> None: + """Validate that kwargs contains only non-tensor arguments.""" + for key, value in kwargs.items(): + assert not isinstance(value, (torch.Tensor, Buffer)), ( + f"kwargs['{key}'] contains tensor {type(value)}. " + f"Tensor arguments should be in input_nodes, not kwargs. " + f"Only scalar/non-tensor parameters should be in kwargs." + ) + + def _validate_layout_equivalence( + self, + op_name: str, + decompositions: list[Callable[..., Any]], + layouts: list[Layout], + ) -> None: + """Ensure all layouts have consistent stride, device, dtype, and sizes for fair autotuning.""" + if not layouts: + return + + reference = layouts[0] + for i, layout in enumerate(layouts[1:], start=1): + if (layout.device, layout.dtype, layout.size, layout.stride) != ( + reference.device, + reference.dtype, + reference.size, + reference.stride, + ): + raise AssertionError( + f"Layout mismatch in custom op '{op_name}': " + f"decomposition '{decompositions[i].__name__}' produces " + f"({layout.device}, {layout.dtype}, {layout.size}, {layout.stride}) " + f"but '{decompositions[0].__name__}' produces " + f"({reference.device}, {reference.dtype}, {reference.size}, {reference.stride})" + ) + + def _infer_custom_op_layout( + self, + input_nodes: list[Buffer], + function_decomposition: Callable[..., Any], + kwargs: dict[str, Any], + default_impl: Optional[Callable[..., Any]] = None, + ) -> Layout: + """Infer output layout for custom ops using the default implementation when available. + Note that the Subgraph assumes custom ops return exactly one tensor output. + TODO: Add support for multiple output custom ops. + """ + import functools + + from torch._inductor.virtualized import V + + # Assert kwargs contain only non-tensor arguments + self._validate_non_tensor_kwargs(kwargs) + + with V.fake_mode: + example_inputs = [] + for inp in input_nodes: + raw_shape = inp.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + fake_tensor = torch.empty( + concrete_shape, dtype=inp.get_dtype(), device=inp.get_device() + ) + example_inputs.append(fake_tensor) + + fn = functools.partial(function_decomposition, **kwargs) + output = fn(*example_inputs) + + # Assert single output + assert isinstance(output, torch.Tensor), ( + f"Expected single tensor output, got {type(output)}. " + f"Multi-output custom ops not yet supported in autotuning." + ) + + return FixedLayout( + device=output.device, + dtype=output.dtype, + size=output.shape, + stride=output.stride(), + ) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 59f2006823042..bba4ce2d2aca1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -745,7 +745,12 @@ def _print_TruncToInt(self, expr: sympy.Expr) -> str: ) def _print_Float(self, expr: sympy.Expr) -> str: - if config.is_fbcode() and torch.version.hip: + if expr.is_integer: + # sympy considers 0.0 to be integer, but triton doesn't. + # this workaround prints the float as an integer + # xref: https://github.com/sympy/sympy/issues/26620 + ret = str(int(expr)) + elif config.is_fbcode() and torch.version.hip: ret = f"{expr}" else: ret = f"tl.full([], {expr}, tl.float64)" @@ -801,6 +806,9 @@ def _print_CeilToInt(self, expr: sympy.Expr) -> str: return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" def _helper_sqrt(self, expr: sympy.Expr) -> str: + # work around for https://github.com/pytorch/pytorch/issues/165738 + if torch.xpu.is_available(): + return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))" return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))" def _print_FloatPow(self, expr: sympy.Expr) -> str: @@ -1212,6 +1220,9 @@ def expm1(x): @staticmethod @maybe_upcast_float32() def sqrt(x): + # work around for https://github.com/pytorch/pytorch/issues/165738 + if torch.xpu.is_available(): + return f"libdevice.sqrt({x})" return f"tl.sqrt_rn({x})" @staticmethod @@ -1918,6 +1929,7 @@ def partial_accumulate( name: str, reduction_type: str, value: CSEVariable, + extra_meta: dict[str, Any], ) -> None: raise NotImplementedError @@ -3103,7 +3115,9 @@ def _handle_pdl_after_load(self, launch_buffer, result_var): ) self.cse.generate(launch_buffer, launch_if_last_load, dtype=torch.int32) - def partial_accumulate(self, name: str, reduction_type, val): + def partial_accumulate( + self, name: str, reduction_type, val, extra_meta: dict[str, Any] + ): self.saved_partial_accumulate.append( PartialAccumulate(name, reduction_type, val) ) @@ -3357,6 +3371,7 @@ def store( indexing_str += f".broadcast_to({value_shape})" line = f"tl.store({var} + ({indexing_str}), {value}, {indexing.mask_str})" elif mode == "atomic_add": + self.atomic_add_found = True indexing_str = indexing.index_str if ( is_sympy_integer_like(index) @@ -4556,14 +4571,16 @@ def codegen_body(self): ) accumname2var[name] = self.cse.namedvar(name, dtype=torch.float) self.body.writeline("split_size = min(RSPLIT_SIZE, xnumel - xoffset)") - self.body.writeline("for suboff in range(0, split_size, XBLOCK):") + self.body.writeline("for _ in range(0, split_size, XBLOCK):") with self.body.indent(offset=1): + self.body.splice(self.indexing_code) self.body.writelines( [ - "x0 = xindex + suboff", + "xindex += XBLOCK", + # TODO we force XBLOCK==1 for now so there is + # no need to update the xmask ] ) - self.body.splice(self.indexing_code) self.body.splice(self.loads) self.body.splice(self.compute) self.body.splice(self.stores) @@ -4821,7 +4838,7 @@ def codegen_kernel_benchmark(self, num_gb: Optional[float]) -> IndentedBuffer: result.writeline("args = get_args()") result.writeline( - "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" + f"ms = benchmarker.benchmark(lambda: call(args), device={V.graph.get_current_device_or_throw().type}, rep=40)" # noqa: B950 line too long ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") @@ -5048,6 +5065,7 @@ def add_constexpr_arg(arg_name): "mutated_arg_names": mutated_args, "optimize_mem": optimize_mem, "no_x_dim": self.no_x_dim, + "atomic_add_found": self.atomic_add_found, "num_load": self.num_load, "num_store": self.num_store, "num_reduction": self.num_reduction, @@ -5354,7 +5372,10 @@ def create_cse_var(self, *args, **kwargs) -> TritonCSEVariable: def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry): line = f"{entry.name} = {self.kexpr(self.rename_indexing(entry.expr))}" - if entry.root.is_loop: + + # mix order reduction introduces an extra loop across the x + # dimension + if entry.root.is_loop or (self.mix_order_reduction and entry.prefix == "x"): self.indexing_code.writeline(line) else: # lift non-reduction stores outside loop @@ -5776,18 +5797,21 @@ def load_cache(): # skip benchmarking the kernel if there are register spills ms = float("inf") else: + device = V.graph.get_current_device_or_throw() # We have to clone the inplace updated arguments to avoid earlier calls # generating out of range indices for later calls. - ms = benchmarker.benchmark_gpu( - lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ms = benchmarker.benchmark( + lambda: call(wrapped_jit_function.clone_args(*args)[0]), + device=device, ) # overhead of cloning args gives bias for fusing the kernel # in the case of mutating/in-placeable second fusion # TODO - would be better as a hook in triton do_bench that reset # the input values between benchmarking if len(wrapped_jit_function.mutated_arg_names) > 0: - ms = ms - benchmarker.benchmark_gpu( - lambda: wrapped_jit_function.clone_args(*args) + ms = ms - benchmarker.benchmark( + lambda: wrapped_jit_function.clone_args(*args), + device=str(device), ) log.debug( @@ -5956,13 +5980,16 @@ def store_cache(): # skip benchmarking the kernel if there are register spills ms = ms_clone = float("inf") else: + device = V.graph.get_current_device_or_throw() # We have to clone the inplace updated arguments to avoid earlier calls # generating out of range indices for later calls. - ms = benchmarker.benchmark_gpu( - lambda: call(wrapped_jit_function.clone_args(*args)[0]) + ms = benchmarker.benchmark( + lambda: call(wrapped_jit_function.clone_args(*args)[0]), + device=device, ) - ms_clone = benchmarker.benchmark_gpu( - lambda: wrapped_jit_function.clone_args(*args)[0] + ms_clone = benchmarker.benchmark( + lambda: wrapped_jit_function.clone_args(*args)[0], + device=device, ) log.debug( diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 7778498237c09..ec4ff92a3e7d9 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -717,7 +717,7 @@ def add_numel_to_call_args( self, name: str, call_args: list[Any], arg_types: list[Any] ) -> None: for num, sub_kernel in enumerate(self.sub_kernels): - for i, tree in enumerate(sub_kernel.range_trees): + for tree in sub_kernel.range_trees: numel_name = f"{tree.prefix}numel_{num}" if numel_name not in self.dynamic_shape_args: continue @@ -735,7 +735,7 @@ def add_numel_to_call_args( def kernel_benchmark_extra_args(self) -> list[str]: extra_args = [] for num, sub_kernel in enumerate(self.sub_kernels): - for i, tree in enumerate(sub_kernel.range_trees): + for tree in sub_kernel.range_trees: numel_name = f"{tree.prefix}numel_{num}" if numel_name not in self.dynamic_shape_args: continue @@ -896,6 +896,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: result.writeline(f"return {', '.join(var_names)},") result.writelines(["\n", "\n", "def call(args):"]) + device = V.graph.get_current_device_or_throw() index = V.graph.get_current_device_or_throw().index with result.indent(): result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") @@ -930,7 +931,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: result.writeline("args = get_args()") result.writeline( - "ms = benchmarker.benchmark_gpu(lambda: call(args), rep=40)" + f"ms = benchmarker.benchmark(call, fn_args=(args,), device={device.type},rep=40)" ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") @@ -1018,7 +1019,7 @@ def combo_grid_meta(self) -> dict[str, Any]: for num, sub_kernel in enumerate(self.sub_kernels): meta[f"no_x_dim_{num}"] = sub_kernel.no_x_dim - for i, tree in enumerate(sub_kernel.range_trees): + for tree in sub_kernel.range_trees: # pyrefly: ignore [missing-argument] if not tree.is_reduction: numel_name = f"{tree.prefix}numel_{num}" diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index fa5048fd726b7..ea3c6a35d5c5f 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -77,6 +77,8 @@ import triton from ..graph import GraphLowering + from ..ir import ExternKernel + from ..scheduler import BaseSchedulerNode from .wrapper_fxir import FxConverter @@ -287,11 +289,15 @@ def traverse(cur_kernel): if isinstance(symbol, JITFunction): compile_wrapper.newline() compile_wrapper.writeline("@triton.jit") + # pyrefly: ignore # missing-attribute compile_wrapper.splice(symbol.src, strip=True) symbols_included.add(symbol_name) traverse(symbol) elif hasattr(triton, "constexpr_function") and isinstance( - symbol, triton.runtime.jit.ConstexprFunction + # pyrefly: ignore # missing-attribute + symbol, + # pyrefly: ignore # missing-attribute + triton.runtime.jit.ConstexprFunction, ): compile_wrapper.newline() compile_wrapper.writeline("@triton.constexpr_function") @@ -528,6 +534,7 @@ def codegen(self, code: IndentedBuffer) -> None: node.output_view.codegen_reference() if node.output_view else None, args, device, + self.node.get_stack_traces(), ) def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: @@ -1554,6 +1561,7 @@ def _generate_extern_kernel_out_helper( out_view: Optional[str], args: list[str], device: str, + stack_traces: Optional[OrderedSet[str]] = None, ) -> None: # add debug printer code for triton kernel calls at (jit) inductor level debug_printer_manager = V.graph.wrapper_code.debug_printer @@ -3600,16 +3608,12 @@ def codegen_subgraph(subgraph, outer_inputs, outer_outputs): self.writeline("if not should_loop:") if stack_output: # Handle the case when loop never executes - for i, (carried_input, carried_buf) in enumerate( - zip(outer_carried_inputs, while_loop.carried_inputs) - ): + for i, carried_input in enumerate(outer_carried_inputs): self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) self.writeline(f"{name}[{i}] = {carried_input}.unsqueeze(0).clone()") self.writeline(ExitSubgraphLine(self)) else: - for i, (carried_input, carried_buf) in enumerate( - zip(outer_carried_inputs, while_loop.carried_inputs) - ): + for i, carried_input in enumerate(outer_carried_inputs): self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) self.writeline(f"{name}[{i}] = {carried_input}.clone()") self.writeline(ExitSubgraphLine(self)) @@ -3690,6 +3694,29 @@ def static_shape_for_buffer_or_none(buffer): def can_prove_buffer_has_static_shape(buffer): return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None + def write_kernel_context_guard( + self, + kernel_name: str, + node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], + ): + return + + def write_kernel_context_guard_begin( + self, + ): + """ + Mark the beginning of kernel context guard + """ + return + + def write_kernel_context_guard_end( + self, + ): + """ + Mark the end of kernel context guard + """ + return + class SubgraphPythonWrapperCodegen(PythonWrapperCodegen): """ diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py index 7b684124a3980..2fee74458b1d6 100644 --- a/torch/_inductor/codegen/wrapper_fxir.py +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -14,7 +14,6 @@ _node_metadata_hook, _set_node_metadata_hook, ) -from torch._export.utils import _detect_fake_mode_from_gm from torch._higher_order_ops.triton_kernel_wrap import ( TraceableTritonKernelWrapper, tracing_triton_hopifier_singleton, @@ -23,7 +22,7 @@ from torch._inductor.codecache import LambdaFuture, PyCodeCache from torch._inductor.runtime.triton_heuristics import CachingAutotuner from torch._inductor.select_algorithm import extern_kernels # noqa: F401 -from torch._inductor.utils import convert_shape_to_symint, convert_to_symint +from torch._inductor.utils import convert_to_symint from torch._inductor.virtualized import V from torch._library.triton import wrap_triton from torch.fx import GraphModule @@ -315,21 +314,6 @@ def _import_kernel(self, code: str, kernel_name: str) -> CachingAutotuner: return kernel - def _fake_tensor( - self, - size: tuple[Any, ...], - stride: tuple[Any, ...], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - with V.fake_mode: - return torch.empty_strided( - convert_shape_to_symint(size), - convert_shape_to_symint(stride), - dtype=dtype, - device=device, - ) - def _create_as_strided( self, input_node: torch.fx.Node, @@ -606,11 +590,9 @@ def generate(self) -> torch.fx.GraphModule: self._generate_graph_constants() self._generate_subgm_getattrs() - fake_mode = _detect_fake_mode_from_gm(self.gm) - with _set_node_metadata_hook( self.gm, - functools.partial(_node_metadata_hook, fake_mode=fake_mode), + functools.partial(_node_metadata_hook, fake_mode=V.fake_mode), ): self._generate_graph_input_shapes() @@ -967,7 +949,9 @@ def tune_kernel(tuner: CachingAutotuner, call_args: Sequence[Any]) -> None: from triton.runtime import driver log.info("Autotuning Triton kernel %s at compile time.", kernel_name) + # pyrefly: ignore # missing-attribute device = driver.active.get_current_device() + # pyrefly: ignore # missing-attribute stream = driver.active.get_current_stream(device) def node_to_tuning_arg(arg: Any) -> Any: diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 51c5472c7fe34..77c8bf94015c8 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -2,11 +2,12 @@ import logging import math from enum import IntEnum -from typing import Optional +from typing import Any, Optional import sympy import torch +import torch.utils._pytree as pytree from torch.fx.operator_schemas import normalize_function from . import ir @@ -67,15 +68,18 @@ def get_collective_type(node: ir.IRNode) -> NCCL_COLL: return get_collective_type_from_kernel_name(name) +def get_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: + numel = sympy_product(size) + if isinstance(numel, sympy.Integer): + return int(numel) + + return V.graph.sizevars.size_hint(numel, fallback=fallback) + + def get_collective_input_size_bytes(node: ir.IRNode) -> int: sz_bytes = 0 for inp in node.inputs: # type: ignore[attr-defined] - numel = sympy_product(inp.layout.size) - if isinstance(numel, sympy.Integer): - # For ease of testing - numel = int(numel) - else: - numel = V.graph.sizevars.size_hint(numel, fallback=0) + numel = get_size_numel(inp.layout.size) sz_bytes += numel * get_dtype_size(inp.layout.dtype) return sz_bytes @@ -176,13 +180,9 @@ def estimate_nccl_collective_runtime_nccl_estimator(snode) -> Optional[float]: kernel = snode.node assert kernel is not None py_kernel_name = getattr(kernel, "python_kernel_name", "") - if not ("all_gather" in py_kernel_name or "reduce_scatter" in py_kernel_name): - # NCCL of version 2.27 sometimes unrecoverably fail for all_to_all, all_reduce - return None - + pg_name = kernel.constant_args[-1] # type: ignore[attr-defined] from torch.distributed.distributed_c10d import _resolve_process_group - pg_name = kernel.constant_args[-1] # type: ignore[attr-defined] pg = _resolve_process_group(pg_name) rank: int = torch.distributed.get_rank(pg) # TODO(ivankobzarev): Figure out how we can use time estimations, @@ -357,7 +357,9 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: def estimate_nccl_collective_runtime_from_fx_node( - fx_node: torch.fx.Node, override_size: Optional[int] = None + fx_node: torch.fx.Node, + override_size: Optional[int] = None, + use_nccl_estimator: bool = True, ) -> float: """ Returns estimated NCCL collective runtime in nanoseconds (ns). @@ -386,12 +388,59 @@ def estimate_nccl_collective_runtime_from_fx_node( normalize_to_only_use_kwargs=True, ) assert opt_args_kwargs is not None - _, kwargs = opt_args_kwargs + args, kwargs = opt_args_kwargs - group_size = _get_group_size_by_name(kwargs["group_name"]) + group_name = kwargs["group_name"] + group_size = _get_group_size_by_name(group_name) assert isinstance(fx_node.target, torch._ops.OpOverload) coll = get_collective_type_from_kernel_name(fx_node.target.name()) + def _nccl_estimate() -> Optional[float]: + # TODO: Refactor with estimate_nccl_collective_runtime_nccl_estimator + + flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs)) + + def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def] + return torch.empty( + size if override_size is None else [override_size], + dtype=dtype, + device=device, + ) + + def try_size_hint(s: sympy.Expr) -> int: + return V.graph.sizevars.size_hint(s, fallback=0) + + def to_real_tensor(e: Any) -> Any: + if isinstance(e, torch.fx.Node): + return to_real_tensor(e.meta["val"]) + if isinstance(e, torch.Tensor): + return _tensor([get_size_numel(e.size())], e.dtype, e.device) + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + real_args, real_kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec) + + from torch.distributed.distributed_c10d import _resolve_process_group + + pg = _resolve_process_group(group_name) + fn = fx_node.target + assert isinstance(fn, torch._ops.OpOverload) + with torch.distributed._time_estimator(group=pg) as time_estimator: + w = fn(*real_args, **real_kwargs) + torch.ops._c10d_functional.wait_tensor.default(w) + est_time_us = time_estimator.estimated_time + # -1000 constant is NCCL return in case of error during estimations. + # Observed it for all_to_all estimations. + if est_time_us < 0: + return None + est_time_ms = est_time_us / 1e3 + return est_time_ms + + if torch.distributed.is_nccl_available() and use_nccl_estimator: + est_time_ms = _nccl_estimate() + if est_time_ms is not None: + return est_time_ms + return estimate_nccl_collective_runtime_impl( tensor_storage_size_bytes, group_size, coll ) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index f063d911b2a46..6c7c9a8bd7dab 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -424,10 +424,7 @@ def _update_memory_tracking_after_swap( return # Candidate becomes last use of some bufs - for ( - gn, - bufs, - ) in group_n_to_bufs_after_swap_dealloc_by_candidate.items(): + for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): for buf in bufs: buf_to_snode_last_use[buf] = candidate @@ -840,7 +837,7 @@ def schedule_collective_for_overlap(snode): else: schedule(snode) - for snode, deps in unmet_deps.items(): + for deps in unmet_deps.values(): assert len(deps) == 0, ( f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}" ) @@ -1396,7 +1393,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph): for idx, node in enumerate(node_list): if ( node.op == "call_function" - and node.target == torch.ops.inductor.resize_storage_bytes_.default + and node.target is torch.ops.inductor.resize_storage_bytes_.default ): assert node.args[0].op == "placeholder", f"""\ Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} @@ -1447,7 +1444,7 @@ def check_resize_pattern(graph_input): # Find all eligible unsharded params and their corresponding graph intermediates. unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list) for idx, node in enumerate(node_list): - if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default: + if node.op == "call_function" and node.target is torch.ops.fsdp.copy_.default: fsdp_copy_node = node unsharded_param = node.args[0] assert unsharded_param.op == "placeholder", f""" @@ -1459,8 +1456,8 @@ def check_resize_pattern(graph_input): def is_allowed_mutation(node): return ( - node.target == torch.ops.fsdp.copy_.default - or node.target == torch.ops.inductor.resize_storage_bytes_.default + node.target is torch.ops.fsdp.copy_.default + or node.target is torch.ops.inductor.resize_storage_bytes_.default ) def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): @@ -1552,11 +1549,8 @@ def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): node.args = new_args # Delete `fsdp.copy_(unsharded_param, Y)` nodes - for ( - unsharded_param, - fsdp_copy_node_idxes, - ) in unsharded_param_to_fsdp_copy_node_idxes.items(): - for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes): + for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values(): + for fsdp_copy_node_idx in fsdp_copy_node_idxes: fsdp_copy_node = node_list[fsdp_copy_node_idx] graph.erase_node(fsdp_copy_node) @@ -1564,7 +1558,7 @@ def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params): for node in node_list: if ( node.op == "call_function" - and node.target == torch.ops.inductor.resize_storage_bytes_.default + and node.target is torch.ops.inductor.resize_storage_bytes_.default and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes ): graph.erase_node(node) @@ -1612,7 +1606,7 @@ def remove_unused_getitem(g): node_list = list(g.nodes) for n in node_list: if ( - n.target == operator.getitem + n.target is operator.getitem and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default and n.args[1] == 1 ): diff --git a/torch/_inductor/comms_debug.py b/torch/_inductor/comms_debug.py index b6012828b8731..20c9779a4ef3f 100644 --- a/torch/_inductor/comms_debug.py +++ b/torch/_inductor/comms_debug.py @@ -46,7 +46,7 @@ def _debug_iterative_memory_recompute( if iter_cm != new_cm: log = "ITERATIVE CURR MEMORY CANDIDATE DOES NOT MATCH" iterative_recompute_error = True - for i, gn in enumerate(gns): + for gn in gns: iter_gnm = iter_curr_memory[gn] new_gnm = est_curr_memory[gn] if iter_gnm != new_gnm: @@ -65,7 +65,7 @@ def _debug_iterative_memory_recompute( f"\nCANDIDATE_NEW_ALLOCFREE:{snodes_allocfree[candidate]}" ) peak_log = "" - for i, (pre, post) in enumerate(snodes_curr_memory): + for i, (pre, _post) in enumerate(snodes_curr_memory): if est_peak_memory == pre: n = snodes[i] peak_log = ( diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 03f890558b395..957f9f9fafe23 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1550,9 +1550,12 @@ def codegen_and_compile( payload_fn=lambda: inductor_kernel_stack_trace_str, ) if inductor_kernel_stack_trace_str: - get_metrics_context().add_to_set( - "inductor_provenance", inductor_kernel_stack_trace_str - ) + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.add_to_set( + "inductor_provenance", + inductor_kernel_stack_trace_str, + ) node_runtimes = None if inductor_metrics_log.isEnabledFor(logging.INFO): @@ -2445,6 +2448,11 @@ def compile_fx( # Some arguments trigger a recursive call to compile_fx. Handle these # short circuits first, before anything else + from torch._inductor.compiler_bisector import CompilerBisector + + if CompilerBisector.disable_subsystem("inductor", "pre_grad_graph"): + return model_ + if config_patches: with config.patch(config_patches): return compile_fx( diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 30a4e2203bb96..060dc17b0ce4b 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -29,6 +29,7 @@ from torch._inductor.compile_worker.utils import _async_compile_initializer from torch._inductor.utils import get_ld_library_path, python_subprocess_env from torch._utils_internal import find_compile_subproc_binary +from torch.monitor import _WaitCounter, _WaitCounterTracker log = logging.getLogger(__name__) @@ -193,9 +194,26 @@ def __init__( self.futures_lock = threading.Lock() self.pending_futures: dict[int, Future[Any]] = {} + # The pending waitcounter, is used to indicate the time when we have any specific job running. + self.pending_waitcounters: dict[int, Any] = {} self.job_id_count = itertools.count() + # The running waitcounter indicates the time when the SubProcPool object exists. self.running = True + self.running_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.running" + ).guard() + self.running_waitcounter.__enter__() + + # The quiesce waitcounter indicates when the job is in a quiesced state. + self.quiesce_waitcounter: Optional[_WaitCounterTracker] = None + + # Firstjob is used to capture the time from when the firstjob is queued, to when the first job is done. + self.firstjob = True + self.firstjob_id: Optional[int] = None + self.firstjob_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.first_job" + ).guard() # Start thread last to ensure all member variables are initialized # before any access. @@ -212,6 +230,19 @@ def submit( with self.futures_lock: job_id = next(self.job_id_count) self.pending_futures[job_id] = future = Future() + self.pending_waitcounters[job_id] = _WaitCounter( + "pytorch.wait_counter.subproc_pool.job" + ).guard() + self.pending_waitcounters[job_id].__enter__() + if self.quiesce_waitcounter: + self.firstjob = True + self.quiesce_waitcounter.__exit__() + self.quiesce_waitcounter = None + # This can be entered from either quiesce wakeup, or from startup. + if self.firstjob: + self.firstjob_id = job_id + self.firstjob_waitcounter.__enter__() + self.firstjob = False future.set_running_or_notify_cancel() self._send(MsgHeader.JOB, job_id, job_data) return future @@ -239,6 +270,7 @@ def _read_thread(self) -> None: if self.running: log.warning("SubprocPool unclean exit") self.running = False + self.running_waitcounter.__exit__() self.read_pipe.close() # Cancel all the pending futures. self.shutdown() @@ -265,10 +297,21 @@ def _read_thread(self) -> None: self.pending_futures[job_id].set_exception(result) else: self.pending_futures[job_id].set_result(result) + + self.pending_waitcounters[job_id].__exit__() + del self.pending_waitcounters[job_id] + if self.firstjob_id == job_id: + self.firstjob_waitcounter.__exit__() + del self.pending_futures[job_id] def quiesce(self) -> None: self._send(MsgHeader.QUIESCE) + assert self.quiesce_waitcounter is None + self.quiesce_waitcounter = _WaitCounter( + "pytorch.wait_counter.subproc_pool.running" + ).guard() + self.quiesce_waitcounter.__enter__() def wakeup(self) -> None: self._send(MsgHeader.WAKEUP) @@ -279,6 +322,7 @@ def shutdown(self) -> None: if not self.running: return self.running = False + self.running_waitcounter.__exit__() _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) self.write_pipe.close() self.process.wait(300) diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py new file mode 100644 index 0000000000000..d4b0c0dc9e281 --- /dev/null +++ b/torch/_inductor/compile_worker/timer.py @@ -0,0 +1,54 @@ +from threading import Lock, Thread +from time import monotonic, sleep +from typing import Callable, Optional, Union + + +class Timer: + """ + This measures how long we have gone since last receiving an event and if it is greater than a set interval, calls a function. + """ + + def __init__( + self, + duration: Union[int, float], # Duration in seconds + call: Callable[[], None], # Function to call when we expire + ) -> None: + # We don't start the background thread until we actually get an event. + self.background_thread: Optional[Thread] = None + self.last_called: Optional[float] = None + self.duration = duration + self.sleep_time = 60 + self.call = call + self.exit = False + + self.lock = Lock() + + def record_call(self) -> None: + with self.lock: + if self.background_thread is None: + self.background_thread = Thread( + target=self.check, daemon=True, name="subproc_worker_timer" + ) + self.background_thread.start() + self.last_called = monotonic() + + def quit(self) -> None: + with self.lock: + self.exit = True + + def check(self) -> None: + while True: + # We have to be sensitive on checking here, to avoid too much impact on cpu + sleep(self.sleep_time) + with self.lock: + if self.exit: + return + assert self.last_called is not None + if self.last_called + self.duration >= monotonic(): + continue + self.last_called = None + self.background_thread = None + + # Releasing lock in case self.call() takes a very long time or is reentrant + self.call() + return diff --git a/torch/_inductor/compiler_bisector.py b/torch/_inductor/compiler_bisector.py index 41dc4777df823..32cceb9c384a3 100644 --- a/torch/_inductor/compiler_bisector.py +++ b/torch/_inductor/compiler_bisector.py @@ -491,6 +491,13 @@ def do_bisect( Run fn repeatedly attempting to bisect torch.compile. fn should return True on success and False on failure. """ + # TODO graph bisecting is not well composed with lowering + # bisector so far. Use a config to opt-in + import torch._inductor.config as inductor_config + + if inductor_config.test_configs.bisect_pre_grad_graph: + BACKENDS["inductor"].insert(0, BisectSubsystem("pre_grad_graph")) + if not cli_interface: bisection_enabled_orig = cls.bisection_enabled cls.delete_bisect_status() @@ -502,6 +509,9 @@ def cleanup() -> None: cls.delete_bisect_status() cls.in_process_cache = None + if BACKENDS["inductor"][0].name == "pre_grad_graph": + del BACKENDS["inductor"][0] + cleanup_handler = atexit.register(cleanup) class DisableBisect: diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index ec083c45da825..e3517c05299dc 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -8,6 +8,9 @@ from torch.utils._config_module import Config, get_tristate_env, install_config_module +if TYPE_CHECKING: + from torch._inductor.choices import InductorChoices + inplace_padding = os.environ.get("TORCHINDUCTOR_INPLACE_PADDING", "1") == "1" can_inplace_pad_graph_input = False # ease testing @@ -483,6 +486,11 @@ def prologue_fusion_enabled() -> bool: == "1" ) +# register ops upon which inductor should partition the graph. name format should be +# "namespace::kernel_name" (e.g., aten::mm) for op overload packet, or +# "namespace::kernel_name.overload" (e.g., aten::mm.default). +custom_should_partition_ops: list[str] = [] + # whether template autotuning should allow flexible layouts if possible (e.g. only extern choices) max_autotune_allow_flexible_layouts: bool = False @@ -648,6 +656,9 @@ def use_autoheuristic(name: str) -> bool: os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1" ) +# Custom InductorChoices callable to use (can be a class or functools.partial with kwargs) +inductor_choices_class: Optional[Callable[[], "InductorChoices"]] = None + # fuse even in cases without common reads aggressive_fusion = False @@ -1544,6 +1555,9 @@ class triton: os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION", "0") == "1" ) + mix_order_reduction_split_size: Optional[int] = None + mix_order_reduction_autotune_split_size = True + class aot_inductor: """ @@ -1934,6 +1948,9 @@ class rocm: # Backend to use for CUDA codegen either "triton" or "halide" (experimental) cuda_backend: Literal["triton", "halide"] = "triton" +# Backend to use for XPU codegen either "triton" +xpu_backend: Literal["triton"] = "triton" + class halide: # Base halide target to use for CPU devices @@ -2087,6 +2104,17 @@ class trace: ) +class lookup_table: + # Lookup table for template config overrides + table: Optional[dict[str, list[dict[str, Any]]]] = None + + # Enable template src_hash checking in lookup table to prevent using stale configs. + # If True, configs with 'template_hash' field will be compared against the template's + # src_hash at runtime and filtered out if they don't match. If False, no + # hash checking is performed. + check_src_hash: bool = True + + class test_configs: force_extern_kernel_in_multi_template: bool = False @@ -2125,6 +2153,9 @@ class test_configs: "TORCHINDUCTOR_DISTORT_BENCHMARKING_RESULT", "" ) + bisect_pre_grad_graph = False + bisect_keep_custom_backend_for_inductor = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 85033e2b3e8d5..8150c1a1ea4b1 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -122,7 +122,7 @@ def _deduce_value(self, node: torch.fx.Node) -> Any: def is_impure(self, node: torch.fx.node.Node) -> bool: def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: return ( - node.target == torch.ops.prims.convert_element_type.default # type: ignore[return-value] + node.target is torch.ops.prims.convert_element_type.default # type: ignore[return-value] and isinstance(node.args[0], torch.fx.Node) and "val" in node.args[0].meta and node.args[0].meta["val"].dtype == torch.int8 # type: ignore[union-attr] @@ -132,7 +132,7 @@ def is_woq_int8_pattern(node: torch.fx.node.Node) -> bool: if ( is_woq_int8_pattern(node) or ( - node.target == torch.ops.aten.permute.default + node.target is torch.ops.aten.permute.default and len(node.users) == 1 and is_woq_int8_pattern(next(iter(node.users))) ) @@ -214,7 +214,7 @@ def set_env(arg: torch.fx.Node) -> None: # TODO - fix errors with this if ( node.op == "call_function" - and node.target == aten._efficientzerotensor.default + and node.target is aten._efficientzerotensor.default ): return self.unknown_value diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 34f55c7bf797c..6ebbd1bb5b719 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -949,7 +949,7 @@ def maybe_get_static_data_ptr( self.recorded_liveness_before_graph: LevelList[OutputList[bool]] = [] self.recorded_liveness_after_graph: LevelList[OutputList[bool]] = [] - # List of Tuples of (depth, output_index) that index into node at depth + # List of tuples of (depth, output_index) that index into node at depth # number of nodes from root and output_index of outputs. Will index into # path_weakrefs. self.expected_dead_indices_before_graph: list[PathOutputIndex] = [] diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index effed470548cb..5fdf3fe1c0429 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -73,7 +73,7 @@ def get_mutating_use_stack_trace_from_node( return next(iter(placeholder_node.users)).meta.get("stack_trace", None) for use in placeholder_node.users: - if use.target == torch.ops.aten.copy_.default: + if use.target is torch.ops.aten.copy_.default: if stack_trace := use.meta.get("stack_trace", None): return stack_trace diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index aab2c49fe4e35..366d57647ec68 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -253,6 +253,7 @@ def partial_accumulate( name: str, reduction_type: str, value: DTypeArg, + extra_meta: dict[str, Any], ) -> None: return None diff --git a/torch/_inductor/freezing.py b/torch/_inductor/freezing.py index 76bad0ec967b1..70ebe6e9ead06 100644 --- a/torch/_inductor/freezing.py +++ b/torch/_inductor/freezing.py @@ -271,7 +271,7 @@ def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule): folded by freezing. """ with dynamo_timed("convert_conv_weights_to_channels_last"): - convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default] + convs = [n for n in gm.graph.nodes if n.target is aten.convolution.default] for conv in convs: weight_node = conv.args[1] if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[ diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 0a8459da2ffc9..382cf3c54db3d 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -517,6 +517,7 @@ def keys(self) -> KeysView[ComboType]: "joint_custom_pre_pass": DEFAULT, # Typing "pre_grad_custom_pass": DEFAULT, # Typing "custom_partitioner_fn": DEFAULT, # Typing + "inductor_choices_class": DEFAULT, # Typing }, "torch._dynamo.config": { "traceable_tensor_subclasses": DEFAULT, # Typing diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index 303d9bfd59a39..9faec788e9e3a 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -591,7 +591,7 @@ def is_pointwise_node(node: torch.fx.Node) -> bool: ) def is_mm(node: torch.fx.Node) -> bool: - return node.target == torch.ops.aten.mm.default + return node.target is torch.ops.aten.mm.default # the inner MM inner_mm = match.nodes[-1] diff --git a/torch/_inductor/fx_passes/binary_folding.py b/torch/_inductor/fx_passes/binary_folding.py index f2f68a76c426f..2f9bce1a8a2d5 100644 --- a/torch/_inductor/fx_passes/binary_folding.py +++ b/torch/_inductor/fx_passes/binary_folding.py @@ -284,7 +284,7 @@ def _is_foldable_pattern(match): if binary_node.args[0].target in _computation_ops: computation_node = binary_node.args[0] other = binary_node.args[1] - elif binary_node.args[0].target == aten.reshape.default: + elif binary_node.args[0].target is aten.reshape.default: computation_node = binary_node.args[0].args[0] other = binary_node.args[1] has_reshape = True @@ -295,7 +295,7 @@ def _is_foldable_pattern(match): computation_node = binary_node.args[1].args[0] other = binary_node.args[0] has_reshape = False - if computation_node.target == aten.convolution.default: + if computation_node.target is aten.convolution.default: return _check_conv_and_broadcast_op(computation_node, other) elif computation_node.target in [aten.addmm.default, aten.mm.default]: return ( @@ -344,7 +344,7 @@ def resize_scalar_or_tensor_to_shape(graph, other, shape, weight): return res def _create_new_conv_node(graph, conv_node, binary_node, other): - assert conv_node.target == aten.convolution.default + assert conv_node.target is aten.convolution.default conv_args = list(conv_node.args) weight_meta_value = conv_node.args[1].meta.get("val") bias = conv_args[2] @@ -472,7 +472,7 @@ def folded_op(match, *args, **kwargs): reshape_node = None if binary_node.args[0].target in _computation_ops: computation_node = binary_node.args[0] - elif binary_node.args[0].target == aten.reshape.default: + elif binary_node.args[0].target is aten.reshape.default: computation_node = binary_node.args[0].args[0] reshape_node = binary_node.args[0] elif binary_node.args[1].target in _computation_ops: @@ -483,7 +483,7 @@ def folded_op(match, *args, **kwargs): graph = match.graph with graph.inserting_before(reshape_node if reshape_node else binary_node): assert computation_node.target in _computation_ops - if computation_node.target == aten.convolution.default: + if computation_node.target is aten.convolution.default: counters["inductor"]["binary_folding_conv"] += 1 new_computation_node = _create_new_conv_node( graph, computation_node, binary_node, other @@ -494,7 +494,7 @@ def folded_op(match, *args, **kwargs): ) new_computation_node.meta.update(computation_node.meta) if reshape_node: - assert reshape_node.target == aten.reshape.default + assert reshape_node.target is aten.reshape.default computation_node.replace_all_uses_with(new_computation_node) binary_node.replace_all_uses_with(reshape_node) else: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index d509e8c515e4f..ab831c96c94ba 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -51,9 +51,12 @@ def _ar_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]: return (group_name, reduce_op, dtype) -def bucket_key(node: torch.fx.Node) -> object | None: +def bucket_key(node: torch.fx.Node, mode: BucketMode | None = None) -> object | None: if is_all_gather_into_tensor(node): - return _ag_group_key(node) + group_key_fn = ( + _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key + ) + return group_key_fn(node) elif is_reduce_scatter_tensor(node): return _rs_group_key(node) elif is_all_reduce_tensor(node): @@ -119,28 +122,28 @@ def bucket_reduce_scatter( def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: # type: ignore[arg-type] return ( node.op == "call_function" - and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + and node.target is torch.ops._c10d_functional.all_gather_into_tensor.default ) def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool: return ( node.op == "call_function" - and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default + and node.target is torch.ops._c10d_functional.reduce_scatter_tensor.default ) def is_wait_tensor(node: torch.fx.Node) -> bool: return ( node.op == "call_function" - and node.target == torch.ops._c10d_functional.wait_tensor.default + and node.target is torch.ops._c10d_functional.wait_tensor.default ) def is_all_reduce_tensor(node: torch.fx.Node) -> bool: return ( node.op == "call_function" - and node.target == torch.ops._c10d_functional.all_reduce.default + and node.target is torch.ops._c10d_functional.all_reduce.default ) @@ -727,7 +730,7 @@ def process_collective_bucket( is_all_gather_into_tensor(n) and isinstance(node_in, torch.fx.Node) # Add type check and node_in.op == "call_function" - and node_in.target == torch.ops.prims.convert_element_type.default + and node_in.target is torch.ops.prims.convert_element_type.default and len(node_in.users) == 1 ): ag_node_to_pre_nodes[n].append(node_in) diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 8f55d670058fb..8a4de1a604869 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -92,12 +92,12 @@ def get_comm_block(comm_node: fx.Node) -> CommBlock | None: first_user = next(iter(comm_node.users)) if ( len(comm_node.users) == 1 - and first_user.target == torch.ops._c10d_functional.wait_tensor.default + and first_user.target is torch.ops._c10d_functional.wait_tensor.default ): # Collective with only one output node_list = [comm_node, first_user] wait_nodes.append(first_user) - elif len(comm_node.users) > 1 and first_user.target == operator.getitem: + elif len(comm_node.users) > 1 and first_user.target is operator.getitem: # Collective with only more than one output node_list.append(comm_node) for user in comm_node.users: @@ -348,7 +348,7 @@ def _scatter_fused_allreduce_waits( # Some descendant users of the orig_comm_blocks may be scheduled before # the fused all_reduce. For example, the user nodes of the very first # all_reduce may be scheduled before the second all_reduce. Since the - # fused all_reduce is inserted right after the last all_reudce, the + # fused all_reduce is inserted right after the last all_reduce, the # order can be wrong. # `incorrect_order_nodes` records these nodes. diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index 04add7596b6d4..b8fca2087a5d5 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -144,7 +144,7 @@ def check_int8_woq_concat_linear_weights(match): weight_inputs.append("w3") if not all( - match.kwargs[wgt].target == torch.ops.prims.convert_element_type.default + match.kwargs[wgt].target is torch.ops.prims.convert_element_type.default for wgt in weight_inputs ): return False diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 6a1a2d227de1d..6b0c2ad2c94a7 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -48,7 +48,7 @@ def is_fsdp_reduce_scatter_wait(wait: torch.fx.Node) -> bool: return ( is_graph_output(user) and user.op == "call_function" - and user.target == torch.ops.prims.convert_element_type.default + and user.target is torch.ops.prims.convert_element_type.default ) return False diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 471bdfb02813b..295c720382853 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -71,9 +71,9 @@ def update_stack_example_value(node, metadata, dim=0, op=torch.stack): Update the example value of the node in the graph to enable followup split cat opt. """ if node is not None and hasattr(node, "meta"): - if op == torch.stack: + if op is torch.stack: example_value = torch.stack(metadata, dim=dim) - elif op == torch.unbind: + elif op is torch.unbind: example_value = torch.unbind(metadata, dim=dim) # type: ignore[assignment] else: return @@ -85,9 +85,9 @@ def update_pointwise_example_value(pointwise_node, input, other, op): Update the example value of the add node in the graph to enable followup split cat opt. """ if pointwise_node is not None and hasattr(pointwise_node, "meta"): - if op == torch.add: + if op is torch.add: example_value = torch.add(input, other) - elif op == torch.mul: + elif op is torch.mul: example_value = torch.mul(input, other) else: return @@ -414,12 +414,12 @@ def match(self, node: torch.fx.Node): if self.graph_search_options.get("fuse_nodes_with_same_parent", False): # only consider the linear case so far # pyre-fixme[16] - if input.target == aten.select or other.target == aten.select: # type: ignore[union-attr] + if input.target is aten.select or other.target is aten.select: # type: ignore[union-attr] parent = ( # pyre-fixme[16] input.args[0] # type: ignore[union-attr] # pyre-fixme[16] - if input.target == aten.select # type: ignore[union-attr] + if input.target is aten.select # type: ignore[union-attr] else other.args[0] # type: ignore[union-attr] ) else: @@ -950,7 +950,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): torch.stack, args=(batch_inputs,), kwargs={"dim": 0} ) update_stack_example_value(stack_inputs, batch_inputs_metadata) - if self.op == torch.nn.functional.relu: + if self.op is torch.nn.functional.relu: batch_op = graph.call_function( # type: ignore[operator] self.op, args=(stack_inputs,), @@ -958,6 +958,7 @@ def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): ) batch_op.meta["example_value"] = self.op( stack_inputs.meta["example_value"], + # pyrefly: ignore [bad-argument-type] inplace=subset[0].kwargs.get("inplace", False), ) else: diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 87075efc20258..25b10966cfa96 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -227,7 +227,8 @@ def __init__(self, gm, skip_constructors=False) -> None: self.symint_nodes = _SymHashingDict() for n in self.module.graph.nodes: # type: ignore[union-attr] if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt): - self.symint_nodes[n.meta["val"]] = n + if n.meta["val"] not in self.symint_nodes: + self.symint_nodes[n.meta["val"]] = n # reference from torch/_funtorch/partitioners.py:get_default_op_list self.view_op_packets = [ @@ -315,7 +316,7 @@ def _deduce_value(self, node: torch.fx.Node): # single-elem attrs if node.op == "get_attr" or ( node.op == "call_function" - and node.target == torch.ops.aten.lift_fresh_copy.default + and node.target is torch.ops.aten.lift_fresh_copy.default ): out = super(ConstantFolder, self).run_node(node) if isinstance(out, torch.Tensor) and out.numel() == 1: @@ -328,7 +329,7 @@ def _deduce_value(self, node: torch.fx.Node): # constructors ops if ( node.op == "call_function" - and node.target == aten.full.default + and node.target is aten.full.default and len(node.args) == 2 ): args, kwargs = self.fetch_args_kwargs_from_env(node) @@ -339,7 +340,7 @@ def _deduce_value(self, node: torch.fx.Node): return aten.full.default(*new_args, **node.kwargs) # handle before view ops because this changes value - if node.target == aten.view.dtype: + if node.target is aten.view.dtype: return super(ConstantFolder, self).run_node(node) # view ops, return input tensor, the first argument @@ -437,7 +438,7 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule): # the conversion from tensor and back to value can be lossy, just use the original full ctor value if ( node.op == "call_function" - and node.target == aten.full.default + and node.target is aten.full.default and len(node.args) == 2 ): value = node.args[1] @@ -534,7 +535,7 @@ def canonicalize_quant_mapping(gm: torch.fx.GraphModule): if ( len(invoke_quant_replacement.users) == 1 and len(subgraph.users) == 1 - and first_user.target == operator.getitem + and first_user.target is operator.getitem and first_user.args[1] == 0 ): subgraph_graph = getattr(gm, subgraph.target) diff --git a/torch/_inductor/fx_passes/micro_pipeline_tp.py b/torch/_inductor/fx_passes/micro_pipeline_tp.py index 97b4342fa7638..dd61909bdb358 100644 --- a/torch/_inductor/fx_passes/micro_pipeline_tp.py +++ b/torch/_inductor/fx_passes/micro_pipeline_tp.py @@ -386,9 +386,9 @@ def __post_init__(self): if len(self.nodes) == 1: assert self.nodes[0].target in (aten.mm.default, aten._scaled_mm.default) else: - assert self.nodes[0].target == aten.reshape.default + assert self.nodes[0].target is aten.reshape.default assert self.nodes[1].target in (aten.mm.default, aten._scaled_mm.default) - assert self.nodes[2].target == aten.reshape.default + assert self.nodes[2].target is aten.reshape.default self.arg_ancestor_nodes = _find_ancestors(self.B_node) def replace_with(self, new_node: torch.fx.Node) -> None: @@ -415,7 +415,7 @@ def replace_with(self, new_node: torch.fx.Node) -> None: output_reshape_node = self.nodes[2] assert mm_node.target in (aten.mm.default, aten._scaled_mm.default) - assert output_reshape_node.target == aten.reshape.default + assert output_reshape_node.target is aten.reshape.default output_reshape_node.replace_all_uses_with(new_node) if len(mm_node.users) > 1: @@ -482,7 +482,7 @@ def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any: # Use mm_node with 2D args for both A and B, even if this is a "reshape -> mm -> reshape" pattern. # We will store the reshapes in pre_mm_reshape and post_mm_reshape, to be referenced later to # produce the correct output shapes, reduce-scatter along the correct dimensions, etc. - is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default + is_reshape_mm_reshape_pattern = match[0].target is aten.reshape.default mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0] pre_mm_reshape = match[0] if is_reshape_mm_reshape_pattern else None post_mm_reshape = match[-1] if is_reshape_mm_reshape_pattern else None @@ -540,10 +540,10 @@ def _find_reshape_mm_reshape(node: torch.fx.Node) -> list[_Matmul]: matmuls = [] for match in matches: mm_node = match[1] - if mm_node.target == aten.mm.default: + if mm_node.target is aten.mm.default: matmul = _Matmul.from_match(match) matmuls.append(matmul) - elif mm_node.target == aten._scaled_mm.default: + elif mm_node.target is aten._scaled_mm.default: matmul = _ScaledMatmul.from_match(match) matmuls.append(matmul) else: @@ -561,13 +561,13 @@ def _find_consumer_matmuls(node: torch.fx.Node) -> list[_Matmul]: matmuls = [] for user in node.users: # ND matmuls - if user.target == aten.reshape.default: + if user.target is aten.reshape.default: matmuls.extend(_find_reshape_mm_reshape(user)) # 2D matmuls - elif user.target == aten.mm.default: + elif user.target is aten.mm.default: matmul = _Matmul.from_match(match=[user]) matmuls.append(matmul) - elif user.target == aten._scaled_mm.default: + elif user.target is aten._scaled_mm.default: matmul = _ScaledMatmul.from_match([user]) matmuls.append(matmul) return matmuls @@ -790,11 +790,11 @@ def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None: """ Returns producer matmul node if found, otherwise returns None. """ - if node.target == aten.mm.default: + if node.target is aten.mm.default: return _Matmul.from_match(match=[node]) - elif node.target == aten._scaled_mm.default: + elif node.target is aten._scaled_mm.default: return _ScaledMatmul.from_match(match=[node]) - elif node.target == aten.reshape.default: + elif node.target is aten.reshape.default: reshape_node_1 = node mm_node = reshape_node_1.args[0] @@ -807,9 +807,9 @@ def _find_producer_matmul(node: torch.fx.Node) -> _Matmul | None: if reshape_node_0.target != aten.reshape.default: return None - if mm_node.target == aten.mm.default: + if mm_node.target is aten.mm.default: return _Matmul.from_match(match=[reshape_node_0, mm_node, reshape_node_1]) - elif mm_node.target == aten._scaled_mm.default: + elif mm_node.target is aten._scaled_mm.default: return _ScaledMatmul.from_match( match=[reshape_node_0, mm_node, reshape_node_1] ) diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index addc6e1ea8ece..70b3a3c355dde 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -73,7 +73,7 @@ def get_linear_transpose_weight(self, weight_node): packed_weight_node = weight_node assert packed_weight_node.target == mkldnn._reorder_linear_weight transpose_weight_node = packed_weight_node.args[0] - assert transpose_weight_node.target == aten.permute.default + assert transpose_weight_node.target is aten.permute.default return transpose_weight_node def pack_conv_weight( @@ -991,7 +991,7 @@ def _register_binary_unary_fusion(): def _recover_linear(): # convert reshape+linear+reshape to a single linear for applying fusion path. - # concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_numer=1) -> _recover_linear(pass_number=2) + # concat_linear (pass_number=0) -> mkldnn_linear_pack (pass_number=1) -> _recover_linear(pass_number=2) @register_freezing_graph_pattern( CallFunction( aten.reshape.default, @@ -1213,13 +1213,13 @@ def is_const_or_cat_by_const(weight): linear_node = match.output_node() # mkldnn linear only supports beta=1or0 and alpha=1 - if linear_node.target == aten.addmm.default: + if linear_node.target is aten.addmm.default: alpha = linear_node.kwargs.get("alpha", 1.0) beta = linear_node.kwargs.get("beta", 1.0) if (beta != 0.0 and beta != 1.0) or alpha != 1.0: return False # weight_idx is 1 for aten.mm and is 2 for aten.addmm - weight_idx = 2 if linear_node.target == aten.addmm.default else 1 + weight_idx = 2 if linear_node.target is aten.addmm.default else 1 if not is_const_or_cat_by_const(linear_node.args[weight_idx]): return False input_meta_value = linear_node.args[weight_idx - 1].meta.get("val") @@ -1437,17 +1437,17 @@ def get_item(graph, node, index): def linear(match, *args, **kwargs): graph = match.graph linear_node = match.output_node() - input = args[0] if linear_node.target == aten.mm.default else args[1] + input = args[0] if linear_node.target is aten.mm.default else args[1] bias = ( None - if linear_node.target == aten.mm.default + if linear_node.target is aten.mm.default or ( - linear_node.target == aten.addmm.default + linear_node.target is aten.addmm.default and linear_node.kwargs.get("beta", 1.0) == 0.0 ) else args[0] ) - weight = args[1] if linear_node.target == aten.mm.default else args[2] + weight = args[1] if linear_node.target is aten.mm.default else args[2] device_type = input.meta.get("val").device.type mkldnn_device_op = _get_mkldnn_device_op(device_type) with graph.inserting_before(linear_node): diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index 660befff00eee..1d25896cb8c4a 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -1,17 +1,123 @@ +import logging from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Literal, Optional import torch.fx as fx +from torch._dynamo.utils import counters from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch._inductor.fx_passes.bucketing import ( bucket_key, + BucketMode, is_all_gather_into_tensor as is_all_gather, is_reduce_scatter_tensor as is_reduce_scatter, is_wait_tensor, ) -from torch._inductor.fx_passes.overlap_scheduling import CollBucket, CollectiveInfo +from torch._inductor.fx_passes.overlap_scheduling import ( + CollBucket, + CollectiveInfo, + get_group_name, + is_compute_node, +) from torch.utils._ordered_set import OrderedSet +bucket_log = logging.getLogger(__name__) + + +@dataclass +class WhyNoBucket: + name1: str + name2: str + reason: str + args: tuple[Any, ...] + + def __init__(self, node1: fx.Node, node2: fx.Node) -> None: + self.name1 = node1.name + self.name2 = node2.name + self.reason = "" + self.args = () + + def __call__(self, reason: str, *args: Any) -> None: + if bucket_log.isEnabledFor(logging.DEBUG): + bucket_log.debug( + "cannot bucket %s with %s: " + reason, # noqa: G003 + self.name1, + self.name2, + *args, + ) + + +def is_collective_or_wait(n: fx.Node) -> bool: + """Check if node is a collective start or wait.""" + if is_wait_tensor(n): + return True + # Collective starts have exactly one use: the wait_tensor + if len(n.users) == 1: + user = next(iter(n.users.keys())) + if is_wait_tensor(user): + return True + return False + + +@dataclass +class PGEvent: + """ + Represents an important event in a process group timeline. Either + a collective start, wait, or hiding compute. Each node is linked + to its prev and next and these dependencies are reflected + in the augmented graph. + + We want to enforce a sequential ordering of collective starts and waits + because NCCL collectives on the same process group execute on the same CUDA + stream, creating implicit dependencies between all operations on that PG. + + A wait of a particular collective will implicitly force realization of all collectives + enqueued prior to that collective. + """ + + node: fx.Node + event_type: Literal["compute", "starts", "waits"] + position: int + prev: Optional["PGEvent"] = None + next: Optional["PGEvent"] = None + + @property + def is_start(self) -> bool: + return self.event_type == "starts" + + @property + def is_wait(self) -> bool: + return self.event_type == "waits" + + @property + def is_compute(self) -> bool: + return self.event_type == "compute" + + def unlink(self) -> tuple[Optional["PGEvent"], Optional["PGEvent"]]: + """Remove this event from the linked list, return (prev, next).""" + prev_event, next_event = self.prev, self.next + if self.prev: + self.prev.next = self.next + if self.next: + self.next.prev = self.prev + self.prev = None + self.next = None + return prev_event, next_event + + def insert_between( + self, prev_event: Optional["PGEvent"], next_event: Optional["PGEvent"] + ) -> None: + """Insert this event between prev_event and next_event in the linked list.""" + if prev_event: + prev_event.next = self + self.prev = prev_event + + if next_event: + next_event.prev = self + self.next = next_event + + class OverlapPreservingBucketer: """ Buckets collective operations while preserving compute-collective overlap relationships. @@ -27,6 +133,7 @@ def __init__( max_bucket_memory_gb: float = 1.0, max_coll_distance: int = 1000, insert_overlap_deps: bool = False, + bucket_mode: BucketMode = "custom_ops_multidtype", ): self.graph = graph self.collective_info = collective_info @@ -37,50 +144,159 @@ def __init__( self.aug_graph = AugmentedGraphHelper(self.graph, self.node_ancestors) self.max_coll_distance = max_coll_distance self.insert_overlap_deps = insert_overlap_deps + self.bucket_mode = bucket_mode + self.node_to_event: dict[fx.Node, PGEvent] = {} + self.pg_to_timeline_head: dict[str, Optional[PGEvent]] = self.build_timelines() - def bucket_collectives(self) -> None: - """Main entry point for bucketing collectives.""" + self._add_hiding_interval_constraints() + + def build_timelines(self) -> dict[str, Optional[PGEvent]]: + "Construct each process groups ordered series of event" + all_pgs: OrderedSet[str] = OrderedSet() + for start in self.collective_info: + pg = get_group_name(start) + all_pgs.add(pg) + + pg_timeline: dict[str, Optional[PGEvent]] = {} + for pg in all_pgs: + pg_timeline[pg] = self.build_timeline(pg) + + return pg_timeline + + def build_timeline(self, pg: str) -> Optional[PGEvent]: + """ + Build a timeline of important events (starts, waits, hiding compute) for this process group + and constrain this ordering in the augmented graph. + + Sequential dependencies are added between all events because NCCL collectives on the same + process group execute on the same CUDA stream, enforcing LIFO semantics where later-issued + collectives must complete before earlier ones can finish. + """ + + head = None + prev_event = None + position = 0 + + for node in self.scheduled: + node_type = None + + # Determine if this node is relevant for this PG + if node in self.collective_info and get_group_name(node) == pg: + node_type = "starts" + elif is_wait_tensor(node): + wait_input = node.args[0] + if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg: + node_type = "waits" + elif is_compute_node(node): + node_type = "compute" + + if node_type is None: + continue + + event = PGEvent(node=node, event_type=node_type, position=position) # type: ignore[arg-type] + + event.insert_between(prev_event, None) + + # Add sequential dependency to augmented graph + if prev_event: + self.aug_graph.add_extra_dep(n=event.node, dep=prev_event.node) + else: + head = event + + prev_event = event + position += 1 + + return head + + def _populate_node_to_event(self, pg: str) -> None: + """Populate node_to_event mapping for a specific PG's timeline.""" + self.node_to_event.clear() + head = self.pg_to_timeline_head[pg] + curr = head + while curr is not None: + self.node_to_event[curr.node] = curr + curr = curr.next - # Add extra dependencies for hidden collectives - # For each hidden collective, add: compute -> start and wait -> compute - for start_node, info in self.collective_info.items(): + def _add_hiding_interval_constraints(self) -> None: + """ + Add hiding interval constraints: start -> compute -> wait. + """ + for start, info in self.collective_info.items(): if info.hiding_node and not info.is_exposed: - # Add edge: hiding_compute depends on start (start must come before compute) - self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start_node) - # Add edge: wait depends on hiding_compute (compute must come before wait) + # Enforce: start -> compute -> wait + self.aug_graph.add_extra_dep(n=info.hiding_node, dep=start) self.aug_graph.add_extra_dep(n=info.wait_node, dep=info.hiding_node) - # Group collectives by bucket key (type, group, etc.) - grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + def bucket_collectives(self) -> None: + """Main entry point for bucketing collectives.""" + # Group collectives by PG first + pg_collectives: dict[str, OrderedSet[fx.Node]] = defaultdict(OrderedSet) for start in self.collective_info: - key = bucket_key(start) - if key is not None: - grouped_collectives[key].add(start) + pg = get_group_name(start) + pg_collectives[pg].add(start) all_buckets: list[CollBucket] = [] - for collective_group in grouped_collectives.values(): - buckets = self._find_buckets(collective_group) - all_buckets.extend(buckets) + for pg, collectives in pg_collectives.items(): + # Populate node_to_event for this PG's timeline + self._populate_node_to_event(pg) - # Collect all extra dependencies to preserve after bucketing - additional_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) + # Group by bucket key within this PG + grouped_collectives: dict[object, OrderedSet[fx.Node]] = defaultdict( + OrderedSet + ) + for start in collectives: + key = bucket_key(start, self.bucket_mode) + if key is not None: + grouped_collectives[key].add(start) + + # Find buckets for this PG + for key, collective_group in grouped_collectives.items(): + bucket_log.debug( + "bucketing collective group with key %s: %s", + key, + [n.name for n in collective_group], + ) + buckets = self._find_buckets(collective_group) + all_buckets.extend(buckets) # Apply bucketing transformations + # Dependencies are tracked in aug_graph.extra_deps during bucketing for coll_bucket in all_buckets: if len(coll_bucket.collectives) <= 1: continue - bucket_deps = self._apply_bucket(coll_bucket) - additional_deps.update(bucket_deps) + counters["inductor"]["collective_buckets"] += 1 + self._apply_bucket(coll_bucket) + + # Extract all dependencies from augmented graph + # This includes: + # - Sequential timeline deps (added during build_timeline) + # - Hiding interval deps (added during _add_hiding_interval_constraints) + # - All transferred deps from bucketing (transferred during _apply_bucket) + additional_deps = self.aug_graph.get_all_extra_deps() - # Apply topological sort with all the collected dependencies + # Apply topological sort with all dependencies from torch._dynamo.graph_deduplication import _stable_topological_sort _stable_topological_sort(self.graph, additional_deps) # After topological sort, preserve dependencies using effect tokens + # Only preserve edges where NOT both nodes are collective starts or waits if self.insert_overlap_deps: - self._preserve_dependencies_with_tokens(additional_deps) + filtered_deps: dict[fx.Node, OrderedSet[fx.Node]] = {} + for node, deps in additional_deps.items(): + filtered_node_deps: OrderedSet[fx.Node] = OrderedSet() + + # only preserve comm-comptue overlap for now, although we could more + # generally constrain + for dep in deps: + if not (is_collective_or_wait(node) and is_collective_or_wait(dep)): + filtered_node_deps.add(dep) + + if filtered_node_deps: + filtered_deps[node] = filtered_node_deps + + self._preserve_dependencies_with_tokens(filtered_deps) self.graph.lint() @@ -89,12 +305,14 @@ def _find_buckets( collective_group: OrderedSet[fx.Node], ) -> list[CollBucket]: """Find valid buckets within a group of similar collectives.""" - max_bucket_bytes = int(self.max_bucket_memory_gb * 1024 * 1024 * 1024) buckets = [] processed: OrderedSet[fx.Node] = OrderedSet() - for start_node in collective_group: + # Sort collectives by node index for efficient distance checking + sorted_collectives = sorted(collective_group, key=lambda n: self.node_idx[n]) + + for start_node in sorted_collectives: if start_node in processed: continue @@ -106,14 +324,18 @@ def _find_buckets( processed.add(start_node) start_node_idx = self.node_idx[start_node] - # TODO - limit within range - for candidate in collective_group: + # Check candidates in sorted order, break when beyond max distance + for candidate in sorted_collectives: if candidate in processed: continue candidate_idx = self.node_idx[candidate] # Check if candidate is within max distance from the bucket start - if abs(candidate_idx - start_node_idx) > self.max_coll_distance: + distance = abs(candidate_idx - start_node_idx) + if distance > self.max_coll_distance: + # Since sorted, all remaining candidates will be too far + if candidate_idx > start_node_idx: + break continue candidate_bytes = self.collective_info[candidate].size_bytes @@ -134,107 +356,441 @@ def _ancestor_dep(self, n1: fx.Node, n2: fx.Node) -> bool: """Check if there's an ancestor relationship between two nodes.""" return n1 in self.node_ancestors[n2] or n2 in self.node_ancestors[n1] - def _can_add_to_bucket( + def _get_intervals( + self, event: PGEvent + ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]: + """Get (execution_interval, hiding_interval) for a collective event. + + Returns: + (execution_interval, hiding_interval) where: + - execution_interval is (start_pos, wait_pos) or None + - hiding_interval is (start_pos, compute_pos) or None if no hiding node + + Works for both start and wait events by looking up the collective info. + """ + # For start events, directly use the node + if event.is_start: + coll = event.node + # For wait events, look up the start node from the event's args + elif event.is_wait: + wait_input = event.node.args[0] + if not isinstance(wait_input, fx.Node): + return None, None + coll = wait_input + else: + return None, None + + if coll not in self.collective_info: + return None, None + + info = self.collective_info[coll] + start_event = self.node_to_event[coll] + wait_event = self.node_to_event[info.wait_node] + + execution_interval = (start_event.position, wait_event.position) + + hiding_interval = None + if info.hiding_node: + hiding_interval = ( + start_event.position, + self.node_to_event[info.hiding_node].position, + ) + + return execution_interval, hiding_interval + + def _preserves_hiding_intervals( self, bucket_info: CollBucket, candidate: fx.Node, + start_pos: fx.Node, + wait_pos: fx.Node, + why: WhyNoBucket, ) -> bool: """ - Check if candidate can be added to bucket without interfering - with comm/compute overlap. + Check that (start_pos, wait_pos) doesn't violate any hiding intervals or collectives. + + Collects all execution and hiding intervals in the affected timeline regions, + then checks: + 1. All bucket hiding compute stays between new start/wait + 2. No other collective's compute interval is enclosed by bucket execution interval + 3. No other collective's execution interval encloses bucket compute intervals """ + # Collect all collectives being bucketed + all_bucketed_colls = [candidate] + list(bucket_info.collectives) + all_bucketed_waits = [ + self.collective_info[coll].wait_node for coll in all_bucketed_colls + ] + + # Collect hiding compute positions for the bucket + bucket_hiding_compute_positions = [] + for coll in all_bucketed_colls: + if hiding_node := self.collective_info[coll].hiding_node: + bucket_hiding_compute_positions.append( + self.node_to_event[hiding_node].position + ) + + # Get new positions + new_start_event = self.node_to_event[start_pos] + new_wait_event = self.node_to_event[wait_pos] + + # Check 1: All bucket hiding compute must be between new start and wait + for compute_pos in bucket_hiding_compute_positions: + if not (new_start_event.position < compute_pos < new_wait_event.position): + why( + "hiding compute at pos %d not between start %d and wait %d", + compute_pos, + new_start_event.position, + new_wait_event.position, + ) + return False - candidate_info = self.collective_info[candidate] - candidate_wait = candidate_info.wait_node + def get_wait(n: fx.Node) -> fx.Node: + return self.collective_info[n].wait_node - # Step 1: Quick check using precomputed ancestors - # This will not be fully up to date because bucketing changes ancestors, - # however any ancestor at the start of bucketing will remain an ancestor. - for coll in bucket_info.collectives: - if self._ancestor_dep(coll, candidate): - return False + def get_pos(n: fx.Node) -> int: + return self.node_to_event[n].position - coll_wait = self.collective_info[coll].wait_node - if self._ancestor_dep(candidate_wait, coll_wait): - return False + latest_start_pos = max(get_pos(candidate), get_pos(bucket_info.collectives[0])) + earliest_wait_pos = min( + get_pos(get_wait(candidate)), get_pos(get_wait(bucket_info.collectives[0])) + ) - if hiding_node := self.collective_info[coll].hiding_node: - if self._ancestor_dep(hiding_node, candidate_wait): + # Bucket execution interval + bucket_execution_interval = (new_start_event.position, new_wait_event.position) + + # Because collectives on the same PG operate under LIFO semantics, + # it's only possible for us to force an early realization of an unrelated collective + # by delaying a start or raising a wait. + # We search in the interval from old_start -> new_start, to see if would be + # forcing another collective to be realized prior to its hiding nodes. + # Similarly, we search from old_wait -> new_wait, in the reverse direction, + # to check the same thing. + + execution_intervals = [bucket_execution_interval] + hiding_intervals = [ + (bucket_execution_interval[0], pos) + for pos in bucket_hiding_compute_positions + ] + + curr_event = new_start_event.next + while curr_event is not None and curr_event.position < latest_start_pos: + if ( + curr_event.node not in all_bucketed_colls + and curr_event.node not in all_bucketed_waits + ): + exec_interval, hiding_interval = self._get_intervals(curr_event) + if exec_interval: + execution_intervals.append(exec_interval) + if hiding_interval: + hiding_intervals.append(hiding_interval) + curr_event = curr_event.next + + curr_event = new_wait_event.prev + while curr_event is not None and curr_event.position > earliest_wait_pos: + if ( + curr_event.node not in all_bucketed_colls + and curr_event.node not in all_bucketed_waits + ): + exec_interval, hiding_interval = self._get_intervals(curr_event) + if exec_interval: + execution_intervals.append(exec_interval) + if hiding_interval: + hiding_intervals.append(hiding_interval) + curr_event = curr_event.prev + + # Check: no hiding interval should be enclosed by any execution interval + def enclosed_interval(inner: tuple[int, int], outer: tuple[int, int]) -> bool: + return outer[0] < inner[0] and inner[1] < outer[1] + + for hiding_interval in hiding_intervals: + for execution_interval in execution_intervals: + if enclosed_interval(hiding_interval, execution_interval): + why( + "hiding interval %s enclosed by execution interval %s", + hiding_interval, + execution_interval, + ) return False - if new_hiding_node := candidate_info.hiding_node: - if self._ancestor_dep(new_hiding_node, coll_wait): - return False + return True - # Step 2: Check and merge starts - # Check if there's a path between any existing start and candidate start. - # Because the collectives have already been merged, we can just start from one - # of them. - # TODO: we have a range of possible idxs of the merged node, and idx of new node. - # we should not do path search beyond that range - existing_coll = bucket_info.collectives[0] - if self.aug_graph.has_path(existing_coll, candidate): + def remove_from_event( + self, node: fx.Node + ) -> tuple[Optional[PGEvent], Optional[PGEvent]]: + """Remove node from timeline and return (prev_event, next_event).""" + event = self.node_to_event[node] + assert not event.is_compute, "Cannot remove compute events from timeline" + + prev_event, next_event = event.unlink() + + # Remove augmented graph dependency + if prev_event: + self.aug_graph.remove_extra_dep(n=node, dep=prev_event.node) + if next_event: + self.aug_graph.remove_extra_dep(n=next_event.node, dep=node) + + # Add bypass dependency + if prev_event and next_event: + self.aug_graph.add_extra_dep(n=next_event.node, dep=prev_event.node) + + return prev_event, next_event + + def restore_to_event( + self, + node: fx.Node, + prev_event: Optional[PGEvent], + next_event: Optional[PGEvent], + ) -> None: + """Restore node to timeline after failed merge attempt.""" + event = self.node_to_event[node] + + # Reinsert into linked list + event.insert_between(prev_event, next_event) + if prev_event: + self.aug_graph.add_extra_dep(n=node, dep=prev_event.node) + if next_event and not prev_event: + self.aug_graph.add_extra_dep(n=next_event.node, dep=node) + + # Remove bypass dependency + if prev_event and next_event: + self.aug_graph.remove_extra_dep(n=next_event.node, dep=prev_event.node) + + def _try_timeline_position( + self, + bucket_info: CollBucket, + candidate: fx.Node, + start_pos: fx.Node, + wait_pos: fx.Node, + why: WhyNoBucket, + ) -> bool: + """ + Try a specific timeline position for the candidate. + Returns True if valid and merges are successful. + """ + candidate_info = self.collective_info[candidate] + candidate_wait = candidate_info.wait_node + + # Quick check: does this violate hiding intervals? + if not self._preserves_hiding_intervals( + bucket_info, candidate, start_pos, wait_pos, why + ): return False - if self.aug_graph.has_path(candidate, existing_coll): + + # Determine which start needs to move + existing_coll = bucket_info.collectives[0] + if start_pos == existing_coll: + start_to_move = candidate + else: + assert start_pos == candidate + start_to_move = existing_coll + + # Remove start from timeline + start_prev, start_next = self.remove_from_event(start_to_move) + + # Check if starts can be merged + if self.aug_graph.has_path(existing_coll, candidate) or self.aug_graph.has_path( + candidate, existing_coll + ): + # Restore start constraints + self.restore_to_event(start_to_move, start_prev, start_next) + why("path exists between starts") return False - # Safe to merge starts - do the merge + # Merge starts self.aug_graph.merge_to_set(existing_coll, candidate) - # Step 3: Check and merge waits + # Determine which wait needs to move existing_wait = self.collective_info[existing_coll].wait_node - candidate_wait = candidate_info.wait_node - # TODO - as above, limit search by idx + candidate_wait = self.collective_info[candidate].wait_node + + if wait_pos == existing_wait: + wait_to_move = candidate_wait + else: + wait_to_move = existing_wait + + # Remove wait from timeline + wait_prev, wait_next = self.remove_from_event(wait_to_move) + + # Check if waits can be merged if self.aug_graph.has_path( existing_wait, candidate_wait ) or self.aug_graph.has_path(candidate_wait, existing_wait): + # Restore wait constraints + self.restore_to_event(wait_to_move, wait_prev, wait_next) # Unmerge the start we just merged self.aug_graph.unmerge_node(candidate) + # Restore start constraints + self.restore_to_event(start_to_move, start_prev, start_next) + why("path exists between waits") return False + # Merge waits - success! self.aug_graph.merge_to_set(existing_wait, candidate_wait) + + # Update node_to_event for moved nodes + target_start_event = self.node_to_event[start_pos] + target_wait_event = self.node_to_event[wait_pos] + + self.node_to_event[candidate] = target_start_event + self.node_to_event[candidate_wait] = target_wait_event + self.node_to_event[existing_coll] = target_start_event + self.node_to_event[existing_wait] = target_wait_event + return True - def _apply_bucket( - self, bucket_info: CollBucket - ) -> dict[fx.Node, OrderedSet[fx.Node]]: - """Apply bucketing transformation and return dependencies to preserve.""" + def _has_ancestor_conflicts( + self, bucket_info: CollBucket, candidate: fx.Node + ) -> bool: + """ + Check if candidate has ancestor conflicts with bucket collectives. + Returns True if there are conflicts. + """ + candidate_info = self.collective_info[candidate] + candidate_wait = candidate_info.wait_node + + for coll in bucket_info.collectives: + # Check if collectives are ancestors of each other + if self._ancestor_dep(coll, candidate): + return True + + # Check if waits are ancestors of each other + coll_wait = self.collective_info[coll].wait_node + if self._ancestor_dep(candidate_wait, coll_wait): + return True + + # Check if existing hiding node conflicts with candidate wait + if hiding_node := self.collective_info[coll].hiding_node: + if self._ancestor_dep(hiding_node, candidate_wait): + return True + + # Check if candidate hiding node conflicts with existing wait + if new_hiding_node := candidate_info.hiding_node: + if self._ancestor_dep(new_hiding_node, coll_wait): + return True + + return False + + def _can_add_to_bucket( + self, + bucket_info: CollBucket, + candidate: fx.Node, + ) -> bool: + """ + Check if candidate can be added to bucket without breaking comm/compute overlap. + + Strategy: Try all timeline positions - combinations of [existing_start, candidate_start] + x [existing_wait, candidate_wait]. For each position, verify: + 1. Hiding intervals preserved - for any (start, hiding_compute, wait) interval, no other + collective's (start, wait) pair falls between start and hiding_compute, which would + force realization and break overlap due to LIFO semantics + 2. Topologically valid (no dependency cycles) + + Return True if any timeline position satisfies both constraints. + """ + existing_coll = bucket_info.collectives[0] + why = WhyNoBucket(existing_coll, candidate) + + candidate_info = self.collective_info[candidate] + + # Step 1: Quick check using precomputed ancestors + # These ancestors are computed prior to adding augmented dependencies and not updated, + # so if any of these checks fail then the merge will not be topologically valid + # even ignoring comm/compute overlap + if self._has_ancestor_conflicts(bucket_info, candidate): + why("has ancestor conflicts") + return False + + # Step 2: Try different rail positions + existing_wait = self.collective_info[existing_coll].wait_node + + candidate_start = candidate + candidate_wait = candidate_info.wait_node + + # Try combinations in order of likelihood to succeed + # (early start, later wait is most likely to work) + combinations = [ + ( + existing_coll, + candidate_wait, + ), # Move candidate start early, keep wait late + ( + existing_coll, + existing_wait, + ), # Move candidate start early, move wait early + (candidate_start, candidate_wait), # Keep both in place + (candidate_start, existing_wait), # Keep start in place, move wait early + ] + + for i, (start_pos, wait_pos) in enumerate(combinations): + if self._try_timeline_position( + bucket_info, candidate, start_pos, wait_pos, why + ): + bucket_log.debug( + "bucketed %s with %s using timeline position %d: (start=%s, wait=%s)", + candidate.name, + existing_coll.name, + i + 1, + start_pos.name, + wait_pos.name, + ) + return True + + why("all timeline positions failed") + return False + + def _apply_bucket(self, bucket_info: CollBucket) -> None: + """ + Apply bucketing transformation. + + Dependencies are added to aug_graph.extra_deps and transferred from old nodes. + """ from torch._inductor.fx_passes.bucketing import ( + is_all_reduce_tensor, merge_all_gather_bucket, + merge_all_reduce_bucket, merge_reduce_scatter_bucket, ) bucket = bucket_info.collectives + # Collect old nodes BEFORE they're erased + old_starts = list(bucket) + old_waits = [self.collective_info[n].wait_node for n in bucket] + # Find where to place the bucketed operations next_node = bucket[0] while next_node in bucket: next_node = next_node.next - waits = [self.collective_info[n].wait_node for n in bucket] - first_wait = min(waits, key=lambda w: self.node_idx[w]) - # Create bucketed collective + # Don't use wait_insertion_point - let merge functions place waits naturally + # The wait_insertion_point feature tries to move waits to a specific location, + # but this can cause issues when that location is one of the nodes being erased + # Create bucketed collective (this will erase old nodes) if is_all_gather(bucket[0]): new_nodes, replacements = merge_all_gather_bucket( self.graph, bucket, - wait_insertion_point=first_wait, insert_before=next_node, mode="custom_ops", ) + elif is_all_reduce_tensor(bucket[0]): + new_nodes, replacements = merge_all_reduce_bucket( + self.graph, + bucket, + mode="custom_ops", + insert_before=next_node, + ) else: assert is_reduce_scatter(bucket[0]) new_nodes, replacements = merge_reduce_scatter_bucket( self.graph, bucket, - wait_insertion_point=first_wait, insert_before=next_node, mode="custom_ops", ) - # Build dependencies to preserve overlap - # replacements maps old_start -> new_start, old_wait -> new_wait + # Get new nodes new_waits = [n for n in new_nodes if is_wait_tensor(n)] assert len(new_waits) == 1 @@ -242,18 +798,15 @@ def _apply_bucket( new_start = new_wait.args[0] assert isinstance(new_start, fx.Node) - overlap_deps: dict[fx.Node, OrderedSet[fx.Node]] = defaultdict(OrderedSet) - - # Create dependencies to preserve overlap - for coll in bucket: - info = self.collective_info[coll] - if info.hiding_node and not info.is_exposed: - # Compute depends on collective start - overlap_deps[info.hiding_node].add(new_start) - # Wait depends on compute - overlap_deps[new_wait].add(info.hiding_node) + # Create mapping of all erased nodes to their replacements + erased_to_new = {} + for old_start in old_starts: + erased_to_new[old_start] = new_start + for old_wait in old_waits: + erased_to_new[old_wait] = new_wait - return overlap_deps + # Transfer all dependencies from old nodes to new nodes + self.aug_graph.transfer_erased_node_deps(erased_to_new) def _preserve_dependencies_with_tokens( self, additional_deps: dict[fx.Node, OrderedSet[fx.Node]] diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index 3575b2b49efbb..a47aa960e58c5 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -18,8 +18,8 @@ MemoryTracker, ) from torch.fx.operator_schemas import normalize_function -from torch.utils._mode_utils import no_dispatch from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import _disable_current_modes log = logging.getLogger(__name__) @@ -136,7 +136,7 @@ def to_real(t: torch.Tensor) -> torch.Tensor | None: key += f"T: {shape, stride, t.dtype} " return rand_strided(shape, stride, device=t.device, dtype=t.dtype) # type: ignore[arg-type] - with no_dispatch(): + with _disable_current_modes(): args, kwargs = torch.utils._pytree.tree_map_only( torch.Tensor, lambda t: to_real(t), @@ -732,12 +732,12 @@ def should_assume_bucketed(self, node: fx.Node) -> bool: if not torch._inductor.config.test_configs.assume_bucketing_reduces_latency: return False - key = bucket_key(node) + key = bucket_key(node, mode="custom_ops_multidtype") if key is None: return False for in_flight_coll in self.in_flight.keys(): - if bucket_key(in_flight_coll) == key: + if bucket_key(in_flight_coll, mode="custom_ops_multidtype") == key: return True return False diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 909b199cf4da2..30768fda9bb72 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -305,7 +305,7 @@ def fmt_pad(name: str) -> str | None: def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node: - if node.op == operator.getitem: + if node.op is operator.getitem: return get_non_view_def(node.args[0]) # type: ignore[arg-type] if ( @@ -344,7 +344,7 @@ def should_exclude_padding_time(match: Match, arg_name: str) -> bool: return False if ( - node_def.target == aten.cat.default + node_def.target is aten.cat.default and len(node_def.all_input_nodes) > torch._inductor.config.max_pointwise_cat_inputs ): diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index bc5e03ea44fc1..f11817e1d4c51 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -454,7 +454,7 @@ def body_fn(*flat_args): graph_pass.apply(gm) - for node in gm.graph.find_nodes( + for _node in gm.graph.find_nodes( op="call_function", target=torch.ops.higher_order.map_impl ): raise AssertionError("map is not lowered to while_loop") @@ -585,7 +585,7 @@ def lower_to_while_loop(*args, **kwargs): # NOTE [Pre-allocate scan's output buffer] # In order to pre-allocate the output buffer for ys, we rely on the meta of scan's fx_node. # However, the meta consists of concrete symints, we need to bind those symints with - # proxies in order to trace the torch.empyt_strided call correctly. + # proxies in order to trace the torch.empty_strided call correctly. # # Also note that basic free symbols of tensor's shapes are guaranteed to be lifted as subgraph inputs # in dynamo so we can always re-construct the sym expression from placeholders. @@ -666,7 +666,7 @@ def body_fn(*flat_args): graph_pass.apply(gm) - for node in gm.graph.find_nodes( + for _node in gm.graph.find_nodes( op="call_function", target=torch.ops.higher_order.scan ): raise AssertionError("scan is not lowered to while_loop") @@ -1265,7 +1265,7 @@ def decomp(*flat_args): graph_pass.apply(graph) - for node in graph.find_nodes( + for _ in graph.find_nodes( op="call_function", target=torch.ops.higher_order.triton_kernel_wrapper_functional, ): @@ -1867,7 +1867,7 @@ def __call__(self, graph: fx.Graph) -> None: noop_device_puts = [ user for user in gpu_node.users - if user.target == torch.ops.prims.device_put.default + if user.target is torch.ops.prims.device_put.default and user.args[1] == target_device ] for noop in noop_device_puts: diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index b953a7ad01a23..051c75b2c2a90 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -737,7 +737,7 @@ def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: input_node = node.kwargs["input"] if ( input_node.op == "call_function" - and input_node.target == torch.nn.functional.linear + and input_node.target is torch.nn.functional.linear ): normalized = NormalizedLinearNode(input_node) input = normalized.get_input() diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 80bb9a05e2aae..a0567da118109 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -1175,8 +1175,8 @@ def woq_int8(match: Match, *args, **kwargs): mm_node_of_x = None for candidate in iter(x.users.keys()): if ( - candidate.target == aten.mm.default - and list(candidate._input_nodes)[1].target == aten.cat.default + candidate.target is aten.mm.default + and list(candidate._input_nodes)[1].target is aten.cat.default ): mm_node_of_x = candidate break diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 8b9deac6ba5a5..52222f3da8344 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -434,7 +434,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: for i, node in enumerate(reversed(graph.nodes)): node_order[node] = len(graph.nodes) - i - 1 storage_to_nodes[get_node_storage(node)].append(node) - if node.target == aten.copy_.default and node.args[0].op in ( + if node.target is aten.copy_.default and node.args[0].op in ( "placeholder", "get_attr", ): @@ -442,13 +442,13 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None: src = node.args[1] # If the target is a getitem and it indexes a possible clone, # then skip over it - if src.target == operator.getitem and ( + if src.target is operator.getitem and ( ( src.args[0].target == triton_kernel_wrapper_functional and src.args[0].kwargs["kwargs"][src.args[1]] == node.args[0] ) or (src.args[0].target in inplaceable_foreach_ops) - or (src.args[0].target == torch.ops.higher_order.auto_functionalized) + or (src.args[0].target is torch.ops.higher_order.auto_functionalized) ): src = src.args[0] @@ -643,7 +643,7 @@ def tensor_with_same_storage_already_reinplaced(arg): # output atindex size(out)+i. # This used to compare string with integers before for auto_functionalize_v2. Not sure # if it was needed for inplaceable_triton_ops? - if user.target == operator.getitem and user.args[1] == arg: + if user.target is operator.getitem and user.args[1] == arg: replace_dict[user] = mutated_arg if isinstance(mutated_arg, (list, tuple)): @@ -679,7 +679,7 @@ def tensor_with_same_storage_already_reinplaced(arg): if copy_node is not None: replace_dict[copy_node] = copy_node.args[0] node.target = inplaceable_op.inplace_op - elif node.target == torch.ops.higher_order.auto_functionalized_v2: + elif node.target is torch.ops.higher_order.auto_functionalized_v2: _mutable_op = node.args[0] kwargs = node.kwargs @@ -696,7 +696,7 @@ def tensor_with_same_storage_already_reinplaced(arg): # auto_functionalized into clones + a mutable op; this metadata # tells the decomp to only clone the following inputs node.meta["only_clone_these_tensors"] = new_bases_to_clone - elif node.target == torch.ops.higher_order.auto_functionalized: + elif node.target is torch.ops.higher_order.auto_functionalized: _mutable_op = node.args[0] from torch._higher_order_ops.auto_functionalize import get_mutable_args diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 15ea6867dba38..92e1e6f375f44 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -127,12 +127,12 @@ def _get_dim(node: Any): if "dim" in node.kwargs: assert isinstance(node.kwargs["dim"], int) return node.kwargs["dim"] - if node.target == torch.unbind: + if node.target is torch.unbind: if len(node.args) == 2: assert isinstance(node.args[-1], int) return node.args[-1] return 0 # defaults to dim=0 - if node.target == torch.split: + if node.target is torch.split: if len(node.args) == 3: assert isinstance(node.args[-1], int) return node.args[-1] @@ -351,7 +351,7 @@ def is_empty_tensor(x): cat_node.args == new_args and cat_node.kwargs == new_kwargs and cat_node.op == "call_function" - and cat_node.target == torch.cat + and cat_node.target is torch.cat ): return @@ -866,7 +866,7 @@ def get_transform_params( cat_dim = get_arg_value(user_node, 1, "dim") transform_params: list[_TransformParam] = [] for user_input in user_inputs: - if split_dim == cat_dim and user_node.target == torch.cat: + if split_dim == cat_dim and user_node.target is torch.cat: # No transform needed transform_params.append((None, None, None, None)) elif isinstance(user_input, tuple): # Split being simplified @@ -888,7 +888,7 @@ def get_transform_params( (unflatten_params, movedim_params, None, None) ) elif ( - user_node.target == torch.stack or split_dim != cat_dim + user_node.target is torch.stack or split_dim != cat_dim ): # We need to unsqueeze inputs not coming through split transform_params.append((None, None, (cat_dim,), None)) else: # Non-split inputs @@ -1107,9 +1107,9 @@ def replace_cat( ) if ( - user_node.target == torch.cat + user_node.target is torch.cat and split_dim != cat_dim - and split_node.target == torch.split + and split_node.target is torch.split ): with graph.inserting_after(new_cat_node): new_cat_node_meta = new_cat_node.meta["example_value"] @@ -1225,13 +1225,13 @@ def get_transform_params( (split_dim, cat_dim) if split_dim != cat_dim else None ) flatten_params = None - if user_node.target == torch.cat: + if user_node.target is torch.cat: flatten_params = (cat_dim, cat_dim + 1) transform_params.append( (None, movedim_params, None, flatten_params) ) elif ( - user_node.target == torch.stack + user_node.target is torch.stack ): # We need to unsqueeze inputs not coming through unbind into cat transform_params.append((None, None, (cat_dim,), None)) else: # Non-unbind inputs @@ -1298,13 +1298,13 @@ def merge_split_squeeze( match: Match, split_input: torch.fx.Node, split_sizes: list[int], dim: int ): graph = match.graph - split = next(node for node in match.nodes if node.target == torch.split) + split = next(node for node in match.nodes if node.target is torch.split) if not all(s == 1 for s in split_sizes): return if isinstance(dim, Sequence): return next_users = find_next_users(split) - if not all(node.target == torch.squeeze for node in next_users): + if not all(node.target is torch.squeeze for node in next_users): return with graph.inserting_before(match.output_node()): unbind = graph.call_function( @@ -1364,7 +1364,7 @@ def merge_split_squeeze( pass_dict=construct_pattern_matcher_pass("unbind_stack_pass"), ) def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): - unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + unbind_node = next(node for node in match.nodes if node.target is torch.unbind) UnbindCatRemover().remove_unbind(match.graph, unbind_node) @@ -1431,7 +1431,7 @@ def merge_unbind_stack(match: Match, unbind_input: torch.fx.Node, dim: int): def simplify_split_cat(match: Match, split_sections: list[int], dim: int): if not isinstance(split_sections, (list, tuple)): # Unnormalized split return - split_node = next(node for node in match.nodes if node.target == torch.split) + split_node = next(node for node in match.nodes if node.target is torch.split) # pyrefly: ignore [bad-argument-type] SplitCatSimplifier().simplify(match.graph, split_node, split_sections) @@ -1518,7 +1518,7 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): if not isinstance(split_sections, (list, tuple)): # Unnormalized split return graph = match.graph - split_node = next(node for node in match.nodes if node.target == torch.split) + split_node = next(node for node in match.nodes if node.target is torch.split) split_input, _split_size, split_dim = _get_split_args_default(split_node) # if the cat and split have different dims, return # Find the next users (i.e. users after the getitem) @@ -1526,7 +1526,7 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int): # 'immutable_list' object does not support mutation. Create a new copy of it split_sections = list(split_sections) for cat_user in next_users: - if cat_user.target == torch.cat: + if cat_user.target is torch.cat: cat_dim = get_arg_value(cat_user, 1, "dim") # check the all getitems in the cat_user from the same node # check the input of the cat has all getitem from the split @@ -1625,13 +1625,13 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int): if not isinstance(split_sections, (list, tuple)): # Unnormalized split return graph = match.graph - split_node = next(node for node in match.nodes if node.target == torch.split) + split_node = next(node for node in match.nodes if node.target is torch.split) _split_input, _split_size, split_dim = _get_split_args_default(split_node) # if the cat and split have different dims, return # Find the next users (i.e. users after the getitem) next_users = find_next_users(split_node) for cat_user in next_users: - if cat_user.target == torch.cat: + if cat_user.target is torch.cat: cat_dim = get_arg_value(cat_user, 1, "dim") or 0 # check that all getitems in the cat_user from the same node # check the input of the cat has all getitem from the split @@ -1904,7 +1904,7 @@ def merge_select_cat_aten(match: Match, *args, **kwargs): # get the select nodes from the node select_nodes = list(node_input.users.keys()) for cat_node in list(node.users.keys()): - if cat_node.target == torch.ops.aten.cat.default: + if cat_node.target is torch.ops.aten.cat.default: cat_dim = get_arg_value(cat_node, 1, "dim") cat_inputs = get_arg_value(cat_node, 0, "tensors") # check all select nodes has same slice dim @@ -2010,7 +2010,7 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs): cat_dim = get_arg_value(node, 1, "dim") # check the unsqueeze nodes come from the select nodes if not all( - get_arg_value(unsqueeze_node, 0, "input").target == torch.ops.aten.select + get_arg_value(unsqueeze_node, 0, "input").target is torch.ops.aten.select for unsqueeze_node in unsqueeze_nodes ): return @@ -2020,7 +2020,7 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs): parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") # check the target of select_nodes are the same if not all( - select_node.target == torch.ops.aten.select for select_node in select_nodes + select_node.target is torch.ops.aten.select for select_node in select_nodes ): return # check the select nodes come from the same parent node @@ -2319,7 +2319,7 @@ def construct_cat_args( def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.Node]): nodes = OrderedSet[Any]() for input in inputs: - if input.target == operator.getitem: + if input.target is operator.getitem: nodes.add(input.args[0]) # type: ignore[union-attr] if len(input.users.keys()) == 0: graph.erase_node(input) @@ -2357,7 +2357,7 @@ def remove_split_unbind_children(graph: torch.fx.Graph, inputs: list[torch.fx.No def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): if not isinstance(split_sections, (list, tuple)): # Unnormalized split return - split_nodes = [node for node in match.nodes if node.target == torch.split] + split_nodes = [node for node in match.nodes if node.target is torch.split] if split_nodes: split_node = next(node for node in split_nodes) else: @@ -2438,7 +2438,7 @@ def split_cat_to_slices(match: Match, split_sections: list[int], dim: int): pass_dict=construct_pattern_matcher_pass("unbind_cat_to_view_pass"), ) def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int): - unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + unbind_node = next(node for node in match.nodes if node.target is torch.unbind) graph = match.graph # get the cat_node and check its inputs and meta data next_users = find_next_users(unbind_node) @@ -2614,7 +2614,7 @@ def convert_reshape_cat_arg_to_stack( def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): if not isinstance(split_sections, (list, tuple)): # Unnormalized split return - split_node = next(node for node in match.nodes if node.target == torch.split) + split_node = next(node for node in match.nodes if node.target is torch.split) split_dim = get_arg_value(split_node, 2, "dim") or 0 graph = match.graph threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ @@ -2685,7 +2685,7 @@ def split_stack_to_cats(match: Match, split_sections: list[int], dim: int): pass_dict=construct_pattern_matcher_pass("unbind_stack_to_slices_pass"), ) def unbind_stack_to_slices(match: Match, unbind_input: torch.fx.Node, dim: int): - unbind_node = next(node for node in match.nodes if node.target == torch.unbind) + unbind_node = next(node for node in match.nodes if node.target is torch.unbind) graph = match.graph # get the cat_node and check its inputs and meta data next_users = find_next_users(unbind_node) @@ -2755,13 +2755,13 @@ def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]: # cat_arg must be the split input view_shape_list = [] for user in cat_arg.users.keys(): - if user.target == torch.split: + if user.target is torch.split: for getitem in user.users.keys(): - if getitem.target == operator.getitem: + if getitem.target is operator.getitem: reshape_user = [ user for user in getitem.users.keys() - if user.target == torch.reshape + if user.target is torch.reshape ] if len(reshape_user) > 0: view_shape_list = list( @@ -2785,10 +2785,10 @@ def get_view_shape_list(cat_arg: torch.fx.Node, stack_dim: int) -> list[int]: pass_dict=construct_pattern_matcher_pass("move_reshape_out_of_split_stack_pass"), ) def move_reshape_out_of_split_stack(match: Match, *args, **kwargs): - split_node = next(node for node in match.nodes if node.target == torch.split) + split_node = next(node for node in match.nodes if node.target is torch.split) split_dim = _get_dim(split_node) split_users = list(split_node.users.keys()) - stack_nodes = [node for node in match.nodes if node.target == torch.stack] + stack_nodes = [node for node in match.nodes if node.target is torch.stack] graph = match.graph threshold_to_cat = torch._inductor.config.pre_grad_fusion_options[ "move_reshape_out_of_split_stack_pass" @@ -2926,17 +2926,17 @@ def move_view_after_cat(match: Match, *args, **kwargs): split_node = next( node for node in match.nodes - if node.target == torch.ops.aten.split_with_sizes.default + if node.target is torch.ops.aten.split_with_sizes.default ) split_input, split_section, split_dim = _get_split_args_default(split_node) split_users = list(split_node.users.keys()) getitem_indices = [ - getitem.args[1] for getitem in split_users if getitem.target == operator.getitem + getitem.args[1] for getitem in split_users if getitem.target is operator.getitem ] if not is_sorted_and_consecutive(getitem_indices): # type: ignore[arg-type] return cat_nodes = [ - node for node in match.nodes if node.target == torch.ops.aten.cat.default + node for node in match.nodes if node.target is torch.ops.aten.cat.default ] graph = match.graph for cat_node in cat_nodes: @@ -2950,13 +2950,13 @@ def move_view_after_cat(match: Match, *args, **kwargs): continue # check if the cat inputs are all the view nodes if not all( - view_node.target == torch.ops.aten.reshape.default + view_node.target is torch.ops.aten.reshape.default for view_node in cat_inputs ): continue # check if the view nodes are all from getitem nodes if not all( - view_node.args[0].target == operator.getitem for view_node in cat_inputs + view_node.args[0].target is operator.getitem for view_node in cat_inputs ): continue view_indices = [view.args[0].args[1] for view in cat_inputs] @@ -3037,7 +3037,7 @@ def should_replace_einsum(einsum_node) -> bool: and is_node_meta_valid(input) and is_node_meta_valid(weights) and any( - user.target == "add" or user.target == operator.add for user in users + user.target == "add" or user.target is operator.add for user in users ) and match_einsum_strings(equation) ) diff --git a/torch/_inductor/fx_utils.py b/torch/_inductor/fx_utils.py index ec3a1d83d9248..adeca75ff53ab 100644 --- a/torch/_inductor/fx_utils.py +++ b/torch/_inductor/fx_utils.py @@ -146,7 +146,7 @@ def any_user_may_alias(node): (torch._ops.OpOverload, torch._ops.HigherOrderOperator), ) or user.target - == torch._inductor.fx_passes.reinplace._generalized_scatter + is torch._inductor.fx_passes.reinplace._generalized_scatter ): return True if isinstance(user.target, torch._ops.HigherOrderOperator): @@ -200,9 +200,9 @@ def should_process_node(node): # tensors from an op. return node.op == "call_function" and ( isinstance(node.target, torch._ops.OpOverload) - or node.target == operator.getitem + or node.target is operator.getitem or node.target - == torch._inductor.fx_passes.reinplace._generalized_scatter + is torch._inductor.fx_passes.reinplace._generalized_scatter ) to_process = OrderedSet[int]() diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 2e83e6b3a694b..5a552e096f75f 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -661,7 +661,7 @@ def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: return True conv_nodes = [ - n for n in gm.graph.nodes if n.target == torch.ops.aten.convolution.default + n for n in gm.graph.nodes if n.target is torch.ops.aten.convolution.default ] nconv = len(conv_nodes) @@ -860,7 +860,7 @@ def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]: nodes_cannot_propagate = [torch.ops.aten.bmm.default] output_set = OrderedSet[Node]() for n in reversed(self.module.graph.nodes): # type: ignore[arg-type, union-attr] - if n.target == torch.ops.aten.convolution.default: + if n.target is torch.ops.aten.convolution.default: output_set.add(n) if last_conv is None: last_conv = n @@ -1988,7 +1988,7 @@ def make_assert(expr: SympyBoolean, msg: str) -> None: if ( full_aoti_runtime_assert() - and n.target == torch.ops.aten._assert_scalar.default + and n.target is torch.ops.aten._assert_scalar.default and self.aot_mode ): node_args, _ = self.fetch_args_kwargs_from_env(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index a5da990e4ba24..12f13cfdb0c77 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -64,6 +64,7 @@ compute_unbacked_bindings, free_symbols, free_unbacked_symbols, + IterateExprs, rebind_unbacked, resolve_unbacked_bindings, ShapeEnv, @@ -97,6 +98,7 @@ argsort, argsort_sym, cache_on_self, + cache_on_self_and_args, ceildiv, convert_shape_to_inductor, convert_shape_to_symint, @@ -933,6 +935,7 @@ class Loops(IRNode): inner_fn: Callable[..., Any] ranges: Sequence[_IntLike] + @cache_on_self_and_args("Loops") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -1228,6 +1231,7 @@ def __str__(self) -> str: __repr__ = __str__ + @cache_on_self_and_args("Reduction") def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union( *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges) @@ -1510,6 +1514,10 @@ def create( reduction_hint: ReductionHint = ReductionHint.DEFAULT, input_node: Optional[IRNode] = None, ) -> Union[TensorBox, ShapeAsConstantBuffer]: + """ + Create a reduction node. May split the reduction to multiple layers to expose + more parallelism. + """ reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) if reduction_numel == 0: @@ -1636,7 +1644,7 @@ def _maybe_increase_split(split: int) -> int: ) elif split > 1: # triton doesn't support reduce to single element well, so break it up - return cls.create_multilayer( + out = cls.create_multilayer( device, dst_dtype, src_dtype, @@ -1649,7 +1657,47 @@ def _maybe_increase_split(split: int) -> int: input_node, ) - return TensorBox.create( + # Find the reduction that get split + split_reduction = None + if config.triton.mix_order_reduction and isinstance(out, TensorBox): + + def _find_split_reduction( + cur_node: TensorBox, + ) -> Optional[ComputedBuffer]: + read_names = cur_node.get_read_names() + if len(read_names) != 1: + return None + + bufname = next(iter(read_names)) + if bufname not in V.graph.name_to_buffer: + return None + buf = V.graph.name_to_buffer[bufname] + if not isinstance(buf, ComputedBuffer): + return None + + assert buf.data.get_reduction_type() is not None + + return buf + + split_reduction = _find_split_reduction(out) + + if split_reduction: + # If a reduction is split to more than 2 layers, + # say there are 3 layers, + # we always have the correct setting for layer1 (top layer). + # The setting on layer2 may be incorrect but it's fine + # since they are never get used. + # TODO: should we skip setting these fields for layer2 + assert isinstance(split_reduction.data, Reduction), ( + f"{type(split_reduction.data)}" + ) + split_reduction._split_size = split_reduction.data.reduction_ranges[0] + split_reduction._original_inner_fn = inner_fn + split_reduction._original_ranges = ranges + split_reduction._original_reduction_ranges = reduction_ranges + return out + + out = TensorBox.create( Reduction( device=device, dtype=dst_dtype, @@ -1661,6 +1709,7 @@ def _maybe_increase_split(split: int) -> int: reduction_hint=reduction_hint, ) ) + return out @staticmethod def default_accumulator( @@ -2327,6 +2376,7 @@ class Scan(Loops): # HACK we mimic reduction + @cache_on_self_and_args("Scan") def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we # need to explicitly represent the closure so we can pull out unbacked @@ -2537,6 +2587,7 @@ class Sort(Loops): # HACK we mimic reduction + @cache_on_self_and_args("Sort") def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: return ( super().get_free_symbol_uses(unbacked_only) @@ -2785,6 +2836,7 @@ def is_unaligned(node: IRNode) -> bool: class BaseView(IRNode): data: IRNode + @cache_on_self_and_args("BaseView") def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]: return self.data.get_free_symbol_uses(unbacked_only) @@ -3359,6 +3411,7 @@ def get_layout(self) -> Layout: def freeze_layout(self) -> None: pass + @cache_on_self_and_args("ReinterpretView") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -3643,13 +3696,37 @@ def __init__( self.dtype = dtype assert len(size) == len(stride), f"size={size}, stride={stride}" assert all(isinstance(s, (Expr, int)) for s in size) - self.size = size - self.stride = stride - self.offset = offset + self._size = size + self._stride = stride + self._offset = offset self.is_pinned = is_pinned # is_pinned implies cpu assert (not self.is_pinned) or (self.device.type == "cpu") + @property + def size(self) -> Sequence[Expr]: + return self._size + + @size.setter + def size(self, value: Sequence[Expr]) -> None: + self._size = value + + @property + def stride(self) -> Sequence[Expr]: + return self._stride + + @stride.setter + def stride(self, value: Sequence[Expr]) -> None: + self._stride = value + + @property + def offset(self) -> Expr: + return self._offset + + @offset.setter + def offset(self, value: Expr) -> None: + self._offset = value + def __str__(self) -> str: offset = "" if self.offset != 0: @@ -3869,6 +3946,7 @@ def __eq__(self, other: object) -> bool: def storage_size(self) -> Expr: return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type] + @cache_on_self_and_args("Layout") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -3888,7 +3966,11 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: class FlexibleLayout(Layout): - """A Tensor layout that we are allowed to change""" + """ + A Tensor layout that we are allowed to change + + Assumption: layout change should NOT add or remove free symbols + """ allow_indexing = False @@ -3973,6 +4055,33 @@ def same_ordered( fill_order = sorted(range(len(stride)), key=stride.__getitem__) return FlexibleLayout.fill_ordered(sizes, fill_order) + @property + def size(self) -> Sequence[Expr]: + return self._size + + @size.setter + def size(self, value: Sequence[Expr]) -> None: + self.assert_free_symbol_uses_unchanged("size", value) + self._size = value + + @property + def stride(self) -> Sequence[Expr]: + return self._stride + + @stride.setter + def stride(self, value: Sequence[Expr]) -> None: + self.assert_free_symbol_uses_unchanged("stride", value) + self._stride = value + + @property + def offset(self) -> Expr: + return self._offset + + @offset.setter + def offset(self, value: Expr) -> None: + self.assert_free_symbol_uses_unchanged("offset", value) + self._offset = value + def as_stride_order( self, order: Sequence[int], allow_padding: bool = False ) -> FixedLayout: @@ -4031,6 +4140,25 @@ def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout: self.is_pinned, ) + def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]: + initial_free_symbols = {} + for name in ["size", "stride", "offset"]: + for unbacked_only in [True, False]: + key = (name, unbacked_only) + initial_free_symbols[key] = OrderedSet( + get_free_symbols(getattr(self, name), unbacked_only) + ) + + return initial_free_symbols + + def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None: + for unbacked_only in [True, False]: + old_free_symbols = self.initial_free_symbols[(name, unbacked_only)] + new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only)) + assert new_free_symbols == old_free_symbols, ( + f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}" + ) + def __init__( self, device: torch.device, @@ -4045,6 +4173,10 @@ def __init__( strides = FlexibleLayout.contiguous_strides(size) super().__init__(device, dtype, size, strides, is_pinned=is_pinned) + # record the initial free symbols to check that we do not add new free symbols + # later when modifying sizes, strides, and offsets. + self.initial_free_symbols = self.get_initial_free_symbol_uses() + class NonOwningLayout(Layout): """Is a view into the storage of another tensor""" @@ -4070,6 +4202,7 @@ def maybe_guard_aligned(self) -> bool: return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) + @cache_on_self_and_args("NonOwningLayout") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -4358,6 +4491,7 @@ def get_mutation_names(self) -> Sequence[str]: def get_read_names(self) -> OrderedSet[str]: return OrderedSet([self.get_name()]) + @cache_on_self_and_args("Buffer") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -4430,6 +4564,7 @@ class NoneAsConstantBuffer(IRNode): def get_reads(self) -> OrderedSet[Dep]: return OrderedSet() + @cache_on_self_and_args("NoneAsConstantBuffer") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -4449,6 +4584,7 @@ def has_tensor_output(self) -> bool: class ShapeAsConstantBuffer(IRNode): expr: Expr + @cache_on_self_and_args("ShapeAsConstantBuffer") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -4470,6 +4606,47 @@ class ComputedBuffer(OperationBuffer): data: Loops _force_realize: ClassVar[bool] = False + # fields for split reduction + _split_size: Optional[int] = None + _original_inner_fn: Optional[Callable[..., Any]] = None + _original_ranges: Optional[Sequence[_IntLike]] = None + _original_reduction_ranges: Optional[Sequence[_IntLike]] = None + + @contextlib.contextmanager + def with_original_inner_fn(self) -> Iterator[None]: + assert self._split_size is not None + assert self._original_inner_fn is not None + assert self._original_ranges is not None + assert self._original_reduction_ranges is not None + + assert isinstance(self.data, Reduction), f"{type(self.data)}" + old_data = self.data + old_layout = self.layout + try: + new_data = Reduction( + device=old_data.device, + dtype=old_data.dtype, + inner_fn=self._original_inner_fn, + ranges=self._original_ranges, + reduction_ranges=self._original_reduction_ranges, + reduction_type=old_data.reduction_type, + src_dtype=old_data.src_dtype, + reduction_hint=old_data.reduction_hint, + ) + self.data = new_data + # this layout does not matter since we skip tl.store + # later + self.layout = FixedLayout( + old_data.device, + old_data.dtype, + self._original_ranges, + ) + self.get_default_sizes_body.clear_cache(self) + yield + finally: + self.data = old_data + self.layout = old_layout + @staticmethod @contextlib.contextmanager def force_realize() -> Iterator[None]: @@ -4521,6 +4698,7 @@ def get_read_writes(self) -> dependencies.ReadWrites: self.data.get_size(), ) + @cache_on_self_and_args("ComputedBuffer") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -4624,7 +4802,7 @@ def get_default_sizes_body( tuple[list[Expr], list[Expr]], ]: args, var_ranges = dependencies.index_vars_squeeze( - self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" + self.get_pointwise_size(), self.get_reduction_size(), prefix="q" ) with patch.object(ConstantBuffer, "override_device", self.get_device()): body = LoopBody( @@ -4839,6 +5017,9 @@ def _apply_loop_reordering( sizes = [sizes[i] for i in order] return sizes, same_reorder(order), inverse_reorder(order) + def get_pointwise_size(self) -> Sequence[Expr]: + return self.data.get_pointwise_size() + def get_reduction_size(self) -> Sequence[Expr]: return self.data.get_reduction_size() @@ -4974,6 +5155,7 @@ def __init__( self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None self.subgraph_outs: Optional[list[Optional[IRNode]]] = None + @cache_on_self_and_args("TritonTemplateBuffer") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -5048,7 +5230,9 @@ def benchmark(self, *args: Any, out: torch.Tensor) -> float: } if config.profile_bandwidth_with_do_bench_using_profiling: return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs) # type: ignore[arg-type] - return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs) + return benchmarker.benchmark( + algo, args, {"out": out}, device=None, **benchmark_configs + ) def call_name(self) -> str: raise NotImplementedError @@ -5340,6 +5524,7 @@ def is_extern(self) -> bool: def num_reads(self) -> int: return 1 + @cache_on_self_and_args("InputsKernel") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -5514,6 +5699,7 @@ def can_realize_into_without_copy( and not isinstance(src.data, ExternKernelAlloc) ) + @cache_on_self_and_args("ConcatKernel") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -5849,7 +6035,7 @@ def unflatten_args( if shape_env := V.fake_mode.shape_env: node_meta_val = V.current_node.meta.get("val") ctx: AbstractContextManager[None] = nullcontext() - if V.current_node.target == torch._higher_order_ops.effects.with_effects: + if V.current_node.target is torch._higher_order_ops.effects.with_effects: # remove the first effect token in meta["val"] and meta["unbacked_bindings"] node_meta_val = node_meta_val[1] ctx = _remove_effect_token_unbacked_bindings(V.current_node) @@ -6430,6 +6616,7 @@ def canonicalize(self) -> tuple[Expr, Sequence[Expr]]: index = sympy_subs(sympy.expand(index), replacement) return index, tuple(new_sizes) + @cache_on_self_and_args("ExternKernel") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -6798,6 +6985,7 @@ def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]: configs = kernel.configs kernel = kernel.fn + # pyrefly: ignore # bad-return return kernel, configs, restore_value_args, reset_to_zero_args @override @@ -6889,6 +7077,7 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None: original_fxnode_name=self.fx_node.name, ) + @cache_on_self_and_args("UserDefinedTritonKernel") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -6952,7 +7141,10 @@ def __init__( self.mutable_args = [ kernel_args[key] for key in identify_mutated_tensors( - kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata + # pyrefly: ignore # bad-argument-type + kernel, + {**kernel_args, **autotuned_kwargs}, + tma_descriptor_metadata, ) ] @@ -7327,6 +7519,7 @@ def __init__( def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet([self.unbacked_offset_symbol]) + @cache_on_self_and_args("DynamicSelectStorageOffset") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -7377,6 +7570,7 @@ def __init__( def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet([self.unbacked_size_symbol]) + @cache_on_self_and_args("DynamicSliceSize") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -7441,6 +7635,7 @@ def __init__(self, scalar: SympyBoolean, msg: str) -> None: def has_side_effects(self) -> bool: return True + @cache_on_self_and_args("AssertScalar") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -8115,6 +8310,7 @@ def __init__( self.indices = indices self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks + @cache_on_self_and_args("MultiOutput") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -8237,6 +8433,7 @@ def get_inputs_that_alias_output(self) -> Sequence[str]: def realize(self) -> Optional[str]: return self.data.realize() + @cache_on_self_and_args("MutableBox") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: @@ -8770,9 +8967,7 @@ def _clone_aliased_inputs(carried_inputs: Sequence[IRNode]) -> Sequence[IRNode]: seen_buffers: OrderedSet[int] = OrderedSet() result: list[Union[IRNode, TensorBox, ShapeAsConstantBuffer]] = [] - for i, (original_input, unwrapped_buffer) in enumerate( - zip(carried_inputs, unwrapped_buffers) - ): + for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers): if id(unwrapped_buffer) in seen_buffers: result.append(ExternKernel.copy_input(original_input)) else: @@ -9075,6 +9270,7 @@ def has_side_effects(self) -> bool: class NonTensorObj(IRNode): + @cache_on_self_and_args("NonTensorObj") def get_free_symbol_uses( self, unbacked_only: bool = False ) -> OrderedSet[sympy.Symbol]: diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 2977932c084f6..8e5a2aa09d4ea 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -677,7 +677,7 @@ def _convolution( def constrain_conv_to_fx_strides(fx_node, *args, **kwargs): - assert fx_node.target == torch.ops.aten.convolution.default + assert fx_node.target is torch.ops.aten.convolution.default if V.graph.layout_opt: return args, kwargs else: diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py new file mode 100644 index 0000000000000..303110a561b5e --- /dev/null +++ b/torch/_inductor/kernel/custom_op.py @@ -0,0 +1,426 @@ +# Owner(s): ["module: inductor"] + +import functools +import logging +from typing import Any, Callable, Optional, Union + +import torch +from torch._inductor.codegen.subgraph import SubgraphTemplate +from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox +from torch._inductor.lowering import lowerings, validate_ir +from torch._inductor.select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, +) +from torch._inductor.virtualized import V +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +class CustomOpConfig: + """Config for custom op autotuning. + + Specifies optional decomposition function with parameter values. + Each config creates exactly one variant. + + Args: + decomposition: Optional functions to autotune. If not provided, default will be used. + **params: Parameters passed to the function + + Examples: + CustomOpConfig(attention_impl, head_dim=32, method='chunked') + CustomOpConfig(head_dim=32, method='chunked') + """ + + def __init__( + self, + decomposition: Optional[Callable[..., Any]] = None, + **params: Any, + ): + if decomposition is not None and not callable(decomposition): + raise TypeError( + f"decomposition must be callable, got {type(decomposition)}" + ) + + self.decomposition = decomposition + self.params = params + + def get_decomposition( + self, default_impl: Optional[Callable[..., Any]] = None + ) -> Callable[..., Any]: + """Return the decomposition function for this config. + When decomposition is not specified, return the default implementation. + """ + if self.decomposition is not None: + return self.decomposition + + if default_impl is not None and callable(default_impl): + return default_impl + + raise TypeError( + "No decomposition specified in config and no default implementation provided. " + "Please provide a decomposition function in CustomOpConfig." + ) + + def __repr__(self) -> str: + decomp_name = self.decomposition.__name__ if self.decomposition else "default" + if self.params: + params_str = ", ".join(f"{k}={v}" for k, v in self.params.items()) + return f"CustomOpConfig({decomp_name}, {params_str})" + return f"CustomOpConfig({decomp_name})" + + +__all__ = [ + "autotune_custom_op", + "register_custom_op_autotuning", + "CustomOpConfig", +] + + +def _extract_tensor_inputs( + args: tuple[Any, ...], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + """Extract tensor inputs from mixed args/kwargs. + Separates tensors (for autotuning input_nodes) from non-tensor parameters. + Non-tensor kwargs are later functools.partial'd into decomposition functions. + + Args: + args: Positional arguments (mix of tensors and scalars) + kwargs: Keyword arguments (mix of tensors and scalars) + + Returns: + Tuple of (tensor_inputs_list, non_tensor_kwargs) + """ + tensor_inputs = [] + non_tensor_kwargs = {} + + # Process args and kwargs: separate tensor inputs and non tensor args + for i, arg in enumerate(args): + if isinstance(arg, (TensorBox, Buffer)): + tensor_inputs.append(arg) + else: + # Add non-tensor positional args to kwargs with generated names + non_tensor_kwargs[f"arg_{i}"] = arg + + for key, value in kwargs.items(): + if isinstance(value, (TensorBox, Buffer)): + tensor_inputs.append(value) + else: + non_tensor_kwargs[key] = value + + return tensor_inputs, non_tensor_kwargs + + +def _merge_config_and_runtime_kwargs( + config_params: dict[str, Any], + runtime_kwargs: dict[str, Any], +) -> dict[str, Any]: + """Merge config parameters with runtime kwargs. Runtime kwargs take precedence. + If there are conflicts, log a warning and use runtime value. + + Args: + config_params: Parameters from CustomOpConfig + runtime_kwargs: Runtime non-tensor kwargs from _extract_tensor_inputs + + Returns: + Merged kwargs dictionary with runtime values taking precedence + """ + merged_kwargs = config_params.copy() + + # Check for conflicts and let runtime kwargs dominate + conflicts = OrderedSet(config_params.keys()).intersection(runtime_kwargs.keys()) + + for key in conflicts: + log.warning( + "Parameter '%s' specified both in CustomOpConfig (%s) " + "and at runtime (%s). Using runtime value.", + key, + config_params[key], + runtime_kwargs[key], + ) + + # Runtime kwargs override config params + merged_kwargs.update(runtime_kwargs) + + return merged_kwargs + + +def _adapt_user_input_gen_fns( + inputs: list[Any], + arg_names: list[str], + user_input_gen_fns: dict[str, Callable[[torch.Tensor], torch.Tensor]], +) -> dict[int, Callable[[Any], torch.Tensor]]: + """Convert user input generators from name-based to index-based format. + Inductor autotune's input_gen_fns expects index of arg_names as key. + + Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes. + """ + from torch._inductor import config + + name_to_index = {name: i for i, name in enumerate(arg_names)} + index_based_fns = {} + + for name, gen_fn in user_input_gen_fns.items(): + if name in name_to_index: + index_based_fns[name_to_index[name]] = gen_fn + else: + log.warning( + "Unknown argument name '%s' in input_gen_fns. " + "Available argument names: %s", + name, + list(name_to_index.keys()), + ) + + def create_internal_input_gen_fn( + user_function: Callable[[torch.Tensor], torch.Tensor], arg_name: str + ) -> Callable[[Any], torch.Tensor]: + """Create internal input generator that converts IR buffer to user's fake tensor.""" + + def internal_input_gen_fn(ir_buffer: Any) -> torch.Tensor: + raw_shape = ir_buffer.get_size() + concrete_shape = V.graph.sizevars.size_hints( + raw_shape, fallback=config.unbacked_symint_fallback + ) + + fake_tensor = torch.empty( + concrete_shape, dtype=ir_buffer.get_dtype(), device="meta" + ) + return user_function(fake_tensor) + + return internal_input_gen_fn + + return { + i: create_internal_input_gen_fn( + user_gen_fn, arg_names[i] if i < len(arg_names) else f"arg_{i}" + ) + for i, user_gen_fn in index_based_fns.items() + if i < len(inputs) + } + + +def _create_fallback_choice( + name: str, + default_impl: Callable[..., Any], + fake_output: torch.Tensor, + kwargs: dict[str, Any], +) -> ExternKernelChoice: + """Create fallback choice for default implementation.""" + + def fallback_wrapper(*args: Any) -> Any: + return default_impl(*args, **kwargs) + + return ExternKernelChoice( + kernel=fallback_wrapper, + name=f"{name}_fallback_default", + has_out_variant=False, + op_overload=default_impl, + use_fallback_kernel=True, + ) + + +def autotune_custom_op( + name: str, + decompositions: list[Callable[..., Any]], + inputs: list[Any], + non_tensor_args: list[dict[str, Any]], + op_overload: torch._ops.OpOverload, + user_input_gen_fns: Optional[ + dict[str, Callable[[torch.Tensor], torch.Tensor]] + ] = None, +) -> Union[TensorBox, Any]: + """Autotune custom operations by comparing multiple decomposition implementations. + + Currently supports SINGLE OUTPUT custom ops only. + TODO: Add support for multiple output custom ops (tuple/list returns). + + This function generates multiple implementation choices for a custom operation and + uses Inductor's autotuning system to select the best performing variant at runtime. + + Args: + name: Unique identifier for the autotuning operation + decompositions: List of alternative implementation functions to benchmark + inputs: Input tensor IR nodes from compilation (TensorBox/Buffer objects) + non_tensor_args: List of kwargs dicts, paired with corresponding decompositions arg + op_overload: OpOverload of the custom op, used as fallback implementation + user_input_gen_fns: Optional custom input generators for benchmarking. + Maps input indices to functions that take fake tensors + and return real tensors for performance measurement. + + Returns: + IR node representing the optimized operation result + + Raises: + TypeError: If decompositions is not a list/tuple + RuntimeError: If no inputs or no valid choices generated + """ + if not isinstance(decompositions, (list, tuple)): + raise TypeError( + f"decompositions must be a list or tuple of callables, got {type(decompositions)}" + ) + + if not inputs: + raise RuntimeError(f"Custom op '{name}' requires tensor inputs for autotuning") + + if len(decompositions) != len(non_tensor_args): + raise ValueError( + f"decompositions and non_tensor_args must have same length, " + f"got {len(decompositions)} decompositions and {len(non_tensor_args)} kwargs" + ) + + template = SubgraphTemplate(name=name) + choices = template.generate_custom_op_choices( + name=name, + # pyrefly: ignore [bad-argument-type] + decompositions=decompositions, + input_nodes=list(inputs), + non_tensor_args=non_tensor_args, + ) + + # Add default implementation as fallback + if op_overload and hasattr(op_overload, "_op"): + fallback_name = f"{name}_fallback_default" + from torch._inductor.select_algorithm import extern_kernels + + # Skip if extern_kernel already registered to avoid duplicate registration error + if not hasattr(extern_kernels, fallback_name): + with V.fake_mode: + fake_inputs = [ir_node_to_tensor(inp) for inp in inputs] + fallback_kwargs = non_tensor_args[0] if non_tensor_args else {} + fake_output = op_overload(*fake_inputs, **fallback_kwargs) + + fallback_choice = _create_fallback_choice( + name, op_overload, fake_output, fallback_kwargs + ) + fallback_choice.maybe_append_choice( + choices=choices, + input_nodes=list(inputs), + layout=FixedLayout( + device=fake_output.device, + dtype=fake_output.dtype, + size=fake_output.shape, + stride=fake_output.stride(), + ), + ) + + if not choices: + raise RuntimeError(f"No valid choices generated for {name}") + + # Convert user input generation functions to internal format + input_gen_fns = {} + if user_input_gen_fns: + import inspect + + arg_names = ( + list(inspect.signature(decompositions[0]).parameters.keys()) + if decompositions + else [] + ) + input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) + + return autotune_select_algorithm( + name=name, + choices=choices, + input_nodes=list(inputs), + layout=choices[0].layout, + input_gen_fns=input_gen_fns, + ) + + +def register_custom_op_autotuning( + custom_op: torch._library.custom_ops.CustomOpDef, + configs: Union[list[CustomOpConfig], list[Callable[..., Any]]], + name: Optional[str] = None, + input_gen_fns: Optional[dict[str, Callable[[torch.Tensor], torch.Tensor]]] = None, +) -> None: + """Register custom op for autotuning with custom_op configs where each config + specifies a decomposition implementation function with its parameter values. + + Args: + custom_op: Custom operation (decorated function from @torch.library.custom_op) + configs: List of CustomOpConfig objects + name: Operation name (default: "{op_name}_autotuned") + input_gen_fns: Custom input generators for benchmarking + + Examples: + @torch.library.custom_op("mylib::attention", mutates_args=()) + def my_attention(query, key, value, head_dim=32): + ... + + register_custom_op_autotuning( + my_attention, + configs=[ + CustomOpConfig(attention_impl, head_dim=32, method='chunked'), + CustomOpConfig(attention_impl, head_dim=64, method='tiled'), + CustomOpConfig(head_dim=128), # No decomposition specified, use default + ], + input_gen_fns={ + "query": lambda fake: torch.randn_like(fake, device='cuda'), + "key": lambda fake: torch.randn_like(fake, device='cuda'), + "value": lambda fake: torch.randn_like(fake, device='cuda'), + } + ) + """ + from torch._library.custom_ops import CustomOpDef + + if not isinstance(custom_op, CustomOpDef): + raise TypeError( + f"custom_op must be a CustomOpDef (decorated function from @torch.library.custom_op), " + f"got {type(custom_op)}." + ) + + op_overload = custom_op._opoverload + default_impl = custom_op._init_fn + + if not isinstance(configs, (list, tuple)): + raise TypeError(f"configs must be a list or tuple, got {type(configs)}") + + processed_configs = [] + for config in configs: + if isinstance(config, CustomOpConfig): + processed_configs.append(config) + else: + raise TypeError( + f"Each config must be a CustomOpConfig object, got {type(config)}" + ) + + if not processed_configs: + raise ValueError("At least one config must be provided") + + if name is None: + name = f"{op_overload._name}_autotuned" + + @functools.wraps(op_overload) + def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: + """Inductor lowering function that replaces custom op calls with autotuned versions.""" + # Extract tensor inputs and non-tensor parameters (runtime kwargs) + tensor_inputs, runtime_kwargs = _extract_tensor_inputs(args, kwargs) + + # Prepare decompositions and kwargs by merging config params with runtime kwargs + decompositions = [] + non_tensor_args = [] + + for config in processed_configs: + decomp = config.get_decomposition(default_impl=default_impl) + decompositions.append(decomp) + + # Merge config params with runtime kwargs (runtime takes precedence) + merged_kwargs = _merge_config_and_runtime_kwargs( + config.params, runtime_kwargs + ) + non_tensor_args.append(merged_kwargs) + + result = autotune_custom_op( + name=name, + decompositions=decompositions, + inputs=tensor_inputs, + non_tensor_args=non_tensor_args, + op_overload=op_overload, + user_input_gen_fns=input_gen_fns, + ) + + validate_ir(result) + return result + + lowerings[op_overload] = autotuning_lowering diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py index 8ad64cf1800f0..b604514f30d14 100644 --- a/torch/_inductor/kernel/flex/common.py +++ b/torch/_inductor/kernel/flex/common.py @@ -5,7 +5,7 @@ from collections.abc import Sequence from functools import partial from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import sympy @@ -14,6 +14,13 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map, tree_map_only + +if TYPE_CHECKING: + from torch._inductor.codegen.cuda_combined_scheduling import _IntLike +else: + _IntLike = Union[int, sympy.Expr] + + from ...ir import ( ComputedBuffer, ExternKernel, @@ -214,18 +221,18 @@ def create_placeholder( def construct_strides( - sizes: Sequence[int], + sizes: Sequence[_IntLike], fill_order: Sequence[int], -) -> Sequence[int]: +) -> Sequence[_IntLike]: """From a list of sizes and a fill order, construct the strides of the permuted tensor.""" # Initialize strides assert len(sizes) == len(fill_order), ( "Length of sizes must match the length of the fill order" ) - strides = [0] * len(sizes) + strides: list[_IntLike] = [0] * len(sizes) # Start with stride 1 for the innermost dimension - current_stride = 1 + current_stride: _IntLike = 1 # Iterate through the fill order populating strides for dim in fill_order: @@ -235,7 +242,10 @@ def construct_strides( return strides -def infer_dense_strides(size: Sequence[int], orig_strides: Sequence[int]): +def infer_dense_strides( + size: Sequence[_IntLike], + orig_strides: Sequence[_IntLike], +): """This is a mirror of the same function in aten/src/ATen/ExpandUtils.cpp Args: diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 9e7a217829d7e..1a72e279aab79 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -193,24 +193,6 @@ def flex_attention( score_mod_other_buffers, mask_mod_other_buffers, ) - if _use_flex_flash_attention( - subgraph, - mask_graph, - kernel_options, - num_score_mod_placeholders=len(placeholder_inps), - ): - return create_flex_flash_attention_kernel( - query, - key, - value, - block_mask, - scale, - kernel_options, - subgraph_buffer, - mask_graph_buffer, - score_mod_other_buffers, - mask_mod_other_buffers, - ) ( query, @@ -240,6 +222,31 @@ def flex_attention( ] ) + if _use_flex_flash_attention( + subgraph, + mask_graph, + kernel_options, + num_score_mod_placeholders=len(placeholder_inps), + ): + return create_flex_flash_attention_kernel( + query, + key, + value, + block_mask, + scale, + kernel_options, + subgraph_buffer, + mask_graph_buffer, + score_mod_other_buffers, + mask_mod_other_buffers, + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + mask_graph=mask_graph, + subgraph=subgraph, + ) + score_mod_other_buffers = maybe_realize(score_mod_other_buffers) mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) @@ -478,7 +485,7 @@ def validate_joint_graph(joint_graph: torch.fx.Graph): for node in joint_graph.nodes: if ( node.op == "call_function" - and node.target == torch.ops.flex_lib.zeros_and_scatter.default + and node.target is torch.ops.flex_lib.zeros_and_scatter.default ): for user in node.users: if user.op != "output": diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 4fd38b0c66f53..c100df84d5a73 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,12 +3,15 @@ import functools import importlib -from typing import Any +from contextlib import contextmanager +from typing import Any, Callable, Optional, Sequence import sympy +from sympy import Expr, Integer import torch from torch.fx import GraphModule +from torch.utils._sympy.functions import Identity from ...ir import FixedLayout, ShapeAsConstantBuffer, Subgraph, TensorBox from ...lowering import empty_strided @@ -27,7 +30,7 @@ def ensure_flash_available() -> bool: in the same interpreter to retry the import. """ try: - return importlib.util.find_spec("flash_attn.cute") is not None + return importlib.util.find_spec("flash_attn.cute") is not None # type: ignore[attr-defined] except ImportError: return False @@ -40,6 +43,63 @@ def ensure_flash_available() -> bool: ) +def _fixed_indexer_cute( + size: Sequence[int], + stride: Optional[Sequence[int]] = None, + offset: Expr = Integer(0), +) -> Callable[[Sequence[Expr]], Expr]: + """ + Colexicographic indexer for CuteDSL - matches CuTe's coordinate interpretation. + + CuTe interprets linear indices in colexicographic (column-major) order, + whereas Inductor's default _fixed_indexer uses lexicographic (row-major) order. + + For size=[4, 128] with index=[b, q_idx]: + - Lexicographic: b*128 + q_idx*1 + - Colexicographic: b*1 + q_idx*2 + + CuTe then applies the tensor's actual memory strides to get the correct offset. + """ + + def indexer(index: Sequence[Expr]) -> Expr: + assert offset == Integer(0), "Offset not supported for colexicographic indexing" + if not index: + return Integer(0) + + result = index[0] + runner = size[0] + + for idx, sz in zip(index[1:], size[1:], strict=True): + result = result + runner * Identity(idx) + runner = runner * sz + + return result + + return indexer + + +@contextmanager +def patch_fixed_layout_indexer_for_cutedsl(): + """ + Temporarily swap FixedLayout.make_indexer so CuteDSL sees colexicographic indexing. + + Note [CuteDSL indexer patch]: + Flex flash attention only supports a limited set of IR ops (pointwise, reads, no stores), + so temporarily changing the indexing order is safe for the kernels we emit today. + TODO(dynamic shapes): Reconfirm once flex flash attention supports dynamic shapes. + """ + original_make_indexer = FixedLayout.make_indexer + + def cutedsl_make_indexer(self): + return _fixed_indexer_cute(self.size, self.stride, self.offset) + + FixedLayout.make_indexer = cutedsl_make_indexer # type: ignore[assignment] + try: + yield + finally: + FixedLayout.make_indexer = original_make_indexer # type: ignore[assignment] + + def input_buffers_require_grads(graph_module, num_score_mod_placeholders: int): """Check if any of the input buffers (beyond the score mod placeholders) require gradients.""" inputs = [] @@ -56,10 +116,8 @@ def requires_grad(n): return any(requires_grad(n) for n in inputs[num_score_mod_placeholders:]) -def is_trivial_graph( - graph_module: GraphModule, is_score_graph: bool, num_score_mod_placeholders: int -): - """Check if the flex graphs are compatible with Flash Attention.""" +def is_trivial_mask_graph(graph_module: GraphModule) -> bool: + """Mask graph is trivial when it only gates via the default full op.""" graph = graph_module.graph nodes = list(graph.nodes) placeholders = [n for n in nodes if n.op == "placeholder"] @@ -67,12 +125,14 @@ def is_trivial_graph( assert len(output) == 1, "Got graph w/ multiple outputs" output_val = output[0].args[0] - if is_score_graph: - if input_buffers_require_grads(graph_module, num_score_mod_placeholders): - return False - return True # party on garth # mask mod graph is empty if we have 4 inputs and full_default output - return len(placeholders) == 4 and output_val.target == torch.ops.aten.full.default + return len(placeholders) == 4 and output_val.target is torch.ops.aten.full.default + + +@functools.lru_cache(maxsize=1) +def _supports_nontrivial_mask_graphs() -> bool: + """Currently only supported on Hopper (SM90) GPUs.""" + return torch.cuda.get_device_capability()[0] == 9 def _can_use_flex_flash_attention( @@ -91,32 +151,15 @@ def _can_use_flex_flash_attention( False, "Input buffers require gradients (not supported by flash attention)", ) + mask_trivial = is_trivial_mask_graph(mask_graph.graph_module) - score_trivial = is_trivial_graph( - subgraph.graph_module, - is_score_graph=True, - num_score_mod_placeholders=num_score_mod_placeholders, - ) - mask_trivial = is_trivial_graph( - mask_graph.graph_module, - is_score_graph=False, - num_score_mod_placeholders=num_score_mod_placeholders, - ) + if mask_trivial: + return True, "" - if not score_trivial and not mask_trivial: - return ( - False, - "Both score and mask graphs are too complex for flash attention (require simple operations only)", - ) - elif not score_trivial: + if not _supports_nontrivial_mask_graphs(): return ( False, - "Score modification captured tensors that require gradients (not supported by flash attention)", - ) - elif not mask_trivial: - return ( - False, - "A non None BlockMask was passed to flex attention (not supported by flash attention yet)", + "NYI: Non-trivial mask graphs only supported on Hopper (SM90) for flash attention", ) return True, "" @@ -154,6 +197,12 @@ def create_flex_flash_attention_kernel( mask_graph_buffer: SubgraphResults, score_mod_other_buffers: list[TensorBox], mask_mod_other_buffers: list[TensorBox], + kv_num_blocks: TensorBox | None, + kv_indices: TensorBox | None, + full_kv_num_blocks: TensorBox | None, + full_kv_indices: TensorBox | None, + mask_graph: Subgraph, + subgraph: Subgraph | None = None, ) -> tuple[TensorBox | ShapeAsConstantBuffer, TensorBox | ShapeAsConstantBuffer]: """Create a flex flash attention kernel using CuteDSL template.""" if not ensure_flash_available(): @@ -193,19 +242,54 @@ def create_flex_flash_attention_kernel( stride=[sympy.sympify(s) for s in output.get_stride()], ) + # Used to check if we can skip block sparse impl + mask_graph_is_trivial = is_trivial_mask_graph(mask_graph.graph_module) + + needs_block_mask = not mask_graph_is_trivial + has_full_blocks = full_kv_num_blocks is not None + choices: list[Any] = [] - causal = kernel_options.get("causal", False) assert flash_attention_cutedsl_template is not None + + input_nodes = [query, key, value, lse] + if has_full_blocks: + input_nodes.extend( + [kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices] + ) + + if needs_block_mask and not has_full_blocks: + raise NotImplementedError( + "Flash attention with block mask but without full blocks is not supported yet" + ) + error = flash_attention_cutedsl_template.maybe_append_choice( choices, - input_nodes=[query, key, value, lse], + input_nodes=input_nodes, layout=output_layout, mutated_inputs=[lse], subgraphs=[subgraph_buffer, mask_graph_buffer], SM_SCALE=scale, - CAUSAL=causal, + NEEDS_BLOCK_MASK=needs_block_mask, ) + def wrap_choice_render(choice): + # See Note [CuteDSL indexer patch] + original_make_kernel_render = choice.make_kernel_render + + def make_kernel_render_with_patch(*args, **kwargs): + render_kernel, render = original_make_kernel_render(*args, **kwargs) + + # Let the template construct its closures, then scope the indexer patch + # to the actual render call that emits the kernel + render_with_patch = patch_fixed_layout_indexer_for_cutedsl()(render) + + return render_kernel, render_with_patch + + choice.make_kernel_render = make_kernel_render_with_patch + + for choice in choices: + wrap_choice_render(choice) + if error or not choices: # Fallback to original implementation raise RuntimeError(f"CuteDSL template failed: {error}") diff --git a/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja b/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja index d4f29bb847033..252e324554fdf 100644 --- a/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja +++ b/torch/_inductor/kernel/flex/templates/flash_attention.py.jinja @@ -1,6 +1,10 @@ - +{% if NEEDS_BLOCK_MASK %} +{{def_kernel("Q", "K", "V", "LOGSUMEXP", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} +{% else %} {{def_kernel("Q", "K", "V", "LOGSUMEXP")}} +{% endif %} from flash_attn.cute.interface import _flash_attn_fwd + from flash_attn.cute.block_sparsity import BlockSparseTensorsTorch # Transpose tensors for _flash_attn_fwd compatibility (B,H,M,D) -> (B,M,H,D) q_transposed = Q.transpose(1, 2) @@ -26,6 +30,25 @@ output = {{get_output()}} output_transposed = output.transpose(1, 2) + {% if NEEDS_BLOCK_MASK %} + @cute.jit + def mask_mod(b_idx, h_idx, q_idx, kv_idx, aux_tensors): + {{unpack_buffers("aux_tensors", indent_width=8)}} + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + b="b_idx", + h="h_idx", + m="q_idx", + n="kv_idx", + ) | indent_except_first(2) }} + return mask_mod_output + block_sparse_tensors = BlockSparseTensorsTorch(KV_NUM_BLKS, KV_IDX, FULL_KV_NUM_BLKS, FULL_KV_IDX) + {% else %} + block_sparse_tensors = None + mask_mod = None + {% endif %} + # Collect any additional tensor buffers that were added during modifications {% set tensor_buffers = get_tensor_buffers() -%} {% if tensor_buffers -%} @@ -41,10 +64,11 @@ k_transposed, v_transposed, softmax_scale={{SM_SCALE}}, - causal={{CAUSAL}}, return_lse=True, score_mod=score_mod, + mask_mod=mask_mod, out=output_transposed, lse=LOGSUMEXP, + block_sparse_tensors=block_sparse_tensors, aux_tensors=buffers - ) \ No newline at end of file + ) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 5e4aed0d507a0..6a8657f86bf03 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1545,6 +1545,7 @@ def tuned_sparse_semi_structured_mm( (ScalingType.TensorWise, ScalingType.TensorWise), (ScalingType.RowWise, ScalingType.RowWise), (ScalingType.BlockWise1x128, ScalingType.BlockWise128x128), + (ScalingType.BlockWise1x128, ScalingType.BlockWise1x128), ] @@ -1563,11 +1564,15 @@ def _is_rowwise_scaling(sz: Any, transpose: bool) -> bool: return V.graph.sizevars.statically_known_equals(sz[idx], 1) -def _is_blockwise1xTILESIZE_scaling(sz: Any, tensor_sz: Any, tile_size: int) -> bool: +def _is_blockwise1xTILESIZE_scaling( + sz: Any, tensor_sz: Any, tile_size: int, transpose: bool +) -> bool: + lhs = 1 if transpose else 0 + rhs = 0 if transpose else 1 return V.graph.sizevars.statically_known_equals( - sz[0], tensor_sz[0] + sz[lhs], tensor_sz[lhs] ) and V.graph.sizevars.statically_known_equals( - sz[1], ceildiv(tensor_sz[1], tile_size) + sz[rhs], ceildiv(tensor_sz[rhs], tile_size) ) @@ -1589,7 +1594,9 @@ def is_desired_scaling( case ScalingType.RowWise: return _is_rowwise_scaling(scale_size, transpose) case ScalingType.BlockWise1x128: - return _is_blockwise1xTILESIZE_scaling(scale_size, t.get_size(), 128) + return _is_blockwise1xTILESIZE_scaling( + scale_size, t.get_size(), 128, transpose + ) case ScalingType.BlockWise128x128: return _is_blockwise128x128_scaling(scale_size, t.get_size()) case _: diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 5da5eaa70ffb7..b95073e769f31 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -5,7 +5,7 @@ import torch from torch._inductor.select_algorithm import realize_inputs, SymbolicGridFn -from torch._inductor.utils import sympy_product +from torch._inductor.utils import get_current_backend, sympy_product from torch._inductor.virtualized import V from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols @@ -145,7 +145,10 @@ def use_native_matmul(mat1, mat2): raise AssertionError("native matmul doesn't support block_ptr codegen yet") # Currently only enable native matmul for triton on GPU. - if not (mat1.get_device().type == "cuda" and config.cuda_backend == "triton"): + device_type = mat1.get_device().type + if not ( + device_type in ("cuda", "xpu") and get_current_backend(device_type) == "triton" + ): return False # Currently, tl.dot only supports following dtypes diff --git a/torch/_inductor/lookup_table/README.md b/torch/_inductor/lookup_table/README.md new file mode 100644 index 0000000000000..6c87a365bd8ac --- /dev/null +++ b/torch/_inductor/lookup_table/README.md @@ -0,0 +1,253 @@ +# Template Lookup Table System + +The template lookup table system provides a way to pre-configure kernel template parameters for specific operations and +input configurations, bypassing the default choice generation and autotuning process. + +## Overview + +The lookup table system replaces default choice generation with pre-configured template parameters for specific +operations and input configurations. It sits orthogonal to `max-autotune(-gemm)` in the following way + +If a lookup table is provided and there is a match + +- We check whether the template(s) in the match are currently in use +- If so, we use the pre-configured template(s) and config and bypass choice generation + - If more than one choice is provided, we run autotune among the pre-configured choices +- If not, we fall back to the default choice generation process, including max-autotune(-gemm) logic + +If there is no match, we fall back to the default choice generation process, including max-autotune(-gemm) logic + +## Configuration + +Enable the system by setting both: + +```python +from torch._inductor import config +config.lookup_table.table = your_table_dict +# You also need to set it as the default choice handler +from torch._inductor.lookup_table import LookupTableChoices +torch._inductor.V.set_choices_handler(LookupTableChoices()) +``` + +### Device Key Handling + +The key schema format is described in detail in the [Key Schemas](#key-schemas) section below. + +Configure device key behavior: + +```python +# Control whether entries include device-specific keys for lookups +# Device-agnostic entries work across different GPU models +``` + +**Lookup Behavior**: During lookup, the system automatically tries both key formats: + +1. **Device-specific key** (e.g., `"NVIDIA H100+input_data+mm"`) - tried first +1. **Device-agnostic key** (e.g., `"input_data+mm"`) - tried if device-specific fails + +**Priority**: If both device-specific and device-agnostic entries exist for the same inputs, the device-specific entry +takes priority. + +**NOTE**: Device-based keys simplify hardware-specific optimization without complex build rules. Currently limited to +device name only. If you need additional conditional key attributes (e.g., CUDA version filtering), please file an issue +or submit a patch. + +## Behavior + +When the table is active, the following behavior occurs for all supported operations: + +### Match Found + +- Uses pre-configured choices from the table instead of generating default choices +- Bypasses autotuning if only a single choice is provided +- If multiple choices are provided, autotuning occurs among those choices only + +### No Match Found + +- Standard default behavior - generates choices using heuristics and max-autotune settings + +### Table Not Set or Inactive + +- Standard default behavior - generates choices using heuristics and max-autotune settings + +## Supported Operations + +Currently supports: `mm`, `addmm`, `bmm`, `mm_plus_mm`, `scaled_mm` operations with + +- Triton +- ATEN +- DecomposeK + +## Table Format + +The table is a dictionary with keys in the format: + +``` +"input_key+op_name" +``` + +Where: + +- `input_key`: Generated from `KernelInputs.key` property, represents tensor shapes/dtypes/strides +- `op_name`: Operation name (`"mm"`, `"addmm"`, etc.) + +Each value is a list of configuration dictionaries containing: + +- `template_id`: Template identifier (`"triton:mm"`, `"triton::mm_persistent_tma"`, `"decompose_k"`, etc.) +- Template-specific parameters (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `num_warps`, etc.) + +## Key Schemas + +**NOTE**: The key schema format is subject to change as the system evolves. + +The lookup table uses composite keys to match kernel configurations. See +[Implementation Details](#implementation-details) below for more technical information about key generation. This +section describes the structure of these keys. + +### Key Format Structure + +Keys follow the pattern: + +``` +[device_name+]input_key+[additional_params+]op_name +``` + +Components: + +- **device_name** (optional): GPU device identifier (e.g., `"NVIDIA H100"`) + + - Obtained from `torch.cuda.get_device_properties().gcnArchName` + - Enables device-specific optimizations + - When omitted, creates device-agnostic entries that work across hardware + +- **input_key**: Tensor configuration representation from `KernelInputs.key` + + - Format: `((dtype, shape, stride), (dtype, shape, stride), ...)` + - Each tuple represents one input tensor's properties + - Example: `((torch.float16, [128, 256], [0, 1]), (torch.float16, [64, 256], [256, 1]))` + - Order matches the operation's input argument order + +- **additional_params** (optional): Operation-specific parameters + + - Format: `key1=value1&key2=value2` + - Example: `alpha=1&beta=1` for addmm operations + +- **op_name**: Operation identifier + + - Examples: `"mm"`, `"addmm"`, `"bmm"`, `"mm_plus_mm"`, `"scaled_mm"` + +### Key Examples + +**Device-specific key for addmm:** + +``` +"NVIDIA H100+((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm" +``` + +**Device-agnostic key for mm:** + +``` +"((torch.float16, [64, 128], [128, 1]), (torch.float16, [128, 256], [256, 1]))+mm" +``` + +**Key with no additional parameters:** + +``` +"((torch.float32, [512, 512], [512, 1]), (torch.float32, [512, 512], [512, 1]))+bmm" +``` + +### Lookup Strategy + +During lookup, the system tries keys in priority order: + +1. **Device-specific key** - checked first if device information is available +1. **Device-agnostic key** - fallback if device-specific lookup fails + +This allows tables to contain: + +- Device-optimized configurations (higher priority) +- Portable configurations that work across devices +- Mix of both for flexible deployment + +## Example Table + +This is an example table for a single input showing two configurations + +```python +table = { + "((torch.float16, [128, 256], [0, 1]), (torch.float16, [128, 64], [64, 1]), (torch.float16, [64, 256], [256, 1]))+alpha=1&beta=1+addmm": [ + { + "template_id": "triton::mm", + "EVEN_K": true, + "USE_FAST_ACCUM": false, + "ACC_TYPE": "tl.float32", + "num_stages": 2, + "num_warps": 4, + "BLOCK_M": 32, + "BLOCK_N": 32, + "BLOCK_K": 64, + "hint_override": null, + "GROUP_M": 8, + "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896" + }, + { + "template_id": "aten::bias_addmm" + }, + ] +} +``` + +## Source Hashing Safety + +The lookup table system includes source hashing to prevent using stale configurations when template code changes. + +### Configuration + +- **Enabled by default**: `torch._inductor.config.lookup_table.check_src_hash = True` +- **Optional field**: Add `"template_hash"` to table entries for enhanced safety + +### Behavior + +When source hash checking is enabled: + +- Template configurations with `"template_hash"` fields are validated against current template source hashes +- Mismatched hashes indicate the template code has changed since the configuration was created +- Stale configurations are automatically filtered out with a warning message +- Configurations without hash fields are preserved for backward compatibility or if the user wants to fly looser + +### Example with Template Hash + +```python +{ + "template_id": "triton::mm", + "BLOCK_M": 32, + "BLOCK_N": 32, + "BLOCK_K": 16, + "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896" +} +``` + +## Performance Impact + +- **Lookup Hit**: Eliminates heuristic choice generation and autotuning overhead (if a single choice) +- **Lookup Miss**: Default behavior, including heuristic choice generation and autotuning +- **Memory**: Table stored in memory, minimal overhead for key generation and lookup + +## Implementation Details + +### Key Generation + +- Device key: Uses `torch.cuda.get_device_properties().gcnArchName` (e.g., "NVIDIA H100") +- Input key: Generated from `KernelInputs.key` containing tensor properties + +### Entry Points + +The system is accessed through: + +- `lookup_template_configs(kernel_inputs, op_name, template_uids)` - Main lookup function +- `LookupTableChoices._finalize_template_configs()` - Integration point with existing choice system + +### Error Handling + +- Validates config dictionaries contain required `template_id` field +- Gracefully handles non-CUDA devices by returning empty results diff --git a/torch/_inductor/lookup_table/__init__.py b/torch/_inductor/lookup_table/__init__.py new file mode 100644 index 0000000000000..0ebb1d5618bfa --- /dev/null +++ b/torch/_inductor/lookup_table/__init__.py @@ -0,0 +1,32 @@ +""" +Template lookup table system for PyTorch Inductor. + +This package provides functionality for: +- Loading pre-configured template choices from lookup tables +- Managing template configurations and choices + +All functionality is contained within the LookupTableChoices class. +You can customize any aspect by subclassing LookupTableChoices and overriding methods. + +Usage: + # Basic usage + choices = LookupTableChoices() + V.set_choices_handler(choices) + + # Custom usage + class MyCustomChoices(LookupTableChoices): + def _get_lookup_table(self): + return my_custom_table + + def make_lookup_key(self, kernel_inputs, op_name, include_device=False): + return f"custom_{op_name}_{hash(str(kernel_inputs))}" + + V.set_choices_handler(MyCustomChoices()) +""" + +from .choices import LookupTableChoices + + +__all__ = [ + "LookupTableChoices", +] diff --git a/torch/_inductor/lookup_table/choices.py b/torch/_inductor/lookup_table/choices.py new file mode 100644 index 0000000000000..46e54180114ab --- /dev/null +++ b/torch/_inductor/lookup_table/choices.py @@ -0,0 +1,418 @@ +from __future__ import annotations + +import copy +import logging +from functools import lru_cache +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +from torch._inductor import config +from torch._inductor.choices import InductorChoices +from torch._inductor.kernel_template_choice import KernelTemplateChoice +from torch._inductor.template_heuristics.params import DictKernelTemplateParams + + +log = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from collections.abc import Generator + + from torch._inductor.codegen.common import KernelTemplate + from torch._inductor.kernel_inputs import KernelInputs + from torch._inductor.select_algorithm import ExternKernelChoice + + +class LookupTableChoices(InductorChoices): + """ + InductorChoices subclass that uses lookup table when available, otherwise falls back to parent. + All lookup functionality is contained within this class and can be customized by overriding methods. + """ + + def _get_lookup_table(self) -> dict[str, list[dict[str, Any]]]: + """ + Get the template lookup table from config. + Override this method to use custom lookup table sources (database, API, etc.). + """ + if not torch.cuda.is_available() or config.lookup_table.table is None: + return {} + return config.lookup_table.table + + @staticmethod + @lru_cache + def _get_device_key(device: torch.device) -> Optional[str]: + """ + Generate a device key for lookup table indexing. + For CPU devices, returns None. + For CUDA devices, returns the props.gcnArchName string. + """ + if device.type != "cuda": + # only cuda devices are supported, this indicates that the system is not in use + # for this device + return None + + # Get CUDA device properties + props = torch.cuda.get_device_properties(device.index) + return props.gcnArchName + + @staticmethod + def _generate_kernel_inputs_key(kernel_inputs: KernelInputs) -> str: + """ + Generate a key based on input node properties and scalars. + The key includes dtype, size, and stride information for each input node, + plus scalar values as key=value pairs separated by & signs. + """ + # Get node information using existing methods + dtypes = kernel_inputs.dtypes() + shapes = kernel_inputs.shapes_hinted() + strides = kernel_inputs.strides_hinted() + + # Create tuple of (dtype, shape_list, stride_list) for each node + node_info = tuple( + (dtype, list(shape), list(stride)) + for dtype, shape, stride in zip(dtypes, shapes, strides) + ) + + # Create base key from node information + fmt_key = str(node_info) + # Add scalar information if present + if kernel_inputs._scalars: + # Sort scalars for consistent key generation and join with & + scalar_parts = [ + f"{key}={value}" + for key, value in sorted(kernel_inputs._scalars.items()) + ] + scalars_key = "&".join(scalar_parts) + fmt_key = f"{fmt_key}+{scalars_key}" + + return f"{fmt_key}" + + def make_lookup_key( + self, kernel_inputs: KernelInputs, op_name: str, include_device: bool = False + ) -> Optional[str]: + """ + Create a flattened lookup key from kernel inputs and operation name. + Override this method to customize key generation. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + include_device: Whether to include device key in the generated key + + Returns: + A string key combining device (optional), operation, and input information + """ + device = kernel_inputs.device() + dev_key = self._get_device_key(device) + if dev_key is None: + # The system does not run when dev_key is None, regardless of + # whether include_device is True or False + return None + if not include_device: + dev_key = None + + # Generate input key using our staticmethod + input_key = self._generate_kernel_inputs_key(kernel_inputs) + + # Create the flattened lookup key + if dev_key is not None: + key_parts = [dev_key, input_key, op_name] + else: + key_parts = [input_key, op_name] + + return "+".join(key_parts) + + def make_lookup_key_variants( + self, kernel_inputs: KernelInputs, op_name: str + ) -> tuple[Optional[str], Optional[str]]: + """ + Generate both device-specific and device-agnostic lookup keys. + Override this method to customize key variant generation. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + + Returns: + Tuple of (device_key, device_agnostic_key). Either may be None if generation fails. + """ + device_key = self.make_lookup_key(kernel_inputs, op_name, include_device=True) + device_agnostic_key = self.make_lookup_key( + kernel_inputs, op_name, include_device=False + ) + + return device_key, device_agnostic_key + + @staticmethod + def _entry_is_valid( + cfg: dict[str, Any], + template_id: str, + template_hash_map: Optional[dict[str, Optional[str]]], + ) -> bool: + """ + Check if a config entry is valid based on template hash validation. + + Args: + cfg: Configuration dictionary that may contain a template_hash field + template_id: The template identifier + template_hash_map: Optional mapping from template_uid to src_hash for validation + + Returns: + True if the config is valid and should be kept, False if it should be filtered out + """ + # If hash checking is disabled or no hash map provided, keep the config + if not config.lookup_table.check_src_hash or not template_hash_map: + return True + + template_hash = template_hash_map.get(template_id) + config_hash = cfg.get("template_hash") + + # Both hashes present - validate they match + if template_hash is not None and config_hash is not None: + if config_hash != template_hash: + log.warning( + "Hash validation failed for template '%s': config_hash='%s' != template_hash='%s'. " + "Template code may have changed. Filtering out config: %s", + template_id, + config_hash, + template_hash, + {k: v for k, v in cfg.items() if k != "template_hash"}, + ) + return False + else: + log.debug( + "Hash validation passed for template '%s': hash='%s'", + template_id, + template_hash, + ) + return True + # Config has no hash - keep it + elif config_hash is None: + log.debug( + "Config for template '%s' has no hash - keeping it (template_hash='%s')", + template_id, + template_hash, + ) + return True + # Template has no hash - keep config + else: + log.debug( + "Template '%s' has no src_hash - keeping config with hash '%s'", + template_id, + config_hash, + ) + return True + + def lookup_template_configs( + self, + kernel_inputs: KernelInputs, + op_name: str, + template_uids: list[str], + template_hash_map: Optional[dict[str, Optional[str]]] = None, + ) -> dict[str, list[dict[str, Any]]]: + """ + Unified function to look up template configurations for multiple templates. + Override this method to customize lookup logic. + + Args: + kernel_inputs: KernelInputs object containing input nodes and scalars + op_name: Operation name (e.g., "mm", "addmm") + template_uids: List of template identifiers (e.g., ["mm", "tma", "decompose_k"]) + template_hash_map: Optional mapping from template_uid to src_hash for validation + + Returns: + {}: No lookup table in use, or no matches found for any template + {"template_uid1": [config1, config2], ...}: Matches found, filtered configurations + """ + lookup_table = self._get_lookup_table() + if not lookup_table: + log.debug("Lookup table: no table configured or CUDA unavailable") + return {} + + # Try both key variants: device-specific first, then device-agnostic + # If both exist, device-specific takes priority + device_key, device_agnostic_key = self.make_lookup_key_variants( + kernel_inputs, op_name + ) + + config_list = [] + + for key_type, key in [ + ("device-specific", device_key), + ("device-agnostic", device_agnostic_key), + ]: + if key is not None: + config_list = lookup_table.get(key, []) + if config_list: + log.debug( + "Lookup table: found %d configs using %s key '%s' for %s", + len(config_list), + key_type, + key, + op_name, + ) + break + else: + log.debug( + "Lookup table: no match for %s (tried keys: %s, %s) (table has %d keys)", + op_name, + device_key, + device_agnostic_key, + len(lookup_table), + ) + return {} + + log.debug( + "Lookup table: found %d configs for %s templates %s", + len(config_list), + op_name, + template_uids, + ) + # Group configs by template_id + configs_by_template: dict[str, list[dict[str, Any]]] = {} + for cfg in config_list: + if not isinstance(cfg, dict): + raise ValueError( + f"Config for {op_name} operation is not a dictionary: {cfg}" + ) + if "template_id" not in cfg: + raise ValueError( + f"Config for {op_name} operation missing required 'template_id' field: {cfg}" + ) + + template_id = cfg["template_id"] + if template_id in template_uids: + if template_id not in configs_by_template: + configs_by_template[template_id] = [] + configs_by_template[template_id].append(cfg) + + # Check template hashes and clean up template_id field + result = {} + for template_id, matching_configs in configs_by_template.items(): + filtered_configs = [] + for cfg in matching_configs: + # Check template hash using helper function + if not self._entry_is_valid(cfg, template_id, template_hash_map): + continue + + # Return a copy of the config, as we don't want to modify the original + cconfig = copy.deepcopy(cfg) + # Lastly, we have to throw out the template_id, as it's not a valid kwarg + # and just used to identify which template the entry belongs to + del cconfig["template_id"] + # Similarly, the template_hash is not a valid kwarg + cconfig.pop("template_hash", None) + filtered_configs.append(cconfig) + + if filtered_configs: + result[template_id] = filtered_configs + + return result + + def _finalize_template_configs( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[KernelTemplateChoice]: + """Check lookup table for hits, use those if found, otherwise fall back to parent.""" + # 1. Collect template src_hashes for validation + template_uids = [template.uid for template in templates] + template_hash_map = {} + for template in templates: + src_hash = getattr(template, "src_hash", None) + template_hash_map[template.uid] = src_hash + + log.debug( + "Choices: attempting lookup for %s with %d templates", + op_name, + len(template_uids), + ) + + # 2. Single batch lookup for all templates + lookup_results = self.lookup_template_configs( + kernel_inputs, op_name, template_uids, template_hash_map + ) + + # 3. Early exit if no lookup table or no matches + if not lookup_results: # Empty dict + log.info("LookupChoices: lookup miss for %s, using fallback", op_name) + return self._fallback( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + + log.info( + "LookupChoices: lookup hit for %s - found %d/%d templates: %s", + op_name, + len(lookup_results), + len(template_uids), + list(lookup_results.keys()), + ) + + # 4. Create KTCs only for templates with lookup entries + return self._create_lookup_choices( + lookup_results, templates, kernel_inputs, op_name + ) + + def _fallback( + self, + template_choices: dict[str, Generator[KernelTemplateChoice, None, None]], + kernel_inputs: KernelInputs, + templates: list[Union[KernelTemplate, ExternKernelChoice]], + op_name: str, + kwarg_overrides: Optional[dict[str, dict[str, Any]]] = None, + ) -> list[KernelTemplateChoice]: + """Fallback to parent if no lookup table or no matches.""" + # NOTE: this is broken out, so that subclasses are able to override this + # to handle explicitly the situations where the lookup take had a miss vs + # overriding the entire logic + return super()._finalize_template_configs( + template_choices, + kernel_inputs, + templates, + op_name, + kwarg_overrides, + ) + + def _create_lookup_choices( + self, + lookup_results: dict[str, list[dict[str, Any]]], + templates: list[Union[KernelTemplate, ExternKernelChoice]], + kernel_inputs: KernelInputs, + op_name: str, + ) -> list[KernelTemplateChoice]: + """Create KernelTemplateChoice objects from lookup results using parent's get_ktc method.""" + templates_by_uid = {template.uid: template for template in templates} + lookup_choices: list[KernelTemplateChoice] = [] + + for template_uid, configs in lookup_results.items(): + template = templates_by_uid[template_uid] + + # Use parent's get_ktc method to get a generator, then get the first base KTC + ktc_generator = self.get_ktc(kernel_inputs, template, op_name) + + try: + base_ktc = next(ktc_generator) + except StopIteration: + # No configs from heuristic, skip this template + continue + + # For each lookup config, create a KTC with the override kwargs + for c in configs: + lookup_ktc = KernelTemplateChoice( + template=base_ktc.template, + # use the ones from the lookup table + params=DictKernelTemplateParams(c), + extra_kwargs=base_ktc.extra_kwargs, + layout=base_ktc.layout, + inputs=base_ktc.inputs, + ) + lookup_choices.append(lookup_ktc) + + return lookup_choices diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 77969952fce20..b077817cf30cf 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -132,6 +132,14 @@ def __init__( self.indexing = None + def get_original_num_rdims(self) -> int: + assert self.has_partial_accumulate + node = self.root_block.graph.find_nodes( + op="call_method", target="partial_accumulate" + )[0] + meta = node.args[-1] + return meta["num_reduction_dims"] + def extract_pw_from_reduction(self): self.root_block = self.root_block.extract_pw_from_reduction() self.has_partial_accumulate = True @@ -553,9 +561,12 @@ def extract_pw_from_reduction(self): buf = store.args[1] ops = store.args[0] + extra_meta = { + "num_reduction_dims": len(self.body.reduce_vars), + } with self.graph.inserting_after(store): self.graph.call_method( - "partial_accumulate", (ops, buf, reduction_type, red_arg) + "partial_accumulate", (ops, buf, reduction_type, red_arg, extra_meta) ) self.graph.erase_node(store) self.graph.erase_node(red) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 52521285dfec3..288db20e64db5 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -27,7 +27,7 @@ from torch._higher_order_ops.associative_scan import associative_scan_op from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation from torch._library.utils import get_layout_constraint_tag -from torch._prims_common import ( # pyrefly: ignore # deprecated +from torch._prims_common import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] canonicalize_dim, canonicalize_dims, check, @@ -2490,11 +2490,18 @@ def inner_fn(index): def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]: + # Calculate the maximum offset for the boundaries tensor + # For a strided tensor, this is sum((size[i] - 1) * stride[i]) + stride[-1] + # This ensures the mask check in bucketize_binary_search works correctly + # for both contiguous and non-contiguous tensors. + size = tb.get_size() + stride = tb.get_stride() + max_offset = sum((s - 1) * st for s, st in zip(size, stride)) + stride[-1] return ( tb.get_name(), - tb.get_size()[-1], - tb.get_size()[0] * tb.get_stride()[0], - tb.get_stride()[-1], + size[-1], + max_offset, + stride[-1], ) @@ -3922,7 +3929,7 @@ def indice_slice_from_randperm(indice): isinstance(indice, ir.StorageBox) and isinstance(indice.data, ir.ExternKernel) and getattr(indice.data, "fx_node", None) - and indice.data.fx_node.target == torch.ops.aten.randperm.default + and indice.data.fx_node.target is torch.ops.aten.randperm.default ) return False diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index a8df2fe559875..1987195516ba7 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -928,7 +928,7 @@ def reorder_for_peak_memory( # other methods for method in methods: try: - if method == topological_sort_lpmf: + if method is topological_sort_lpmf: order = method( nodes, name_to_freeable_input_buf, name_to_buf, graph_outputs ) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 857fe238c25c8..dccfd0a2f769f 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -293,6 +293,7 @@ def partial_accumulate( name: str, reduction_type: ReductionType, value: T, + extra_meta: dict[str, Any], ) -> None: raise NotImplementedError diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 214b52a7491ad..f5ab01374d8a3 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -88,6 +88,9 @@ class OutputCode: def __call__(self, inputs: Sequence[Any]) -> Any: raise NotImplementedError(type(self)) + def prepare_for_serialization(self) -> None: + raise NotImplementedError(type(self)) + def post_compile( self, example_inputs: Sequence[InputType], @@ -677,7 +680,7 @@ def post_compile( ] else: # On the forward we don't know whether or not - # boxed_foward_device_index is set yet + # boxed_forward_device_index is set yet boxed_forward_device_index = graph_kwargs.get( "boxed_forward_device_index", None ) @@ -783,6 +786,9 @@ def post_compile( ) -> None: pass + def prepare_for_serialization(self) -> None: + pass + def set_triton_bundle(self, triton_bundle: Any) -> None: pass @@ -807,3 +813,97 @@ def __call__(self, inputs: Sequence[Any]) -> Any: def set_triton_bundle(self, triton_bundle: Any) -> None: pass + + +@dataclasses.dataclass +class RegionalOutputCode(OutputCode): + """ + OutputCode for regional inductor compilation results. + + Regional inductor returns a torch.fx.GraphModule that contains both + compiled regions (via standalone_compile) and eager regions. This needs + special serialization using GraphPickler instead of standard pickle. + + The serialization strategy stores the GraphModule as bytes using + GraphPickler.dumps(), which handles FakeTensors, AOTCompiledArtifacts, + and other special objects that standard pickle cannot handle. + """ + + # The serialized graph module as bytes (using GraphPickler) + _serialized_graph_module: Optional[bytes] = dataclasses.field( + default=None, init=False + ) + + # The actual graph module (cleared during serialization) + _graph_module: Optional[torch.fx.GraphModule] = dataclasses.field( + default=None, init=False + ) + + def __init__(self, graph_module: torch.fx.GraphModule): + """ + Args: + graph_module: The torch.fx.GraphModule returned by regional_inductor + """ + super().__init__() + self._graph_module = graph_module + self._serialized_graph_module = None + + def __call__(self, inputs: Sequence[Any]) -> Any: + """Execute the regional compiled graph.""" + if self._graph_module is None: + raise RuntimeError( + "RegionalOutputCode has no graph module loaded. " + "Did you forget to call post_compile()?" + ) + return self._graph_module(*inputs) + + def post_compile( + self, + example_inputs: Sequence[InputType], + constants: CompiledFxGraphConstants, + graph_kwargs: _CompileFxKwargs, + ) -> None: + """ + Post-compile processing for regional inductor. + + This deserializes the GraphModule from bytes using GraphPickler, + extracting the fake_mode from example_inputs. + """ + if self._graph_module is not None: + return + assert self._serialized_graph_module is not None + # Get fake mode from example inputs + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(example_inputs) + if fake_mode is None: + raise RuntimeError( + "Could not detect fake mode from example inputs. " + "Regional inductor requires fake mode for deserialization." + ) + + # Deserialize the graph module + from torch.fx._graph_pickler import GraphPickler + + gm = GraphPickler.loads(self._serialized_graph_module, fake_mode) + assert isinstance(gm, torch.fx.GraphModule) + gm.recompile() + self._graph_module = gm + + def set_triton_bundle(self, triton_bundle: Any) -> None: + """Regional inductor doesn't use triton bundles directly.""" + + def prepare_for_serialization(self) -> None: + """ + Prepare for serialization by converting the GraphModule to bytes. + + This uses GraphPickler to serialize the graph module since it contains + special objects like FakeTensors and AOTCompiledArtifacts that need + custom pickling. + """ + if self._graph_module is not None: + from torch.fx._graph_pickler import GraphPickler + + self._serialized_graph_module = GraphPickler.dumps(self._graph_module) + # Clear the graph module to avoid pickling it with standard pickle + self._graph_module = None diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index c457a4a863fb2..8133a50ca9405 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -743,7 +743,7 @@ def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: assert len(node_items) == len(self_items) m = Match(ctx, self) - for i, pattern, child_node in zip(itertools.count(), self_items, node_items): + for pattern, child_node in zip(self_items, node_items): if isinstance(pattern, PatternExpr): child_match = ctx.match(pattern, child_node) if not is_match(child_match): @@ -2099,7 +2099,7 @@ def call_function( ) -> PatternExpr: process_arg_fn = process_arg # Indexing is critical for matching getitem nodes, so we can't ignore int args here - if target == operator.getitem: + if target is operator.getitem: def process_arg_fn_impl( x: T, diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index 2ee2a7ae05434..999a27beddd6b 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -4,10 +4,11 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Callable +from typing import Any, Callable, Optional, Union from typing_extensions import Concatenate, ParamSpec, Self, TypeVar import torch +import torch.utils._pytree as pytree from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.config import use_experimental_benchmarker @@ -92,15 +93,45 @@ def wrapper(self: Any, *args: P.args, **kwargs: P.kwargs) -> T: class Benchmarker: + """ + A device-agnostic benchmarking utility for measuring the runtime of + inductor generated callables. + """ + def __init__(self: Self) -> None: pass + def infer_device(self, *fn_args: Any, **fn_kwargs: Any) -> torch.device: + inferred_device: Optional[torch.device] = None + for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): + # Some callables take nested structures as arguments so use the + # flattened form to find any tensors + for arg_or_kwarg_leaf in pytree.tree_leaves(arg_or_kwarg): + if not isinstance(arg_or_kwarg_leaf, torch.Tensor): + continue + if inferred_device is None: + inferred_device = arg_or_kwarg_leaf.device + elif arg_or_kwarg_leaf.device != inferred_device: + raise ValueError( + "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + ) + + if inferred_device is None: + raise ValueError( + "Can't safely infer the device type of `fn` with no device types" + " in `fn_args` or `fn_kwargs`. Use a direct benchmarking method instead e.g. " + "`Benchmarker.benchmark_cpu` or `Benchmarker.benchmark_gpu`." + ) + + return inferred_device + @time_and_count def benchmark( self: Self, fn: Callable[..., Any], - fn_args: tuple[Any, ...], - fn_kwargs: dict[str, Any], + fn_args: Optional[tuple[Any, ...]] = None, + fn_kwargs: Optional[dict[str, Any]] = None, + device: Optional[Union[str, torch.device]] = None, **kwargs: Any, ) -> float: """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the @@ -109,7 +140,14 @@ def benchmark( device-specific implementations, like `benchmark_cpu` and `benchmark_gpu`. Raises `ValueError(...)` if we can't safely infer the device type of `fn`; for example, if multiple device types are found in `fn_args` and `fn_kwargs`, or if no device - types are found. + types are found. To bypass device inference, provide the device to the `device` + parameter. + + WARNING: if `fn` mutates `fn_args` or `fn_kwargs`, benchmarking may fail unexpectedly. + For example, if `fn` clears a mutable object, subsequent invocations of `fn` during + benchmarking will fail. In such cases, `fn` should handle cloning its arguments internally. + If device inference is required, `Benchmarker.infer_device` can be used prior to calling + this method without any arguments for `fn_args` and `fn_kwargs`. Arguments: - fn: The function to benchmark. @@ -117,27 +155,39 @@ def benchmark( - fn_kwargs: The function's kwargs. Keyword Arguments: + - device: Which device to use for benchmarking. If not provided the device will be attempted + to be inferred from `fn_args` and `fn_kwargs`. - **kwargs: The benchmarking implementation's kwargs. Returns: - The runtime of `fn(*fn_args, **fn_kwargs)`, in milliseconds. """ - inferred_device = None - # pyrefly: ignore [bad-assignment] - for arg_or_kwarg in chain(fn_args, fn_kwargs.values()): - if not isinstance(arg_or_kwarg, torch.Tensor): - continue - if inferred_device is None: - inferred_device = arg_or_kwarg.device - elif arg_or_kwarg.device != inferred_device: + inferred_device: Optional[torch.device] = None + if device is not None: + inferred_device = ( + torch.device(device) if isinstance(device, str) else device + ) + else: + if fn_args is None and fn_kwargs is None: raise ValueError( - "Can't safely infer the device type of `fn` with multiple device types in `fn_args` and `fn_kwargs`!" + "`fn_args` and `fn_kwargs` cannot both be None if `device` is not provided." ) - if inferred_device is None: - raise ValueError( - "Can't safely infer the device type of `fn` with no device types in `fn_args` or `fn_kwargs`! You should be calling `.benchmark_cpu` or `.benchmark_gpu` directly." # noqa: B950 - ) - _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 + + fn_args = fn_args or tuple() + fn_kwargs = fn_kwargs or {} + inferred_device = self.infer_device(*fn_args, **fn_kwargs) + + assert isinstance(inferred_device, torch.device) + + fn_args = fn_args or tuple() + fn_kwargs = fn_kwargs or {} + + # No need to wrap if the callable takes no arguments + if len(fn_args) == 0 and len(fn_kwargs) == 0: + _callable = fn + else: + _callable = lambda: fn(*fn_args, **fn_kwargs) # noqa: E731 + if inferred_device == torch.device("cpu"): return self.benchmark_cpu(_callable, **kwargs) # TODO(nmacchioni): For non-CPU functions we default to using the GPU-specific benchmarking diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index b4e2a3bed5a12..0758e11134018 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -183,6 +183,7 @@ def _make_key( fkey: Any = ( (callee, params) if not custom_params_encoder + # pyrefly: ignore [invalid-param-spec] else (callee, custom_params_encoder(*params[0], **params[1])) ) ikey: Any = context._isolation_key( @@ -193,8 +194,10 @@ def _make_key( def _make_dummy_record_wrapper(self, fn: Callable[P, R]) -> Callable[P, R]: @wraps(fn) def dummy_wrapper(*args: Any, **kwargs: Any) -> R: + # pyrefly: ignore [invalid-param-spec] return fn(*args, **kwargs) + # pyrefly: ignore [bad-return] return dummy_wrapper @abstractmethod @@ -359,6 +362,7 @@ def _get_odc_from_callee(self, callee: str) -> impls._OnDiskCacheImpl: callee_sub_dir: PathLike[str] = Path(callee) odc = impls._OnDiskCacheImpl(sub_dir=callee_sub_dir) self._callee_to_odc[callee] = odc + # pyrefly: ignore [unbound-name] return odc @override @@ -501,6 +505,7 @@ def __init__(self) -> None: self._imc: impls._InMemoryCacheImpl = impls._InMemoryCacheImpl() if fpath := os.environ.get("TORCHINDUCTOR_PRE_POPULATE_DETERMINISTIC_CACHE"): + # pyrefly: ignore [bad-assignment] flock: FileLock = FileLock(str(fpath) + ".lock") with locks._acquire_flock_with_timeout(flock): with open(fpath) as fp: @@ -545,6 +550,7 @@ def _get_odc_from_callee(self, callee: str) -> impls._OnDiskCacheImpl: callee_sub_dir: PathLike[str] = Path(callee) odc = impls._OnDiskCacheImpl(sub_dir=callee_sub_dir) self._callee_to_odc[callee] = odc + # pyrefly: ignore [unbound-name] return odc def _dump_imc_to_disk(self) -> Path | None: diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 9f3f0ed1b7d74..4555a94f9da17 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -331,3 +331,40 @@ def autotune( ) return best_config + + @staticmethod + def autotune_single_field(fn, init_val, min_val=None, max_val=None): + """ + fn is a function that takes the field value and returns the benchmarking result + init_val is the starting point of autotuning. + + Should work well for parabola like curve. Here is a real example + for split-size of mix-order-reduction: https://github.com/pytorch/pytorch/pull/166461 + """ + cache = {} + + def _bench(val): + if val not in cache: + cache[val] = fn(val) + # print(f"split size {val} -> {cache[val]:.3f} ms") + return cache[val] + + if min_val is None: + min_val = 1 + if max_val is None: + max_val = 2**30 # some arbitrary large value + + best_val = init_val + improved = True + while improved: + improved = False + candlist = [best_val // 2, best_val * 2] + for cand in candlist: + cand = max(cand, min_val) + cand = min(cand, max_val) + + if _bench(cand) < _bench(best_val): + best_val = cand + improved = True + + return best_val diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index f88278df88a48..5fa6be2c1eb49 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -530,7 +530,7 @@ def _dynamic_scale_rblock(self): # = regs_per_multiprocessor / (nreg * 32 * num_warps) # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps) # = max_threads_per_multi_processor / (32 * num_warps) - # Using a tigher upper bound can reveal more optimization opportunities. + # Using a tighter upper bound can reveal more optimization opportunities. max_blocks_per_sm = max( device_prop.regs_per_multiprocessor // nreg_per_block, 1 ) @@ -917,11 +917,15 @@ def kernel_call(): return do_bench_using_profiling(kernel_call, warmup=10, rep=40) - if self.device_props.type == "cpu": - return benchmarker.benchmark_cpu(kernel_call) - - return benchmarker.benchmark_gpu( - kernel_call, rep=40, is_vetted_benchmarking=True + benchmark_kwargs = ( + {} + if self.device_props.type == "cpu" + else {"rep": 40, "is_vetted_benchmarking": True} + ) + return benchmarker.benchmark( + fn=kernel_call, + device=self.device_props.type, + **benchmark_kwargs, # type: ignore[arg-type] ) def copy_args_to_cpu_if_needed(self, *args, **kwargs): @@ -2623,6 +2627,17 @@ def pointwise( ), ] ) + if inductor_meta.get("atomic_add_found"): + configs.extend( + [ + triton_config_with_settings( + size_hints, + 64, + num_warps=1, + num_stages=1, # 250% improvement + ) + ] + ) if len(size_hints) == 2: # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds # ROCm has observed improvement by diverging here @@ -3377,15 +3392,28 @@ def persistent_reduction( # small XBLOCK to use less registers/smem c.kwargs["XBLOCK"] = 1 - c.num_warps //= 2 - c.num_warps = max(c.num_warps, 2) - # less warps so potentially each sm can run more thread blocks - # Inside each thread block, we handle the split sequentially, - # more thread blocks is beneficial here. - newc = copy.deepcopy(c) - newc.num_warps = 2 - new_configs.append(newc) + rnumel_hint = size_hints["r0_"] + + if rnumel_hint <= 1024: + c.num_warps //= 2 + c.num_warps = max(c.num_warps, 2) + new_configs.append(c) + + # less warps so potentially each sm can run more thread blocks + # Inside each thread block, we handle the split sequentially, + # more thread blocks is beneficial here. + newc = copy.deepcopy(c) + newc.num_warps = 2 + new_configs.append(newc) + else: + # more warps for larger rows + new_configs.append(c) + + if c.num_warps < 32: + newc = copy.deepcopy(c) + newc.num_warps *= 2 + new_configs.append(newc) configs = unique_configs(new_configs) @@ -3691,8 +3719,7 @@ def generate(self, meta: dict[str, int]) -> None: split_size = meta.get("RSPLIT_SIZE") xblock = meta.get("XBLOCK") assert split_size - assert xblock - assert split_size % xblock == 0 + assert xblock == 1, "Mix order reduction force XBLOCK=1 right now" self.x_grid = self.ceildiv("xnumel", split_size) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 35b42174a62ca..d9b9d830b45af 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -20,13 +20,13 @@ from torch.utils._ordered_set import OrderedSet +from .ir import ComputedBuffer + if TYPE_CHECKING: from collections.abc import Iterator, Sequence from types import ModuleType -import weakref - import sympy import torch @@ -68,6 +68,7 @@ cache_on_self, cmp, device_need_guard, + get_current_backend, get_device_tflops, get_dtype_size, get_gpu_dram_gbps, @@ -97,28 +98,6 @@ _P = ParamSpec("_P") -_custom_should_partition_fns: weakref.WeakKeyDictionary[ - torch._ops.OpOverload, Callable[..., bool] -] = weakref.WeakKeyDictionary() - - -def register_should_partition_rule( - op: torch._ops.OpOverload, - func: Callable[..., bool], -) -> None: - """Register a function that says if Inductor should partition the graph on this op. - - The function should be have the same signature as the operator. - Inductor will invoke the function with FakeTensors when it needs to decide - if the graph should be partitioned. - - `register_should_partition_rule` is currently private and experimental. - Use at your own risk. - """ - assert isinstance(op, torch._ops.OpOverload) - _custom_should_partition_fns[op] = func - - class MixOrderReduction: """ This class contains utility functions to decide if we should fuse reductions @@ -126,11 +105,59 @@ class MixOrderReduction: """ @staticmethod + def is_split_reduction(node: BaseSchedulerNode) -> bool: + return node.is_reduction() and all( + subnode.node._split_size is not None + for subnode in node.get_nodes() + if isinstance(subnode, SchedulerNode) + and subnode.is_reduction() + and isinstance(subnode.node, ComputedBuffer) + ) + + @classmethod + def get_numel_rnumel(cls, node: BaseSchedulerNode) -> tuple[sympy.Expr, sympy.Expr]: + if cls.is_split_reduction(node): + xnumel = None + rnumel = None + for subnode in node.get_nodes(): + if not ( + isinstance(subnode, SchedulerNode) + and subnode.is_reduction() + and isinstance(subnode.node, ComputedBuffer) + ): + continue + + assert subnode.node._original_ranges is not None + curxnumel = V.graph.sizevars.simplify( + sympy_product(subnode.node._original_ranges) + ) + assert subnode.node._original_reduction_ranges is not None + currnumel = V.graph.sizevars.simplify( + sympy_product(subnode.node._original_reduction_ranges) + ) + + if xnumel is None: + xnumel = curxnumel + rnumel = currnumel + else: + assert V.graph.sizevars.statically_known_equals( + xnumel, curxnumel + ), f"{xnumel} v.s. {curxnumel}" + assert V.graph.sizevars.statically_known_equals( + rnumel, currnumel + ), f"{rnumel} v.s. {currnumel}" + + assert xnumel is not None + return (xnumel, rnumel) + else: + return node.group[1] # type: ignore[return-value] + + @classmethod def has_mix_reduction_orders( - node1: BaseSchedulerNode, node2: BaseSchedulerNode + cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> bool: - g1 = node1.group[1] - g2 = node2.group[1] + g1 = cls.get_numel_rnumel(node1) + g2 = cls.get_numel_rnumel(node2) if len(g1) != 2 or len(g2) != 2 or g1 == g2: return False @@ -194,9 +221,14 @@ def has_common_read( def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: if not config.triton.mix_order_reduction: return False + if not node1.is_gpu() or not node2.is_gpu(): return False - if node1.get_device().type != "cuda" or config.cuda_backend != "triton": # type: ignore[union-attr] + device_type = node1.get_device().type # type: ignore[union-attr] + if ( + device_type not in ("cuda", "xpu") + or get_current_backend(device_type) != "triton" + ): return False if not node1.is_reduction() or not node2.is_reduction(): return False @@ -210,18 +242,24 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: if len(common_reads) == 0: return False - g1 = node1.group[1] + g1 = cls.get_numel_rnumel(node1) nrow = sympy.Max(g1[0], g1[1]) ncol = sympy.Min(g1[0], g1[1]) # We require more more row than columns since # 1, we prefer doing persistent reduction for each row # 2, we will split the reduction across the rows - if not V.graph.sizevars.statically_known_geq(nrow, ncol * 10): + if not V.graph.sizevars.statically_known_geq(nrow, ncol * 2): + return False + + # When nrow is small, ncol should also be small (due to the check + # above). Thus the entire tensor should be well cached in L2. + # Mix order reduction is less beneficial. + if not V.graph.sizevars.statically_known_geq(nrow, 4096): return False contiguous_node, other_node = ( - (node1, node2) if node1.group[1][1] == ncol else (node2, node1) + (node1, node2) if g1[1] == ncol else (node2, node1) ) if not all( @@ -242,12 +280,12 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: return False # rnumel so large that we will not generated persistent reduction - if not V.graph.sizevars.statically_known_leq(ncol, 1024): + if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16): return False # Other reduction types like max/min is not supported yet. # There are no real use case as well. - return all( + out = all( subnode.node.get_reduction_type() # type: ignore[union-attr] in { "sum", @@ -256,6 +294,7 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: for subnode in other_node.get_nodes() if subnode.is_reduction() ) + return out @classmethod def are_mix_order_reductions( @@ -276,23 +315,23 @@ def is_contiguous_load(cls, buf: str, parent_node: BaseSchedulerNode) -> bool: if len(index_names) == 0: continue - assert len(index_names) == 1 - index_name = index_names[0] - index_expr = loop_body.indexing_exprs[index_name] - var_ranges = loop_body.var_ranges - if len(var_ranges) != 2: - return False - var_symbols = list(var_ranges.keys()) - stride_vars = V.graph.sizevars.stride_vars( - index_expr, - var_symbols, - var_symbols, - ) - n_congituous_read += stride_vars[-1] == 1 - if n_congituous_read > 0: - break - return n_congituous_read > 0 + # there can be multiple index_names some times + for index_name in index_names: + index_expr = loop_body.indexing_exprs[index_name] + var_ranges = loop_body.var_ranges + + # assumes the final symbol is for reduction + var_symbols = list(var_ranges.keys()) + stride_vars = V.graph.sizevars.stride_vars( + index_expr, + var_symbols, + var_symbols, + ) + n_congituous_read += stride_vars[-1] == 1 + if n_congituous_read > 0: + return True + return False @dataclasses.dataclass @@ -1442,8 +1481,12 @@ def apply_new_loop_order(self, new_order: Sequence[int]) -> None: self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True) def swap_pw_red_dimension(self) -> None: - assert len(self._body.sizes[0]) == 2 - self.apply_new_loop_order([1, 0]) + num_rdims = self._body.get_original_num_rdims() + num_pwdims = len(self._body.iter_vars) - num_rdims + pwdims = tuple(range(num_pwdims)) + rdims = tuple(range(num_pwdims, num_pwdims + num_rdims)) + + self.apply_new_loop_order(rdims + pwdims) assert len(self.group[1]) == 2 self.group = self.group[0], (self.group[1][1], self.group[1][0]) @@ -1451,6 +1494,13 @@ def extract_pw_from_reduction(self) -> BaseSchedulerNode: self._body = self._body.extract_pw_from_reduction() return self + def cancel_reduction_split(self) -> None: + if not MixOrderReduction.is_split_reduction(self): + return + assert isinstance(self.node, ir.ComputedBuffer) + with self.node.with_original_inner_fn(): + self._compute_attrs() + def expand_dimension_for_pointwise_node( self, dimension: int, new_range: int ) -> None: @@ -2691,6 +2741,10 @@ def _init(self, nodes: list[ir.Operation]) -> None: } ) + # Unlike V.graph.removed_buffers, the op recorded here is removed but + # we still need the buffer (generated in alternative ways) + self.removed_ops: OrderedSet[str] = OrderedSet() + def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]: name_to_donated_buf = {} for name in V.graph.graph_inputs_original: @@ -2869,7 +2923,7 @@ def add_user( # NB: None means that the dependency is on an input. Don't actually # generate a dependency because if we do, Inductor will start trying # to free the unbacked int but that's pointless - for name, val in V.graph.graph_inputs.items(): + for val in V.graph.graph_inputs.values(): if isinstance(val, sympy.Expr): for fs in val.free_symbols: unbacked_symbol_to_origin_node[fs] = None @@ -3498,8 +3552,8 @@ def speedup_by_fusion( device = node_list_1[0].get_device() assert device - # don't support benchmark fusion for CPU right now. - if device.type == "cpu": + # don't support benchmark fusion for CPU C++ backend right now. + if device.type == "cpu" and config.cpu_backend != "triton": return True node_list_2 = node2.get_nodes() @@ -3569,9 +3623,7 @@ def compile_kernel( future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] for hint_override in config.multi_kernel_hints: choice_timings = multi_node.choice_timings(hint_override) - for choice, unfused_time in sorted( - choice_timings.items(), key=lambda x: x[1] - ): + for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]): if not isinstance( choice, torch._inductor.select_algorithm.TritonTemplateCaller ): @@ -4517,6 +4569,15 @@ def can_fuse( if node1 is node2: return False + # We don't further fuse with FusedMixOrderReductions for now. + # It's not a big deal since the score for fusion with + # mix order reduction is low. When we do this kind of fusion, + # the participants should have already been well fused. + if isinstance(node1, FusedMixOrderReductions) or isinstance( + node2, FusedMixOrderReductions + ): + return False + why = WhyNoFuse(node1, node2) if node1.is_template() and self.get_backend( @@ -4996,21 +5057,21 @@ def should_partition( # Allow users to manually specify if a node should be partitioned # Can only do this for FallbackKernels ir_node = node.node - if isinstance(ir_node, torch._inductor.ir.FallbackKernel): - operator = ir_node.op_overload - if operator is not None and operator in _custom_should_partition_fns: - assert isinstance(operator, torch._ops.OpOverload) - should_partition_fn = _custom_should_partition_fns[operator] - fx_node = ir_node.get_origin_node() - assert fx_node is not None - success, fake_args, fake_kwargs = ( - torch._inductor.fx_utils.get_fake_args_kwargs(fx_node) - ) - assert success, ( - "If this op came from a custom inductor pass, make sure to run FakeTensorUpdator" - ) - should_partition = should_partition_fn(*fake_args, **fake_kwargs) - return should_partition + if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and ( + op := ir_node.op_overload + ): + op_overload_packet_name = op.name() + op_overload_name = ( + f"{op_overload_packet_name}.{op._overloadname}" + if isinstance(op, torch._ops.OpOverload) + else op_overload_packet_name + ) + if ( + op_overload_packet_name in config.custom_should_partition_ops + or op_overload_name in config.custom_should_partition_ops + ): + assert isinstance(op, torch._ops.OpOverload) + return True # When not using cudagraphs, keep all kernels in the `call` function # instead of graph partition functions, since graph partition only brings @@ -5860,8 +5921,8 @@ def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool: subkernel_nodes = nodes device = subkernel_nodes[0].get_device() - # don't support benchmark fusion for CPU right now. - if device is None or device.type == "cpu": + # don't support benchmark fusion for CPU C++ backend right now. + if device is None or (device.type == "cpu" and config.cpu_backend != "triton"): return True from triton.compiler.errors import CompilationError diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 6edf351b42a23..95dab86fc35e3 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2709,8 +2709,10 @@ def __call__( # Templates selected with input_gen_fns require specific input data to avoid IMA # Passing custom input gen fns to benchmark_fusion NYI, so skip deferred template selection - # TODO(jgong5): support multi-template on CPU - if input_gen_fns is not None or layout.device.type == "cpu": + # TODO(jgong5): support multi-template on CPU C++ backend + if input_gen_fns is not None or ( + layout.device.type == "cpu" and config.cpu_backend != "triton" + ): return_multi_template = False # TODO - assert that we have not mutating kernels here @@ -2957,7 +2959,6 @@ def get_timings(hint_override: Optional[int] = None): ) timings = do_autotuning(choices, precompile_fn) - # if timings is empty, we really have no choice but to return a semi-random # choice. returning the first `ExternKernelCaller` is probably the safest bet # in this case, since it will generally be the ATen kernel. if there are no @@ -3680,6 +3681,7 @@ def log_results( dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes]) if config.autotune_num_choices_displayed == 0: return + # when autotune_num_choices_displayed is None, [:None] means all n = config.autotune_num_choices_displayed top_k = sorted(timings, key=timings.__getitem__)[:n] diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 4327637a87207..8b4bab3bca938 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -425,7 +425,7 @@ def apply_var_mapping( new_ranges, norm_pw_vars + norm_red_vars, strict=True ): range_vars = [] - for i in range(len(new_range)): + for _ in range(len(new_range)): range_vars.append(flat_vars[count]) count += 1 diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index d3b9ee49cb7d2..c7c896c2d10e5 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -348,7 +348,7 @@ def _do_bench_using_profiling( ] ) as p: # Benchmark - for i in range(n_repeat): + for _ in range(n_repeat): # we clear the L2 cache before each run cache.zero_() # record time of `fn` @@ -662,6 +662,7 @@ def sort_func(elem: _T) -> str: P = ParamSpec("P") RV = TypeVar("RV", covariant=True) +FN_TYPE = Callable[Concatenate[Any, P], RV] class CachedMethod(Protocol, Generic[P, RV]): @@ -709,6 +710,52 @@ def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]: return cache_on_self(fn) +def cache_on_self_and_args( + class_name: str, +) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]: + # include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls. + + def wrapper( + fn: FN_TYPE[P, RV], + ) -> FN_TYPE[P, RV]: + key = f"__{class_name}_{fn.__name__}_cache" + + # wrapper is likely on the hot path, compile a specialized version of it + ctx = {"fn": fn} + exec( + f"""\ + def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV: + args_kwargs = (args, tuple(sorted(kwargs.items()))) + + if not hasattr(self, "{key}"): + object.__setattr__(self, "{key}", {{}}) + + cache = self.{key} + + try: + return cache[args_kwargs] + except KeyError: + pass + + rv = fn(self, *args, **kwargs) + + cache[args_kwargs] = rv + return rv + """.lstrip(), + ctx, + ) + inner = functools.wraps(fn)(ctx["inner"]) + + def clear_cache(self: Any) -> None: + if hasattr(self, key): + delattr(self, key) + + inner.clear_cache = clear_cache # type: ignore[attr-defined] + return inner + + return wrapper + + def aggregate_origins( node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel], ) -> OrderedSet[Node]: @@ -2508,11 +2555,14 @@ def get_device_tflops(dtype: torch.dtype) -> float: return get_max_simd_tflops(torch.float32, sm_clock) else: if dtype in (torch.float16, torch.bfloat16) and SM80OrLater: + # pyrefly: ignore # missing-argument return get_max_tensorcore_tflops(dtype) if torch.backends.cuda.matmul.allow_tf32: + # pyrefly: ignore # missing-argument return get_max_tensorcore_tflops(torch.float32) else: + # pyrefly: ignore # missing-argument return get_max_simd_tflops(torch.float32) @@ -2526,6 +2576,7 @@ def get_gpu_dram_gbps() -> int: def get_gpu_shared_memory() -> int: from triton.runtime import driver + # pyrefly: ignore # missing-attribute return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) @@ -3315,14 +3366,17 @@ def register_op_requires_libdevice_fp64(name: str) -> None: op_requires_libdevice_fp64.add(name) -def get_current_backend() -> str: +def get_current_backend(device_type: Optional[str] = None) -> str: from torch._inductor.virtualized import V - device_str = V.graph.get_current_device_or_throw().type - if device_str == "cpu": + if not device_type: + device_type = V.graph.get_current_device_or_throw().type + if device_type == "cpu": return config.cpu_backend - elif device_str == "mps": + elif device_type == "mps": return "mps" + elif device_type == "xpu": + return config.xpu_backend else: return config.cuda_backend diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index ea1073f88b714..8088078593fbf 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -207,9 +207,13 @@ def _choices_default(): We virtualize InductorChoices to allow changing inductor heuristics from out of tree. """ + from torch._inductor import config from torch._inductor.choices import InductorChoices - rv = InductorChoices() + if config.inductor_choices_class is not None: + rv = config.inductor_choices_class() + else: + rv = InductorChoices() setattr(threadlocal, _choices._key, rv) return rv diff --git a/torch/_inductor/wrapper_benchmark.py b/torch/_inductor/wrapper_benchmark.py index 9a527471c8cc0..56adde809079f 100644 --- a/torch/_inductor/wrapper_benchmark.py +++ b/torch/_inductor/wrapper_benchmark.py @@ -93,6 +93,7 @@ def benchmark_all_kernels( continue triton_kernel = get_triton_kernel(kernel_mod) + device_type = triton_kernel.device_props.type kernel_category = get_kernel_category(kernel_mod) args = kernel_mod.get_args() num_in_out_ptrs = len( @@ -137,7 +138,11 @@ def get_info_str( f" {get_info_str(ms, launcher.n_regs, launcher.n_spills, launcher.shared)} @ {launcher.config}" ) else: - ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40) + ms = benchmarker.benchmark( + lambda: kernel_mod.call(args), + device=device_type, + rep=40, + ) assert len(triton_kernel.launchers) == 1, ( "Autotuner should have selected the best config" ) diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index b98949b388a91..474df5116e460 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -12,14 +12,15 @@ class FakeScriptObject: - def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObject): - self.wrapped_obj = wrapped_obj - - # The fully qualified name of the class of original script object - self.script_class_name = script_class_name + def __init__( + self, wrapped_obj: Any, script_class_name: str, x: Optional[torch.ScriptObject] + ): + # Use object.__setattr__ to bypass our custom __setattr__ during initialization + object.__setattr__(self, "wrapped_obj", wrapped_obj) + object.__setattr__(self, "script_class_name", script_class_name) try: with _disable_current_modes(): - self.real_obj = copy.deepcopy(x) + real_obj = copy.deepcopy(x) except RuntimeError as e: log.warning( # noqa: G200 "Unable to deepcopy the custom object %s due to %s. " @@ -29,7 +30,31 @@ def __init__(self, wrapped_obj: Any, script_class_name: str, x: torch.ScriptObje script_class_name, str(e), ) - self.real_obj = x + real_obj = x + object.__setattr__(self, "real_obj", real_obj) + + def __getattribute__(self, name): + try: + return super().__getattribute__(name) + except AttributeError as e: + raise AttributeError( + f"Tried to call __getattr__ with attr '{name}' on a FakeScriptObject, " + "implying that you are calling this inside of a fake kernel. " + "The fake kernel should not depend on the contents of the " + "OpaqueObject at all, so we're erroring out. If you need this" + "functionality, consider creating a custom TorchBind Object instead" + "(but note that this is more difficult)." + ) from e + + def __setattr__(self, name, value): + raise AttributeError( + f"Tried to call __setattr__ with attr '{name}' on a FakeScriptObject, " + "implying that you are calling this inside of a fake kernel. " + "The fake kernel should not depend on the contents of the " + "OpaqueObject at all, so we're erroring out. If you need this" + "functionality, consider creating a custom TorchBind Object instead" + "(but note that this is more difficult)." + ) class FakeScriptMethod: @@ -125,7 +150,8 @@ def tracing_with_real(x: torch.ScriptObject) -> bool: def maybe_to_fake_obj( - fake_mode, x: torch.ScriptObject + fake_mode, + x: Any, ) -> Union[FakeScriptObject, torch.ScriptObject]: import torch.utils._pytree as pytree from torch.utils._python_dispatch import _disable_current_modes @@ -135,13 +161,17 @@ def maybe_to_fake_obj( if tracing_with_real(x): return x - from torch._library.opaque_object import FakeOpaqueObject, OpaqueTypeStr + from torch._library.opaque_object import ( + FakeOpaqueObject, + is_opaque_type, + OpaqueTypeStr, + ) - if str(x._type()) == OpaqueTypeStr: + if x is None or is_opaque_type(type(x)) or str(x._type()) == OpaqueTypeStr: # In order to make OpaqueObjects truly opaque, the fake kernel should # not depend on the contents of the OpaqueObject at all. - fake_x = FakeOpaqueObject() - + fake_x_wrapped = FakeScriptObject(FakeOpaqueObject(), OpaqueTypeStr, None) + return fake_x_wrapped else: # x.__obj_flatten__() could be calling some tensor operations inside but we don't # want to call these ops in surrounding dispatch modes when executing it. @@ -209,7 +239,8 @@ def maybe_to_fake_obj( if isinstance(real_attr, torch.ScriptMethod): method_schema = real_attr.schema # type: ignore[attr-defined] - setattr( + # Bypasses our custom setattr function + object.__setattr__( fake_x_wrapped, name, FakeScriptMethod(fake_x_wrapped, name, method_schema), diff --git a/torch/_library/triton.py b/torch/_library/triton.py index 761279743f3aa..dc55cb9b34944 100644 --- a/torch/_library/triton.py +++ b/torch/_library/triton.py @@ -215,7 +215,7 @@ def functional_decomp( # type: ignore[no-untyped-def] # the exported program to be high-level and serializable. If we decompose # the custom op to a functional hop and make it a node in exported program, # we need to figure out ways of serializing the hop and its arguments, which can be triton.jited - # functions and triton dtypes. This is undesireble because: + # functions and triton dtypes. This is undesirable because: # - it can be tedious to maintain a layer that serializes the jited function (e.g. with a string) and dtypes. # - exported program will contain the implementation detail (e.g. triton source code) for a specific # backend (GPU), which is probably at a wrong level of abstraction. diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 91080bf5a8b3f..f84b77e630bf3 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6365,12 +6365,18 @@ def has_zero_dim(tensor_2d): n = mat2.size(1) is_blockwise_scaling = ( - scale_a.dtype == torch.float8_e8m0fnu - and scale_b.dtype == torch.float8_e8m0fnu - ) or ( - scale_a.dtype == torch.float8_e4m3fn - and scale_b.dtype == torch.float8_e4m3fn - ) + ( + scale_a.dtype == torch.float8_e8m0fnu + and scale_b.dtype == torch.float8_e8m0fnu + ) + or ( + scale_a.dtype == torch.float8_e4m3fn + and scale_b.dtype == torch.float8_e4m3fn + ) + ) # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales) + + def ceil_div(a, b): + return (a + b - 1) // b if scale_a.numel() == 1 and scale_b.numel() == 1: # tensorwise scaling @@ -6392,9 +6398,6 @@ def has_zero_dim(tensor_2d): block_size_mn = 128 - def ceil_div(a, b): - return (a + b - 1) // b - num_k_blocks = ceil_div(_k, block_size_k) padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4 @@ -6450,11 +6453,18 @@ def ceil_div(a, b): ) elif ( scale_a.size(0) == m - and scale_a.size(1) == scale_b.size(0) == (_k + 128 - 1) // 128 - and scale_b.size(1) == (n + 128 - 1) // 128 + and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128) + and scale_b.size(1) == ceil_div(n, 128) ): # (BlockWise1x128, BlockWise128x128) pass # do nothing, but do not error + elif ( + scale_a.size(0) == m + and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128) + and scale_b.size(1) == n + ): + # (BlockWise1x128, BlockWise1x128) + pass # do nothing, but do not error else: # does not match any valid scaling type torch._check( @@ -6463,8 +6473,10 @@ def ceil_div(a, b): "Invalid scaling configuration. " "For tensorwise scaling, both scales should be scalar. " f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). " - f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {(_k + 128 - 1) // 128}), " - + f"scale_b should be ({(_k + 128 - 1) // 128}, {(n + 128 - 1) // 128}). " + f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), " + + f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). " + f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), " + + f"scale_b should be ({ceil_div(_k, 128)}, {n}). " f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) " f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})" ), diff --git a/torch/_ops.py b/torch/_ops.py index 95f78ca7f32a9..9cdf735532d7d 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -437,7 +437,7 @@ def check_overloaded(arg): subclass_type = type(arg) if ( subclass_type.__torch_dispatch__ - == torch._C._disabled_torch_dispatch_impl + is torch._C._disabled_torch_dispatch_impl ): continue @@ -530,7 +530,7 @@ def __call__(self, /, *args, **kwargs): dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs) - # NOTE [HigherOrderOprator Schema] + # NOTE [HigherOrderOperator Schema] # Each invocation of a HigherOrderOperator (hop) should have its own schema because # the subgraphs and the arguments can be different even for the same hop. # diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 58a6e8c3c2a6d..9224643fe55ab 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -3155,7 +3155,7 @@ def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorL # Tries to take a view # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view) - # Unbacked semnatics: if validty of in-place flattening is undecided we copy. + # Unbacked semantics: if validity of in-place flattening is undecided we copy. new_shape, _new_strides = prims._collapse_view_helper( a, start_dim, end_dim, must_be_valid=None ) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 0b000cfa1a9aa..ff309af8a29e0 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -73,15 +73,25 @@ def is_noncontiguous_supported(device): aten.ones_like.default, aten.ones_like.out, aten.rand_like.default, + aten.rand_like.generator, aten.rand_like.out, + aten.rand_like.generator_out, aten.randn_like.default, + aten.randn_like.generator, aten.randn_like.out, + aten.randn_like.generator_out, aten.randint_like.default, + aten.randint_like.generator, aten.randint_like.Tensor, + aten.randint_like.Tensor_generator, aten.randint_like.Tensor_out, + aten.randint_like.Tensor_generator_out, aten.randint_like.out, + aten.randint_like.generator_out, aten.randint_like.low_dtype, + aten.randint_like.low_generator_dtype, aten.randint_like.low_dtype_out, + aten.randint_like.low_generator_dtype_out, aten.zeros_like.default, aten.zeros_like.out, aten.new_empty.default, @@ -1338,9 +1348,11 @@ def slow(msg): continue if common_device == cpu and op.device.type != "cpu": common_device = op.device - # Slightly simplified here as target_dtype cannot vary if common_dtype is None: - common_dtype = op.dtype + if type_promotion_kind != ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT: + has_different_input_dtypes = True + else: + common_dtype = op.dtype elif common_dtype != op.dtype: has_different_input_dtypes = True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index bf1ed1ff2b111..d682db9312afd 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -827,7 +827,7 @@ def __torch_dispatch__( # type: ignore[override] # TODO ) -> object: # need to handle here to avoid infinite recursion # see [in_kernel_invocation] - if func == torch.ops.prim.device.default: + if func is torch.ops.prim.device.default: assert len(args) == 1 and isinstance(args[0], FakeTensor) if args[0].fake_mode.in_kernel_invocation: return torch.device("meta") @@ -1665,7 +1665,7 @@ def _validate_cache_key( if torch.Tag.inplace_view in func.tags: raise _BypassDispatchCache("inplace view") - if func == aten._unsafe_view.default: + if func is aten._unsafe_view.default: raise _BypassDispatchCache("unsafe view") if func in self.lift_fns: @@ -2378,12 +2378,12 @@ def _dispatch_impl( avoiding_device_init = False if self.avoid_device_init: if ( - func == torch.ops.aten._to_copy.default + func is torch.ops.aten._to_copy.default and "device" in kwargs and kwargs["device"].type != "cpu" # type: ignore[attr-defined] ): avoiding_device_init = True - if func == torch.ops.prims.device_put.default: + if func is torch.ops.prims.device_put.default: avoiding_device_init = True # skip const prop for aten._to_copy if @@ -3118,7 +3118,7 @@ def _validate_symbolic_output_for_caching( if is_tracing: # Check for SymNode types in PROXY mode - this should bypass caching # regardless of whether symbols are known or not - for node in _iterate_nodes(output): + for _ in _iterate_nodes(output): raise _BypassDispatchCache("Proxy mode with SymNode output") else: # Check for unrepresented symbols in tensor expressions @@ -3226,12 +3226,12 @@ def __torch_function__( kwargs = kwargs if kwargs else {} # clone will get called in Parameter deepcopy - if func == torch._C.TensorBase.clone: + if func is torch._C.TensorBase.clone: assert isinstance(args[0], Tensor) return func( self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs ) - elif func == Tensor.__deepcopy__: + elif func is Tensor.__deepcopy__: assert len(args) == 2 and len(kwargs) == 0 tensor = cast(Tensor, args[0]) memo = cast(dict[int, FakeTensor], args[1]) diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 206a41100b935..a8329d11e7ea1 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -16,6 +16,7 @@ from torch.utils._python_dispatch import ( _detect_infra_mode, _disable_infra_mode, + autograd_would_have_decomposed, return_and_correct_aliasing, TorchDispatchMode, ) @@ -376,7 +377,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): def _can_decompose(func): # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832 # Never decompose dropout in export - if self.export and func == torch.ops.aten.dropout.default: + if self.export and func is torch.ops.aten.dropout.default: return False # We unconditionally decompose ops that are maybe aliasing or mutating ops @@ -412,8 +413,13 @@ def _can_decompose(func): return False return True - # in normal torch.compile IR, we decompose functional composite ops - return True + # in normal torch.compile IR, we only decompose an op if autograd + # would have decomposed it (NB: autograd may have been skipped if + # we are in inference mode) + # TODO: the flatten here can potentially be deduped with the + # unwrapping pytree_map later + flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs)) + return autograd_would_have_decomposed(func, flat_args_kwargs) if ( func not in FunctionalTensor.metadata_fns @@ -549,7 +555,7 @@ def unwrap(x): ) if self.export: - if func == torch.ops.aten.dropout.default: + if func is torch.ops.aten.dropout.default: torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] outs_wrapped = pytree.tree_map_only( torch.Tensor, wrap, outs_unwrapped @@ -576,7 +582,7 @@ def unwrap(x): # aliasing correction step. Otherwise, we would be setting the storage of a # lifted tensor to that of an unlifted tensor. # Ref: https://github.com/pytorch/pytorch/issues/111506 - or func == torch.ops.aten.lift_fresh.default + or func is torch.ops.aten.lift_fresh.default ): return outs_wrapped # for metadata mutations, need to manually mutate the metadata of the FunctionalTensor wrapper diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index 4b6a05a3085d4..f56800367af45 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -1860,7 +1860,7 @@ def is_c_of_r( nt_tensor_id=t.nested_int ) - # pyrefly: ignore [bad-argument-type] + # pyrefly: ignore [bad-argument-type, unbound-name] self.set_tensor_memo(t, r) return self._checked_get_tensor_memo(t) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 3a8c2083afac6..144d433e9a026 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -156,6 +156,7 @@ def merge_dicts(*dicts): input (Tensor): the size of :attr:`input` will determine size of the output tensor. layout (:class:`torch.layout`, optional): the desired layout of returned tensor. Default: if ``None``, defaults to the layout of :attr:`input`. + generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling. dtype (:class:`torch.dtype`, optional): the desired data type of returned Tensor. Default: if ``None``, defaults to the dtype of :attr:`input`. device (:class:`torch.device`, optional): the desired device of returned tensor. @@ -9012,9 +9013,11 @@ def merge_dicts(*dicts): add_docstr( torch.rand_like, - r""" -rand_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor - + """ +rand_like(input, *, generator=None, dtype=None, layout=None, device=None, \ +requires_grad=False, memory_format=torch.preserve_format) -> Tensor +""" + + r""" Returns a tensor with the same size as :attr:`input` that is filled with random numbers from a uniform distribution on the interval :math:`[0, 1)`. ``torch.rand_like(input)`` is equivalent to @@ -9024,6 +9027,7 @@ def merge_dicts(*dicts): {input} Keyword args: + {generator} {dtype} {layout} {device} @@ -9084,9 +9088,10 @@ def merge_dicts(*dicts): add_docstr( torch.randint_like, """ -randint_like(input, low=0, high, \\*, dtype=None, layout=torch.strided, device=None, requires_grad=False, \ -memory_format=torch.preserve_format) -> Tensor - +randint_like(input, low=0, high, \\*, generator=None, dtype=None, layout=torch.strided, \ +device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor +""" + + r""" Returns a tensor with the same shape as Tensor :attr:`input` filled with random integers generated uniformly between :attr:`low` (inclusive) and :attr:`high` (exclusive). @@ -9101,6 +9106,7 @@ def merge_dicts(*dicts): high (int): One above the highest integer to be drawn from the distribution. Keyword args: + {generator} {dtype} {layout} {device} @@ -9168,9 +9174,11 @@ def merge_dicts(*dicts): add_docstr( torch.randn_like, - r""" -randn_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor - + """ +randn_like(input, *, generator=None, dtype=None, layout=None, device=None, \ +requires_grad=False, memory_format=torch.preserve_format) -> Tensor +""" + + r""" Returns a tensor with the same size as :attr:`input` that is filled with random numbers from a normal distribution with mean 0 and variance 1. Please refer to :func:`torch.randn` for the sampling process of complex dtypes. ``torch.randn_like(input)`` is equivalent to @@ -9180,6 +9188,7 @@ def merge_dicts(*dicts): {input} Keyword args: + {generator} {dtype} {layout} {device} diff --git a/torch/ao/ns/_numeric_suite.py b/torch/ao/ns/_numeric_suite.py index 1c8e751b1ebdc..a4b873cb7d168 100644 --- a/torch/ao/ns/_numeric_suite.py +++ b/torch/ao/ns/_numeric_suite.py @@ -137,7 +137,7 @@ def _get_logger_dict_helper( def get_prefix(prefix): return prefix if prefix == "" else prefix + "." - for name, child in mod.named_children(): + for child in mod.children(): if isinstance(child, Logger): target_dict[get_prefix(prefix) + "stats"] = child.stats break diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index 9a93e9ad582d7..1b1726499445f 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -909,8 +909,7 @@ def load_arg(a): # is added prev_node_c_list = [env_c[arg.name] for arg in prev_node_b] - for arg_idx, arg in enumerate(prev_node_b): - prev_node_c = prev_node_c_list[arg_idx] + for arg_idx, prev_node_c in enumerate(prev_node_c_list): env_c[prev_node_c.name] = _insert_logger_after_node( prev_node_c, gm_b, diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 9adca1a7751ab..f16700994d095 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -1050,7 +1050,7 @@ def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module): raise AssertionError(f"Expected exactly 1, got {len(shadow_n.users)}") quant_node = next(iter(shadow_n.users.keys())) new_args: Any = None - if quant_node.target == torch.quantize_per_channel: + if quant_node.target is torch.quantize_per_channel: _weight, scale_node, zp_node, axis, dtype = quant_node.args scale_val = getattr_from_fqn(shadow_wrapper, scale_node.target) zp_val = getattr_from_fqn(shadow_wrapper, zp_node.target) diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index a7541e8a50c79..3423d8533204a 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -204,7 +204,7 @@ def _get_scale_zp_from_function_args(node, gm, scale_arg_idx, zp_arg_idx): if prev_node.op == "call_function": # quantize - read the args directly - if prev_node.target == torch.quantize_per_tensor: + if prev_node.target is torch.quantize_per_tensor: return _get_scale_zp_from_function_args(prev_node, gm, 1, 2) elif prev_node.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu): return _get_scale_zp_from_function_args(prev_node, gm, 2, 3) diff --git a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py index d536245b0e9bb..ef43d7a1f7de2 100644 --- a/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py +++ b/torch/ao/pruning/_experimental/activation_sparsifier/activation_sparsifier.py @@ -128,13 +128,15 @@ def _safe_rail_checks(args): # if features are not None, then feature_dim must not be None features, feature_dim = args["features"], args["feature_dim"] if features is not None: - assert feature_dim is not None, "need feature dim to select features" + if feature_dim is None: + raise AssertionError("need feature dim to select features") # all the *_fns should be callable fn_keys = ["aggregate_fn", "reduce_fn", "mask_fn"] for key in fn_keys: fn = args[key] - assert callable(fn), "function should be callable" + if not callable(fn): + raise AssertionError(f"{fn} must be callable") def _aggregate_hook(self, name): """Returns hook that computes aggregate of activations passing through.""" @@ -209,7 +211,8 @@ def register_layer( - All the functions (fn) passed as argument will be called at a dim, feature level. """ name = module_to_fqn(self.model, layer) - assert name is not None, "layer not found in the model" # satisfy mypy + if name is None: + raise AssertionError("layer not found in the model") if name in self.data_groups: # unregister layer if already present warnings.warn( @@ -261,14 +264,15 @@ def get_mask(self, name: Optional[str] = None, layer: Optional[nn.Module] = None Hence, if get_mask() is called before model.forward(), an error will be raised. """ - assert name is not None or layer is not None, ( - "Need at least name or layer obj to retrieve mask" - ) + if name is None and layer is None: + raise AssertionError("Need at least name or layer obj to retrieve mask") if name is None: - assert layer is not None + if layer is None: + raise AssertionError("layer must be provided when name is None") name = module_to_fqn(self.model, layer) - assert name is not None, "layer not found in the specified model" + if name is None: + raise AssertionError("layer not found in the specified model") if name not in self.state: raise ValueError("Error: layer with the given name not found") @@ -451,7 +455,8 @@ def __set_state__(self, state: dict[str, Any]) -> None: for name, config in self.data_groups.items(): # fetch layer layer = fqn_to_module(self.model, name) - assert layer is not None # satisfy mypy + if layer is None: + raise AssertionError(f"layer {name} not found in the model") # if agg_mode is True, then layer in aggregate mode if "hook_state" in config and config["hook_state"] == "aggregate": diff --git a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py index 0db7becdda5b1..07040584231e1 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/base_data_sparsifier.py @@ -91,9 +91,10 @@ def add_data(self, name: str, data, reuse_mask=True, **config): 4. By default, the config of the replaced data is used as config for the replacing data, unless something is specified in the config dictionary. """ - assert type(data) in SUPPORTED_TYPES, ( - "specified data type not supported at the moment" - ) + if type(data) not in SUPPORTED_TYPES: + raise AssertionError( + f"specified data type:{type(data)} not supported at the moment" + ) local_args = copy.deepcopy(self.defaults) local_args.update(config) weight = self._extract_weight(data) @@ -116,9 +117,10 @@ def add_data(self, name: str, data, reuse_mask=True, **config): if reuse_mask: current_data = self.get_data(name=name) - assert weight.shape == current_data.shape, ( - "to retain the old mask, the shape of the new data must be the same as the previous one" - ) + if weight.shape != current_data.shape: + raise AssertionError( + "to retain the old mask, the shape of the new data must be the same as the previous one" + ) mask = self.get_mask( name=name ) # reuse mask instead of creating a new one diff --git a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py index ff4b4f913f503..4dccb52ee24fb 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/data_norm_sparsifier.py @@ -47,7 +47,8 @@ def __init__( if zeros_per_block is None: zeros_per_block = reduce(operator.mul, sparse_block_shape) - assert norm in ["L1", "L2"], "only L1 and L2 norm supported at the moment" + if norm not in ["L1", "L2"]: + raise AssertionError("only L1 and L2 norm supported at the moment") defaults = { "sparsity_level": sparsity_level, diff --git a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py index b2943e2af1a87..5c3dbde4c3d4c 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -66,17 +66,20 @@ def post_training_sparse_quantize( else: embedding_modules = [] - assert isinstance(select_embeddings, list), ( - "the embedding_modules must be a list of embedding modules" - ) - for emb in select_embeddings: - assert type(emb) in SUPPORTED_MODULES, ( - "the embedding_modules list must be an embedding or embedding bags" + if not isinstance(select_embeddings, list): + raise AssertionError( + "the embedding_modules must be a list of embedding modules" ) + for emb in select_embeddings: + if type(emb) not in SUPPORTED_MODULES: + raise AssertionError( + "the embedding_modules list must be an embedding or embedding bags" + ) fqn_name = module_to_fqn(model, emb) - assert fqn_name is not None, ( - "the embedding modules must be part of input model" - ) + if fqn_name is None: + raise AssertionError( + "the embedding modules must be part of input model" + ) embedding_modules.append((fqn_name, emb)) if sparsify_first: @@ -114,7 +117,8 @@ def post_training_sparse_quantize( for name, _ in embedding_modules: quantized_emb = fqn_to_module(model, name) - assert quantized_emb is not None # satisfy mypy + if quantized_emb is None: + raise AssertionError(f"quantized embedding {name} not found in model") quantized_weight = quantized_emb.weight() # type: ignore[operator] quantize_params["scales"][name] = quantized_weight.q_per_channel_scales() @@ -138,7 +142,8 @@ def post_training_sparse_quantize( for name, _ in embedding_modules: quantized_emb = fqn_to_module(model, name) - assert quantized_emb is not None # satisfy mypy + if quantized_emb is None: + raise AssertionError(f"quantized embedding {name} not found in model") requantized_vector = torch.quantize_per_channel( quantize_params["dequant_weights"][name], scales=quantize_params["scales"][name], diff --git a/torch/ao/pruning/_experimental/pruner/parametrization.py b/torch/ao/pruning/_experimental/pruner/parametrization.py index 58b3f7651caab..4256d6fd01750 100644 --- a/torch/ao/pruning/_experimental/pruner/parametrization.py +++ b/torch/ao/pruning/_experimental/pruner/parametrization.py @@ -28,8 +28,12 @@ def __init__(self, mask): self.register_buffer("mask", mask) def forward(self, x): - assert isinstance(self.mask, torch.Tensor) - assert self.mask.shape[0] == x.shape[0] + if not isinstance(self.mask, torch.Tensor): + raise AssertionError("mask must be a torch.Tensor") + if self.mask.shape[0] != x.shape[0]: + raise AssertionError( + f"mask shape[0] ({self.mask.shape[0]}) must match x shape[0] ({x.shape[0]})" + ) shape = [1] * len(x.shape) shape[0] = -1 return self.mask.reshape(shape) * x diff --git a/torch/ao/pruning/_experimental/pruner/prune_functions.py b/torch/ao/pruning/_experimental/pruner/prune_functions.py index 4294ee04f9f3e..c567e5771859d 100644 --- a/torch/ao/pruning/_experimental/pruner/prune_functions.py +++ b/torch/ao/pruning/_experimental/pruner/prune_functions.py @@ -332,9 +332,10 @@ def prune_conv2d_pool_flatten_linear( linear_ic = linear.weight.shape[1] conv2d_oc = len(mask) - assert linear_ic % conv2d_oc == 0, ( - f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" - ) + if linear_ic % conv2d_oc != 0: + raise AssertionError( + f"Flattening from dimensions {conv2d_oc} to {linear_ic} not supported" + ) flatten_scale = linear_ic // conv2d_oc flattened_mask = torch.tensor( diff --git a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py index 1a97cff7ab231..11c4652a7f0da 100644 --- a/torch/ao/pruning/_experimental/pruner/saliency_pruner.py +++ b/torch/ao/pruning/_experimental/pruner/saliency_pruner.py @@ -23,7 +23,10 @@ def update_mask(self, module, tensor_name, **kwargs): "Structured pruning can only be applied to a 2+dim weight tensor!" ) saliency = -weights.norm(dim=tuple(range(1, weights.dim())), p=1) - assert saliency.shape == mask.shape + if saliency.shape != mask.shape: + raise AssertionError( + f"saliency shape ({saliency.shape}) must match mask shape ({mask.shape})" + ) num_to_pick = int(len(mask) * kwargs["sparsity_level"]) prune = saliency.topk(num_to_pick).indices diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 47ff1b86488ff..14764c77cc604 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -149,7 +149,8 @@ def make_config_from_model( for _name, child in module.named_children(): if type(child) in SUPPORTED_MODULES: module_fqn = module_to_fqn(model, child) - assert isinstance(module_fqn, str) # for mypy + if not isinstance(module_fqn, str): + raise AssertionError("module_fqn must be a string") self.config.append({"tensor_fqn": module_fqn + ".weight"}) else: stack.append(child) @@ -172,20 +173,23 @@ def prepare(self, model, config): # TODO: Remove the configuration by reference ('module') # pyrefly: ignore [not-iterable] for module_config in self.config: - assert isinstance(module_config, dict), ( - "config elements should be dicts not modules i.e.:" - "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" - ) + if not isinstance(module_config, dict): + raise AssertionError( + "config elements should be dicts not modules i.e.:" + "[{`tensor_fqn`: `foo.bar.weight`}, {`tensor_fqn`: ... }, ...]" + ) - assert isinstance(self.defaults, dict) # for mypy + if not isinstance(self.defaults, dict): + raise AssertionError("defaults must be a dict") local_args = copy.deepcopy(self.defaults) local_args.update(module_config) tensor_fqn = local_args.get("tensor_fqn", None) - assert tensor_fqn is not None, ( - "tensor_fqn is a required argument in the sparsity config which" - "replaces previous `module` and [module]`fqn` arguments" - ) + if tensor_fqn is None: + raise AssertionError( + "tensor_fqn is a required argument in the sparsity config which" + "replaces previous `module` and [module]`fqn` arguments" + ) # populate all information from tensor_fqn info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn) @@ -194,16 +198,17 @@ def prepare(self, model, config): # from tensor_fqn for key in info_from_tensor_fqn.keys(): if key in local_args: - assert ( + if not ( info_from_tensor_fqn[key] == local_args[key] or ( key == "tensor_fqn" and "." + info_from_tensor_fqn[key] == local_args[key] ) # info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that - ), ( - f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" - ) + ): + raise AssertionError( + f"Given both `{key}` and `tensor_fqn` in the config, it is expected them to agree!" + ) local_args.update(info_from_tensor_fqn) self.groups.append(local_args) self._prepare() diff --git a/torch/ao/pruning/sparsifier/utils.py b/torch/ao/pruning/sparsifier/utils.py index f16a309583683..a852b35017fcd 100644 --- a/torch/ao/pruning/sparsifier/utils.py +++ b/torch/ao/pruning/sparsifier/utils.py @@ -53,9 +53,10 @@ def swap_module( # respect device affinity when swapping modules # pyrefly: ignore [bad-argument-type] devices = {p.device for p in chain(mod.parameters(), mod.buffers())} - assert len(devices) <= 1, ( - f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" - ) + if len(devices) > 1: + raise AssertionError( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None if device: new_mod.to(device) @@ -129,7 +130,10 @@ def __init__(self, mask): self.register_buffer("mask", mask) def forward(self, x): - assert self.mask.shape == x.shape + if self.mask.shape != x.shape: + raise AssertionError( + f"mask shape ({self.mask.shape}) must match x shape ({x.shape})" + ) return self.mask * x def state_dict(self, *args, **kwargs): diff --git a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py index bb10331826e34..a3645dc3ab872 100644 --- a/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py +++ b/torch/ao/pruning/sparsifier/weight_norm_sparsifier.py @@ -95,7 +95,8 @@ def _scatter_fold_block_mask( ): r"""Creates patches of size `block_shape` after scattering the indices.""" if mask is None: - assert input_shape is not None + if input_shape is None: + raise AssertionError("input_shape must be provided when mask is None") mask = torch.ones(input_shape, device=device) mask.scatter_(dim=dim, index=indices, value=0) mask.data = F.fold( diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index b07494e9a855f..f66a0640fcadc 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -235,9 +235,10 @@ def __init__( from .utils import is_per_channel if is_per_channel(self.qscheme): - assert self.ch_axis is not None, ( - "Must provide a valid ch_axis if qscheme is per channel" - ) + if self.ch_axis is None: + raise AssertionError( + "Must provide a valid ch_axis if qscheme is per channel" + ) def forward(self, x: Tensor) -> Tensor: return x diff --git a/torch/ao/quantization/_correct_bias.py b/torch/ao/quantization/_correct_bias.py index 3f480486893d4..4309e4530cb72 100644 --- a/torch/ao/quantization/_correct_bias.py +++ b/torch/ao/quantization/_correct_bias.py @@ -151,6 +151,6 @@ def bias_correction( bias.data = updated_bias # Resets the data contained in the loggers - for name, submodule in quantized_model.named_modules(): + for submodule in quantized_model.modules(): if isinstance(submodule, MeanShadowLogger): submodule.clear() diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index 5d79f7f71b4f2..a78dd307fc6d6 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -92,9 +92,10 @@ def channel_range(input, axis=0): mins = min_over_ndim(input, axis_list) maxs = max_over_ndim(input, axis_list) - assert mins.size(0) == input.size(axis), ( - "Dimensions of resultant channel range does not match size of requested axis" - ) + if mins.size(0) != input.size(axis): + raise AssertionError( + "Dimensions of resultant channel range does not match size of requested axis" + ) return maxs - mins diff --git a/torch/ao/quantization/_learnable_fake_quantize.py b/torch/ao/quantization/_learnable_fake_quantize.py index d12c96f66c009..00b824f8d1ecf 100644 --- a/torch/ao/quantization/_learnable_fake_quantize.py +++ b/torch/ao/quantization/_learnable_fake_quantize.py @@ -45,7 +45,8 @@ def __init__( **observer_kwargs, ): super().__init__() - assert quant_min < quant_max, "quant_min must be strictly less than quant_max." + if quant_min >= quant_max: + raise AssertionError("quant_min must be strictly less than quant_max.") self.quant_min = quant_min self.quant_max = quant_max # also pass quant_min and quant_max to observer @@ -56,19 +57,16 @@ def __init__( self.scale = Parameter(torch.tensor([scale])) self.zero_point = Parameter(torch.tensor([zero_point])) else: - assert isinstance(channel_len, int) and channel_len > 0, ( - "Channel size must be a positive integer." - ) + if not (isinstance(channel_len, int) and channel_len > 0): + raise AssertionError("Channel size must be a positive integer.") self.scale = Parameter(torch.tensor([scale] * channel_len)) self.zero_point = Parameter(torch.tensor([zero_point] * channel_len)) self.activation_post_process = observer(**observer_kwargs) - assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, ( - "quant_min out of bound" - ) - assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, ( - "quant_max out of bound" - ) + if torch.iinfo(self.activation_post_process.dtype).min > quant_min: + raise AssertionError("quant_min out of bound") + if quant_max > torch.iinfo(self.activation_post_process.dtype).max: + raise AssertionError("quant_max out of bound") self.dtype = self.activation_post_process.dtype self.qscheme = self.activation_post_process.qscheme self.ch_axis = ( diff --git a/torch/ao/quantization/backend_config/onednn.py b/torch/ao/quantization/backend_config/onednn.py index 348cec62ea18a..3cc7a2cf4c669 100644 --- a/torch/ao/quantization/backend_config/onednn.py +++ b/torch/ao/quantization/backend_config/onednn.py @@ -88,9 +88,10 @@ def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu): >>> lr = nn.LeakyReLU(0.01) >>> m2 = _fuse_linear_bn_leaky_relu(m1, b1, lr) """ - assert linear.training == bn.training and bn.training == leaky_relu.training, ( - "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." - ) + if linear.training != bn.training or bn.training != leaky_relu.training: + raise AssertionError( + "Linear, BN and LeakyReLU all must be in the same mode (train or eval)." + ) if is_qat: raise NotImplementedError( diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py index 65094392abf8e..0758c6a3b59d8 100644 --- a/torch/ao/quantization/backend_config/utils.py +++ b/torch/ao/quantization/backend_config/utils.py @@ -164,10 +164,11 @@ def remove_boolean_dispatch_from_name(p) -> Any: return "torch.nn.functional.adaptive_max_pool2d" elif p is F.adaptive_max_pool3d: return "torch.nn.functional.adaptive_max_pool3d" - assert "boolean_dispatch" not in str(p), ( - f"{p} does not have a human readable representation in " - + "quantization documentation" - ) + if "boolean_dispatch" in str(p): + raise AssertionError( + f"{p} does not have a human readable representation in " + + "quantization documentation" + ) return p @@ -300,7 +301,8 @@ def _get_fuser_method_in_reversed_nested_tuple_format( The first argument of a fuser method is always `is_qat` and is not affected in the conversion. We currently only support functions with 3 or 4 arguments. """ - assert config.fuser_method is not None + if config.fuser_method is None: + raise AssertionError("config.fuser_method must be provided") if config._pattern_complex_format is not None: return config.fuser_method if not isinstance(config.pattern, tuple): diff --git a/torch/ao/quantization/fake_quantize.py b/torch/ao/quantization/fake_quantize.py index a1ebffebb7d9d..c4a380946c8a0 100644 --- a/torch/ao/quantization/fake_quantize.py +++ b/torch/ao/quantization/fake_quantize.py @@ -175,9 +175,10 @@ def __init__( super().__init__() # Populate quant_min/quant_max to observer_kwargs if valid if quant_min is not None and quant_max is not None: - assert quant_min <= quant_max, ( - "quant_min must be less than or equal to quant_max" - ) + if quant_min > quant_max: + raise AssertionError( + "quant_min must be less than or equal to quant_max" + ) dtype = observer_kwargs.get("dtype", torch.quint8) if hasattr(observer, "p"): # In case observer is _PartialWrapper, dtype can be stored in @@ -186,9 +187,11 @@ def __init__( "dtype", dtype ) # pyrefly: ignore [bad-argument-type] - assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound" + if torch.iinfo(dtype).min > quant_min: + raise AssertionError("quant_min out of bound") # pyrefly: ignore [bad-argument-type] - assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound" + if quant_max > torch.iinfo(dtype).max: + raise AssertionError("quant_max out of bound") observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max}) observer_kwargs["is_dynamic"] = is_dynamic self.activation_post_process = observer(**observer_kwargs) @@ -210,11 +213,12 @@ def __init__( if hasattr(self.activation_post_process, "ch_axis") else -1 ) - assert _is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme), ( - "Only per channel and per tensor quantization are supported in fake quantize" - + " got qscheme: " - + str(self.qscheme) - ) + if not (_is_per_channel(self.qscheme) or _is_per_tensor(self.qscheme)): + raise AssertionError( + "Only per channel and per tensor quantization are supported in fake quantize" + + " got qscheme: " + + str(self.qscheme) + ) self.is_per_channel = _is_per_channel(self.qscheme) @torch.jit.export @@ -295,7 +299,10 @@ def _load_from_state_dict( if name == "scale": self.scale.resize_(val.shape) else: - assert name == "zero_point" + if name != "zero_point": + raise AssertionError( + "Expected 'zero_point' but got different state key" + ) self.zero_point.resize_(val.shape) # For torchscript module we need to update the attributes here since we do not # call the `_load_from_state_dict` function defined module.py @@ -303,7 +310,10 @@ def _load_from_state_dict( if name == "scale": self.scale.copy_(val) else: - assert name == "zero_point" + if name != "zero_point": + raise AssertionError( + "Expected 'zero_point' but got different state key" + ) self.zero_point.copy_(val) elif strict: missing_keys.append(key) @@ -329,17 +339,19 @@ class FixedQParamsFakeQuantize(FakeQuantize): # TODO: rename observer to observer_ctr def __init__(self, observer): super().__init__(observer=observer) - assert type(self.activation_post_process) is FixedQParamsObserver, ( - f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" - ) + if type(self.activation_post_process) is not FixedQParamsObserver: + raise AssertionError( + f"{self.__class__.__name__}'s observer must be a {FixedQParamsObserver.__name__}" + ) self._observer_ctr = observer self.scale = self.activation_post_process.scale self.zero_point = self.activation_post_process.zero_point - assert _is_per_tensor(self.qscheme), ( - "Only per tensor quantization is supported" - + " FixedQParamsFakeQuantize module, got qscheme:" - + str(self.qscheme) - ) + if not _is_per_tensor(self.qscheme): + raise AssertionError( + "Only per tensor quantization is supported" + + " FixedQParamsFakeQuantize module, got qscheme:" + + str(self.qscheme) + ) @torch.jit.export def calculate_qparams(self): # type: ignore[override] @@ -382,12 +394,13 @@ def __init__( **observer_kwargs: Any, ) -> None: super().__init__(observer, quant_min, quant_max, **observer_kwargs) - assert isinstance( + if not isinstance( self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver), - ), ( - "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" - ) + ): + raise AssertionError( + "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver" + ) self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long)) self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long)) self.is_symmetric_quant = _is_symmetric_quant( diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index f5fd2cad48826..4eef33698d100 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -35,9 +35,10 @@ def fuse_conv_bn(is_qat, conv, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_conv_bn(m1, b1) """ - assert conv.training == bn.training, ( - "Conv and BN both must be in the same mode (train or eval)." - ) + if conv.training != bn.training: + raise AssertionError( + "Conv and BN both must be in the same mode (train or eval)." + ) fused_module_class_map = { nn.Conv1d: nni.ConvBn1d, @@ -46,13 +47,18 @@ def fuse_conv_bn(is_qat, conv, bn): } if is_qat: - assert bn.num_features == conv.out_channels, ( - "Output channel of Conv2d must match num_features of BatchNorm2d" - ) - assert bn.affine, "Only support fusing BatchNorm2d with affine set to True" - assert bn.track_running_stats, ( - "Only support fusing BatchNorm2d with tracking_running_stats set to True" - ) + if bn.num_features != conv.out_channels: + raise AssertionError( + "Output channel of Conv2d must match num_features of BatchNorm2d." + ) + if not bn.affine: + raise AssertionError( + "Only support fusing BatchNorm2d with affine set to True" + ) + if not bn.track_running_stats: + raise AssertionError( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) fused_module_class = fused_module_class_map.get((type(conv)), None) if fused_module_class is not None: return fused_module_class(conv, bn) @@ -81,9 +87,10 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): >>> # xdoctest: +SKIP >>> m2 = fuse_conv_bn_relu(m1, b1, r1) """ - assert conv.training == bn.training == relu.training, ( - "Conv and BN both must be in the same mode (train or eval)." - ) + if not (conv.training == bn.training == relu.training): + raise AssertionError( + "Conv and BN both must be in the same mode (train or eval)." + ) fused_module: Optional[type[nn.Sequential]] = None if is_qat: map_to_fused_module_train = { @@ -91,13 +98,18 @@ def fuse_conv_bn_relu(is_qat, conv, bn, relu): nn.Conv2d: nni.ConvBnReLU2d, nn.Conv3d: nni.ConvBnReLU3d, } - assert bn.num_features == conv.out_channels, ( - "Output channel of Conv must match num_features of BatchNorm" - ) - assert bn.affine, "Only support fusing BatchNorm with affine set to True" - assert bn.track_running_stats, ( - "Only support fusing BatchNorm with tracking_running_stats set to True" - ) + if bn.num_features != conv.out_channels: + raise AssertionError( + "Output channel of Conv2d must match num_features of BatchNorm2d" + ) + if not bn.affine: + raise AssertionError( + "Only support fusing BatchNorm2d with affine set to True" + ) + if not bn.track_running_stats: + raise AssertionError( + "Only support fusing BatchNorm2d with tracking_running_stats set to True" + ) fused_module = map_to_fused_module_train.get(type(conv), None) if fused_module is not None: return fused_module(conv, bn, relu) @@ -134,18 +146,24 @@ def fuse_linear_bn(is_qat, linear, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_linear_bn(m1, b1) """ - assert linear.training == bn.training, ( - "Linear and BN both must be in the same mode (train or eval)." - ) + if linear.training != bn.training: + raise AssertionError( + "Linear and BN both must be in the same mode (train or eval)." + ) if is_qat: - assert bn.num_features == linear.out_features, ( - "Output features of Linear must match num_features of BatchNorm1d" - ) - assert bn.affine, "Only support fusing BatchNorm1d with affine set to True" - assert bn.track_running_stats, ( - "Only support fusing BatchNorm1d with tracking_running_stats set to True" - ) + if bn.num_features != linear.out_features: + raise AssertionError( + "Output features of Linear must match num_features of BatchNorm1d" + ) + if not bn.affine: + raise AssertionError( + "Only support fusing BatchNorm1d with affine set to True" + ) + if not bn.track_running_stats: + raise AssertionError( + "Only support fusing BatchNorm1d with tracking_running_stats set to True" + ) return nni.LinearBn1d(linear, bn) else: return nn.utils.fusion.fuse_linear_bn_eval(linear, bn) @@ -167,9 +185,10 @@ def fuse_convtranspose_bn(is_qat, convt, bn): >>> # xdoctest: +SKIP >>> m2 = fuse_convtranspose_bn(m1, b1) """ - assert convt.training == bn.training, ( - "ConvTranspose and BN both must be in the same mode (train or eval)." - ) + if convt.training != bn.training: + raise AssertionError( + "ConvTranspose and BN both must be in the same mode (train or eval)." + ) if is_qat: raise Exception( # noqa: TRY002 @@ -224,7 +243,8 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None): _DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping ) fuser_method = all_mappings.get(op_list, None) - assert fuser_method is not None, f"did not find fuser method for: {op_list} " + if fuser_method is None: + raise AssertionError(f"did not find fuser method for: {op_list} ") return fuser_method @@ -289,5 +309,6 @@ def get_fuser_method_new( fuser_method = fuser_method_mapping.get(op_pattern) if fuser_method is not None: break - assert fuser_method is not None, f"did not find fuser method for: {op_pattern} " + if fuser_method is None: + raise AssertionError(f"did not find fuser method for: {op_pattern} ") return fuser_method diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index ef7d1436f2178..08a95a7095f33 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -29,15 +29,17 @@ def _quant_min_max_bounds_check(quant_min, quant_max, dtype): raise ValueError(f"Unsupported dtype: {dtype}") quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype] - assert quant_min >= quant_min_lower_bound, ( - "quant_min out of bound for dtype, " - f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" - ) + if quant_min < quant_min_lower_bound: + raise AssertionError( + "quant_min out of bound for dtype, " + f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}" + ) - assert quant_max <= quant_max_upper_bound, ( - "quant_max out of bound for dtype, " - f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" - ) + if quant_max > quant_max_upper_bound: + raise AssertionError( + "quant_max out of bound for dtype, " + f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" + ) quantized_decomposed_lib.define( @@ -72,9 +74,10 @@ def quantize_per_tensor( """ if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert input.dtype == torch.float32, ( - f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - ) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) _quant_min_max_bounds_check(quant_min, quant_max, dtype) inv_scale = 1.0 / scale @@ -94,9 +97,10 @@ def quantize_per_tensor_meta( ) -> torch.Tensor: if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert input.dtype == torch.float32, ( - f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - ) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) return torch.empty_like(input, dtype=dtype) @@ -122,12 +126,14 @@ def quantize_per_tensor_tensor( Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert zero_point.numel() == 1, ( - f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - ) - assert scale.numel() == 1, ( - f"Expecting scale tensor to be one element, but received : {scale.numel()}" - ) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return quantize_per_tensor( input, scale.item(), @@ -149,15 +155,18 @@ def quantize_per_tensor_tensor_meta( ) -> torch.Tensor: if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert zero_point.numel() == 1, ( - f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - ) - assert scale.numel() == 1, ( - f"Expecting scale tensor to be one element, but received : {scale.numel()}" - ) - assert input.dtype == torch.float32, ( - f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - ) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) return torch.empty_like(input, dtype=dtype) @@ -184,12 +193,14 @@ def quantize_per_tensor_tensor2( Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert zero_point.numel() == 1, ( - f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - ) - assert scale.numel() == 1, ( - f"Expecting scale tensor to be one element, but received : {scale.numel()}" - ) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return quantize_per_tensor( input, scale.item(), @@ -266,9 +277,10 @@ def dequantize_per_tensor( Returns: dequantized float32 Tensor """ - assert input.dtype == dtype, ( - f"Expecting input to have dtype: {dtype}, but got {input.dtype}" - ) + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) if out_dtype is None: out_dtype = torch.float32 if dtype in _DTYPE_TO_QVALUE_BOUNDS: @@ -322,12 +334,14 @@ def dequantize_per_tensor_tensor( Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert zero_point.numel() == 1, ( - f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - ) - assert scale.numel() == 1, ( - f"Expecting scale tensor to be one element, but received : {scale.numel()}" - ) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return dequantize_per_tensor( input, scale.item(), @@ -352,13 +366,18 @@ def dequantize_per_tensor_tensor_meta( ) -> torch.Tensor: if out_dtype is None: out_dtype = torch.float32 - assert zero_point.numel() == 1, ( - f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - ) - assert scale.numel() == 1, ( - f"Expecting scale tensor to be one element, but received : {scale.numel()}" - ) - assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got {input.dtype}" + ) if dtype in _DTYPE_TO_QVALUE_BOUNDS: return torch.empty_like(input, dtype=out_dtype) else: @@ -392,12 +411,14 @@ def dequantize_per_tensor_tensor2( Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of scalar values """ - assert zero_point.numel() == 1, ( - f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" - ) - assert scale.numel() == 1, ( - f"Expecting scale tensor to be one element, but received : {scale.numel()}" - ) + if zero_point.numel() != 1: + raise AssertionError( + f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}" + ) + if scale.numel() != 1: + raise AssertionError( + f"Expecting scale tensor to be one element, but received : {scale.numel()}" + ) return dequantize_per_tensor( input, scale.item(), @@ -448,16 +469,18 @@ def choose_qparams_tensor( scale (float): quantization parameter for the target quantized Tensor zero_point (int): quantization parameter for the target quantized Tensor """ - assert input.dtype in [ + if input.dtype not in [ torch.float32, torch.float16, torch.bfloat16, - ], ( - f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" - ) - assert dtype in _DTYPE_TO_QVALUE_BOUNDS, ( - f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" - ) + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise AssertionError( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) validate_qmin_qmax(qmin, qmax) min_val, max_val = torch.aminmax(input) @@ -500,16 +523,18 @@ def choose_qparams_symmetric_tensor( scale (float): quantization parameter for the target quantized Tensor zero_point (int): quantization parameter for the target quantized Tensor """ - assert input.dtype in [ + if input.dtype not in [ torch.float32, torch.float16, torch.bfloat16, - ], ( - f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" - ) - assert dtype in _DTYPE_TO_QVALUE_BOUNDS, ( - f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" - ) + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if dtype not in _DTYPE_TO_QVALUE_BOUNDS: + raise AssertionError( + f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}" + ) validate_qmin_qmax(qmin, qmax) min_val, max_val = torch.aminmax(input) @@ -529,17 +554,18 @@ def choose_qparams_symmetric_tensor( def choose_qparams_tensor_meta( input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - assert input.dtype in [ + if input.dtype not in [ torch.float32, torch.float16, torch.bfloat16, - ], ( - f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" - ) - assert quant_min < quant_max, ( - f"Expecting quant_min to be smaller than quant_max but received min: \ - {quant_min} max: {quant_max}" - ) + ]: + raise AssertionError( + f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}" + ) + if quant_min >= quant_max: + raise AssertionError( + f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}" + ) return torch.empty(1, dtype=torch.double, device=input.device), torch.empty( 1, dtype=torch.int64, device=input.device ) @@ -598,10 +624,12 @@ def quantize_per_channel( """ if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert input.dtype == torch.float32, ( - f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - ) - assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") _quant_min_max_bounds_check(quant_min, quant_max, dtype) input, permute_axis_list = _permute_to_axis_zero(input, axis) @@ -629,10 +657,12 @@ def quantize_per_channel_meta( ) -> torch.Tensor: if input.dtype in [torch.float16, torch.bfloat16]: input = input.to(torch.float32) - assert input.dtype == torch.float32, ( - f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - ) - assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") _quant_min_max_bounds_check(quant_min, quant_max, dtype) return torch.empty_like(input, dtype=dtype) @@ -687,12 +717,14 @@ def dequantize_per_channel( Returns: dequantized float32 Tensor """ - assert input.dtype == dtype, ( - f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" - ) + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype: {dtype}, but got dtype: {input.dtype}" + ) if out_dtype is None: out_dtype = torch.float32 - assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") _quant_min_max_bounds_check(quant_min, quant_max, dtype) input, permute_axis_list = _permute_to_axis_zero(input, axis) @@ -722,12 +754,14 @@ def dequantize_per_channel_meta( *, out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - assert input.dtype == dtype, ( - f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" - ) + if input.dtype != dtype: + raise AssertionError( + f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}" + ) if out_dtype is None: out_dtype = torch.float32 - assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") _quant_min_max_bounds_check(quant_min, quant_max, dtype) return torch.empty_like(input, dtype=out_dtype) @@ -879,12 +913,12 @@ def choose_qparams_per_token_asymmetric_meta( def _per_token_quant_qparam_dim_check(input, scales, zero_points): num_tokens = math.prod(list(input.size())[:-1]) - assert num_tokens == scales.numel(), ( - f"num_tokens: {num_tokens} scales: {scales.size()}" - ) - assert num_tokens == zero_points.numel(), ( - f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" - ) + if num_tokens != scales.numel(): + raise AssertionError(f"num_tokens: {num_tokens} scales: {scales.size()}") + if num_tokens != zero_points.numel(): + raise AssertionError( + f"num_tokens: {num_tokens} zero_points: {zero_points.size()}" + ) quantized_decomposed_lib.define( @@ -1019,17 +1053,21 @@ def quantize_per_channel_group( dtype: torch.dtype, group_size=128, ): - assert group_size > 1 + if group_size <= 1: + raise AssertionError("group_size must be > 1") # needed for GPTQ single column quantize if group_size > input.shape[-1] and scales.shape[-1] == 1: group_size = input.shape[-1] - assert input.shape[-1] % group_size == 0 - assert input.dim() == 2 + if input.shape[-1] % group_size != 0: + raise AssertionError("input.shape[-1] must be divisible by group_size") + if input.dim() != 2: + raise AssertionError("input must be 2-dimensional") # TODO: check for dtype, currently we can't express torch.int4 so it's omitted to_quant = input.reshape(-1, group_size) - assert torch.isnan(to_quant).sum() == 0 + if torch.isnan(to_quant).sum() != 0: + raise AssertionError("to_quant must not contain NaNs") scales = scales.reshape(-1, 1) zero_points = zero_points.reshape(-1, 1) @@ -1074,13 +1112,16 @@ def quantize_per_channel_group_meta( Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters are not stored in the Tensor, we are storing them in function arguments instead """ - assert group_size > 1 + if group_size <= 1: + raise AssertionError("group_size must be > 1") # needed for GPTQ single column quantize if group_size > input.shape[-1] and scales.shape[-1] == 1: group_size = input.shape[-1] - assert input.shape[-1] % group_size == 0 - assert input.dim() == 2 + if input.shape[-1] % group_size != 0: + raise AssertionError("input.shape[-1] must be divisible by group_size") + if input.dim() != 2: + raise AssertionError("input must be 2-dimensional") return torch.empty_like(input, dtype=dtype) @@ -1124,12 +1165,15 @@ def dequantize_per_channel_group( dequantized Tensor with dtype `output_dtype` """ - assert group_size > 1 + if group_size <= 1: + raise AssertionError("group_size must be > 1") # needed for GPTQ single column dequantize if group_size > w_int8.shape[-1] and scales.shape[-1] == 1: group_size = w_int8.shape[-1] - assert w_int8.shape[-1] % group_size == 0 - assert w_int8.dim() == 2 + if w_int8.shape[-1] % group_size != 0: + raise AssertionError("w_int8.shape[-1] must be divisible by group_size") + if w_int8.dim() != 2: + raise AssertionError("w_int8 must be 2-dimensional") w_int8_grouped = w_int8.reshape(-1, group_size) scales = scales.reshape(-1, 1) @@ -1155,10 +1199,12 @@ def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max): scales = scales.to(torch.float32) if zero_points.dtype != torch.int32: zero_points = zero_points.to(torch.int32) - assert input.dtype == torch.float32, ( - f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - ) - assert axis < input.dim(), f"Expecting axis to be < {input.dim()}" + if input.dtype != torch.float32: + raise AssertionError( + f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + ) + if axis >= input.dim(): + raise AssertionError(f"Expecting axis to be < {input.dim()}") broadcast_dims = list(range(axis)) + list(range(axis + 1, input.ndim)) unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims) unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims) diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index 4d0b098b93abd..b8809c1c60871 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -90,7 +90,7 @@ def __init__( self.equalization_shape: list[int] = [] def forward(self, x_orig): - if not (x_orig.ndim >= 2 and x_orig.ndim <= 5): + if x_orig.ndim < 2 or x_orig.ndim > 5: raise ValueError( "InputEqualizationObserver only supports Linear and Conv layers" ) @@ -191,7 +191,7 @@ def __init__( self.equalization_scale = torch.tensor(1) def forward(self, w_orig): - if not (w_orig.ndim >= 2 and w_orig.ndim <= 5): + if w_orig.ndim < 2 or w_orig.ndim > 5: raise ValueError( "InputEqualizationObserver only supports Linear and Conv layers" ) @@ -232,7 +232,7 @@ def calculate_equalization_scale( ) return torch.tensor(1) - if not (min_inputs.shape == min_weights.shape): + if min_inputs.shape != min_weights.shape: raise ValueError( "Input and Weight must have the same column dimension. " + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead." @@ -355,30 +355,45 @@ def get_op_node_and_weight_eq_obs( op_node = user break - assert op_node is not None + if op_node is None: + raise AssertionError( + "Expected an operation node after the input equalization observer" + ) if op_node.op == "call_module": # If the op_node is a nn.Linear layer, then it must have a # WeightEqualizationObserver configuration maybe_equalization_node_name_to_config = _get_observed_graph_module_attr( model, "equalization_node_name_to_qconfig" ) - assert maybe_equalization_node_name_to_config is not None + if maybe_equalization_node_name_to_config is None: + raise AssertionError( + "Expected 'equalization_node_name_to_qconfig' attribute in observed graph module" + ) equalization_node_name_to_qconfig: dict[str, Any] = ( maybe_equalization_node_name_to_config # type: ignore[assignment] ) - assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None + if equalization_node_name_to_qconfig.get(op_node.name, None) is None: + raise AssertionError( + f"No equalization qconfig found for op node {op_node.name}" + ) weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr] op_node.name, None ).weight() - assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) return op_node, weight_eq_obs elif op_node.op == "call_function": weight_node = maybe_get_weight_eq_obs_node(op_node, modules) if weight_node is not None: weight_eq_obs = modules[str(weight_node.target)] - assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) return op_node, weight_eq_obs return None, None @@ -388,17 +403,20 @@ def maybe_get_weight_eq_obs_node( op_node: Node, modules: dict[str, nn.Module] ) -> Optional[Node]: """Gets the weight equalization observer node if it exists.""" - assert op_node.op == "call_function" + if op_node.op != "call_function": + raise AssertionError( + "maybe_get_weight_eq_obs_node expects a call_function op_node" + ) for node_arg in op_node.args: if node_arg_is_weight(op_node, node_arg): - assert ( + if ( isinstance(node_arg, Node) and node_arg.op == "call_module" and isinstance( modules[str(node_arg.target)], _WeightEqualizationObserver ) - ) - return node_arg + ): + return node_arg return None @@ -422,7 +440,8 @@ def maybe_get_next_input_eq_obs( the following equalization observer for linear2. """ - assert node_supports_equalization(node, modules) + if not node_supports_equalization(node, modules): + raise AssertionError("Node does not support equalization") # Locate the following nn.ReLU or F.relu node if it exists maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU) @@ -448,7 +467,10 @@ def maybe_get_next_input_eq_obs( return None maybe_eq_obs = modules[str(maybe_eq_obs_node)] - assert isinstance(maybe_eq_obs, _InputEqualizationObserver) + if not isinstance(maybe_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected the following equalization observer to be an _InputEqualizationObserver" + ) return maybe_eq_obs @@ -480,10 +502,16 @@ def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None: equalization observer """ input_eq_obs = modules[str(node.target)] - assert isinstance(input_eq_obs, _InputEqualizationObserver) + if not isinstance(input_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected the module at node.target to be an _InputEqualizationObserver" + ) input_quant_obs_node = node.args[0] - assert isinstance(input_quant_obs_node, Node) + if not isinstance(input_quant_obs_node, Node): + raise AssertionError( + "Expected the input quantization observer node to be a Node" + ) input_quant_obs = modules[str(input_quant_obs_node.target)] if not isinstance(input_quant_obs, ObserverBase): @@ -518,14 +546,19 @@ def scale_weight_node( op_module = modules[str(node.target)][0] # type: ignore[index] else: op_module = modules[str(node.target)] - assert nn_module_supports_equalization( - op_module - ) or custom_module_supports_equalization(op_module) + if not ( + nn_module_supports_equalization(op_module) + or custom_module_supports_equalization(op_module) + ): + raise AssertionError( + "Expected operation module to support equalization (nn or custom)" + ) # Scale the weights for input-weight equalization # If the following layer needs to be equalized then we will multiply its scale weight = op_module.weight - assert isinstance(weight, torch.Tensor) + if not isinstance(weight, torch.Tensor): + raise AssertionError("Expected op_module.weight to be a torch.Tensor") # Scale the weights by the reciprocal of the equalization scale # Reshape the equalization scale so that we can multiply it to the weight along axis=1 @@ -547,7 +580,8 @@ def scale_weight_node( bias = op_module.bias if bias is None: return - assert isinstance(bias, torch.Tensor) + if not isinstance(bias, torch.Tensor): + raise AssertionError("Expected op_module.bias to be a torch.Tensor") # Reshape the equalization scale so that we can multiply it element-wise to the bias next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias) @@ -581,15 +615,20 @@ def scale_weight_functional( weight_quant_obs_node = weight_eq_obs_node.args[0] if weight_quant_obs_node is None: return - assert isinstance(weight_quant_obs_node, Node) and isinstance( - modules[str(weight_quant_obs_node.target)], ObserverBase - ) + if not ( + isinstance(weight_quant_obs_node, Node) + and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) + ): + raise AssertionError( + "Expected weight_quant_obs_node to be a Node whose module is an ObserverBase" + ) # Get the get_attr(weight) node weight_node = weight_quant_obs_node.args[0] if weight_node is None: return - assert isinstance(weight_node, Node) and weight_node.op == "get_attr" + if not (isinstance(weight_node, Node) and weight_node.op == "get_attr"): + raise AssertionError("Expected weight node to be a 'get_attr' Node") weight_parent_name, weight_name = _parent_name(weight_node.target) weight = getattr(modules[weight_parent_name], weight_name) @@ -612,7 +651,8 @@ def scale_weight_functional( scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped) setattr(modules[weight_parent_name], weight_name, scaled_weight) - assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight) + if not torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight): + raise AssertionError("Model buffer for weight does not match the scaled weight") # Multiply the bias element wise by the next equalization scale bias_node = None @@ -644,10 +684,14 @@ def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) -> weight_quant_obs_node = weight_eq_obs_node.args[0] if weight_quant_obs_node is None: return - assert isinstance(weight_quant_obs_node, Node) + if not isinstance(weight_quant_obs_node, Node): + raise AssertionError("Expected weight_quant_obs_node to be a Node") weight_quant_obs = modules[str(weight_quant_obs_node.target)] - assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase) + if not isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase): + raise AssertionError( + "Expected the module at weight_quant_obs_node to be an ObserverBase" + ) weight_quant_obs.reset_min_max_vals() # type: ignore[operator] @@ -682,7 +726,10 @@ def update_obs_for_equalization( modules[node.target], _InputEqualizationObserver ): input_eq_obs = modules[node.target] - assert isinstance(input_eq_obs, _InputEqualizationObserver) + if not isinstance(input_eq_obs, _InputEqualizationObserver): + raise AssertionError( + "Expected module at node.target to be an _InputEqualizationObserver" + ) op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules) if op_node is None or weight_eq_obs is None: @@ -693,7 +740,10 @@ def update_obs_for_equalization( # been created if fused_module_supports_equalization(modules[str(op_node.target)]): module = modules[str(op_node.target)][0] # type: ignore[index] - assert nn_module_supports_equalization(module) + if not nn_module_supports_equalization(module): + raise AssertionError( + "Expected fused module to support equalization" + ) weight_eq_obs(module.weight) else: weight_eq_obs(modules[str(op_node.target)].weight) @@ -810,7 +860,10 @@ def convert_eq_obs( elif weight_eq_obs_dict.get(node.name, None) is not None: weight_eq_obs = weight_eq_obs_dict.get(node.name) - assert isinstance(weight_eq_obs, _WeightEqualizationObserver) + if not isinstance(weight_eq_obs, _WeightEqualizationObserver): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) equalization_scale = weight_eq_obs.equalization_scale if ( @@ -844,9 +897,12 @@ def convert_eq_obs( weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules) if weight_eq_obs_node is None: return - assert isinstance( + if not isinstance( modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver - ) + ): + raise AssertionError( + "Expected weight equalization observer to be a _WeightEqualizationObserver" + ) # Clear the quantization observer's min/max values so that they # can get updated later based on the new scale values diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index fa8e7d53e6b02..6ef9c6302d711 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -227,7 +227,7 @@ def is_dequantize_node(node): def is_getattr_tensor_metadata_node(node): return ( node.op == "call_function" - and node.target == getattr + and node.target is getattr and node.args[1] == "shape" ) @@ -523,7 +523,7 @@ def load_arg(a): del original_weights_lookup[str(lookup_counter)] lookup_counter += 1 elif prepack_node is not None: - # remove the foled node + # remove the fold node continue else: # copy other nodes @@ -585,7 +585,8 @@ def _match_static_pattern( return SKIP_LOWERING_VALUE q_node = node ref_node = q_node.args[0] - assert isinstance(ref_node, Node) + if not isinstance(ref_node, Node): + raise AssertionError("Expected the reference node to be a torch.fx Node") # Handle cases where the node is wrapped in a ReLU if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or ( @@ -593,7 +594,10 @@ def _match_static_pattern( ): relu_node = ref_node ref_node = relu_node.args[0] - assert isinstance(ref_node, Node) + if not isinstance(ref_node, Node): + raise AssertionError( + "Expected the reference node after ReLU to be a torch.fx Node" + ) else: relu_node = None if should_skip_lowering(ref_node, qconfig_map): @@ -616,9 +620,10 @@ def _match_static_pattern( # (2) There must be at least one dequantize node matched_dequantize = False for i in dequantize_node_arg_indices: - assert i < len(ref_node.args), ( - f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" - ) + if i >= len(ref_node.args): + raise AssertionError( + f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}" + ) arg = ref_node.args[i] if is_dequantize_node(arg): matched_dequantize = True @@ -660,7 +665,8 @@ def _match_static_pattern_with_two_inputs( return SKIP_LOWERING_VALUE q_node = node ref_node = q_node.args[0] - assert isinstance(ref_node, Node) + if not isinstance(ref_node, Node): + raise AssertionError("Expected the reference node to be a torch.fx Node") if should_skip_lowering(ref_node, qconfig_map): return SKIP_LOWERING_VALUE @@ -711,13 +717,21 @@ def _lower_static_weighted_ref_module( ) if q_node is None: continue - assert ref_node is not None + if ref_node is None: + raise AssertionError( + "Expected a reference node when matching static pattern" + ) (_, scale_node, zero_point_node, _) = q_node.args ref_module = _get_module(ref_node, modules) ref_class = type(ref_module) - assert isinstance(scale_node, Node) - assert isinstance(zero_point_node, Node) - assert issubclass(ref_class, nn.Module) + if not isinstance(scale_node, Node): + raise AssertionError("Expected scale_node to be a Node") + if not isinstance(zero_point_node, Node): + raise AssertionError("Expected zero_point_node to be a Node") + if not issubclass(ref_class, nn.Module): + raise AssertionError( + "Expected reference module class to be a subclass of nn.Module" + ) # Step 1: Change this pattern to use the corresponding quantized module # For fused modules, we also check whether the inner module is a reference module @@ -736,9 +750,11 @@ def _lower_static_weighted_ref_module( setattr(modules[parent_name], module_name, q_module) # Step 2: Reroute around dq_node, and remove q_node and its args - assert len(ref_node.args) == 1 + if len(ref_node.args) != 1: + raise AssertionError("Expected reference node to have exactly 1 arg") dq_node = ref_node.args[0] - assert isinstance(dq_node, Node) + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] q_node.replace_all_uses_with(ref_node) model.graph.erase_node(q_node) @@ -771,13 +787,21 @@ def _lower_static_weighted_ref_module_with_two_inputs( ) if q_node is None: continue - assert ref_node is not None + if ref_node is None: + raise AssertionError( + "Expected a reference node when matching static pattern with two inputs" + ) (_, scale_node, zero_point_node, _) = q_node.args ref_module = _get_module(ref_node, modules) ref_class = type(ref_module) - assert isinstance(scale_node, Node) - assert isinstance(zero_point_node, Node) - assert issubclass(ref_class, nn.Module) + if not isinstance(scale_node, Node): + raise AssertionError("Expected scale_node to be a Node") + if not isinstance(zero_point_node, Node): + raise AssertionError("Expected zero_point_node to be a Node") + if not issubclass(ref_class, nn.Module): + raise AssertionError( + "Expected reference module class to be a subclass of nn.Module" + ) # Step 1: Change this pattern to use the corresponding quantized module # For fused modules, we also check whether the inner module is a reference module @@ -798,12 +822,14 @@ def _lower_static_weighted_ref_module_with_two_inputs( setattr(modules[parent_name], module_name, q_module) # Step 2: Reroute around dq_node, and remove q_node and its args - assert len(ref_node.args) == 2 + if len(ref_node.args) != 2: + raise AssertionError("Expected reference node to have exactly 2 args") for arg in ref_node.args: if not is_dequantize_node(arg): continue dq_node = arg - assert isinstance(dq_node, Node) + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type] q_node.replace_all_uses_with(ref_node) @@ -900,14 +926,21 @@ def _lower_static_weighted_ref_functional( ) if q_node is None: continue - assert func_node is not None + if func_node is None: + raise AssertionError( + "Expected a function node when matching static functional pattern" + ) (_, output_scale_node, output_zp_node, _) = q_node.args (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args - assert isinstance(output_zp_node, Node) - assert isinstance(input_dq_node, Node) - assert isinstance(weight_dq_node, Node) + if not isinstance(output_zp_node, Node): + raise AssertionError("Expected output_zp_node to be a Node") + if not isinstance(input_dq_node, Node): + raise AssertionError("Expected input_dq_node to be a Node") + if not isinstance(weight_dq_node, Node): + raise AssertionError("Expected weight_dq_node to be a Node") quantized_weight = weight_dq_node.args[0] - assert isinstance(quantized_weight, Node) + if not isinstance(quantized_weight, Node): + raise AssertionError("Expected quantized_weight to be a Node") if quantized_weight.op != "call_function" or quantized_weight.target not in ( torch.quantize_per_tensor, torch.quantize_per_channel, @@ -919,14 +952,14 @@ def _lower_static_weighted_ref_functional( # Linear prepack args: (quantized weights[, bias]) # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) prepack_args = [quantized_weight] + remaining_func_args - if func_node.target == F.linear: + if func_node.target is F.linear: weight_dtype = quantized_weight.args[-1] prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) elif func_node.target in CONV_FUNCTIONAL_OPS: prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] # For conv1d, the stride, padding, and dilation args may be ints, # in which case we need to convert them to tuples - if func_node.target == F.conv1d: + if func_node.target is F.conv1d: for i in [2, 3, 4]: if len(prepack_args) > i and isinstance(prepack_args[i], int): prepack_args[i] = (prepack_args[i],) @@ -934,7 +967,7 @@ def _lower_static_weighted_ref_functional( prepack_op = get_qconv_prepack_op(func_node.target) # type: ignore[arg-type] # For conv_transpose1d, the stride, padding, and dilation args may be ints, # in which case we need to convert them to tuples - if func_node.target == F.conv_transpose1d: + if func_node.target is F.conv_transpose1d: # Note prepack_args[5] is groups. for i in [2, 3, 4, 6]: if len(prepack_args) > i and isinstance(prepack_args[i], int): @@ -951,7 +984,7 @@ def _lower_static_weighted_ref_functional( # They are not needed for compute op (i.e., quantized::linear) kwargs = func_node.kwargs # F.linear uses 'bias' key for bias while qlinear_prepack uses 'B' for bias - if func_node.target == F.linear and "bias" in kwargs: + if func_node.target is F.linear and "bias" in kwargs: kwargs = kwargs.copy() kwargs["B"] = kwargs["bias"] del kwargs["bias"] @@ -1006,7 +1039,7 @@ def _lower_dynamic_weighted_ref_functional( # Handle cases where the functional op is wrapped in a ReLU if ( func_node.op == "call_function" - and func_node.target == F.relu + and func_node.target is F.relu or func_node.op == "call_module" and type(modules[str(func_node.target)]) is torch.nn.ReLU ): @@ -1078,7 +1111,7 @@ def _lower_dynamic_weighted_ref_functional( # Conv prepack args: (quantized weights[, bias, stride, padding, dilation, groups]) prepack_args = [quantized_weight] + remaining_func_args prepack_kwargs = {} - if func_node.target == F.linear: + if func_node.target is F.linear: prepack_op = get_linear_prepack_op_for_dtype(weight_dtype) kwargs = func_node.kwargs.copy() if "bias" in kwargs: @@ -1089,7 +1122,7 @@ def _lower_dynamic_weighted_ref_functional( prepack_op = get_qconv_prepack_op(func_node.target) # For conv1d, the stride, padding, and dilation args may be ints, # in which case we need to convert them to tuples - if func_node.target == F.conv1d: + if func_node.target is F.conv1d: for i in [2, 3, 4]: if len(prepack_args) > i and isinstance(prepack_args[i], int): prepack_args[i] = (prepack_args[i],) @@ -1135,7 +1168,10 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi ) if q_node is None: continue - assert bop_node is not None + if bop_node is None: + raise AssertionError( + "Expected a binary op node when matching quantized binary op pattern" + ) (_, scale_node, zero_point_node, _) = q_node.args # Step 1: Remove dequant nodes @@ -1144,14 +1180,21 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi if not is_dequantize_node(arg): continue dq_node = arg - assert isinstance(dq_node, Node) + if not isinstance(dq_node, Node): + raise AssertionError("Expected dq_node to be a Node") dn_input = dq_node.args[0] bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type] num_dq_nodes += 1 - assert num_dq_nodes > 0 + if num_dq_nodes <= 0: + raise AssertionError( + "Expected at least one dequantize node in binary op args" + ) # Step 2: Swap binary op to quantized binary op - assert bop_node.target in QBIN_OP_MAPPING + if bop_node.target not in QBIN_OP_MAPPING: + raise AssertionError( + f"Unsupported binary op {bop_node.target} for lowering" + ) binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING qbin_op = binop_to_qbinop[bop_node.target] # prepare the args for quantized binary op @@ -1181,14 +1224,15 @@ def special_pattern_replacement(model: GraphModule): modules = dict(model.named_modules(remove_duplicate=False)) for n in model.graph.nodes: q_node = n - is_quantize = q_node.target == torch.quantize_per_tensor + is_quantize = q_node.target is torch.quantize_per_tensor is_to_fp16 = ( q_node.op == "call_method" and q_node.target == "to" and len(q_node.args) == 2 and q_node.args[1] == torch.float16 ) - if not (is_quantize or is_to_fp16): + # Only continue when neither quantize nor to_fp16 + if not is_quantize and not is_to_fp16: continue ref_node = q_node.args[0] # get output scale/zero_point/dtype from the quantize node @@ -1217,13 +1261,17 @@ def special_pattern_replacement(model: GraphModule): ) if not (is_call_module or is_call_function or is_call_method): continue - assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0 + if len(ref_node.args) <= 0 and len(ref_node.kwargs) <= 0: + raise AssertionError("Expected ref_node to have args or kwargs") dq_node_or_nodes = ( ref_node.args[0] if len(ref_node.args) > 0 else next(iter(ref_node.kwargs.values())) ) - assert isinstance(dq_node_or_nodes, (Node, tuple, list)) + if not isinstance(dq_node_or_nodes, (Node, tuple, list)): + raise AssertionError( + "Expected dq_node_or_nodes to be a Node, tuple, or list" + ) is_dequantize = False if isinstance(dq_node_or_nodes, Node): is_dequantize = ( diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index b3bc3c3847603..993a6c41f176f 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -362,11 +362,15 @@ def _detect_per_channel_helper(self, model: nn.Module): # assert statement for MyPy q_config_file = module.qconfig - assert isinstance(q_config_file, QConfig) + if not isinstance(q_config_file, QConfig): + raise AssertionError("module.qconfig must be a QConfig") # this object should either be fake quant or observer q_or_s_obj = module.qconfig.weight.p.func() - assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)) + if not isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)): + raise AssertionError( + "module.qconfig.weight must be a FakeQuantize or ObserverBase" + ) per_channel_used = False # will be true if found in qconfig @@ -1160,9 +1164,10 @@ def _generate_comparison_values( input_channels = len(input_ratio) if weight_channels != input_channels: # we try to replicate - assert input_channels % weight_channels == 0, ( - "input channels should be divisible by weight channels." - ) + if input_channels % weight_channels != 0: + raise AssertionError( + "input channels should be divisible by weight channels." + ) # get replication factor rep_factor: int = input_channels // weight_channels @@ -1418,11 +1423,15 @@ def __init__( self.ratio_threshold = ratio_threshold # make sure passed in percentile is valid - assert reference_percentile >= 0 and reference_percentile <= 1 - assert ( + if reference_percentile < 0 or reference_percentile > 1: + raise AssertionError("reference_percentile must be between 0 and 1") + if not ( fraction_batches_used_threshold >= 0 and fraction_batches_used_threshold <= 1 - ) + ): + raise AssertionError( + "fraction_batches_used_threshold must be between 0 and 1" + ) self.reference_percentile = reference_percentile self.fraction_batches_used_threshold = fraction_batches_used_threshold self.ch_axis = ch_axis diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index ca9c1099298fc..0ffbff88dd2d8 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -261,7 +261,8 @@ def _get_node_from_fqn(self, node_fqn: str) -> torch.fx.node.Node: raise ValueError("The node_fqn is was not found within the module.") # assert for MyPy - assert isinstance(node_to_return, torch.fx.node.Node) + if not isinstance(node_to_return, torch.fx.node.Node): + raise AssertionError("node_to_return must be a torch.fx.node.Node") return node_to_return diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 853641c6b66ba..08ae102f69f41 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -112,8 +112,12 @@ def _replace_observer_with_quantize_dequantize_node_decomposed( or quantize_per_channel and dequantize_per_channel """ graph = model.graph - assert modules is not None - assert isinstance(node.target, str) + if modules is None: + raise AssertionError("modules must not be None") + if not isinstance(node.target, str): + raise AssertionError( + f"Expected node.target to be a str, but got {type(node.target)}" + ) module_path, prefix = _get_module_path_and_prefix( node, node_name_to_scope, node_name_to_qconfig ) @@ -260,10 +264,10 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): # and that can be done after we remove reduce_range flag # 1. extract qparams from activation_post_process module dtype_ = to_underlying_dtype(dtype) - assert dtype_ in [torch.uint8, torch.int8], ( - "only uint8 and int8 are supported in reference flow for " - "dynamic quantization right now" - ) + if dtype_ not in [torch.uint8, torch.int8]: + raise AssertionError( + "only uint8 and int8 are supported in reference flow for dynamic quantization right now" + ) quant_min = activation_post_process.quant_min # type: ignore[attr-defined] quant_max = activation_post_process.quant_max # type: ignore[attr-defined] qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] @@ -379,8 +383,12 @@ def _replace_observer_with_quantize_dequantize_node( After: ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... """ - assert modules is not None - assert isinstance(node.target, str) + if modules is None: + raise AssertionError("modules must not be None") + if not isinstance(node.target, str): + raise AssertionError( + f"Expected node.target to be a str, but got {type(node.target)}" + ) graph = model.graph module_path, prefix = _get_module_path_and_prefix( node, node_name_to_scope, node_name_to_qconfig @@ -521,9 +529,10 @@ def _replace_observer_or_dequant_stub_with_dequantize_node( node: Node, graph: Graph ) -> None: call_custom_module_node = node.args[0] - assert isinstance(call_custom_module_node, Node), ( - f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" - ) + if not isinstance(call_custom_module_node, Node): + raise AssertionError( + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + ) node.replace_all_uses_with(call_custom_module_node) graph.erase_node(node) _insert_dequantize_node(call_custom_module_node, graph) @@ -617,9 +626,10 @@ def _get_module_path_and_prefix( # operator (they can be the same) # this flag identifies if the observer is inserted only because the observed node is # the input of the next operator - assert isinstance(observed_node, Node), ( - f"Expecting observed node to be a Node, but got {observed_node}" - ) + if not isinstance(observed_node, Node): + raise AssertionError( + f"Expecting observed node to be a Node, but got {observed_node}" + ) is_input_observer_only = ( node_name_to_qconfig[observed_node.name] is None if observed_node.name in node_name_to_qconfig @@ -633,7 +643,7 @@ def _get_module_path_and_prefix( first_linear_use_or_first_use = users[0] if users else None linear_node = None for n in users: - if n.op == "call_function" and n.target == torch.nn.functional.linear: + if n.op == "call_function" and n.target is torch.nn.functional.linear: linear_node = n break if linear_node: @@ -727,8 +737,10 @@ def convert_standalone_module( "_observed_graph_module_attrs" ].standalone_module_output_quantized_idxs if len(sm_output_quantized_idxs) > 0: - assert sm_output_quantized_idxs[0] == 0, "Currently only quantized" - "output idxs = [0] is supported" + if sm_output_quantized_idxs[0] != 0: + raise AssertionError( + "Currently only quantized output idxs = [0] is supported" + ) # if it's non-empty, then it means the output is kept in quantized form # we'll just add a dequantize node after this node @@ -882,9 +894,10 @@ def convert_weighted_module( ref_qmodule_cls = root_module_to_quantized_reference_module.get( type_before_parametrizations(float_module), None ) - assert ref_qmodule_cls is not None, ( - f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" - ) + if ref_qmodule_cls is None: + raise AssertionError( + f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" + ) ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] if fused_module is not None: fused_module[0] = ref_qmodule # type: ignore[operator] @@ -904,9 +917,10 @@ def _remove_previous_dequantize_in_custom_module( \\ - dequantize """ # expecting the input node for a custom module node to be a Node - assert isinstance(prev_node, Node), ( - f"Expecting the argument for custom module node to be a Node, but got {prev_node}" - ) + if not isinstance(prev_node, Node): + raise AssertionError( + f"Expecting the argument for custom module node to be a Node, but got {prev_node}" + ) if prev_node.op == "call_method" and prev_node.target == "dequantize": node.replace_input_with(prev_node, prev_node.args[0]) # Remove the dequantize node if it doesn't have other users @@ -952,15 +966,21 @@ def convert_custom_module( if _is_custom_module_lstm(node, modules): # The inputs are tuples in the form (input, (hidden0, hidden1)) # Ensure all three input nodes are quantized - assert ( + if not ( len(node.args) == 2 and isinstance(node.args[1], tuple) and len(node.args[1]) == 2 - ) + ): + raise AssertionError( + "Expected LSTM custom module inputs to be (input, (hidden0, hidden1))" + ) (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] - assert isinstance(inputs, Node) - assert isinstance(hidden0, Node) - assert isinstance(hidden1, Node) + if not isinstance(inputs, Node): + raise AssertionError("Expected inputs to be a Node") + if not isinstance(hidden0, Node): + raise AssertionError("Expected hidden0 to be a Node") + if not isinstance(hidden1, Node): + raise AssertionError("Expected hidden1 to be a Node") _remove_previous_dequantize_in_custom_module(node, inputs, graph) _remove_previous_dequantize_in_custom_module(node, hidden0, graph) _remove_previous_dequantize_in_custom_module(node, hidden1, graph) @@ -971,22 +991,32 @@ def convert_custom_module( # to the module. # Additional handling is yet to be implemented for the outputs, similar # to LSTM custom module - assert len(node.args) == 3 + if len(node.args) != 3: + raise AssertionError( + "Expected MHA custom module inputs to be (query, key, value)" + ) query, key, value = node.args - assert isinstance(query, Node) - assert isinstance(key, Node) - assert isinstance(value, Node) + if not isinstance(query, Node): + raise AssertionError("Expected query to be a Node") + if not isinstance(key, Node): + raise AssertionError("Expected key to be a Node") + if not isinstance(value, Node): + raise AssertionError("Expected value to be a Node") _remove_previous_dequantize_in_custom_module(node, query, graph) _remove_previous_dequantize_in_custom_module(node, key, graph) _remove_previous_dequantize_in_custom_module(node, value, graph) else: # remove the previous dequant node to ensure the inputs are quantized arg = node.args[0] - assert isinstance(arg, Node) + if not isinstance(arg, Node): + raise AssertionError("Expected arg to be a Node") _remove_previous_dequantize_in_custom_module(node, arg, graph) # absorb the following observer into the module conversion activation_post_process = _maybe_get_observer_for_node(node, modules) - assert activation_post_process is not None + if activation_post_process is None: + raise AssertionError( + "Expected activation_post_process to be present for observed custom module" + ) observed_custom_module.activation_post_process = activation_post_process # swap the observed custom module to quantized custom module @@ -1061,7 +1091,8 @@ def convert( QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None ) qconfig_mapping = copy.deepcopy(qconfig_mapping) - assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) + if not (qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)): + raise AssertionError("qconfig_mapping must be None or a QConfigMapping") if isinstance(backend_config, dict): warnings.warn( @@ -1075,7 +1106,8 @@ def convert( if backend_config is None: backend_config = get_native_backend_config() - assert _is_observed_module(model), "incoming model must be produced by prepare_fx" + if not _is_observed_module(model): + raise AssertionError("incoming model must be produced by prepare_fx") observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] node_name_to_scope: dict[str, tuple[str, type]] = ( observed_graph_module_attrs.node_name_to_scope @@ -1121,14 +1153,16 @@ def convert( # all the values either match what was set in prepare node_name_to_qconfig # or are set to None in the convert_node_name_to_qconfig. for k, v in node_name_to_qconfig.items(): - assert k in convert_node_name_to_qconfig, ( - f"Expected key {k} in convert node_name_to_qconfig" - ) - if convert_node_name_to_qconfig[k] is not None: - assert qconfig_equals(v, convert_node_name_to_qconfig[k]), ( - f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " - f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + if k not in convert_node_name_to_qconfig: + raise AssertionError( + f"Expected key {k} in convert node_name_to_qconfig" ) + if convert_node_name_to_qconfig[k] is not None: + if not qconfig_equals(v, convert_node_name_to_qconfig[k]): + raise AssertionError( + f"Expected k {k} to have the same value in prepare and convert QConfigMappings, " + f"but {v} was updated to {convert_node_name_to_qconfig[k]}" + ) node_name_to_qconfig = convert_node_name_to_qconfig custom_module_classes = get_custom_module_class_keys( @@ -1201,7 +1235,10 @@ def convert( ) elif node.op == "call_module": mod = _get_module(node, modules) - assert mod is not None + if mod is None: + raise AssertionError( + "Expected module for call_module node to be present in modules mapping" + ) if _is_activation_post_process(mod): observed_node = node.args[0] if observed_node in statically_quantized_custom_module_nodes: diff --git a/torch/ao/quantization/fx/fuse.py b/torch/ao/quantization/fx/fuse.py index 435085a6b8459..f50f9132cb0e3 100644 --- a/torch/ao/quantization/fx/fuse.py +++ b/torch/ao/quantization/fx/fuse.py @@ -102,7 +102,10 @@ def default_root_node_getter(node_pattern): else: node_subpattern = None if maybe_last_node is node: - assert obj is not None + if obj is None: + raise AssertionError( + "fuse handler object must not be None for matched root node" + ) root_node_getter = fusion_pattern_to_root_node_getter.get( pattern, default_root_node_getter ) diff --git a/torch/ao/quantization/fx/fuse_handler.py b/torch/ao/quantization/fx/fuse_handler.py index 24f3b13381724..b164bd08c344d 100644 --- a/torch/ao/quantization/fx/fuse_handler.py +++ b/torch/ao/quantization/fx/fuse_handler.py @@ -65,9 +65,8 @@ def fuse( fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]], is_qat: bool, ) -> Node: - assert root_node.op == "call_module", ( - "Expecting module node to be a call_module Node" - ) + if root_node.op != "call_module": + raise AssertionError("Expecting module node to be a call_module Node") root_module = named_modules[str(root_node.target)] def get_modules(pattern): @@ -85,7 +84,7 @@ def get_modules(pattern): n = pattern if n.op == "call_module": return named_modules[n.target] - elif n.op == "call_function" and n.target == torch.nn.functional.relu: + elif n.op == "call_function" and n.target is torch.nn.functional.relu: relu = torch.nn.ReLU() relu.training = root_module.training return relu diff --git a/torch/ao/quantization/fx/lstm_utils.py b/torch/ao/quantization/fx/lstm_utils.py index b49f462640f0c..b609cd2b2157d 100644 --- a/torch/ao/quantization/fx/lstm_utils.py +++ b/torch/ao/quantization/fx/lstm_utils.py @@ -109,7 +109,8 @@ def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: # TODO: maybe make this work for layer_bw as well for layer in quantizable_lstm.layers: cell = layer.layer_fw.cell # type: ignore[union-attr] - assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module" + if not isinstance(cell, torch.nn.Module): + raise AssertionError("cell should be a nn.Module") cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config) # HACK: Manually replace the activation_post_process following these ops. # This is needed for FloatFunctional ops because there is currently no way @@ -139,10 +140,10 @@ def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: mul_count = 0 for node in cell.graph.nodes: op_index: Optional[tuple[Callable, int]] = None # e.g. (torch.add, 1) - if node.target == torch.add: + if node.target is torch.add: op_index = (torch.add, add_count) add_count += 1 - elif node.target == torch.mul: + elif node.target is torch.mul: op_index = (torch.mul, mul_count) mul_count += 1 else: @@ -150,7 +151,8 @@ def make_qconfig(obs_ctr: _PartialWrapper) -> QConfig: continue if op_index not in op_index_to_activation_post_process_ctr: continue - assert len(node.users) == 1 + if len(node.users) != 1: + raise AssertionError("expected exactly one user for the node") activation_post_process_name = next(iter(node.users.keys())).name activation_post_process_ctr = op_index_to_activation_post_process_ctr[ op_index @@ -195,7 +197,8 @@ def _get_reference_quantized_lstm_module( for i, layer in enumerate(quantized_lstm.layers): cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr] cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type] - assert isinstance(cell, torch.fx.GraphModule) + if not isinstance(cell, torch.fx.GraphModule): + raise AssertionError("cell must be converted to a torch.fx.GraphModule") # HACK: Manually remove input quantize nodes and output dequantize nodes, # since custom modules expect quint8 inputs and outputs for now. Note that # this functionality is supposedly handled through PrepareCustomConfig's @@ -205,11 +208,11 @@ def _get_reference_quantized_lstm_module( # on custom module input/output dtypes, and (2) expand support for complex # input/output structures. for node in cell.graph.nodes: - if node.target == torch.quantize_per_tensor: + if node.target is torch.quantize_per_tensor: arg = node.args[0] # Remove quantize(x), quantize(hidden[0]), and quantize(hidden[1]) if arg.target == "x" or ( - arg.target == operator.getitem and arg.args[0].target == "hidden" + arg.target is operator.getitem and arg.args[0].target == "hidden" ): with cell.graph.inserting_before(node): node.replace_all_uses_with(arg) diff --git a/torch/ao/quantization/fx/match_utils.py b/torch/ao/quantization/fx/match_utils.py index 95d2b27f23ca1..86dee6a8965b5 100644 --- a/torch/ao/quantization/fx/match_utils.py +++ b/torch/ao/quantization/fx/match_utils.py @@ -33,7 +33,8 @@ def _is_match(modules, node, pattern, max_uses=sys.maxsize): if isinstance(pattern, tuple): self_match, *arg_matches = pattern if self_match is getattr: - assert len(pattern) == 2, "Expecting getattr pattern to have two elements" + if len(pattern) != 2: + raise AssertionError("Expecting getattr pattern to have two elements") arg_matches = [] else: self_match = pattern @@ -190,7 +191,8 @@ def record_match(pattern, node, last_node, matched_node_pattern, match_map): break # add custom module instances to the match result - assert modules is not None + if modules is None: + raise AssertionError("modules must not be None") for node in graph.nodes: if ( node.op == "call_module" @@ -204,7 +206,8 @@ def record_match(pattern, node, last_node, matched_node_pattern, match_map): ) def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]): - assert modules is not None + if modules is None: + raise AssertionError("modules must not be None") return ( node_target in standalone_module_names or type(modules[node_target]) # type: ignore[operator] diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 4b97311cd93d3..0c05e6499901d 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -149,10 +149,11 @@ def _create_obs_or_fq_from_qspec( return None if isinstance(quantization_spec, SharedQuantizationSpec): edge_or_node = quantization_spec.edge_or_node - assert edge_or_node in obs_or_fq_map, ( - "please make sure only refer to edge or node that has " - f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" - ) + if edge_or_node not in obs_or_fq_map: + raise AssertionError( + "please make sure only refer to edge or node that has " + f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" + ) return obs_or_fq_map[edge_or_node] elif isinstance(quantization_spec, DerivedQuantizationSpec): # can't use asdict, so not calling get_observer_kwargs here @@ -177,7 +178,8 @@ def _create_obs_or_fq_from_qspec( else: return observer_ctr() - assert isinstance(quantization_spec, QuantizationSpec) + if not isinstance(quantization_spec, QuantizationSpec): + raise AssertionError("quantization_spec must be a QuantizationSpec") observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr kwargs = _get_observer_kwargs(quantization_spec) kwargs.pop("observer_or_fake_quant_ctr") @@ -214,10 +216,14 @@ def _needs_obs_or_fq( # need to insert placeholder observer for dynamic quantization so that it can # be converted to choose_qparams -> q -> dq in convert step if cur_target_is_dynamic: - assert cur_target_dtype in _OBS_DTYPE_LIST, ( - f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" - ) - assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST + if cur_target_dtype not in _OBS_DTYPE_LIST: + raise AssertionError( + f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" + ) + if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST: + raise AssertionError( + "prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST" + ) return is_zeroth_arg if reuse_input_obs_or_fq: return False @@ -398,7 +404,8 @@ def _is_pattern_dtype_config_and_qconfig_supported_by_backend( """ if backend_config is None or pattern is None: return True - assert matched_node_pattern is not None and len(matched_node_pattern) >= 1 + if matched_node_pattern is None or len(matched_node_pattern) < 1: + raise AssertionError("matched_node_pattern must be non-empty") pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) @@ -535,7 +542,8 @@ def _set_target_dtype_info_for_matched_node_pattern( # other types of matched object, e.g. int, float literals, are ignored elif isinstance(matched_node_pattern, Node): # for pyre - assert isinstance(matched_node_pattern, Node) + if not isinstance(matched_node_pattern, Node): + raise AssertionError("matched_node_pattern must be a Node") node = matched_node_pattern if node in processed_nodes: return @@ -674,7 +682,8 @@ def _get_output_act_obs_or_fq( We are assuming that the observers are inserted correctly, and the dtype for argument in quantized graph will match what is specified by the qconfig """ - assert isinstance(arg, Node) + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") if "quantization_annotation" in arg.meta: return _create_obs_or_fq_from_qspec( arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat @@ -698,9 +707,8 @@ def _get_output_act_obs_or_fq( ) elif _is_activation_post_process_node(arg, named_modules): observed_arg = arg.args[0] - assert isinstance(observed_arg, Node), ( - "Currently we only support observing Node" - ) + if not isinstance(observed_arg, Node): + raise AssertionError("Currently we only support observing Node") if "quantization_annotation" in observed_arg.meta: output_act_obs_or_fq = _create_obs_or_fq_from_qspec( observed_arg.meta["quantization_annotation"].output_qspec, @@ -708,7 +716,10 @@ def _get_output_act_obs_or_fq( is_qat, ) else: - assert "target_dtype_info" in observed_arg.meta + if "target_dtype_info" not in observed_arg.meta: + raise AssertionError( + "expected 'target_dtype_info' in observed_arg.meta" + ) output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][ "output_act_obs_or_fq_ctr" ] @@ -754,7 +765,8 @@ def _get_arg_as_input_act_obs_or_fq( """Get the observer or fake quant constructor for the Argument `arg`, as input to Node `node` """ - assert isinstance(arg, Node) + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") # "input_qspec_map" is the more general design we'll use for pt2e path # it is a map from input argument node to observer or fake quant constructor, for example # for the following graph: @@ -838,7 +850,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( if not isinstance(arg, Node): return arg - assert isinstance(arg, Node) + if not isinstance(arg, Node): + raise AssertionError("arg must be a Node") # default (no observer) new_arg = arg @@ -854,7 +867,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( "quantization_annotation" ]._reuse_input_obs_or_fq else: - assert "target_dtype_info" in node.meta + if "target_dtype_info" not in node.meta: + raise AssertionError("expected 'target_dtype_info' in node.meta") # TODO: we are assuming "target_dtype_info" exists here, maybe # a default value also need to be provided here target_dtype_info = node.meta["target_dtype_info"] @@ -889,7 +903,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( ) else: - assert qconfig is not None + if qconfig is None: + raise AssertionError("qconfig must not be None") # custom flow for standalone modules _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs( node, named_modules, prepare_custom_config, qconfig, backend_config @@ -946,7 +961,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( existing_obs_node = maybe_obs_node break - assert arg_as_input_act_obs_or_fq is not None + if arg_as_input_act_obs_or_fq is None: + raise AssertionError("arg_as_input_act_obs_or_fq must not be None") obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq if existing_obs_node is None: new_obs_node = _insert_obs_or_fq( @@ -1102,7 +1118,8 @@ def _maybe_insert_output_observer_for_node( Note: inserting dynamic quantization ops for output is not supported in fx graph mode quantization code path right now """ - assert node.op != "output", "observer insertion for outputs is handled elsewhere" + if node.op == "output": + raise AssertionError("observer insertion for outputs is handled elsewhere") is_standalone_module = False if "quantization_annotation" in node.meta: @@ -1110,7 +1127,8 @@ def _maybe_insert_output_observer_for_node( node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat ) else: - assert "target_dtype_info" in node.meta + if "target_dtype_info" not in node.meta: + raise AssertionError("expected 'target_dtype_info' in node.meta") is_standalone_module = node.meta["target_dtype_info"].get( "_is_standalone_module", False ) @@ -1222,7 +1240,10 @@ def _recursive_maybe_replace_node_with_obs( and arg_as_input_target_dtype != torch.float ) if need_obs: - assert observer_mod is not None + if observer_mod is None: + raise AssertionError( + "observer_mod must not be None when need_obs is True" + ) # insert observer observer_node = _insert_obs_or_fq( maybe_node, observer_mod, model, named_modules, graph @@ -1393,9 +1414,11 @@ def _maybe_make_input_output_share_observers( if iteration_guard > 10000: raise AssertionError("Unable to find observer of previous node") - assert isinstance(first_arg_arg, Node) + if not isinstance(first_arg_arg, Node): + raise AssertionError("first_arg_arg must be a Node") target_to_use = first_arg_arg.target - assert isinstance(target_to_use, str) + if not isinstance(target_to_use, str): + raise AssertionError("target_to_use must be a string") obs_mod_to_use = named_modules[target_to_use] if isinstance(first_arg, (list, tuple)): @@ -1418,7 +1441,10 @@ def _maybe_make_input_output_share_observers( # set the output observer node to use that module for output_obs_node in node.users.keys(): - assert _is_activation_post_process_node(output_obs_node, named_modules) + if not _is_activation_post_process_node(output_obs_node, named_modules): + raise AssertionError( + "output_obs_node must be an activation post process node" + ) parent_name, name = _parent_name(output_obs_node.target) setattr(named_modules[parent_name], name, obs_mod_to_use) @@ -1431,7 +1457,10 @@ def _remove_output_observer( ): items = list(node.users.items()) for output_obs_node, _ in items: - assert _is_activation_post_process_node(output_obs_node, named_modules) + if not _is_activation_post_process_node(output_obs_node, named_modules): + raise AssertionError( + "output_obs_node must be an activation post process node" + ) output_obs_node.replace_all_uses_with(node) model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] @@ -1554,7 +1583,8 @@ def insert_observers_for_model( qhandler, qconfig, ) = match_res_with_qconfig - assert qhandler is not None + if qhandler is None: + raise AssertionError("qhandler must not be None") _set_target_dtype_info_for_matched_node_pattern( matched_node_pattern, last_node, @@ -1632,7 +1662,8 @@ def insert_observers_for_model( pattern, matched_node_pattern, qconfig, backend_config ) ) - assert qhandler is not None + if qhandler is None: + raise AssertionError("qhandler must not be None") # get output_act_dtype so that we don't also reset the special typed nodes # TODO: we might want to handle these more uniformly with the default path @@ -1726,7 +1757,8 @@ def insert_observers_for_model( if not skip_inserting_observers and is_supported_by_backend: named_modules = dict(model.named_modules(remove_duplicate=False)) if node.op != "output": - assert matched_node_pattern is not None + if matched_node_pattern is None: + raise AssertionError("matched_node_pattern must not be None") # add matched nodes to the observed node name set _add_matched_node_name_to_set( matched_node_pattern, observed_node_names @@ -2064,8 +2096,10 @@ def prepare( ) backend_config = BackendConfig.from_dict(backend_config) - assert isinstance(qconfig_mapping, QConfigMapping) - assert isinstance(_equalization_config, QConfigMapping) + if not isinstance(qconfig_mapping, QConfigMapping): + raise AssertionError("qconfig_mapping must be a QConfigMapping") + if not isinstance(_equalization_config, QConfigMapping): + raise AssertionError("_equalization_config must be a QConfigMapping") qconfig_mapping = copy.deepcopy(qconfig_mapping) _equalization_config = copy.deepcopy(_equalization_config) @@ -2194,11 +2228,12 @@ def prepare( ) if is_standalone_module: - assert result_node is not None - assert isinstance(result_node.args[0], Node), ( - "standalone module only supports returning simple value currently" - "(not tuple, dict etc.)" - ) + if result_node is None: + raise AssertionError("result_node must not be None for standalone modules") + if not isinstance(result_node.args[0], Node): + raise AssertionError( + "standalone module only supports returning simple value currently (not tuple, dict etc.)" + ) # these inputs are observed in parent # converting List[int] to Tensor since module attribute is # Union[Tensor, Module] diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 7e4ebbf75bc3d..74f90505ea2af 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -228,11 +228,12 @@ def _compare_prepare_convert_qconfig_mappings( `prepare_qconfig_mapping`: configuration for prepare quantization step `convert_qconfig_mapping`: configuration for convert quantization step """ - assert qconfig_equals( + if not qconfig_equals( prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig - ), ( - "Expected global qconfigs to be the same in the prepare and convert quantization configs" - ) + ): + raise AssertionError( + "Expected global qconfigs to be the same in the prepare and convert quantization configs" + ) prepare_dicts: list[OrderedDict] = [ prepare_qconfig_mapping.object_type_qconfigs, prepare_qconfig_mapping.module_name_qconfigs, @@ -250,16 +251,17 @@ def _compare_prepare_convert_qconfig_mappings( ] for i in range(len(prepare_dicts)): for name in prepare_dicts[i].keys(): - assert name in convert_dicts[i], ( - f"Missing key {dict_names[i]} {name} in convert QConfigMapping \ - when it was present in prepare" - ) - assert convert_dicts[i][name] is None or qconfig_equals( + if name not in convert_dicts[i]: + raise AssertionError( + f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare" + ) + if convert_dicts[i][name] is not None and not qconfig_equals( prepare_dicts[i][name], convert_dicts[i][name] - ), ( - f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \ - prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" - ) + ): + raise AssertionError( + "Expected convert QConfigMapping to have the same qconfig as prepare for key " + f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" + ) def _is_qconfig_supported_by_dtype_configs( diff --git a/torch/ao/quantization/fx/quantize_handler.py b/torch/ao/quantization/fx/quantize_handler.py index 2acb711943172..6ab33a2283112 100644 --- a/torch/ao/quantization/fx/quantize_handler.py +++ b/torch/ao/quantization/fx/quantize_handler.py @@ -119,10 +119,11 @@ def __init__( ): super().__init__(node_pattern, modules, root_node_getter) if num_tensor_args_to_observation_type: - assert self.num_tensor_args in num_tensor_args_to_observation_type, ( - f"Must provide observation_type config for tensor number {self.num_tensor_args}" - f" in num_tensor_args_to_observation_type for {node_pattern}" - ) + if self.num_tensor_args not in num_tensor_args_to_observation_type: + raise AssertionError( + f"Must provide observation_type config for tensor number {self.num_tensor_args}" + f" in num_tensor_args_to_observation_type for {node_pattern}" + ) self.observation_type = num_tensor_args_to_observation_type[ self.num_tensor_args ] diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 2044fce538fd9..3e2afaaa1d9f3 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -165,7 +165,8 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable: torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack, } prepack_op = prepack_ops.get(conv_op) - assert prepack_op, f"Didn't find prepack op for {conv_op}" + if prepack_op is None: + raise AssertionError(f"Didn't find prepack op for {conv_op}") return prepack_op @@ -214,7 +215,7 @@ def forward(self, x): # hit input, can't fold in this case return None nodes.append(arg) - if not (arg.op == "call_function" and arg.target == getattr): + if not (arg.op == "call_function" and arg.target is getattr): frontier.append(arg) return nodes @@ -230,7 +231,8 @@ def graph_module_from_producer_nodes( Return: A graph module constructed from the producer nodes """ - assert len(producer_nodes) > 0, "list of producer nodes can not be empty" + if len(producer_nodes) == 0: + raise AssertionError("list of producer nodes can not be empty") # since we traced back from node to getattr producer_nodes.reverse() graph = Graph() @@ -300,7 +302,8 @@ def all_node_args_have_no_tensors( elif node.op == "placeholder": result = False elif node.op == "call_module": - assert isinstance(node.target, str) + if not isinstance(node.target, str): + raise AssertionError("node.target must be a string for call_module nodes") if _is_activation_post_process(modules[node.target]): result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] elif node.op == "call_module": @@ -503,9 +506,10 @@ def _is_custom_module_lstm( """ mod = _get_module(node, named_modules) if qconfig is not None and qhandler is not None: - assert isinstance( + if not isinstance( qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler - ) # type: ignore[attr-defined] + ): # type: ignore[attr-defined] + raise AssertionError("qhandler must be a QuantizeHandler when provided") return ( isinstance(mod, torch.nn.LSTM) and activation_is_statically_quantized(qconfig) @@ -527,9 +531,10 @@ def _is_custom_module_mha( """ mod = _get_module(node, named_modules) if qconfig is not None and qhandler is not None: - assert isinstance( + if not isinstance( qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler - ) # type: ignore[attr-defined] + ): # type: ignore[attr-defined] + raise AssertionError("qhandler must be a QuantizeHandler when provided") return ( isinstance(mod, torch.nn.MultiheadAttention) and activation_is_statically_quantized(qconfig) @@ -701,7 +706,7 @@ def match_lstm(a): return _is_custom_module_lstm(a, named_modules) def match_getitem(a): - return a.op == "call_function" and a.target == operator.getitem + return a.op == "call_function" and a.target is operator.getitem def match_tuple(a): return a.op == "call_function" and a.target is tuple @@ -717,7 +722,7 @@ def _match_pattern(match_pattern: list[Callable]) -> Optional[Node]: return None # Match next arg, for tuple the arg is a tuple of a list, e.g. ([dq_1, other_node],) if i < len(match_pattern) - 1: - if match == match_tuple: + if match is match_tuple: a = a.args[0][0] # type: ignore[assignment,index] else: a = a.args[0] # type: ignore[assignment] @@ -805,7 +810,7 @@ def find_patterns( find_patterns( user, index_stack, current_pattern, matched_patterns, seen ) - elif user.op == "call_function" and user.target == operator.getitem: + elif user.op == "call_function" and user.target is operator.getitem: if len(index_stack) > 0: if user.args[1] == index_stack[-1]: index_stack.pop() @@ -826,11 +831,17 @@ def find_patterns( for pattern in matched_patterns: first_tuple = pattern[0] last_getitem = pattern[-1] - assert first_tuple.op == "call_function" and first_tuple.target is tuple - assert ( + if not (first_tuple.op == "call_function" and first_tuple.target is tuple): + raise AssertionError( + "first tuple node must be a call_function with target tuple" + ) + if not ( last_getitem.op == "call_function" - and last_getitem.target == operator.getitem - ) + and last_getitem.target is operator.getitem + ): + raise AssertionError( + "last getitem node must be a call_function with target operator.getitem" + ) last_getitem_index = last_getitem.args[1] new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index] for user in list(last_getitem.users.keys()): @@ -847,7 +858,10 @@ def _get_observer_from_activation_post_process( if isinstance(activation_post_process, ObserverBase): return activation_post_process else: - assert isinstance(activation_post_process, FakeQuantizeBase) + if not isinstance(activation_post_process, FakeQuantizeBase): + raise AssertionError( + "activation_post_process must be an ObserverBase or FakeQuantizeBase" + ) return activation_post_process.activation_post_process # type: ignore[return-value] @@ -966,7 +980,10 @@ def _activation_post_process_satisfies_dtype_config_constraints( satisfies_constraints = True if activation_post_process_ctr is not None: activation_post_process = activation_post_process_ctr() - assert _is_activation_post_process(activation_post_process) + if not _is_activation_post_process(activation_post_process): + raise AssertionError( + "activation_post_process must be an activation post process" + ) # If dtypes don't match, don't check the activation_post_process and return True early if activation_post_process.dtype != dtype_with_constraints.dtype: return True diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 20b1252f1be80..e7e04795302f2 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -250,17 +250,17 @@ def __init__( ) self.reduce_range = reduce_range self.register_buffer("eps", torch.tensor([eps], **factory_kwargs)) - assert self.qscheme in ( + if self.qscheme not in ( torch.per_tensor_affine, torch.per_tensor_symmetric, torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams, - ), ( - "Default Observer only works for per_tensor_affine, \ - per_tensor_symmetric, per_channel_affine, \ - per_channel_symmetric and per_channel_float_qparams quantization scheme" - ) + ): + raise AssertionError( + "Default Observer only works for per_tensor_affine, per_tensor_symmetric, " + "per_channel_affine, per_channel_symmetric and per_channel_float_qparams quantization scheme" + ) _ALLOWED_DTYPES = ( torch.qint8, @@ -276,9 +276,10 @@ def __init__( torch.uint16, ) - assert self.dtype in _ALLOWED_DTYPES, ( - f"Default Observer only works for {_ALLOWED_DTYPES} data type" - ) + if self.dtype not in _ALLOWED_DTYPES: + raise AssertionError( + f"Default Observer only works for {_ALLOWED_DTYPES} data type" + ) self.has_customized_qrange = (quant_min is not None) and (quant_max is not None) if self.has_customized_qrange: # pyrefly: ignore [bad-argument-type] @@ -337,12 +338,12 @@ def _validate_qmin_qmax(self, quant_min: int, quant_max: int) -> None: """ # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. - assert quant_min <= 0 <= quant_max, ( - "Used-specified quantization range must include 0." - ) - assert quant_min < quant_max, ( - "qmin must be strictly less than qmax for user-specified quantization range." - ) + if not quant_min <= 0 <= quant_max: + raise AssertionError("Used-specified quantization range must include 0.") + if quant_min >= quant_max: + raise AssertionError( + "qmin must be strictly less than qmax for user-specified quantization range." + ) @torch.jit.export def _calculate_qparams( @@ -1134,7 +1135,8 @@ def _non_linear_param_search(self) -> tuple[torch.Tensor, torch.Tensor]: This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in caffe2/quantization/server/norm_minimization.cc """ - assert self.histogram.size()[0] == self.bins, "bins mismatch" + if self.histogram.size()[0] != self.bins: + raise AssertionError("bins mismatch") bin_width = (self.max_val - self.min_val) / self.bins # cumulative sum @@ -1213,7 +1215,7 @@ def _upscale_histogram( boundaries_new_histogram = torch.linspace( update_min, update_max, self.bins + 1, device=update_min.device ).to(histogram.device) - # this maps the mid-poits of the histogram to the new histogram's space + # this maps the mid-points of the histogram to the new histogram's space bucket_assignments = ( torch.bucketize(mid_points_histogram, boundaries_new_histogram, right=True) - 1 @@ -1255,8 +1257,10 @@ def _combine_histograms( return transformed_orig_hist + update_hist # We assume the update_hist is already in the target range, we will map the orig_max to it - assert update_min <= orig_min - assert update_max >= orig_max + if update_min > orig_min: + raise AssertionError("update_min must be <= orig_min") + if update_max < orig_max: + raise AssertionError("update_max must be >= orig_max") # Now we need to turn the old_histogram, into the range of the new histogram transformed_orig_hist = self._upscale_histogram( @@ -1276,9 +1280,8 @@ def reset_histogram( self.min_val.copy_(min_val) self.max_val.resize_(max_val.shape) self.max_val.copy_(max_val) - assert min_val.numel() == 1 and max_val.numel() == 1, ( - "histogram min/max values must be scalar." - ) + if min_val.numel() != 1 or max_val.numel() != 1: + raise AssertionError("histogram min/max values must be scalar.") new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val) # type: ignore[arg-type] self.histogram.detach_().resize_(new_histogram.shape) self.histogram.copy_(new_histogram) @@ -1356,10 +1359,11 @@ def calculate_qparams(self): # type: ignore[override] return torch.tensor([1.0], device=self.min_val.device.type), torch.tensor( [0], device=self.min_val.device.type ) - assert self.bins == len(self.histogram), ( - "The number of bins in histogram should be equal to the number of bins " - "supplied while making this observer" - ) + if self.bins != len(self.histogram): + raise AssertionError( + "The number of bins in histogram should be equal to the number of bins " + "supplied while making this observer" + ) new_min, new_max = self._non_linear_param_search() @@ -1792,9 +1796,10 @@ def get_block_size( input_shape: The input tensor shape possibly more than 2 dimensions granularity: The granularity type of the quantization """ - assert isinstance(granularity, Granularity), ( - "Please provide an instance of Granularity, not subclass of it" - ) + if not isinstance(granularity, Granularity): + raise AssertionError( + "Please provide an instance of Granularity, not subclass of it" + ) if isinstance(granularity, PerTensor): return input_shape elif isinstance(granularity, PerAxis): @@ -1804,9 +1809,10 @@ def get_block_size( elif isinstance(granularity, PerRow): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) elif isinstance(granularity, PerGroup): - assert len(input_shape) == 2, ( - f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" - ) + if len(input_shape) != 2: + raise AssertionError( + f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" + ) return (1, granularity.group_size) elif isinstance(granularity, PerToken): block_size = [1] * len(input_shape) @@ -1843,8 +1849,8 @@ def __init__( **kwargs, ): super().__init__() - assert granularity is not None, "granularity is None" - + if granularity is None: + raise AssertionError("granularity is None") self.mapping_type = mapping_type self.target_dtype = target_dtype self.granularity = granularity @@ -1882,10 +1888,10 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): from torch.ao.quantization.fx.utils import create_getattr_from_value with model.graph.inserting_before(observer_node): - assert self.block_size is not None, "Expecting block_size to be populated" - assert self.original_dtype is not None, ( - "Expecting original_dtype to be populated" - ) + if self.block_size is None: + raise AssertionError("Expecting block_size to be populated") + if self.original_dtype is None: + raise AssertionError("Expecting original_dtype to be populated") if hasattr(self, "is_dynamic") and self.is_dynamic: choose_qparams_affine = model.graph.call_function( torch.ops.pt2e_quant.choose_qparams_affine, diff --git a/torch/ao/quantization/pt2e/duplicate_dq_pass.py b/torch/ao/quantization/pt2e/duplicate_dq_pass.py index 34a95eb80fb22..81c03e5141432 100644 --- a/torch/ao/quantization/pt2e/duplicate_dq_pass.py +++ b/torch/ao/quantization/pt2e/duplicate_dq_pass.py @@ -64,7 +64,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: if ( isinstance(getitem_node, torch.fx.node.Node) and getitem_node.op == "call_function" - and getitem_node.target == operator.getitem + and getitem_node.target is operator.getitem ): choose_qparam_node = getitem_node.args[0] if ( diff --git a/torch/ao/quantization/pt2e/lowering.py b/torch/ao/quantization/pt2e/lowering.py index 742549dedcf8d..c306b1745bada 100644 --- a/torch/ao/quantization/pt2e/lowering.py +++ b/torch/ao/quantization/pt2e/lowering.py @@ -12,7 +12,7 @@ def lower_pt2e_quantized_to_x86( model: torch.fx.GraphModule, example_inputs: tuple[torch.Tensor, ...], ) -> torch.fx.GraphModule: - """Lower a PT2E-qantized model to x86 backend. + """Lower a PT2E-quantized model to x86 backend. Args: * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. @@ -38,7 +38,7 @@ def _node_replace(m): # type: ignore[no-untyped-def] aten = torch.ops.aten g = m.graph for node in g.nodes: - if node.target == aten.t.default: + if node.target is aten.t.default: with g.inserting_before(node): x = node.args[0] dims = [1, 0] diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index aedad07cc8a67..6eac69a96ba42 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -455,9 +455,9 @@ def _maybe_insert_input_observers_for_node( # gelu has a has an approximate kwarg that persist in exported graph. # This is just a work around for these. if not ( - node.target == torch.ops.aten.clone.default - or node.target == torch.ops.aten.zeros_like.default - or node.target == torch.ops.aten.gelu.default + node.target is torch.ops.aten.clone.default + or node.target is torch.ops.aten.zeros_like.default + or node.target is torch.ops.aten.gelu.default or len(node.kwargs) == 0 ): raise AssertionError(" expecting kwargs for aten op IR to be empty") diff --git a/torch/ao/quantization/pt2e/qat_utils.py b/torch/ao/quantization/pt2e/qat_utils.py index b7daca97b18f7..e5a245dc3dadd 100644 --- a/torch/ao/quantization/pt2e/qat_utils.py +++ b/torch/ao/quantization/pt2e/qat_utils.py @@ -396,7 +396,7 @@ def _get_nodes(nodes: list[Node]) -> tuple[Node, Node, Optional[Node]]: f"Found multiple bn nodes in match, previous: {bn_node}, new: {n}" ) bn_node = n - if n.target == operator.getitem: + if n.target is operator.getitem: if getitem_node is not None: raise AssertionError( f"Found multiple getitem nodes in match, previous: {getitem_node}, new: {n}" @@ -939,7 +939,7 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule: # remove in place add from batchnorm tracking training stats for node in m.graph.nodes: if ( - node.target == torch.ops.aten.add_.Tensor + node.target is torch.ops.aten.add_.Tensor and node.args[0].op == "get_attr" and node.args[1] == 1 and ( diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index e43195a38085a..f6e9789e94827 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -97,10 +97,10 @@ def _find_q_dq_node_for_user( def _is_sym_size_node(node: Node): return ( node.op == "call_function" - and node.target == torch.ops.aten.sym_size.default - or node.target == torch.ops.aten.sym_numel.default - or node.target == torch.ops.aten.sym_numel - or node.target == torch.ops.aten.sym_size + and node.target is torch.ops.aten.sym_size.default + or node.target is torch.ops.aten.sym_numel.default + or node.target is torch.ops.aten.sym_numel + or node.target is torch.ops.aten.sym_size ) @@ -204,7 +204,7 @@ def _is_conv_transpose_fn(conv_fn: Callable): def _is_bn_node(n: Node): return ( _is_supported_batch_norm_for_training(n) - or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default + or n.target is torch.ops.aten._native_batch_norm_legit_no_training.default ) @@ -228,7 +228,7 @@ def fold_bn_weights_into_conv_node( bn_b = _get_tensor_constant_from_node(bn_args[2], m) bn_rm = _get_tensor_constant_from_node(bn_args[3], m) bn_rv = _get_tensor_constant_from_node(bn_args[4], m) - if bn_node.target == torch.ops.aten._native_batch_norm_legit_no_training.default: + if bn_node.target is torch.ops.aten._native_batch_norm_legit_no_training.default: eps_arg_index = 6 elif _is_supported_batch_norm_for_training(bn_node): eps_arg_index = 7 @@ -268,7 +268,7 @@ def fold_bn_weights_into_conv_node( # native_batch_norm has 3 outputs, we expect getitem calls on the output # and we want to replace the uses of getitem 0 with the output of conv # - if bn_node.target == torch.ops.aten.batch_norm.default: + if bn_node.target is torch.ops.aten.batch_norm.default: # With the new training ir, instead of batch_norm + getitem, # we only have the batch_norm node. # @@ -377,7 +377,7 @@ def _get_aten_graph_module_for_pattern( for node in aten_pattern.graph.nodes: # type: ignore[union-attr] if ( node.op == "call_function" - and node.target == torch.ops.aten.copy_.default + and node.target is torch.ops.aten.copy_.default and len(node.users) == 0 ): aten_pattern.graph.erase_node(node) # type: ignore[operator, union-attr] diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index 4cf9cf834de1c..89c5bb107c931 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -566,9 +566,10 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N torch.ao.quantization.MovingAveragePerChannelMinMaxObserver, ), ) - assert not is_per_channel, ( - "Per channel weight observer is not supported yet for ConvTranspose{n}d." - ) + if is_per_channel: + raise AssertionError( + "Per channel weight observer is not supported yet for ConvTranspose{n}d." + ) if sys.version_info < (3, 12): @@ -600,7 +601,8 @@ def _add_module_to_qconfig_obs_ctr( return qconfig def get_factory_kwargs_based_on_module_device(): - assert isinstance(module, torch.nn.Module) + if not isinstance(module, torch.nn.Module): + raise AssertionError("module must be an instance of torch.nn.Module") devices = {p.device for p in module.parameters()} | { p.device for p in module.buffers() } @@ -672,7 +674,10 @@ def qconfig_equals(q1: QConfigAny, q2: QConfigAny): if q1 is None or q2 is None: return q1 == q2 else: - assert q1 is not None and q2 is not None + if q1 is None or q2 is None: + raise AssertionError( + "Both q1 and q2 must be non-None for qconfig comparison" + ) try: # Qconfig weight and activation can be either a partial wrapper, # or an observer class. Special handling is required (above) for diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py index ee2c63cc291b8..c9173e6bc6e91 100644 --- a/torch/ao/quantization/quantization_mappings.py +++ b/torch/ao/quantization/quantization_mappings.py @@ -252,10 +252,11 @@ def get_static_quant_module_class( additional_static_quant_mapping, ) static_quant_module_class = all_mappings.get(float_module_class, None) - assert static_quant_module_class is not None, ( - f"Floating point module class {str(float_module_class)}" - + " does not have a corresponding quantized module class" - ) + if static_quant_module_class is None: + raise AssertionError( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) return copy.deepcopy(static_quant_module_class) @@ -272,10 +273,11 @@ def get_dynamic_quant_module_class( DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS, additional_dynamic_quant_mapping ) dynamic_quant_module_class = all_mappings.get(float_module_class, None) - assert dynamic_quant_module_class is not None, ( - f"Floating point module class {str(float_module_class)}" - + " does not have a corresponding quantized module class" - ) + if dynamic_quant_module_class is None: + raise AssertionError( + f"Floating point module class {str(float_module_class)}" + + " does not have a corresponding quantized module class" + ) return copy.deepcopy(dynamic_quant_module_class) @@ -344,9 +346,10 @@ def get_default_float_to_quantized_operator_mappings() -> dict[ def get_quantized_operator(float_op: Union[Callable, str]) -> Callable: """Get the quantized operator corresponding to the float operator""" quantized_op = DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS.get(float_op) - assert quantized_op is not None, ( - f"Operator {str(float_op)} does not have corresponding quantized op" - ) + if quantized_op is None: + raise AssertionError( + f"Operator {str(float_op)} does not have corresponding quantized op" + ) return quantized_op diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 3c53876081e07..e71dd24fda745 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -158,9 +158,10 @@ def _observer_forward_pre_hook(self, input): def _register_activation_post_process_hook(module, pre_hook=False): - assert hasattr(module, "activation_post_process"), ( - "Expect activation_post_process attribute already attached to the module" - ) + if not hasattr(module, "activation_post_process"): + raise AssertionError( + "Expect activation_post_process attribute already attached to the module" + ) if pre_hook: module.register_forward_pre_hook(_observer_forward_pre_hook, prepend=True) else: @@ -198,9 +199,10 @@ def _add_observer_( # respect device affinity when adding observers if device is None: devices = _get_unique_devices_(module) - assert len(devices) <= 1, ( - f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" - ) + if len(devices) > 1: + raise AssertionError( + f"_add_observer_ only works with cpu or single-device CUDA modules, but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None def get_activation_post_process(qconfig, device, special_act_post_process=None): @@ -243,9 +245,10 @@ def insert_activation_post_process(m, special_act_post_process=None): type_before_parametrizations(child), (nnq.FloatFunctional, nnq.QFunctional) ): if needs_observation(child): - assert hasattr(child, "activation_post_process"), ( - f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" - ) + if not hasattr(child, "activation_post_process"): + raise AssertionError( + f"functional class {type_before_parametrizations(child)} has no pre-defined `activation_post_process`" + ) child.activation_post_process = get_activation_post_process( child.qconfig, device ) @@ -585,7 +588,8 @@ def prepare_qat(model, mapping=None, inplace=False): is mutated """ torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat") - assert model.training, "prepare_qat only works on models in training mode" + if not model.training: + raise AssertionError("prepare_qat only works on models in training mode") if mapping is None: mapping = get_default_qat_module_mappings() @@ -761,7 +765,10 @@ def swap_module( elif type_before_parametrizations(mod) in mapping: qmod = mapping[type_before_parametrizations(mod)] if hasattr(qmod, "_IS_REFERENCE") and qmod._IS_REFERENCE: - assert mod.qconfig is not None + if mod.qconfig is None: + raise AssertionError( + "module qconfig must not be None when swapping to reference module" + ) weight_post_process = mod.qconfig.weight() weight_post_process(mod.weight) weight_qparams = get_qparam_dict(weight_post_process) @@ -788,11 +795,13 @@ def swap_module( # respect device affinity when swapping modules devices = _get_unique_devices_(mod) - assert len(devices) <= 1 or ( - len(devices) == 2 and torch.device("meta") in devices - ), ( - f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" - ) + if not ( + len(devices) <= 1 + or (len(devices) == 2 and torch.device("meta") in devices) + ): + raise AssertionError( + f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None if device: new_mod.to(device) diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 38d9cd6b8b765..79f8db1a792fc 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -157,12 +157,12 @@ def _convert_ondevice_jit( model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC ): _check_is_script_module(model) - assert quant_type == QuantType.DYNAMIC, ( - "This API, while should work for static quant, is only tested for dynamic quant." - ) - assert not method_name.startswith("observe_"), ( - "Pass in valid method to be quantized, e.g. forward" - ) + if quant_type != QuantType.DYNAMIC: + raise AssertionError( + "This API, while should work for static quant, is only tested for dynamic quant." + ) + if method_name.startswith("observe_"): + raise AssertionError("Pass in valid method to be quantized, e.g. forward") observe_method_name = "observe_" + method_name quantize_method_name = "quantize_" + method_name model_c = model._c @@ -230,12 +230,14 @@ def _quantize_jit( model = prepare_dynamic_jit(model, qconfig_dict, inplace) model = convert_dynamic_jit(model, True, debug) else: - assert run_fn, ( - "Must provide calibration function for post training static quantization" - ) - assert run_args, ( - "Must provide calibration dataset for post training static quantization" - ) + if not run_fn: + raise AssertionError( + "Must provide calibration function for post training static quantization" + ) + if not run_args: + raise AssertionError( + "Must provide calibration dataset for post training static quantization" + ) model = prepare_jit(model, qconfig_dict, inplace) run_fn(model, *run_args) model = convert_jit(model, True, debug) diff --git a/torch/ao/quantization/quantizer/embedding_quantizer.py b/torch/ao/quantization/quantizer/embedding_quantizer.py index 88bc6f3c8c9ff..b0f1b823b7fdb 100644 --- a/torch/ao/quantization/quantizer/embedding_quantizer.py +++ b/torch/ao/quantization/quantizer/embedding_quantizer.py @@ -77,7 +77,7 @@ def _annotate_embedding_ops(self, graph: torch.fx.Graph) -> None: # just as an example of alternate ways of annotating if ( node.op == "call_function" - and node.target == torch.ops.aten.embedding.default + and node.target is torch.ops.aten.embedding.default ): if embedding_config.config.weight is None: raise ValueError( diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index e2482077b73eb..b10163d4b1e50 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -263,7 +263,10 @@ def _is_quantized_op_pt2e(node: torch.fx.Node): # The node has not been annotated, directly return False return False quantization_annotation = node.meta.get(QUANT_ANNOTATION_KEY, None) - assert isinstance(quantization_annotation, _X86InductorQuantizationAnnotation) + if not isinstance(quantization_annotation, _X86InductorQuantizationAnnotation): + raise AssertionError( + "quantization_annotation must be an _X86InductorQuantizationAnnotation" + ) return quantization_annotation._is_output_of_quantized_pattern @@ -429,20 +432,22 @@ def _get_current_quantization_mode(self) -> _CurrentQuantizationMode: if qat_state is None: qat_state = qconfig.is_qat else: - assert qat_state == qconfig.is_qat, ( - f"All non-None quantization configs should have the same `is_qat`," - f"but got {qat_state} and {qconfig.is_qat}." - ) + if qat_state != qconfig.is_qat: + raise AssertionError( + f"All non-None quantization configs should have the same `is_qat`," + f"but got {qat_state} and {qconfig.is_qat}." + ) # Query the `is_dynamic` state input_activation_spec = qconfig.input_activation if input_activation_spec is not None: if dynamic_state is None: dynamic_state = input_activation_spec.is_dynamic else: - assert dynamic_state == input_activation_spec.is_dynamic, ( - f"All non-None `input_activation_spec` should have the same `is_dynamic`," - f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." - ) + if dynamic_state != input_activation_spec.is_dynamic: + raise AssertionError( + f"All non-None `input_activation_spec` should have the same `is_dynamic`," + f"but got {dynamic_state} and {input_activation_spec.is_dynamic}." + ) return _CurrentQuantizationMode( qat_state=qat_state, dynamic_state=dynamic_state ) @@ -576,10 +581,12 @@ def _annotate_conv_node_helper( return input_qspec_map = {} input_node = conv_node.args[0] - assert isinstance(input_node, Node) + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") input_qspec_map[input_node] = get_input_act_qspec(quantization_config) weight_node = conv_node.args[1] - assert isinstance(weight_node, Node) + if not isinstance(weight_node, Node): + raise AssertionError("weight_node must be a FX Node") input_qspec_map[weight_node] = get_weight_qspec(quantization_config) bias_node = None if len(conv_node.args) == 2 else conv_node.args[2] if isinstance(bias_node, Node): @@ -607,18 +614,23 @@ def _annotate_linear_node_helper( _annotate_nodes_not_quantize(linear_node) return input_qspec_map = {} - assert linear_node.target == torch.ops.aten.linear.default + if linear_node.target is not torch.ops.aten.linear.default: + raise AssertionError( + "linear_node.target must be torch.ops.aten.linear.default" + ) has_bias = len(linear_node.args) == 3 input_index = 0 weight_index = 1 bias_index = 2 input_node = linear_node.args[input_index] - assert isinstance(input_node, Node) + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") input_qspec_map[input_node] = get_input_act_qspec(quantization_config) weight_node = linear_node.args[weight_index] - assert isinstance(weight_node, Node) + if not isinstance(weight_node, Node): + raise AssertionError("weight_node must be a FX Node") input_qspec_map[weight_node] = get_weight_qspec(quantization_config) bias_node = linear_node.args[bias_index] if has_bias else None @@ -646,7 +658,8 @@ def _get_output_nodes_of_partitions( if len(partition.output_nodes) > 1: raise ValueError("Input partition has more than one output node") output_node = partition.output_nodes[0] - assert isinstance(output_node, Node) + if not isinstance(output_node, Node): + raise AssertionError("output_node must be a FX Node") output_node_list.append(output_node) if len(output_node_list) != len(partition_list): raise ValueError( @@ -675,7 +688,8 @@ def _get_input_idx_for_binary_node( conv_gemm_node_idx = 1 extra_input_node_idx = 0 extra_input_node = binary_node.args[extra_input_node_idx] # type: ignore[index] - assert isinstance(extra_input_node, Node) + if not isinstance(extra_input_node, Node): + raise AssertionError("extra_input_node must be a FX Node") return conv_gemm_node_idx, extra_input_node_idx def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: @@ -1132,7 +1146,8 @@ def _annotate_conv2d_binary( if conv_node != binary_node.args[conv_node_idx]: raise ValueError(f"{conv_node} doesn't match input of binary node") extra_input_node = binary_node.args[extra_input_node_idx] - assert isinstance(conv_node, Node) + if not isinstance(conv_node, Node): + raise AssertionError("conv_node must be a FX Node") if ( conv_node.op != "call_function" or conv_node.target != torch.ops.aten.conv2d.default @@ -1246,7 +1261,8 @@ def _annotate_maxpool2d( return input_node = maxpool_node.args[0] - assert isinstance(input_node, Node) + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") input_qspec_map = {} input_qspec_map[input_node] = get_input_act_qspec(quantization_config) maxpool_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -1263,11 +1279,14 @@ def _annotate_cat( return cat_node = node input_nodes = cat_node.args[0] - assert isinstance(input_nodes, Sequence) + if not isinstance(input_nodes, Sequence): + raise AssertionError("input_nodes must be a Sequence of FX Nodes") first_input_node = input_nodes[0] input_qspec_map = {} - assert isinstance(first_input_node, Node) - assert isinstance(cat_node, Node) + if not isinstance(first_input_node, Node): + raise AssertionError("first_input_node must be a FX Node") + if not isinstance(cat_node, Node): + raise AssertionError("cat_node must be a FX Node") input_qspec_map[first_input_node] = get_input_act_qspec(quantization_config) share_qparams_with_input_act0_qspec = SharedQuantizationSpec( (first_input_node, cat_node) @@ -1276,7 +1295,8 @@ def _annotate_cat( for input_node in input_nodes[1:]: if input_node not in input_qspec_map: # There has the case of cat same nodes: torch.cat([input0, input0], 1) - assert isinstance(input_node, Node) + if not isinstance(input_node, Node): + raise AssertionError("input_node must be a FX Node") input_qspec_map[input_node] = share_qparams_with_input_act0_qspec cat_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation( @@ -1396,7 +1416,7 @@ def _annotate_output_for_int8_in_int8_out_pattern( """ # noqa: B950 edge_or_node: tuple[Node, Node] if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): - if node.target == torch.ops.aten.max_pool2d.default: + if node.target is torch.ops.aten.max_pool2d.default: maxpool_node = node if not _is_all_annotated( [ @@ -1415,8 +1435,10 @@ def _annotate_output_for_int8_in_int8_out_pattern( ): # Annotate the output_qspec of getitem_node input_act = maxpool_node.args[0] - assert isinstance(input_act, Node) - assert isinstance(maxpool_node, Node) + if not isinstance(input_act, Node): + raise AssertionError("input_act must be a FX Node") + if not isinstance(maxpool_node, Node): + raise AssertionError("maxpool_node must be a FX Node") edge_or_node = (input_act, maxpool_node) maxpool_node_quantization_annotation.output_qspec = ( SharedQuantizationSpec(edge_or_node) @@ -1544,7 +1566,8 @@ def _annotate_linear_binary_unary( raise ValueError( f"{linear_node} doesn't match input of binary node" ) - assert isinstance(linear_node, Node) + if not isinstance(linear_node, Node): + raise AssertionError("linear_node must be a FX Node") if ( linear_node.op != "call_function" or linear_node.target != torch.ops.aten.linear.default diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index 177203e8ff47b..792285dc8aead 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -347,9 +347,8 @@ def set_module_name( quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator patterns in the submodule with this module name with the given `quantization_config` """ - assert quantization_config is not None, ( - " quantization_config == None is not supported yet" - ) + if quantization_config is None: + raise AssertionError("quantization_config == None is not supported yet") self.module_name_config[module_name] = quantization_config return self diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py index db790a12430e5..36289b49331aa 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py @@ -121,10 +121,13 @@ def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]): if quantization_config.input_activation is None: return None quantization_spec: QuantizationSpec = quantization_config.input_activation - assert quantization_spec.qscheme in [ + if quantization_spec.qscheme not in [ torch.per_tensor_affine, torch.per_tensor_symmetric, - ] + ]: + raise AssertionError( + f"Unsupported activation qscheme: {quantization_spec.qscheme}" + ) return quantization_spec @@ -134,17 +137,21 @@ def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]): if quantization_config.output_activation is None: return None quantization_spec: QuantizationSpec = quantization_config.output_activation - assert quantization_spec.qscheme in [ + if quantization_spec.qscheme not in [ torch.per_tensor_affine, torch.per_tensor_symmetric, - ] + ]: + raise AssertionError( + f"Unsupported activation qscheme: {quantization_spec.qscheme}" + ) return quantization_spec def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): if quantization_config is None: return None - assert quantization_config is not None + if quantization_config is None: + raise AssertionError("quantization_config must not be None") if quantization_config.weight is None: return None quantization_spec: QuantizationSpec = quantization_config.weight @@ -162,13 +169,15 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]): def get_bias_qspec(quantization_config: Optional[QuantizationConfig]): if quantization_config is None: return None - assert quantization_config is not None + if quantization_config is None: + raise AssertionError("quantization_config must not be None") if quantization_config.bias is None: return None quantization_spec: QuantizationSpec = quantization_config.bias - assert quantization_spec.dtype == torch.float, ( - "Only float dtype for bias is supported for bias right now" - ) + if quantization_spec.dtype != torch.float: + raise AssertionError( + "Only float dtype for bias is supported for bias right now" + ) return quantization_spec @@ -253,11 +262,13 @@ def _annotate_linear_relu( input_qspec_map = {} input_act = linear_node.args[0] - assert isinstance(input_act, Node) + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") input_qspec_map[input_act] = input_act_qspec weight = linear_node.args[1] - assert isinstance(weight, Node) + if not isinstance(weight, Node): + raise AssertionError("weight must be a FX Node") input_qspec_map[weight] = weight_qspec # adding weight node to the partition as well @@ -303,11 +314,13 @@ def _annotate_conv( input_qspec_map = {} input_act = conv_node.args[0] - assert isinstance(input_act, Node) + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") input_qspec_map[input_act] = get_input_act_qspec(quantization_config) weight = conv_node.args[1] - assert isinstance(weight, Node) + if not isinstance(weight, Node): + raise AssertionError("weight must be a FX Node") input_qspec_map[weight] = get_weight_qspec(quantization_config) # adding weight node to the partition as well @@ -362,11 +375,13 @@ def _do_annotate_conv_relu( input_qspec_map = {} input_act = conv_node.args[0] - assert isinstance(input_act, Node) + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") input_qspec_map[input_act] = get_input_act_qspec(quantization_config) weight = conv_node.args[1] - assert isinstance(weight, Node) + if not isinstance(weight, Node): + raise AssertionError("weight must be a FX Node") input_qspec_map[weight] = get_weight_qspec(quantization_config) # adding weight node to the partition as well @@ -635,8 +650,10 @@ def _annotate_gru_io_only( # subgraph input_act = input_nodes[0] input_act_user = next(iter(input_act.users.keys())) - assert isinstance(input_act, Node) - assert isinstance(input_act_user, Node) + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") + if not isinstance(input_act_user, Node): + raise AssertionError("input activation user must be a FX Node") input_act_user.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ input_act: get_input_act_qspec(quantization_config), @@ -646,8 +663,10 @@ def _annotate_gru_io_only( hidden_state = input_nodes[1] hidden_state_user = next(iter(hidden_state.users.keys())) - assert isinstance(hidden_state, Node) - assert isinstance(hidden_state_user, Node) + if not isinstance(hidden_state, Node): + raise AssertionError("hidden state must be a FX Node") + if not isinstance(hidden_state_user, Node): + raise AssertionError("hidden state user must be a FX Node") hidden_state_user.meta["quantization_annotation"] = QuantizationAnnotation( input_qspec_map={ hidden_state: get_input_act_qspec(quantization_config), @@ -655,7 +674,8 @@ def _annotate_gru_io_only( _annotated=True, ) - assert len(output_nodes) == 2, "expecting GRU to have two outputs" + if len(output_nodes) != 2: + raise AssertionError("expecting GRU to have two outputs") for output in output_nodes: output.meta["quantization_annotation"] = QuantizationAnnotation( output_qspec=get_output_act_qspec(quantization_config), @@ -691,7 +711,8 @@ def _annotate_adaptive_avg_pool2d( annotated_partitions.append(partition.nodes) input_act = pool_node.args[0] - assert isinstance(input_act, Node) + if not isinstance(input_act, Node): + raise AssertionError("input activation must be a FX Node") # only annotate input output sharing operator # when the output of the input node is annotated diff --git a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py index eff97dbcf27da..1888eb57396e9 100644 --- a/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/xpu_inductor_quantizer.py @@ -112,7 +112,7 @@ def _annotate_output_for_int8_in_int8_out_pattern( node: Node, ) -> None: if (node.target in int8_in_int8_out_ops) and (_is_any_annotated([node])): - if node.target == torch.ops.aten.max_pool2d.default: + if node.target is torch.ops.aten.max_pool2d.default: return else: input_node = node.all_input_nodes[0] diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index 63c635565c4ce..cc21ca2818662 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -214,7 +214,8 @@ def to_underlying_dtype(qdtype): torch.float8_e5m2: torch.float8_e5m2, torch.float8_e4m3fn: torch.float8_e4m3fn, } - assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype) + if qdtype not in DTYPE_MAPPING: + raise AssertionError("Unsupported dtype: " + str(qdtype)) return DTYPE_MAPPING[qdtype] @@ -269,21 +270,24 @@ def get_swapped_custom_module_class( """ quant_type = get_quant_type(qconfig) class_mapping = custom_module_class_mapping.get(quant_type, {}) - assert type(custom_module) in class_mapping, ( - "did not find corresponding observed " - f"module class for {type(custom_module)} in mapping: {class_mapping}" - ) + if type(custom_module) not in class_mapping: + raise AssertionError( + "did not find corresponding observed " + f"module class for {type(custom_module)} in mapping: {class_mapping}" + ) return class_mapping[type(custom_module)] def activation_dtype(qconfig): - assert qconfig is not None + if qconfig is None: + raise AssertionError("qconfig must be provided to determine activation dtype") activation = qconfig.activation() return activation.dtype def weight_dtype(qconfig): - assert qconfig is not None + if qconfig is None: + raise AssertionError("qconfig must be provided to determine weight dtype") weight = qconfig.weight() return weight.dtype @@ -377,7 +381,8 @@ def get_qconfig_dtypes(qconfig): r"""returns the qconfig tuple for qconfig: (activation_dtype, weight_dtype, activation_is_dynamic) """ - assert qconfig is not None + if qconfig is None: + raise AssertionError("qconfig must be provided to extract dtypes") activation = qconfig.activation() weight = qconfig.weight() act_is_dynamic = getattr(activation, "is_dynamic", False) @@ -385,7 +390,8 @@ def get_qconfig_dtypes(qconfig): def get_quant_type(qconfig): - assert qconfig is not None + if qconfig is None: + raise AssertionError("qconfig must be provided to determine quant type") activation = qconfig.activation() weight = qconfig.weight() static_dtypes = [ @@ -442,11 +448,11 @@ def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool: return False - assert min_val <= max_val, f"min {min_val} should be less than max {max_val}" + if min_val > max_val: + raise AssertionError(f"min {min_val} should be less than max {max_val}") else: - assert torch.all(min_val <= max_val), ( - f"min {min_val} should be less than max {max_val}" - ) + if torch.any(min_val > max_val): + raise AssertionError(f"min {min_val} should be less than max {max_val}") return True @@ -481,13 +487,15 @@ def calculate_qmin_qmax( qrange_len = initial_quant_max - initial_quant_min + 1 if dtype in [torch.qint8, torch.int8]: - assert 0 < qrange_len <= 256, ( - "quantization range should be positive and not exceed the maximum bit range (=256)." - ) + if not (0 < qrange_len <= 256): + raise AssertionError( + "quantization range should be positive and not exceed the maximum bit range (=256)." + ) elif dtype in [torch.qint32, torch.int32]: - assert 0 < qrange_len <= 2**32, ( - "quantization range should be positive and not exceed the maximum bit range (=4294967296)." - ) + if not (0 < qrange_len <= 2**32): + raise AssertionError( + "quantization range should be positive and not exceed the maximum bit range (=4294967296)." + ) if reduce_range: quant_min, quant_max = quant_min // 2, quant_max // 2 else: @@ -635,12 +643,12 @@ def validate_qmin_qmax(quant_min: int, quant_max: int) -> None: """ # The variable names are prefixed with "initial" because their values (qmin and qmax) might be adjusted # based on whether quantization range is reduced and the datatype (signed/unsigned) used by the observer. - assert quant_min <= 0 <= quant_max, ( - "Used-specified quantization range must include 0." - ) - assert quant_min < quant_max, ( - "qmin must be strictly less than qmax for user-specified quantization range." - ) + if not (quant_min <= 0 <= quant_max): + raise AssertionError("Used-specified quantization range must include 0.") + if quant_min >= quant_max: + raise AssertionError( + "qmin must be strictly less than qmax for user-specified quantization range." + ) # Functionally equivalent to '_calculate_qparams' in observer.py. Observers must be torchscriptable however and qscheme @@ -813,10 +821,11 @@ def _assert_and_get_unique_device(module: torch.nn.Module) -> Any: ) devices = {torch.device("cpu")} "" - assert len(devices) <= 1, ( - "prepare only works with cpu or single-device CUDA modules, " - f"but got devices {devices}" - ) + if len(devices) > 1: + raise AssertionError( + "prepare only works with cpu or single-device CUDA modules, " + f"but got devices {devices}" + ) device = next(iter(devices)) if len(devices) > 0 else None return device diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 53c8c28af9759..053be3450d6d2 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -297,7 +297,7 @@ def _get_numerical_jacobian( inp_indices = [ i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad ] - for i, (inp, inp_idx) in enumerate(zip(_iter_tensors(target, True), inp_indices)): + for inp, inp_idx in zip(_iter_tensors(target, True), inp_indices): jacobians += [ get_numerical_jacobian_wrt_specific_input( fn, @@ -549,7 +549,7 @@ def _get_analytical_jacobian_forward_ad( with fwAD.dual_level(): fw_grads = [] dual_inputs = [] - for i, inp in enumerate(inputs): + for inp in inputs: if is_tensor_like(inp) and inp.requires_grad: if inp.layout == torch._mkldnn: # type: ignore[attr-defined] raise ValueError( @@ -1275,7 +1275,7 @@ def _test_undefined_forward_mode(func, outputs, inputs): tensor_indices.add(i) dual_inputs.append(inp) - for i, (fw_grad, u) in enumerate(zip(fw_grads, all_u)): + for fw_grad, u in zip(fw_grads, all_u): fw_grad.copy_(u.view_as(fw_grad)) for idx, inp in enumerate(inputs): diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 2ade6485fff71..f7c7150aa7e9d 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -44,6 +44,7 @@ "GradientEdge", "get_gradient_edge", "increment_version", + "set_warn_on_accumulate_grad_stream_mismatch", ] @@ -438,6 +439,13 @@ def disable_saved_tensors_hooks(error_message: str) -> Generator[None, None, Non torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message) +def set_warn_on_accumulate_grad_stream_mismatch(enabled: bool) -> None: + """Whether to warn when the AccumulateGrad node's stream does not match the stream + of the node that produced the incoming gradient. + """ + return torch._C._set_warn_on_accumulate_grad_stream_mismatch(enabled) + + class _MultiHandle(RemovableHandle): handles: tuple[RemovableHandle, ...] diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp index 319eee8a41c6a..fd7d72228fcea 100644 --- a/torch/csrc/Event.cpp +++ b/torch/csrc/Event.cpp @@ -49,6 +49,7 @@ static PyObject* THPEvent_pynew( } THPEvent* self = reinterpret_cast(ptr.get()); + self->weakreflist = nullptr; // TODO: blocking and interprocess are not supported yet. To support them, the // flag system of c10::Event needs to be refactored. C10::Event should also @@ -73,6 +74,7 @@ PyObject* THPEvent_new(c10::DeviceType device_type, c10::EventFlag flag) { auto self = THPObjectPtr{type->tp_alloc(type, 0)}; TORCH_CHECK(self, "Failed to allocate memory for Event"); auto self_ = reinterpret_cast(self.get()); + self_->weakreflist = nullptr; new (&self_->event) c10::Event(device_type, flag); return self.release(); } @@ -82,6 +84,7 @@ static void THPEvent_dealloc(THPEvent* self) { pybind11::gil_scoped_release no_gil{}; self->event.~Event(); } + PyObject_ClearWeakRefs((PyObject*)self); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -282,7 +285,8 @@ static PyMethodDef THPEvent_methods[] = { {"synchronize", THPEvent_synchronize, METH_NOARGS, nullptr}, {"ipc_handle", THPEvent_ipc_handle, METH_NOARGS, nullptr}, {nullptr}}; - +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Winvalid-offsetof" PyTypeObject THPEventType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch.Event", /* tp_name */ @@ -308,7 +312,7 @@ PyTypeObject THPEventType = { nullptr, /* tp_traverse */ nullptr, /* tp_clear */ nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ + offsetof(THPEvent, weakreflist), /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ THPEvent_methods, /* tp_methods */ @@ -323,6 +327,7 @@ PyTypeObject THPEventType = { nullptr, /* tp_alloc */ THPEvent_pynew, /* tp_new */ }; +#pragma GCC diagnostic pop void THPEvent_init(PyObject* module) { THPEventClass = &THPEventType; diff --git a/torch/csrc/Event.h b/torch/csrc/Event.h index 3bbc7d3793997..7dfc7bb426d32 100644 --- a/torch/csrc/Event.h +++ b/torch/csrc/Event.h @@ -7,6 +7,7 @@ struct TORCH_API THPEvent { PyObject_HEAD c10::Event event; + PyObject* weakreflist; }; TORCH_API extern PyTypeObject* THPEventClass; TORCH_API extern PyTypeObject THPEventType; diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 8c5f8e5918397..ad37abe3b560b 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1605,6 +1605,32 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled( END_HANDLE_TH_ERRORS } +static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + PyBool_Check(arg), + "enabled must be a bool, " + "but got ", + THPUtils_typename(arg)); + at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THPModule_warn_on_accumulate_grad_stream_mismatch( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + if (at::globalContext().warnOnAccumulateGradStreamMismatch()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + static PyObject* THCPModule_ensureCUDADeviceGuardSet( PyObject* self, PyObject* noargs) { @@ -1822,6 +1848,14 @@ static std::initializer_list TorchMethods = { THPModule_are_vmap_fallback_warnings_enabled, METH_NOARGS, nullptr}, + {"_set_warn_on_accumulate_grad_stream_mismatch", + THPModule_set_warn_on_accumulate_grad_stream_mismatch, + METH_O, + nullptr}, + {"_warn_on_accumulate_grad_stream_mismatch", + THPModule_warn_on_accumulate_grad_stream_mismatch, + METH_NOARGS, + nullptr}, {"_to_dlpack", castPyCFunctionWithKeywords(THPModule_toDLPack), METH_VARARGS | METH_KEYWORDS, diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp index 534294909a18f..6993f726597cb 100644 --- a/torch/csrc/Stream.cpp +++ b/torch/csrc/Stream.cpp @@ -95,6 +95,7 @@ static PyObject* THPStream_pynew( self->device_index = static_cast(stream_opt->device_index()); self->device_type = static_cast(stream_opt->device_type()); self->context = nullptr; + self->weakreflist = nullptr; return static_cast(ptr.release()); END_HANDLE_TH_ERRORS @@ -114,11 +115,13 @@ PyObject* THPStream_Wrap(const c10::Stream& stream) { self->device_index = static_cast(stream.device_index()); self->device_type = static_cast(stream.device_type()); self->context = nullptr; + self->weakreflist = nullptr; return ptr.release(); END_HANDLE_TH_ERRORS } static void THPStream_dealloc(THPStream* self) { + PyObject_ClearWeakRefs((PyObject*)self); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -444,7 +447,7 @@ static PyTypeObject THPStreamType = { nullptr, /* tp_traverse */ nullptr, /* tp_clear */ THPStream_richcompare, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ + offsetof(THPStream, weakreflist), /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ // NOLINTNEXTLINE(*const-cast) diff --git a/torch/csrc/Stream.h b/torch/csrc/Stream.h index 43b2b3ea43ec1..b4378f30a44e3 100644 --- a/torch/csrc/Stream.h +++ b/torch/csrc/Stream.h @@ -13,6 +13,7 @@ struct THPStream { int64_t device_index; // Used to switch stream context management, initialized lazily. PyObject* context; + PyObject* weakreflist; }; extern TORCH_API PyTypeObject* THPStreamClass; diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h index 2ea2b52fa0fb9..198172ab56489 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -8,9 +8,9 @@ namespace torch::nn { class ParameterListImpl : public Cloneable { public: - using Iterator = typename std::vector< - OrderedDict::Item>::iterator; - using ConstIterator = typename std::vector< + using Iterator = + std::vector::Item>::iterator; + using ConstIterator = std::vector< OrderedDict::Item>::const_iterator; ParameterListImpl() = default; diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h index bbaecbeb97b61..0648c3d3aa6d5 100644 --- a/torch/csrc/api/include/torch/nn/options/conv.h +++ b/torch/csrc/api/include/torch/nn/options/conv.h @@ -60,7 +60,7 @@ struct ConvNdOptions { TORCH_ARG(padding_t, padding) = 0; public: - decltype(auto) padding(std::initializer_list il) { + auto padding(std::initializer_list il) { return padding(IntArrayRef{il}); } @@ -139,7 +139,7 @@ struct ConvOptions { TORCH_ARG(padding_t, padding) = 0; public: - decltype(auto) padding(std::initializer_list il) { + auto padding(std::initializer_list il) { return padding(IntArrayRef{il}); } @@ -209,7 +209,7 @@ struct ConvFuncOptions { TORCH_ARG(padding_t, padding) = 0; public: - decltype(auto) padding(std::initializer_list il) { + auto padding(std::initializer_list il) { return padding(IntArrayRef{il}); } diff --git a/torch/csrc/api/include/torch/nn/pimpl.h b/torch/csrc/api/include/torch/nn/pimpl.h index 3c1206e4edb82..aef9b590c716e 100644 --- a/torch/csrc/api/include/torch/nn/pimpl.h +++ b/torch/csrc/api/include/torch/nn/pimpl.h @@ -130,7 +130,7 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator { /// NOTE: std::forward is qualified to prevent VS2017 emitting /// error C2872: 'std': ambiguous symbol template - decltype(auto) operator[](Arg&& arg) { + auto operator[](Arg&& arg) { return (*impl_)[::std::forward(arg)]; } diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index ff58cfd18ee39..42d701298b0d1 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -4568,7 +4568,7 @@ std::tuple linalg_solve_triangular_backward( if (!grad.defined() || (!A_requires_grad && !B_requires_grad)) { return std::make_tuple(Tensor{}, Tensor{}); } - // We always need to comput G_B + // We always need to compute G_B const Tensor A_H = A.mH(); const Tensor G_B = at::linalg_solve_triangular(A_H, grad, !upper, left, unitriangular); @@ -6503,7 +6503,7 @@ Tensor rms_norm_jvp( Tensor rstd_t; if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || input_t._is_zerotensor()) { - rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + rstd_t = -rstd_p.pow(3) * input_t * input_p; } else { rstd_t = input_t * input_p; rstd_t *= -rstd_p.pow(3); @@ -6514,7 +6514,7 @@ Tensor rms_norm_jvp( Tensor result_t; if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || input_t._is_zerotensor()) { - result_t = (input_t)*rstd_p + (input_p)*rstd_t; + result_t = input_t * rstd_p + input_p * rstd_t; } else { result_t = input_t * rstd_p; auto temp = input_p * rstd_t; @@ -6558,7 +6558,7 @@ Tensor rms_norm_rstd_jvp( Tensor rstd_t; if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) || input_t._is_zerotensor()) { - rstd_t = -rstd_p.pow(3) * (input_t) * (input_p); + rstd_t = -rstd_p.pow(3) * input_t * input_p; } else { rstd_t = input_t * input_p; rstd_t *= -rstd_p.pow(3); diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index f92af4994fd5b..0b70aae489e33 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -1199,7 +1199,11 @@ void Engine::evaluate_function( // Accumulates into buffer auto opt_next_stream = next.function->stream(); input_buffer.add( - next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); + next.input_nr, + std::move(output), + opt_parent_stream, + opt_next_stream, + next.function.get()); if (is_ready) { auto queue = ready_queue(cpu_ready_queue, next.function->device()); @@ -1215,7 +1219,11 @@ void Engine::evaluate_function( // Accumulates into buffer auto opt_next_stream = next.function->stream(); input_buffer.add( - next.input_nr, std::move(output), opt_parent_stream, opt_next_stream); + next.input_nr, + std::move(output), + opt_parent_stream, + opt_next_stream, + next.function.get()); if (is_ready) { auto queue = ready_queue(cpu_ready_queue, next.function->device()); queue->push( @@ -1368,7 +1376,8 @@ auto Engine::execute( root_edges.at(0).input_nr, std::move(input), input_stream, - opt_next_stream); + opt_next_stream, + root_edges.at(0).function.get()); execute_with_graph_task( graph_task, std::move(graph_root), std::move(input_buffer)); diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 63ca5daedd236..62770ef946592 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -11,6 +12,7 @@ #include #include #include +#include #include #include @@ -191,7 +193,8 @@ void InputBuffer::add( size_t pos, Variable&& var, const std::optional& opt_producer_stream_, - const std::optional& opt_consumer_stream_) { + const std::optional& opt_consumer_stream_, + Node* fn) { TORCH_INTERNAL_ASSERT(pos < buffer.size()); if (!var.defined()) { @@ -231,6 +234,21 @@ void InputBuffer::add( TORCH_INTERNAL_ASSERT(opt_consumer_stream && opt_producer_stream); + if (*opt_consumer_stream != *opt_producer_stream && + dynamic_cast(fn) && + at::globalContext().warnOnAccumulateGradStreamMismatch()) { + TORCH_WARN_ONCE( + "The AccumulateGrad node's stream does not match the stream of the node that produced " + "the incoming gradient. This may incur unnecessary synchronization and break CUDA graph " + "capture if the AccumulateGrad node's stream is the default stream. This mismatch is " + "caused by an AccumulateGrad node created prior to the current iteration being kept alive. " + "This can happen if the autograd graph is still being kept alive by tensors such as the " + "loss, or if you are using DDP, which will stash a reference to the node. To resolve the " + "mismatch, delete all references to the autograd graph or ensure that DDP initialization is " + "performed under the same stream as subsequent forwards. If the mismatch is intentional, " + "you can use torch.autograd.graph.set_warn_on_accumulate_grad_stream_mismatch(False) to suppress this " + "warning."); + } // See Note: [Autograd Producer-Consumer Stream Syncs] if (!opt_accum_streams[pos].has_value()) { // [ First producer ] diff --git a/torch/csrc/autograd/input_buffer.h b/torch/csrc/autograd/input_buffer.h index 89abd91f49126..791710d295248 100644 --- a/torch/csrc/autograd/input_buffer.h +++ b/torch/csrc/autograd/input_buffer.h @@ -32,7 +32,8 @@ struct InputBuffer { size_t pos, Variable&& var, const std::optional& opt_producer_stream, - const std::optional& opt_consumer_stream); + const std::optional& opt_consumer_stream, + Node* fn); Variable operator[](size_t pos) { return buffer[pos]; diff --git a/torch/csrc/autograd/profiler_legacy.h b/torch/csrc/autograd/profiler_legacy.h index 30a9fb96f258d..7753deec04a63 100644 --- a/torch/csrc/autograd/profiler_legacy.h +++ b/torch/csrc/autograd/profiler_legacy.h @@ -117,7 +117,7 @@ struct TORCH_API LegacyEvent { } double cpuElapsedUs(const LegacyEvent& e) const { - return static_cast(e.cpu_ns_ - cpu_ns_) / (1000.0); + return static_cast(e.cpu_ns_ - cpu_ns_) / 1000.0; } void setCpuUs(int64_t cpu_us) { @@ -125,7 +125,7 @@ struct TORCH_API LegacyEvent { } double cpuUs() const { - return static_cast(cpu_ns_) / (1000.0); + return static_cast(cpu_ns_) / 1000.0; } double cudaElapsedUs(const LegacyEvent& e) const; diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 4d6c618d0faef..946a8d5f1d367 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -828,19 +830,26 @@ static bool arg_type_tensor_or_tensor_list_like(py::handle arg) { return true; } -#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \ - _(_comparison_key) \ - _(_local_tensor) \ - _(_spec) \ - _(args_schema) \ - _(has_symints) \ - _(kwargs_schema) \ - _(op) \ - _(schema_info) \ - _(shape) \ - _(static_argnum) \ - _(static_kwargkey) \ - _(stride) \ +#if IS_PYTHON_3_11_PLUS +#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) +#else +#define MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) _(__name__) +#endif + +#define FOR_EACH_DTENSOR_INTERNED_STRING(_) \ + MAYBE_FOR_EACH_PYTHON_3_10_MINUS_DTENSOR_INTERNED_STRING(_) \ + _(_comparison_key) \ + _(_local_tensor) \ + _(_spec) \ + _(args_schema) \ + _(kwargs_schema) \ + _(op) \ + _(schema_info) \ + _(shape) \ + _(size) \ + _(static_argnum) \ + _(static_kwargkey) \ + _(stride) \ _(tensor_meta) struct DTensorInternedStrings { @@ -1094,41 +1103,133 @@ static PyObject* DTensor_OpSchema_post_init(PyObject* mod, PyObject* self) { return nullptr; } - const auto dtensor_spec_class = get_dtensor_spec_class(); - bool has_symints = false; - for (const auto& a : args_schema) { - if (Py_TYPE(a.ptr()) != (PyTypeObject*)(dtensor_spec_class.ptr()) && - !py::isinstance(a, dtensor_spec_class)) { - continue; - } - const py::handle tensor_meta = a.attr(dtensor_interned_strings.tensor_meta); - if (tensor_meta.is_none()) { - continue; - } - const auto contains_any_symint = [](const py::tuple& sequence) { - for (const auto& s : sequence) { - if (THPUtils_checkLong(s.ptr())) { - continue; - } - if (torch::is_symint(s)) { - return true; + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static py::list symint_array_to_list(SymIntArrayRef arr) { + py::list result(arr.size()); + for (const auto idx : c10::irange(arr.size())) { + result[idx] = py::cast(arr[idx]); + } + return result; +} + +static PyObject* DTensor_compute_global_tensor_info_impl( + const Tensor& tensor, + py::handle mesh, + const py::sequence& placements) { + Py_ssize_t idx = 0; + c10::SymDimVector tensor_shape( + tensor.sym_sizes().begin(), tensor.sym_sizes().end()); + c10::SymDimVector tensor_strides( + tensor.sym_strides().begin(), tensor.sym_strides().end()); + // NOTE: if this is a py::handle then this code stops working; + // apparently we can't rely on the bound method to stick around. + py::object mesh_size; + for (const auto& placement : placements) { + // TODO: C++ify DeviceMesh somehow; profiling seems + // to say that nearly all our remaining time spent is spent + // calling back into Python. + const auto& cpp_placement = placement.cast(); + if (const auto* cpp_shard = + dynamic_cast(&cpp_placement)) { + const auto shard_dim = cpp_shard->dim; + TORCH_CHECK( + shard_dim >= 0, + "Shard placements should have negative dims normalized in the user-facing APIs: ", + py::cast(py::str(placement))); + const auto tensor_ndim = tensor.dim(); + TORCH_CHECK( + shard_dim < tensor_ndim, + "Sharding dim ", + shard_dim, + " greater than tensor ndim ", + tensor_ndim, + " for placement number ", + idx); + + if (!mesh_size) { + mesh_size = mesh.attr(dtensor_interned_strings.size); + } + const auto mesh_dim_size = py::cast(mesh_size(idx)); + tensor_shape[shard_dim] *= mesh_dim_size; + // recover tensor stride by modifying the strides that are + // larger than the current stride on the shard_dim. + for (const auto i : c10::irange(tensor_strides.size())) { + if (static_cast(i) != shard_dim && + tensor_strides[i] >= tensor_strides[shard_dim]) { + tensor_strides[i] *= mesh_dim_size; } } - return false; - }; - // Specifically it's supposed to be torch.Size. - py::object raw_shape = tensor_meta.attr(dtensor_interned_strings.shape); - if (!PyTuple_Check(raw_shape.ptr())) { - PyErr_SetString(PyExc_TypeError, "OpSchema.shape must be a tuple!"); - return nullptr; - } - const auto shape = py::reinterpret_steal(raw_shape.release()); - if (contains_any_symint(shape)) { - has_symints = true; + } else if (!cpp_placement.is_replicate() && !cpp_placement.is_partial()) { +#if IS_PYTHON_3_11_PLUS + const auto placement_type_name = + py::str(py::handle(PyType_GetName(Py_TYPE(placement.ptr())))); +#else + const auto placement_type_name = + py::str(py::handle((PyObject*)Py_TYPE(placement.ptr())) + .attr(dtensor_interned_strings.__name__)); +#endif + return PyErr_Format( + PyExc_RuntimeError, + "placement type %s not supported!", + py::cast(placement_type_name).c_str()); } + idx++; } - self_handle.attr(dtensor_interned_strings.has_symints) = has_symints; - Py_RETURN_NONE; + return py::make_tuple( + symint_array_to_list(tensor_shape), + symint_array_to_list(tensor_strides)) + .release() + .ptr(); +} + +// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) +static constexpr const char compute_global_tensor_info_doc[] = + "Compute the global size and stride of a DTensor from the given local tensor.\n" + "The local size is multiplied by `world_size` per Sharding dim.\n" + "The local stride is multiplied by `world_size` per Sharding dim, as long as the\n" + "dimension is outside sharding dim.\n" + "\n" + "For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8).\n" + "If the DTensor placements are [Shard(2)] and world_size is 2;\n" + "then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8).\n" + "\n" + "Args:\n" + " tensor (:class:`torch.Tensor`):\n" + " Local tensor which DTensor will be constructed from.\n" + " mesh (:class:`DeviceMesh`):\n" + " Object which describes the mesh topology\n" + " of devices for the DTensor.\n" + " placements (Sequence[:class:`Placement`]]):\n" + " The attribute of the DTensor that describes its layout\n" + " on the mesh topology.\n" + "\n" + "Return:\n" + " tensor_shape: A List of int which specifies the size of DTensor which build\n" + " on top of the local tensor.\n" + " tensor_stride: A List of int which specifies the stride of DTensor.\n"; + +static PyObject* DTensor_compute_global_tensor_info( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs) { + HANDLE_TH_ERRORS + TORCH_CHECK_VALUE( + nargs == 3, + "compute_global_tensor_info expects 3 arguments, got ", + nargs); + TORCH_CHECK_TYPE( + THPVariable_Check(args[0]), + "compute_global_tensor_info 1st argument must be Tensor!"); + const auto& tensor = THPVariable_Unpack(args[0]); + const py::handle mesh = args[1]; + TORCH_CHECK_TYPE( + PySequence_Check(args[2]), + "compute_global_tensor_info 3rd argument must be sequence!"); + const py::sequence placements = py::reinterpret_borrow(args[2]); + return DTensor_compute_global_tensor_info_impl(tensor, mesh, placements); END_HANDLE_TH_ERRORS } @@ -2114,6 +2215,10 @@ static PyMethodDef extra_functions[] = { DTensor_OpSchema_recompute_comparison_key, METH_O, nullptr}, + {"_DTensor_compute_global_tensor_info", + castPyCFunctionFast(DTensor_compute_global_tensor_info), + METH_FASTCALL, + compute_global_tensor_info_doc}, {nullptr}}; struct THPVariableMeta { diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index b559ba44bf52f..89135f9aa9a22 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -115,8 +115,8 @@ ViewInfo ViewInfo::chain( // view_func() AND as_strided() isn't supported; there's no obvious way // to chain the two views. auto error_msg = - ("Attempted to chain views when the parent view has no view_func() and " - "does not support as_strided(). This is not supported."); + "Attempted to chain views when the parent view has no view_func() and " + "does not support as_strided(). This is not supported."; view_func = std::make_unique(error_msg); rev_view_func = [=](const at::Tensor& root_view) { TORCH_CHECK(false, error_msg); diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index a9a5a13206f9c..b14323a47bf35 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -457,6 +457,24 @@ PyObject* THCPModule_cudaSleep(PyObject* _unused, PyObject* cycles) { END_HANDLE_TH_ERRORS } +PyObject* THCPModule_cudaBusyWaitForFlag(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + at::cuda::busy_wait_for_flag(); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cudaClearFlag(PyObject* _unused, PyObject* noargs) { + HANDLE_TH_ERRORS { + pybind11::gil_scoped_release no_gil; + at::cuda::clear_flag(); + } + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + // We need to ensure that as long as a thread will NEVER loose the GIL as long // as it holds the CUDA mutex. Otherwise another thread might be scheduled and // try to e.g. allocate a new tensor which will cause a deadlock. It's enough to @@ -1017,7 +1035,7 @@ PyObject* THCPModule_cudaGetSyncDebugMode(PyObject* self, PyObject* noargs) { //////////////////////////////////////////////////////////////////////////////// static void registerCudaDeviceProperties(PyObject* module) { - // Add _cudaDevicePropertires class to torch._C + // Add _cudaDeviceProperties class to torch._C auto m = py::handle(module).cast(); // CUuuid is defined in either cuda.h or driver_types.h // hipified to hipUUID which is defined in hip_runtime_api.h @@ -2074,6 +2092,11 @@ static struct PyMethodDef _THCPModule_methods[] = { {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr}, + {"_cuda_busy_wait_for_flag", + THCPModule_cudaBusyWaitForFlag, + METH_NOARGS, + nullptr}, + {"_cuda_clear_flag", THCPModule_cudaClearFlag, METH_NOARGS, nullptr}, {"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr}, {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr}, {"_cuda_set_sync_debug_mode", diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 3743476c7a52f..156c9efd5ca98 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -98,7 +98,12 @@ void DistEngine::globalCpuThread( InputBuffer::variables(std::move(task.inputs_))]() mutable { InputBuffer inputs(variables.size()); for (const auto i : c10::irange(variables.size())) { - inputs.add(i, std::move(variables[i]), std::nullopt, std::nullopt); + inputs.add( + i, + std::move(variables[i]), + std::nullopt, + std::nullopt, + graphRoot.get()); } execute_graph_task_until_ready_queue_empty( /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)), diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 642893cbf41f5..fd7f0b4246517 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -349,8 +349,7 @@ static void cacheAllocatorDeregisterHook( } static void attachAllocatorHooks() { - static c10::once_flag flag; - c10::call_once(flag, [] { + static auto flag [[maybe_unused]] = [] { // Attaching hooks fails if CUDACachingAllocator is not initialized, so // Init for CUDA is called (and is a no-op if CUDA is already // initialized). @@ -359,7 +358,8 @@ static void attachAllocatorHooks() { &cacheAllocatorRegisterHook); c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( &cacheAllocatorDeregisterHook); - }); + return true; + }(); } static std:: diff --git a/torch/csrc/distributed/c10d/error.h b/torch/csrc/distributed/c10d/error.h index fef7a630410f4..7d41e99edba97 100644 --- a/torch/csrc/distributed/c10d/error.h +++ b/torch/csrc/distributed/c10d/error.h @@ -15,13 +15,12 @@ namespace fmt { template <> struct formatter { - constexpr decltype(auto) parse(format_parse_context& ctx) const { + constexpr auto parse(format_parse_context& ctx) const { return ctx.begin(); } template - decltype(auto) format(const std::error_category& cat, FormatContext& ctx) - const { + auto format(const std::error_category& cat, FormatContext& ctx) const { if (std::strcmp(cat.name(), "generic") == 0) { return fmt::format_to(ctx.out(), "errno"); } else { @@ -32,12 +31,12 @@ struct formatter { template <> struct formatter { - constexpr decltype(auto) parse(format_parse_context& ctx) const { + constexpr auto parse(format_parse_context& ctx) const { return ctx.begin(); } template - decltype(auto) format(const std::error_code& err, FormatContext& ctx) const { + auto format(const std::error_code& err, FormatContext& ctx) const { return fmt::format_to( ctx.out(), "({}: {} - {})", err.category(), err.value(), err.message()); } diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index 48b59f41b7a88..c79f5a04010eb 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -247,12 +247,12 @@ namespace fmt { template <> struct formatter<::addrinfo> { - constexpr decltype(auto) parse(format_parse_context& ctx) const { + constexpr auto parse(format_parse_context& ctx) const { return ctx.begin(); } template - decltype(auto) format(const ::addrinfo& addr, FormatContext& ctx) const { + auto format(const ::addrinfo& addr, FormatContext& ctx) const { return fmt::format_to( ctx.out(), "{}", @@ -262,14 +262,13 @@ struct formatter<::addrinfo> { template <> struct formatter { - constexpr decltype(auto) parse(format_parse_context& ctx) const { + constexpr auto parse(format_parse_context& ctx) const { return ctx.begin(); } template - decltype(auto) format( - const c10d::detail::SocketImpl& socket, - FormatContext& ctx) const { + auto format(const c10d::detail::SocketImpl& socket, FormatContext& ctx) + const { ::sockaddr_storage addr_s{}; auto addr_ptr = reinterpret_cast<::sockaddr*>(&addr_s); diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp index 77dd36b778aea..efec39e9eb72c 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.hpp @@ -103,7 +103,7 @@ class StoreExchange { size_t seq_id_ = 0; }; -// Teturns a pointer of virtual address that is mapped to the physical memory +// Returns a pointer of virtual address that is mapped to the physical memory // held by the handle. void map_block( void** ptr, diff --git a/torch/csrc/distributed/rpc/request_callback_no_python.cpp b/torch/csrc/distributed/rpc/request_callback_no_python.cpp index ef645675af20a..46fb40801d259 100644 --- a/torch/csrc/distributed/rpc/request_callback_no_python.cpp +++ b/torch/csrc/distributed/rpc/request_callback_no_python.cpp @@ -72,7 +72,7 @@ c10::intrusive_ptr RequestCallbackNoPython::processMessage( auto retFuture = rrefsReadyFuture->thenAsync( [this, - // std::function must be copyable, hence hae to cast the unique_ptr to + // std::function must be copyable, hence has to cast the unique_ptr to // a shared_ptr here. rpc = std::shared_ptr(std::move(rpc)), messageType = request.type(), diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index e353c54805415..f4b3c58697898 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -240,7 +240,7 @@ class TORCH_API RpcAgent { // should be profiled or not. void enableGILProfiling(bool flag); - // Retrieve wheher we should profile GIL wait times or not. + // Retrieve whether we should profile GIL wait times or not. bool isGILProfilingEnabled(); // Set type resolver that will be passed to JIT pickler to resolver type Ptr diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index bdcaf71c05d5f..ac07dc47c5574 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -3534,7 +3534,7 @@ class RootGuardManager : public GuardManager { void add_no_tensor_aliasing_guard( std::shared_ptr no_tensor_aliasing_guard) { - // stash a pointer to the _no_tensor_alising_guard + // stash a pointer to the _no_tensor_aliasing_guard _no_tensor_aliasing_guard = no_tensor_aliasing_guard; this->add_relational_guard_resetter(std::move(no_tensor_aliasing_guard)); } diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 1e1783477d2e0..0e70be3e9ffc4 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -980,7 +980,7 @@ static CacheNode* _compiled_autograd_impl( // cache miss, need to capture FX graph TORCH_INTERNAL_ASSERT(!vlogger.has_value() || compile_reason.has_value()); ClosingTHPObjectPtr py_compiler( - check(PyObject_CallNoArgs((the_autograd_compiler)))); + check(PyObject_CallNoArgs(the_autograd_compiler))); PyCompilerGuard py_compiler_guard( std::make_unique()); diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index 1669f79af72aa..11659cc24eb89 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -544,19 +544,19 @@ static int NodeBase_set_sort_key( // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef NodeBase_methods[] = { {"_update_args_kwargs", - (PyCFunction)(void*)(NodeBase__update_args_kwargs), + (PyCFunction)(void*)NodeBase__update_args_kwargs, METH_FASTCALL, "Internal method: do not call directly."}, {"_remove_from_list", - (PyCFunction)(void*)(NodeBase__remove_from_list), + (PyCFunction)(void*)NodeBase__remove_from_list, METH_NOARGS, "Internal method: do not call directly."}, {"_replace_input_with", - (PyCFunction)(void*)(NodeBase__replace_input_with), + (PyCFunction)(void*)NodeBase__replace_input_with, METH_FASTCALL, "Internal method: replace occurrences of one input Node with another."}, {"_prepend", - (PyCFunction)(void*)(NodeBase__prepend), + (PyCFunction)(void*)NodeBase__prepend, METH_O, "Internal method: do not call directly."}, {"__lt__", @@ -832,11 +832,11 @@ static PyObject* py_map_arg( // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef extra_methods[] = { {"_fx_map_aggregate", - (PyCFunction)(void*)(py_map_aggregate), + (PyCFunction)(void*)py_map_aggregate, METH_FASTCALL, "Recursively apply a function to every element in an aggregate object."}, {"_fx_map_arg", - (PyCFunction)(void*)(py_map_arg), + (PyCFunction)(void*)py_map_arg, METH_FASTCALL, "Recursively apply a function to every Node in an aggregate object."}, {nullptr, nullptr, 0, nullptr} // Sentinel diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 1face0cd6b80b..05d7aa04425f5 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -824,7 +824,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader( runner_ = registered_aoti_runner[device_key]( so_path, num_runners, device.str(), cubin_dir, run_single_threaded); - if (weight_blob_filename != "") { + if (!weight_blob_filename.empty()) { runner_->update_constant_buffer_from_blob(weight_blob_filename); } } diff --git a/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h new file mode 100644 index 0000000000000..3489494d77e4e --- /dev/null +++ b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +struct KernelContext { + std::string kernel_name; + std::string python_stack; + + KernelContext(std::string name, std::string stack) + : kernel_name(std::move(name)), python_stack(std::move(stack)) {} +}; + +// Thread-local pointer +extern thread_local KernelContext* tls_kernel_context; + +inline KernelContext* current_kernel_context() { + return tls_kernel_context; +} + +inline void set_kernel_context(KernelContext* ctx) { + tls_kernel_context = ctx; +} + +inline void clear_kernel_context() { + tls_kernel_context = nullptr; +} + +struct KernelContextGuard { + KernelContextGuard(const std::string& name, const std::string& stack) + : owned_context_(name, stack) { + set_kernel_context(&owned_context_); + } + ~KernelContextGuard() { + clear_kernel_context(); + } + + // Delete copy constructor and copy assignment operator + KernelContextGuard(const KernelContextGuard&) = delete; + KernelContextGuard& operator=(const KernelContextGuard&) = delete; + + KernelContextGuard(KernelContextGuard&&) = default; + KernelContextGuard& operator=(KernelContextGuard&&) = delete; + + private: + KernelContext owned_context_; +}; + +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_runtime/model_base.h b/torch/csrc/inductor/aoti_runtime/model_base.h index 19f1dca1b7e27..a23c836a46735 100644 --- a/torch/csrc/inductor/aoti_runtime/model_base.h +++ b/torch/csrc/inductor/aoti_runtime/model_base.h @@ -836,10 +836,9 @@ class AOTInductorModelBase { } void update_constants_array_from_map() { - if (!constants_map_) { - throw std::runtime_error{ - "constants_map_ was not ready when constants_ is trying to be constructed from it!"}; - } + STD_TORCH_CHECK( + constants_map_, + "constants_map_ was not ready when constants_ is trying to be constructed from it!"); if (!constants_) { constants_ = std::make_shared>(constants_info_.size()); @@ -875,9 +874,7 @@ class AOTInductorModelBase { /// Returns true if the model is complete. bool is_finished() { #ifdef USE_CUDA - if (!run_finished_) { - throw std::runtime_error{"Model CUDA event was not initialized"}; - } + STD_TORCH_CHECK(run_finished_, "Model CUDA event was not initialized"); auto event_status = cudaEventQuery(*run_finished_); if (event_status == cudaSuccess) { @@ -886,13 +883,13 @@ class AOTInductorModelBase { return false; } - throw std::runtime_error( - std::string("The model did not finish successfully. Error: ") + + STD_TORCH_CHECK( + false, + "The model did not finish successfully. Error: ", cudaGetErrorString(cudaGetLastError())); #elif defined(USE_XPU) - if (!run_finished_) { - throw std::runtime_error{"Model XPU event was not initialized"}; - } + STD_TORCH_CHECK(run_finished_, "Model XPU event was not initialized"); + using namespace sycl::info; return (*run_finished_)->get_info() == event_command_status::complete; @@ -904,19 +901,14 @@ class AOTInductorModelBase { /// Synchronizes completion event. void wait_for_completion() { + STD_TORCH_CHECK(run_finished_, "Model event was not initialized"); #ifdef USE_CUDA - if (!run_finished_) { - throw std::runtime_error{"Model event was not initialized"}; - } - AOTI_RUNTIME_CUDA_CHECK(cudaEventSynchronize(*run_finished_)); #endif // USE_CUDA + #ifdef USE_XPU - if (!run_finished_) { - throw std::runtime_error{"Model event was not initialized"}; - } (*run_finished_)->wait_and_throw(); -#endif +#endif // USE_XPU } protected: diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 5cb7daa28a064..61c64760f5328 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -123,8 +123,10 @@ class AOTInductorModelContainer { constants_folding_lk.unlock(); model_lk.lock(); } else if (const_folded != ConstantState::FOLDED) { - throw std::runtime_error( - "Unknown constant state: " + toStringConstantState(constant_folded_)); + STD_TORCH_CHECK( + false, + "Unknown constant state: ", + toStringConstantState(constant_folded_)); } try { @@ -167,8 +169,10 @@ class AOTInductorModelContainer { /* validate_full_update = */ false); const_folded = ConstantState::FOLDED; } else if (constant_folded_ != ConstantState::FOLDED) { - throw std::runtime_error( - "Unknown constant state: " + toStringConstantState(constant_folded_)); + STD_TORCH_CHECK( + false, + "Unknown constant state: ", + toStringConstantState(constant_folded_)); } model->run_single_threaded( @@ -202,56 +206,56 @@ class AOTInductorModelContainer { } size_t num_constants() const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->num_constants(); } // retrieve the constant name of constants_info_[idx] const char* constant_name(size_t idx) const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->constant_name(static_cast(idx)); } // retrieve original FQN of constants_info_[idx] const char* constant_original_fqn(size_t idx) const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->constant_original_fqn(static_cast(idx)); } // retrieve whether constant is from folded of constants_info_[idx] bool constant_from_folded(size_t idx) const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->constant_from_folded(static_cast(idx)); } size_t constant_data_size(size_t idx) const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->constant_data_size(static_cast(idx)); } // retrieve type of constants_info_[idx] int32_t constant_type(size_t idx) const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->constant_type(static_cast(idx)); } // retrieve dtype of constants_info_[idx] int32_t constant_dtype(size_t idx) const { - if (this->num_models() == 0) { - throw std::runtime_error("No available models in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); + return models_[0]->constant_dtype(static_cast(idx)); } @@ -383,9 +387,12 @@ class AOTInductorModelContainer { << " in model, but not provided by user!\n"; continue; } - throw std::runtime_error( - std::string("Cannot find constants ") + constant_name + - std::string(" in constants_map!")); + + STD_TORCH_CHECK( + false, + "Cannot find constants ", + constant_name, + " in constants_map!"); } } } @@ -395,9 +402,8 @@ class AOTInductorModelContainer { std::unordered_map&& constants_map, bool use_inactive, bool validate_full_update) { - if (this->num_models() == 0) { - throw std::runtime_error("No model available in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No available models in container!"); if (validate_full_update) { assert_all_constants(constants_map); } @@ -443,9 +449,9 @@ class AOTInductorModelContainer { bool use_inactive, bool validate_full_update, bool user_managed = false) { - if (this->num_models() == 0) { - throw std::runtime_error("No model available in container!"); - } + STD_TORCH_CHECK( + this->num_models() != 0, "No model available in container!"); + if (validate_full_update) { assert_all_constants(constants_map); } diff --git a/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h b/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h index 18e0b80589622..24c7b48743265 100644 --- a/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h +++ b/torch/csrc/inductor/aoti_runtime/scalar_to_tensor.h @@ -7,7 +7,7 @@ namespace torch::aot_inductor { template inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) { - throw std::runtime_error("Unsupported scalar_to_tensor_handle"); + STD_TORCH_CHECK(false, "Unsupported scalar_to_tensor_handle"); } // Specialize for supported C++ primitive types diff --git a/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h b/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h index 3a2e91c37c916..e2b5e04fc455f 100644 --- a/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h +++ b/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h @@ -143,7 +143,7 @@ static std::unique_ptr _createKernel( sycl::range<3> localRange(localRangeZ, localRangeY, localRangeX); sycl::nd_range<3> parallelWorkSize(globalRange, localRange); if (sharedMemory) { - // numParams from sycl info = user provided args + sharedMemroyBuffer + // numParams from sycl info = user provided args + sharedMemoryBuffer numParams -= 1; } // Submit the imported kernel. diff --git a/torch/csrc/inductor/aoti_runtime/thread_local.h b/torch/csrc/inductor/aoti_runtime/thread_local.h index fd931c95626e4..cf7ab0c1e6ed5 100644 --- a/torch/csrc/inductor/aoti_runtime/thread_local.h +++ b/torch/csrc/inductor/aoti_runtime/thread_local.h @@ -11,11 +11,11 @@ template <> struct ThreadLocalCachedOutputTensor { explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle&) {} void copy_data_from(const RAIIAtenTensorHandle& handle) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } AtenTensorHandle tensor() const { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } }; @@ -23,11 +23,11 @@ template <> struct ThreadLocalCachedOutputTensor { explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle&) {} void copy_data_from(const AtenTensorHandle& handle) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } AtenTensorHandle tensor() const { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } }; @@ -35,11 +35,11 @@ template <> struct ThreadLocalCachedOutputTensor { explicit ThreadLocalCachedOutputTensor(const ConstantHandle&) {} void copy_data_from(const ConstantHandle& handle) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } AtenTensorHandle tensor() const { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } }; @@ -92,18 +92,18 @@ struct ThreadLocalCachedOutputArray; template <> struct ThreadLocalCachedOutputArray { explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle&) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } // Not supported yet! We would need to put contiguous() or // expect_contiguous() into the ABI. void copy_data_from(const RAIIAtenTensorHandle&) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } template ArrayRefTensor arrayref_tensor() const { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } }; @@ -111,18 +111,18 @@ struct ThreadLocalCachedOutputArray { template <> struct ThreadLocalCachedOutputArray { explicit ThreadLocalCachedOutputArray(const ConstantHandle&) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } // Not supported yet! We would need to put contiguous() or // expect_contiguous() into the ABI. void copy_data_from(const ConstantHandle&) { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } template ArrayRefTensor arrayref_tensor() const { - throw std::runtime_error("can't happen"); + STD_TORCH_CHECK(false, "can't happen"); } }; diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index 4fb746ea15271..996c6c8de5ea4 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -41,6 +41,7 @@ #include #include #include +#include #ifdef __cplusplus extern "C" { @@ -621,34 +622,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function( int num_tensors, AtenTensorHandle* flatten_tensor_args); -AOTI_TORCH_EXPORT void aoti_torch_check( - bool cond, - const char* func, - const char* file, - uint32_t line, - const char* msg); - -#ifdef STRIP_ERROR_MESSAGES -#define AOTI_TORCH_CHECK(cond, ...) \ - if (!(cond)) { \ - aoti_torch_check( \ - false, \ - __func__, \ - __FILE__, \ - static_cast(__LINE__), \ - TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ - } -#else -#define AOTI_TORCH_CHECK(cond, ...) \ - if (!(cond)) { \ - aoti_torch_check( \ - false, \ - __func__, \ - __FILE__, \ - static_cast(__LINE__), \ - TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ - } -#endif +// Preserve for BC and will delete it later, using the STD_TORCH_CHECK directly +#define AOTI_TORCH_CHECK(cond, ...) STD_TORCH_CHECK(cond, ##__VA_ARGS__) AOTI_TORCH_EXPORT void aoti_torch_warn( const char* func, diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 04e633771ec27..52e7fd1ae6b90 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -1341,13 +1341,14 @@ AOTITorchError aoti_torch_proxy_executor_call_function( int num_tensors, AtenTensorHandle* flatten_tensor_args) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ - if (!proxy_executor) { - throw std::runtime_error( - "Unable to find a proxy executor to run custom ops. Please check if " - "there is a json file generated in the same directory as the so, or use " - "torch._inductor.aoti_compile_and_package to package everything into a " - "PT2 artifact."); - } + TORCH_CHECK( + proxy_executor != nullptr, + "Unable to find a proxy executor to run custom ops.", + "Please check if there is a json file generated", + "in the same directory as the so,", + "or use torch._inductor.aoti_compile_and_package", + "to package everything into a PT2 artifact."); + ProxyExecutor* executor = reinterpret_cast(proxy_executor); executor->call_function( extern_node_index, @@ -1358,17 +1359,6 @@ AOTITorchError aoti_torch_proxy_executor_call_function( }); } -void aoti_torch_check( - bool cond, - const char* func, - const char* file, - uint32_t line, - const char* msg) { - if (C10_UNLIKELY_OR_CONST(!cond)) { - ::c10::detail::torchCheckFail(func, file, line, msg); - } -} - void aoti_torch_warn( const char* func, const char* file, diff --git a/torch/csrc/inductor/aoti_torch/shim_mps.cpp b/torch/csrc/inductor/aoti_torch/shim_mps.cpp index 568350fa717d8..eb753e82f259b 100644 --- a/torch/csrc/inductor/aoti_torch/shim_mps.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_mps.cpp @@ -10,9 +10,7 @@ AOTITorchError aoti_torch_mps_set_arg_tensor( AtenTensorHandle tensor) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ auto t = tensor_handle_to_tensor_pointer(tensor); - if (t == nullptr) { - throw std::runtime_error("Tensor is null."); - } + TORCH_CHECK(t != nullptr, "Tensor is null."); auto func = reinterpret_cast(handle); func->setArg(idx, *t); }); diff --git a/torch/csrc/inductor/aoti_torch/utils.h b/torch/csrc/inductor/aoti_torch/utils.h index 22018cd70c829..78ab1e8387365 100644 --- a/torch/csrc/inductor/aoti_torch/utils.h +++ b/torch/csrc/inductor/aoti_torch/utils.h @@ -92,13 +92,11 @@ inline void assert_inf_and_nan( const std::string& tensor_name, at::Tensor& check_tensor) { auto isnan_tensor = check_tensor.isnan(); - if (isnan_tensor.any().item()) { - throw std::runtime_error("At least one NaN in " + tensor_name); - } + TORCH_CHECK( + !isnan_tensor.any().item(), "At least one NaN in ", tensor_name); auto isinf_tensor = check_tensor.isinf(); - if (isinf_tensor.any().item()) { - throw std::runtime_error("At least one INF in " + tensor_name); - } + TORCH_CHECK( + !isinf_tensor.any().item(), "At least one INF in ", tensor_name); } // utility functions to convert a pointer to an optional value diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index 8ae212d3d3db9..decdef52a1daa 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -14,7 +14,7 @@ // Because AOTInductor generated code will copy-paste this cpp_prefix.h for // the CPU backend, we have to make sure the used headers are implemented // in a header-only way, i.e. all the function and class definitions are -// in .h files instead of .cpp files, to avoid ABI backward-compatiblity +// in .h files instead of .cpp files, to avoid ABI backward-compatibility // breakage. #include diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index 1ef0522d2175a..2dd563302fdb5 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -441,7 +441,7 @@ The following sections look into each the stages in the script frontend in detai [frontend/tree.h](frontend/tree.h) -Our frontends produce ASTs in the form of Tree objects. Trees are similar to [s-expressions](https://en.wikipedia.org/wiki/S-expression). Leafs (i.e. Atoms) are always strings. Compound trees have a `kind` (e.g `TK_CONST` or `TK_IDENT` defined in [lexer.h](frontend/lexer.h)) and a list of sub-trees. For instance, the Tree for `z.sigmoid() - (x + y)` is: +Our frontends produce ASTs in the form of Tree objects. Trees are similar to [s-expressions](https://en.wikipedia.org/wiki/S-expression). Leaves (i.e. Atoms) are always strings. Compound trees have a `kind` (e.g `TK_CONST` or `TK_IDENT` defined in [lexer.h](frontend/lexer.h)) and a list of sub-trees. For instance, the Tree for `z.sigmoid() - (x + y)` is: ``` (- diff --git a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp index a5a331d15c21c..18c1bc62b8c6d 100644 --- a/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp +++ b/torch/csrc/jit/backends/nnapi/nnapi_backend_lib.cpp @@ -121,7 +121,7 @@ class NnapiBackend : public PyTorchBackendInterface { shape_compute_module.run_method("prepare", ser_model, inputs) .toTensorList(); - // Create and initialize NnapiComilation object + // Create and initialize NnapiCompilation object comp_ = std::make_unique(); auto weights = dict.at("weights").toTensorVector(); comp_->init(ser_model, weights); diff --git a/torch/csrc/jit/frontend/schema_matching.cpp b/torch/csrc/jit/frontend/schema_matching.cpp index cbc22fab84e23..f191c7daf6e26 100644 --- a/torch/csrc/jit/frontend/schema_matching.cpp +++ b/torch/csrc/jit/frontend/schema_matching.cpp @@ -209,7 +209,7 @@ static Value* tryMatchArgument( value = tryConvertToType(loc, graph, concrete_type, value, allow_conversions); std::stringstream ss; if (!value->type()->isSubtypeOfExt( - *concrete_type, /*why_not=*/(failure_messages) ? &ss : nullptr)) { + *concrete_type, /*why_not=*/failure_messages ? &ss : nullptr)) { if (failure_messages) { auto& ostream = err() << arg.formatTypeMismatchMsg(value->type()->repr_str()); diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 103fadaf3a57e..0fb50a5d5dd03 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -379,7 +379,7 @@ std::unique_ptr FlatbufferLoader::parseFunction( function->append_type(getOrCreateTypeAnnotations(i)); } - // 3. If upgrader is needed, change change the OP instrunction to CALL + // 3. If upgrader is needed, change change the OP instruction to CALL // instruction (In next PR, use_upgrader will be parsed to parseInstruction // function and do the actual change) if (use_upgrader) { diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 6a0ba7e038ea3..ab05e48143e3e 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -391,7 +391,7 @@ void BytecodeDeserializer::parseMethods( debug_handles_m_tuple, function.get()); - // 3. If upgrader is needed, change change the OP instrunction to CALL + // 3. If upgrader is needed, change change the OP instruction to CALL // instruction (In next PR, use_upgrader will be parsed to parseInstruction // function and do the actual change) if (use_upgrader) { diff --git a/torch/csrc/jit/passes/bailout_graph.cpp b/torch/csrc/jit/passes/bailout_graph.cpp index 7f8d7eedbe6bf..5bea5e42c0d28 100644 --- a/torch/csrc/jit/passes/bailout_graph.cpp +++ b/torch/csrc/jit/passes/bailout_graph.cpp @@ -196,7 +196,7 @@ struct BailOutGraphBuilderForNode { std::shared_ptr buildBailOutGraphFrom(Node* n) { // add graph inputs for guard's input // and loop counts for loops `n` is contained in - // to make sure we can line bailout grap's inputs up properly + // to make sure we can line bailout graph's inputs up properly // with arguments to this BailOut node. for (auto bi : n->inputs()) { getOrAddInputForValue(bi); diff --git a/torch/csrc/jit/passes/device_type_analysis.cpp b/torch/csrc/jit/passes/device_type_analysis.cpp index 26d55deb636df..9c88b3a992a80 100644 --- a/torch/csrc/jit/passes/device_type_analysis.cpp +++ b/torch/csrc/jit/passes/device_type_analysis.cpp @@ -252,7 +252,7 @@ std::unique_ptr> // This analysis propagates input device types (if any) throughout the // graph. bool DeviceTypePropagation(std::shared_ptr& graph) { - auto tp = std::make_unique((graph)); + auto tp = std::make_unique(graph); bool changed = tp->run(); if (changed) { GRAPH_DUMP("After TensorPropertyPropagation pass:", graph); diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index ccb6e0bc163a4..686f6e660dba7 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -2414,8 +2414,8 @@ static size_t ONNXAssignOutputShape( } } else { std::string msg = - ("Model output has unsupported type. See " - "https://pytorch.org/docs/stable/onnx.html#types. Got type: "); + "Model output has unsupported type. See " + "https://pytorch.org/docs/stable/onnx.html#types. Got type: "; msg += THPUtils_typename(output_obj); throw std::runtime_error(msg); } diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 8df57982bc331..a1bcc8d85b70b 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1230,7 +1230,7 @@ void removeDequantizeFromInputs(const std::unordered_set& inputs) { TORCH_INTERNAL_ASSERT( dequantized_val->uses().size() == 1, "Expect to have one dequantize node for each use"); - // Replace useses of dequantized_val with the input of + // Replace uses of dequantized_val with the input of // dequantize node dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]); dequantize_node->removeAllInputs(); diff --git a/torch/csrc/jit/passes/remove_inplace_ops.cpp b/torch/csrc/jit/passes/remove_inplace_ops.cpp index ad5f4ece457e1..87e56d82bb0d0 100644 --- a/torch/csrc/jit/passes/remove_inplace_ops.cpp +++ b/torch/csrc/jit/passes/remove_inplace_ops.cpp @@ -126,7 +126,7 @@ void ImplicitCastForBinaryInplaceOps(Block* b) { originalInputs.at(0)->type()->cast(); TensorTypePtr secondInp_tensor = originalInputs.at(1)->type()->cast(); - if (!(firstInp_tensor) || !(secondInp_tensor) || + if (!firstInp_tensor || !secondInp_tensor || !(firstInp_tensor->scalarType().has_value())) { continue; } diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 18068f2f78cb2..6561dc5bad1d2 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -1906,7 +1906,7 @@ class ShapePropagator : public PropertyPropBase { "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") || node->matches("aten::div(Tensor self, Scalar other) -> Tensor") || node->matches("aten::mul(Tensor self, Scalar other) -> Tensor")) { - auto first_scalar_type = (tensor_types)[0]->scalarType(); + auto first_scalar_type = tensor_types[0]->scalarType(); auto second_scalar_type = tryScalarTypeFromJitType(*node->inputs()[1]->type()); if (!first_scalar_type || !second_scalar_type) { diff --git a/torch/csrc/jit/python/python_ir.cpp b/torch/csrc/jit/python/python_ir.cpp index 14637d425395f..88794ecbf3d73 100644 --- a/torch/csrc/jit/python/python_ir.cpp +++ b/torch/csrc/jit/python/python_ir.cpp @@ -906,7 +906,7 @@ void initPythonIRBindings(PyObject* module_) { "scalarType", [](Type& t) { auto scalar_type = t.expectRef().scalarType(); - return (scalar_type) ? toString(*scalar_type) : nullptr; + return scalar_type ? toString(*scalar_type) : nullptr; }) .def( "device", diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 2170d376dd6a5..1b4cf86a1963c 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -454,8 +454,8 @@ inline CompleteArgumentInfo CompleteArgumentSpec::at(size_t i) const { inline std::optional convertOptional( std::optional const& from) { - return (from) ? std::optional(static_cast(*from)) - : std::optional{}; + return from ? std::optional(static_cast(*from)) + : std::optional{}; } } // namespace torch::jit diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 7578ea6b1f99c..0ca42cfd32316 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -121,7 +121,7 @@ double radians(double x); // Equivalent to list.at(idx) template -decltype(auto) getItem(const c10::List& list, int64_t idx) { +auto getItem(const c10::List& list, int64_t idx) { const int64_t list_size = list.size(); const int64_t normalized_idx = normalizeIndex(idx, list_size); if (normalized_idx < 0 || normalized_idx >= list_size) { diff --git a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp index 5c4c65b24aebe..4194e5201ce75 100644 --- a/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp +++ b/torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp @@ -162,7 +162,7 @@ InlinedCallStackPtr InlinedCallStackDeserializer::deserialize( } cached_inlined_callstacks_[tup] = cs_ptr; // Invoking move constructor - // It is not clear if copy-ellision can happen since + // It is not clear if copy-elision can happen since // cs_ptr is copied into map above. // This is to help avoid ref count update return cs_ptr; diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index 8e4f5fb037be3..fc018a87e9142 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -28,6 +28,10 @@ namespace flatbuffers = flatbuffers_fbsource; #include // NOLINT #endif +C10_CLANG_DIAGNOSTIC_PUSH() +C10_CLANG_DIAGNOSTIC_IGNORE("-Wswitch-default") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wswitch-enum") + namespace torch::jit { using flatbuffers::FlatBufferBuilder; @@ -858,3 +862,5 @@ bool register_flatbuffer_serializer() { } } // namespace torch::jit + +C10_CLANG_DIAGNOSTIC_POP() diff --git a/torch/csrc/jit/serialization/pickler_helper.cpp b/torch/csrc/jit/serialization/pickler_helper.cpp index 66b51b07f8074..c1d6794ded853 100644 --- a/torch/csrc/jit/serialization/pickler_helper.cpp +++ b/torch/csrc/jit/serialization/pickler_helper.cpp @@ -106,7 +106,7 @@ std::array< GetBackendMetaSerialization() { // The array to save function pointer for BackendMeta serialization. // key is the DeviceType, value is std::pair obj. - // value.first represent get function and value.seconde represent set function + // value.first represent get function and value.second represent set function static std::array< std::optional>, at::COMPILE_TIME_MAX_DEVICE_TYPES> diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index 754894e6096b1..c440357f9e16e 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -830,7 +830,7 @@ std::shared_ptr LazyGraphExecutor:: const SyncTensorsConfig& config) { SyncTensorCollection coll = CollectSyncTensors(*tensors, config); if (coll.indices.empty()) { - /* Enure previous execution is complete before exiting this + /* Ensure previous execution is complete before exiting this * function */ TensorCollectionBarrier(&coll); return nullptr; diff --git a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp index 1bb720b810f93..f1f69e092591c 100644 --- a/torch/csrc/lazy/ts_backend/ts_native_functions.cpp +++ b/torch/csrc/lazy/ts_backend/ts_native_functions.cpp @@ -466,6 +466,14 @@ at::Tensor LazyNativeFunctions::linalg_pinv( linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian); } +std::tuple LazyNativeFunctions::svd( + const at::Tensor& self, + bool some, + bool compute_uv) { + return at::functionalization::functionalize_aten_op::call( + self, some, compute_uv); +} + // functionalize_aten_op can't handle out= ops directly. // Instead, we can call the composite kernel from core, and copy and mutations // back to the inputs. diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 0b8171a372653..468e4828c4122 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -10,6 +10,45 @@ namespace torch::mtia { +struct _MTIAGraph { + // MTIA use accelerator hooks to connect pytorch and outside. + // We need to provide the MTIAGraph class at Python layer, but the hooks only + // support hooking functions, not classes. Thus we store all MTIAGraph C++ + // instances in a map, and use a handle to choose the right instance. + int64_t handle_; + + _MTIAGraph(bool keep_graph = false) + : handle_(at::detail::getMTIAHooks().mtiagraphCreate(keep_graph)) {} + + ~_MTIAGraph() { + at::detail::getMTIAHooks().mtiagraphDestroy(handle_); + } + + void capture_begin(at::MempoolId_t pool) { + at::detail::getMTIAHooks().mtiagraphCaptureBegin(handle_, pool); + } + + void capture_end() { + at::detail::getMTIAHooks().mtiagraphCaptureEnd(handle_); + } + + void instantiate() { + at::detail::getMTIAHooks().mtiagraphInstantiate(handle_); + } + + void replay() { + at::detail::getMTIAHooks().mtiagraphReplay(handle_); + } + + void reset() { + at::detail::getMTIAHooks().mtiagraphReset(handle_); + } + + at::MempoolId_t pool() { + return at::detail::getMTIAHooks().mtiagraphPool(handle_); + } +}; + void initModule(PyObject* module) { auto m = py::handle(module).cast(); @@ -131,6 +170,15 @@ void initModule(PyObject* module) { m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) { at::detail::getMTIAHooks().resetPeakMemoryStats(device_index); }); + + py::class_<_MTIAGraph>(m, "_MTIAGraph") + .def(py::init(), py::arg("keep_graph") = false) + .def("capture_begin", &_MTIAGraph::capture_begin) + .def("capture_end", &_MTIAGraph::capture_end) + .def("instantiate", &_MTIAGraph::instantiate) + .def("replay", &_MTIAGraph::replay) + .def("reset", &_MTIAGraph::reset) + .def("pool", &_MTIAGraph::pool); } } // namespace torch::mtia diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 133951dd817ca..617316617fc67 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -768,7 +768,7 @@ void mark_finished(std::shared_ptr& r) { // Assumption: Total threads number will not exceed 2^16-1, and total ops will // not exceed 2^48 -1. static uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) { - return (((tid) << 48) | ((seqNr) & (((uint64_t)1 << 48) - 1))); + return ((tid << 48) | (seqNr & (((uint64_t)1 << 48) - 1))); } void generateForwardBackwardLink( @@ -915,7 +915,7 @@ void passEventsToKineto( // on-demand Kineto activity handling. Enabling this path // for Profiler API could cause side effects as much has changed since. // Make a surgical fix here until we holistically assess the on-demand - // vs API path framentation, which has been snowballing in complexity + // vs API path fragmentation, which has been snowballing in complexity // and thus flakiness. if (config.global()) { e->kineto_activity_ = activity; diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index b05f4608fb77a..d66eb630a47d9 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -380,12 +380,12 @@ struct TORCH_API Result : public std::enable_shared_from_this { } template - decltype(auto) visit(T&& visitor) { + auto visit(T&& visitor) { return std::visit(std::forward(visitor), extra_fields_); } template - decltype(auto) visit(T&& visitor) const { + auto visit(T&& visitor) const { return std::visit(std::forward(visitor), extra_fields_); } diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index 0730a4ce58600..2b30df4e2a60e 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -511,7 +511,7 @@ extern "C" C10_USED void unwind_c( std::shared_lock lock(torch::unwind::cache_mutex_); torch::unwind::UnwindState state{}; // NOLINTNEXTLINE(performance-no-int-to-ptr) - state.rip = *(int64_t*)(rsp); + state.rip = *(int64_t*)rsp; // +8 because we saved rsp after the return address was already pushed // to the stack state.rsp = rsp + 8; diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index 3090d58f5c094..f7abfece3bc31 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1905,8 +1905,8 @@ class Graph { std::unordered_map sym_int_values; std::unordered_map sym_bool_values; bool is_single_tensor_return = false; - std::unordered_map custom_obj_values = {}; - std::unordered_map sym_float_values = {}; + std::unordered_map custom_obj_values; + std::unordered_map sym_float_values; public: @@ -3027,8 +3027,8 @@ class GraphModule { Graph graph; GraphSignature signature; std::vector module_call_graph; - std::unordered_map metadata = {}; - std::unordered_map treespec_namedtuple_fields = {}; + std::unordered_map metadata; + std::unordered_map treespec_namedtuple_fields; public: @@ -3109,9 +3109,9 @@ class ExportedProgram { std::unordered_map opset_version; std::unordered_map range_constraints; SchemaVersion schema_version; - std::vector verifiers = {}; + std::vector verifiers; std::string torch_version = "<=2.4"; - std::vector guards_code = {}; + std::vector guards_code; public: diff --git a/torch/csrc/utils/pycfunction_helpers.h b/torch/csrc/utils/pycfunction_helpers.h index 745e1842e682c..151c11a0df42f 100644 --- a/torch/csrc/utils/pycfunction_helpers.h +++ b/torch/csrc/utils/pycfunction_helpers.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -11,3 +12,15 @@ inline PyCFunction castPyCFunctionWithKeywords(PyCFunctionWithKeywords func) { C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() } + +#if !IS_PYTHON_3_13_PLUS +using PyCFunctionFast = _PyCFunctionFast; +#endif + +inline PyCFunction castPyCFunctionFast(PyCFunctionFast func) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type") + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wcast-function-type-strict") + return reinterpret_cast(func); + C10_DIAGNOSTIC_POP() + C10_DIAGNOSTIC_POP() +} diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h index 16292e4fd0308..16308dad4421d 100644 --- a/torch/csrc/utils/python_compat.h +++ b/torch/csrc/utils/python_compat.h @@ -9,11 +9,11 @@ extern "C" { // PyTorch-only compat functions -#define IS_PYTHON_3_11_PLUS PY_VERSION_HEX >= 0x030B00C1 -#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000 -#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000 -#define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000 -#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000 +#define IS_PYTHON_3_11_PLUS (PY_VERSION_HEX >= 0x030B00C1) +#define IS_PYTHON_3_12_PLUS (PY_VERSION_HEX >= 0x030C0000) +#define IS_PYTHON_3_13_PLUS (PY_VERSION_HEX >= 0x030D0000) +#define IS_PYTHON_3_14_PLUS (PY_VERSION_HEX >= 0x030E0000) +#define IS_PYTHON_3_15_PLUS (PY_VERSION_HEX >= 0x030F0000) static inline int PyCode_GetNCellvars(PyCodeObject* code) { // gh-26364 added co_ncellvars to Python 3.11.0rc1 diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index c422e8af0ecdb..cc2be68402a8e 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -1708,9 +1708,9 @@ bool isValidDLPackCapsule(PyObject* data) { Tensor tensor_fromDLPack(PyObject* data) { const char* bad_capsule = - ("from_dlpack received an invalid capsule. " - "Note that DLTensor capsules can be consumed only once, " - "so you might have already constructed a tensor from it once."); + "from_dlpack received an invalid capsule. " + "Note that DLTensor capsules can be consumed only once, " + "so you might have already constructed a tensor from it once."; if (PyCapsule_IsValid( data, at::DLPackTraits::capsule)) { diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 5398700e93274..44d11a5bd9741 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -261,7 +261,7 @@ static PyObject* THXPModule_resetAccumulatedMemoryStats( // XPU module initialization static void registerXpuDeviceProperties(PyObject* module) { - // Add _xpuDevicePropertires class to torch._C + // Add _xpuDeviceProperties class to torch._C using namespace c10::xpu; auto get_device_type = [](const DeviceProp& prop) { std::ostringstream stream; @@ -420,6 +420,9 @@ static void initXpuMethodBindings(PyObject* module) { [](c10::DeviceIndex device, c10::DeviceIndex peer) { return at::xpu::canDeviceAccessPeer(device, peer); }); + m.def("_xpu_getMemoryFraction", [](c10::DeviceIndex device) { + return c10::xpu::XPUCachingAllocator::getMemoryFraction(device); + }); m.def("_xpu_setMemoryFraction", [](double fraction, c10::DeviceIndex device) { c10::xpu::XPUCachingAllocator::setMemoryFraction(fraction, device); }); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 34723b0e4c2ba..dff869742df56 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -238,6 +238,14 @@ def _sleep(cycles): torch._C._cuda_sleep(cycles) +def _busy_wait_for_flag(): + torch._C._cuda_busy_wait_for_flag() + + +def _clear_flag(): + torch._C._cuda_clear_flag() + + def _extract_arch_version(arch_string: str) -> int: """Extracts the architecture string from a CUDA version""" base = arch_string.split("_", maxsplit=2)[1] diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h index 512320081b5d6..82f710ec553b9 100644 --- a/torch/custom_class_detail.h +++ b/torch/custom_class_detail.h @@ -129,7 +129,7 @@ call_torchbind_method_from_stack( Functor& functor, jit::Stack& stack, std::index_sequence /*unused*/) { - (void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would + (void)stack; // when sizeof...(ivalue_arg_indices) == 0, this argument would // be unused and we have to silence the compiler warning. constexpr size_t num_ivalue_args = sizeof...(ivalue_arg_indices); @@ -138,7 +138,7 @@ call_torchbind_method_from_stack( typename c10::guts::infer_function_traits_t::parameter_types; // TODO We shouldn't use c10::impl stuff directly here. We should use the // KernelFunction API instead. - return (functor)(c10::impl::ivalue_to_arg< + return functor(c10::impl::ivalue_to_arg< typename c10::impl::decay_if_not_tensor< c10::guts::typelist:: element_t>::type, diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index f0beed8f4d4c3..6c8912ffa4fa3 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -135,7 +135,7 @@ def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): # this. # pyrefly: ignore [deprecated] from .distributed_c10d import * # noqa: F403 - from .distributed_c10d import ( # pyrefly: ignore # deprecated + from .distributed_c10d import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] _all_gather_base, _coalescing_manager, _CoalescingManager, diff --git a/torch/distributed/_composable/replicate_with_fsdp.py b/torch/distributed/_composable/replicate_with_fsdp.py index 368026801172f..8ee84bfe34a3b 100644 --- a/torch/distributed/_composable/replicate_with_fsdp.py +++ b/torch/distributed/_composable/replicate_with_fsdp.py @@ -14,13 +14,12 @@ OffloadPolicy, ) from torch.distributed.fsdp._fully_shard._fsdp_common import ( + DDPMeshInfo, detect_compiled_autograd, - HSDPMeshInfo, ) from torch.distributed.fsdp._fully_shard._fsdp_init import ( _get_device_from_mesh, _get_managed_states, - _get_post_forward_mesh_info, _init_default_fully_shard_mesh, _move_states_to_device, ) @@ -184,23 +183,19 @@ def replicate_impl( ) mesh = mesh or _init_default_fully_shard_mesh() - if mesh.ndim != 2: - raise ValueError(f"replicate expects a 2D DeviceMesh but got {mesh}") + if mesh.ndim != 1: + raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}") else: if mesh.mesh_dim_names is None: raise AssertionError( "Please init the 2D mesh for HSDP with mesh_dim_names specified" ) - mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) + mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) device = _get_device_from_mesh(mesh) auto_reshard_after_forward = reshard_after_forward is None - # If the user does not provide ``reshard_after_forward``, we set it to True. - # During lazy_init, we identify which module is the root and override its value to False - post_forward_mesh_info = _get_post_forward_mesh_info( - reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] - mesh_info, - ) + + post_forward_mesh_info = None arg_module = module modules = ( @@ -217,7 +212,7 @@ def replicate_impl( state._fsdp_param_group = FSDPParamGroup( params, modules, - mesh_info, + mesh_info, # type: ignore[arg-type] post_forward_mesh_info, device, shard_placement_fn, @@ -341,8 +336,8 @@ def replicate_mesh(): device = torch._C._get_accelerator() mesh = init_device_mesh( device.type, - mesh_shape=(default_pg.size(), 1), - mesh_dim_names=("replicate", "shard"), + mesh_shape=(default_pg.size(),), + mesh_dim_names=("replicate",), ) return mesh diff --git a/torch/distributed/_composable_state.py b/torch/distributed/_composable_state.py index 4f2808b545210..d75f1cd71a4c7 100644 --- a/torch/distributed/_composable_state.py +++ b/torch/distributed/_composable_state.py @@ -15,7 +15,8 @@ class _State: def _insert_module_state(module: nn.Module, state: _State) -> None: global _module_state_mapping - assert module not in _module_state_mapping, f"Inserting {module} more than once." + if module in _module_state_mapping: + raise AssertionError(f"Inserting {module} more than once.") _module_state_mapping[module] = weakref.ref(state) diff --git a/torch/distributed/_dist2.py b/torch/distributed/_dist2.py index ce5cb8d7e0cc3..d9ed7003ccfdd 100644 --- a/torch/distributed/_dist2.py +++ b/torch/distributed/_dist2.py @@ -71,7 +71,8 @@ def _gloo_factory( ) -> ProcessGroup: from torch.distributed import ProcessGroupGloo - assert len(kwargs) == 0, "Gloo backend received unexpected kwargs" + if len(kwargs) != 0: + raise AssertionError("Gloo backend received unexpected kwargs") backend_class = ProcessGroupGloo(store, rank, world_size, timeout) backend_class._set_sequence_number_for_group() diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 8574e25833523..9308a63d9e7c2 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -194,7 +194,8 @@ def all_gather_tensor( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - assert self.is_contiguous() + if not self.is_contiguous(): + raise AssertionError("Tensor must be contiguous for all_gather_tensor") group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) tensor = torch.ops._c10d_functional.all_gather_into_tensor( @@ -269,9 +270,10 @@ def reduce_scatter_tensor( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - assert self.size(scatter_dim) % group_size == 0, ( - f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" - ) + if self.size(scatter_dim) % group_size != 0: + raise AssertionError( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" + ) if scatter_dim != 0: tensor_list = torch.chunk(self, group_size, dim=scatter_dim) self = torch.cat(tensor_list) @@ -308,9 +310,10 @@ def reduce_scatter_tensor_autograd( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - assert self.size(scatter_dim) % group_size == 0, ( - f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" - ) + if self.size(scatter_dim) % group_size != 0: + raise AssertionError( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + ) if scatter_dim != 0: tensor_list = torch.chunk(self, group_size, dim=scatter_dim) self = torch.cat(tensor_list) @@ -407,11 +410,15 @@ def reduce_scatter_tensor_coalesced( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - assert len(scatter_dim) == len(inputs) - for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): - assert tensor.size(dim) % group_size == 0, ( - f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + if len(scatter_dim) != len(inputs): + raise AssertionError( + f"Length of scatter_dim ({len(scatter_dim)}) must equal length of inputs ({len(inputs)})" ) + for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): + if tensor.size(dim) % group_size != 0: + raise AssertionError( + f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + ) if dim != 0: tensor_list = torch.chunk(tensor, group_size, dim=dim) inputs[idx] = torch.cat(tensor_list) @@ -429,7 +436,8 @@ def reduce_scatter_tensor_coalesced( # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. # Today, this maps 1:1 with "aten ops that are views". def _is_view_op(tgt): - assert isinstance(tgt, torch._ops.OpOverload) + if not isinstance(tgt, torch._ops.OpOverload): + raise AssertionError(f"Expected torch._ops.OpOverload, got {type(tgt)}") # Don't apply the view optimization to any `CompositeImplicitAutograd` ops. # See issue: https://github.com/pytorch/pytorch/issues/133421 if torch._C._dispatch_has_kernel_for_dispatch_key( @@ -466,20 +474,25 @@ def all_to_all_single( that information and perform collective algebraic optimization. Use other forms of input for that. """ if output_split_sizes is not None: - assert all( + if not all( isinstance(size, (int, torch.SymInt)) for size in output_split_sizes - ), output_split_sizes + ): + raise AssertionError( + f"All output_split_sizes must be int or SymInt, got {output_split_sizes}" + ) if input_split_sizes is not None: - assert all( - isinstance(size, (int, torch.SymInt)) for size in input_split_sizes - ), input_split_sizes + if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes): + raise AssertionError( + f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" + ) group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) if output_split_sizes is None or input_split_sizes is None: - assert output_split_sizes is None and input_split_sizes is None, ( - "output_split_sizes and input_split_sizes must either be " - "specified together or both set to None" - ) + if not (output_split_sizes is None and input_split_sizes is None): + raise AssertionError( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) output_split_sizes = [self.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] @@ -502,21 +515,26 @@ def all_to_all_single_autograd( Same as all_to_all_single but supports autograd. """ if output_split_sizes is not None: - assert all( + if not all( isinstance(size, (int, torch.SymInt)) for size in output_split_sizes - ), output_split_sizes + ): + raise AssertionError( + f"All output_split_sizes must be int or SymInt, got {output_split_sizes}" + ) if input_split_sizes is not None: - assert all( - isinstance(size, (int, torch.SymInt)) for size in input_split_sizes - ), input_split_sizes + if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes): + raise AssertionError( + f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" + ) group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) if output_split_sizes is None or input_split_sizes is None: - assert output_split_sizes is None and input_split_sizes is None, ( - "output_split_sizes and input_split_sizes must either be " - "specified together or both set to None" - ) + if not (output_split_sizes is None and input_split_sizes is None): + raise AssertionError( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) output_split_sizes = [self.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined] @@ -599,7 +617,10 @@ def tolist(self): @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): - assert meta is None + if meta is not None: + raise AssertionError( + "meta must be None for AsyncCollectiveTensor unflatten" + ) elem = inner_tensors["elem"] return AsyncCollectiveTensor(elem) @@ -631,7 +652,7 @@ def _get_acs_underlying_tensor(self): @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] - if func == torch.ops.aten.view.default: + if func is torch.ops.aten.view.default: # Fast handle aten.view as a lot of view related op goes to aten.view # eventually, this avoids pytree slowdown # pyrefly: ignore [index-error] @@ -649,7 +670,10 @@ def unwrap(e: AsyncCollectiveTensor): def wrap(e: torch.Tensor): # wait_tensor is idepotent and will do stream sync only once - assert not isinstance(e, AsyncCollectiveTensor) + if isinstance(e, AsyncCollectiveTensor): + raise AssertionError( + "Cannot wrap an AsyncCollectiveTensor inside another AsyncCollectiveTensor" + ) res = AsyncCollectiveTensor(e) return res @@ -723,9 +747,10 @@ def cast_listint(x): group_size = len(rankset) tag = tag or c10d._get_group_tag(group) elif isinstance(group, DeviceMesh): - assert group.ndim == 1, ( - "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" - ) + if group.ndim != 1: + raise AssertionError( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) # TODO: it should run collective in the whole mesh instead of dim 0 pg = group.get_group() rankset = dist.get_process_group_ranks(pg) @@ -764,9 +789,10 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: elif isinstance(group, str): return group elif isinstance(group, DeviceMesh): - assert group.ndim == 1, ( - "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" - ) + if group.ndim != 1: + raise AssertionError( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) return group._dim_group_names[0] elif isinstance(group, tuple): if ( @@ -1056,12 +1082,14 @@ def all_gather_tensor_inplace( tag: str = "", gather_dim: int = 0, ): - assert not async_op, ( - "Can't remap async version of inplace op to functional collective" - ) + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - assert group is not None + if group is None: + raise AssertionError("group cannot be None") return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) @@ -1075,12 +1103,14 @@ def reduce_scatter_tensor_inplace( scatter_dim: int = 0, tag: str = "", ): - assert not async_op, ( - "Can't remap async version of inplace op to functional collective" - ) + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - assert group is not None + if group is None: + raise AssertionError("group cannot be None") return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) @@ -1104,12 +1134,14 @@ def all_reduce_inplace( async_op: bool = False, tag: str = "", ): - assert not async_op, ( - "Can't remap async version of inplace op to functional collective" - ) + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - assert group is not None + if group is None: + raise AssertionError("group cannot be None") return tensor.copy_(all_reduce(tensor, op, group, tag)) @@ -1123,12 +1155,14 @@ def all_to_all_inplace( async_op=False, tag: str = "", ): - assert not async_op, ( - "Can't remap async version of inplace op to functional collective" - ) + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - assert group is not None + if group is None: + raise AssertionError("group cannot be None") return output.copy_( all_to_all_single( @@ -1148,15 +1182,16 @@ def all_gather_inplace( async_op=False, tag: str = "", ): - assert not async_op, ( - "Can't remap async version of inplace op to functional collective" - ) - assert tensor.dim() == 0 or all(t.size(0) == tensor.size(0) for t in tensor_list), ( - "Remapping variable size all_gather is not yet supported" - ) + if async_op: + raise AssertionError( + "Can't remap async version of inplace op to functional collective" + ) + if tensor.dim() != 0 and not all(t.size(0) == tensor.size(0) for t in tensor_list): + raise AssertionError("Remapping variable size all_gather is not yet supported") group = group or dist.group.WORLD - assert group is not None + if group is None: + raise AssertionError("group cannot be None") output = all_gather_tensor(tensor, 0, group, tag) @@ -1177,7 +1212,7 @@ def all_gather_inplace( return tensor_list -from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated +from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] _all_gather_base as legacy_all_gather_base, _reduce_scatter_base as legacy_reduce_scatter_base, all_gather as legacy_all_gather, diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index 0c1ac0a079dec..e6174c11cd61a 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -97,10 +97,11 @@ def _all_to_all_single( group_size: int, ): if output_split_sizes is None or input_split_sizes is None: - assert output_split_sizes is None and input_split_sizes is None, ( - "output_split_sizes and input_split_sizes must either be " - "specified together or both set to None" - ) + if not (output_split_sizes is None and input_split_sizes is None): + raise AssertionError( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) output_split_sizes = [input.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index b3bc1b5ed8164..ea9707b2e1e85 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -1,5 +1,7 @@ from ast import Call +from torch._ops import OpOverload + """ A LocalTensor is a tensor subclass which simulates a tensor that is @@ -65,12 +67,14 @@ from torch import Size, SymBool, SymInt, Tensor from torch._C import DispatchKey, DispatchKeySet, ScriptObject from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.distributed import DeviceMesh, ProcessGroup from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed.distributed_c10d import _get_default_group from torch.fx.experimental._constant_symnode import ConstantIntNode from torch.nested._internal.nested_int import NestedIntNode from torch.utils import _pytree as pytree +from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import return_and_correct_aliasing, TorchDispatchMode from torch.utils.checkpoint import get_device_states, set_device_states @@ -81,6 +85,19 @@ from . import _c10d +def _is_inplace_op(op: OpOverload | Callable[..., Any]) -> bool: + return ( + isinstance(op, OpOverload) + # Not precise heuristic to detect inplace operation + and op._schema.name[-1] == "_" + # Strengthen the heuristic to check that the first argument and return value are a write + and len(op._schema.arguments) > 0 + and op._schema.arguments[0].is_write + and len(op._schema.returns) > 0 + and op._schema.returns[0].is_write + ) + + def _int_on_rank(i: "int | LocalIntNode | ConstantIntNode", r: int) -> int: if isinstance(i, LocalIntNode): return i._local_ints[r] @@ -100,7 +117,13 @@ def _check_for_subclass_arg(x: object) -> bool: return ( not isinstance(x, LocalTensor) and isinstance(x, Tensor) - and type(x) not in (Tensor, torch.nn.Parameter, torch.nn.Buffer) + and type(x) + not in ( + Tensor, + FakeTensor, + torch.nn.Parameter, + torch.nn.Buffer, + ) ) @@ -220,7 +243,7 @@ def _zero_sized_like(tensor: torch.Tensor, dim: int) -> torch.Tensor: def _for_each_rank_run_func( - func: Callable[..., Any], + func: OpOverload | Callable[..., Any], ranks: frozenset[int], args: Sequence[Any], kwargs: dict[str, Any], @@ -256,7 +279,15 @@ def _for_each_rank_run_func( split_dim = 0 if len(rank_flat_args) < 3 else rank_flat_args[2] default_value = _zero_sized_like(tensor, split_dim) - ret = _combine_rank_results(flat_rank_rets, default_value) + if _is_inplace_op(func): + alias = False + # For the in-place ops return self + ret = args[0] + if isinstance(func, OpOverload) and torch.Tag.inplace_view in func.tags: + # Ensure that wrapper tensor size is synchronized with its local tensors + ret._sync_meta() + else: + ret = _combine_rank_results(flat_rank_rets, default_value) if alias: return return_and_correct_aliasing(func, args, kwargs, ret) @@ -386,6 +417,11 @@ def ne(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: r = {self._local_ints[r] != _int_on_rank(other, r) for r in self._local_ints} return torch._C._get_constant_bool_symnode(len(r) > 1 or next(iter(r))) + def ge(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: + r = {self._local_ints[r] >= _int_on_rank(other, r) for r in self._local_ints} + assert len(r) == 1, (self, other) + return torch._C._get_constant_bool_symnode(next(iter(r))) + def gt(self, other: "int | LocalIntNode | ConstantIntNode") -> bool | SymBool: r = {self._local_ints[r] > _int_on_rank(other, r) for r in self._local_ints} assert len(r) == 1, (self, other) @@ -400,6 +436,93 @@ def wrap_int(self, num: int) -> "LocalIntNode | ConstantIntNode": return ConstantIntNode(num) +_LOCAL_TENSOR_ATTR_PREFIX = "_local_tensor_" + + +def _is_local_tensor_attr(attr: str) -> bool: + return attr.startswith(_LOCAL_TENSOR_ATTR_PREFIX) + + +def _to_local_tensor_attr(rank: int) -> str: + return f"{_LOCAL_TENSOR_ATTR_PREFIX}{rank}" + + +def _from_local_tensor_attr(attr: str) -> int: + if not _is_local_tensor_attr(attr): + raise AssertionError(f"Invalid local tensor attr {attr}") + return int(attr[len(_LOCAL_TENSOR_ATTR_PREFIX) :]) + + +def _all_elements_same(values: list[Any]) -> bool: + if not values: + return True + first_value = values[0] + return all(value == first_value for value in values) + + +def _compute_local_tensor_meta( + local_tensors: dict[int, torch.Tensor], +) -> tuple[ + list[torch.SymInt | int], + list[torch.SymInt | int], + torch.device, + torch.dtype, + torch.layout, + DispatchKeySet, +]: + """ + Computes the meta information for a LocalTensor from its local tensors. + """ + it = iter(local_tensors.values()) + first_local_tensor = next(it) + + first_shape = first_local_tensor.shape + first_stride = first_local_tensor.stride() + dtype = first_local_tensor.dtype + device = first_local_tensor.device + layout = first_local_tensor.layout + + extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor) + + # Assert that all tensors have the same dtype, layout and dispatch keys. Due + # to uneven sharding, it is possible that tensors will have different shapes. + for local_tensor in it: + assert dtype == local_tensor.dtype, ( + "Tensors representing LocalTensor shards must have the same dtype" + ) + assert layout == local_tensor.layout, ( + "Tensors representing LocalTensor shards must have the same layout" + ) + assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), ( + "Tensors representing LocalTensor shards must have the same set of extra dispatch keys" + ) + + # Compute shape/stride. We allow for non-SPMD'ness here + local_shapes: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size + local_strides: dict[int, dict[int, int]] = defaultdict(dict) # dim => rank => size + for r, local_tensor in local_tensors.items(): + for d, size in enumerate(local_tensor.shape): + local_shapes[d][r] = size + local_strides[d][r] = local_tensor.stride(d) + shape = [ + ( + first_shape[d] + if _all_elements_same(list(local_shapes[d].values())) + else torch.SymInt(LocalIntNode(local_shapes[d])) + ) + for d in range(len(first_shape)) + ] + strides = [ + ( + first_stride[d] + if _all_elements_same(list(local_strides[d].values())) + else torch.SymInt(LocalIntNode(local_strides[d])) + ) + for d in range(len(first_shape)) + ] + return shape, strides, device, dtype, layout, extra_dispatch_keys + + class LocalTensor(torch.Tensor): """ LocalTensor is a Tensor subclass that simulates a tensor distributed across multiple SPMD @@ -418,13 +541,15 @@ class LocalTensor(torch.Tensor): _local_tensors: dict[int, torch.Tensor] # Precomputed for speed set of keys from the local tensor map. _ranks: frozenset[int] - __slots__ = ["_local_tensors", "_ranks"] + _size: list[torch.SymInt | int] + __slots__ = ["_local_tensors", "_ranks", "_size"] @staticmethod @torch._disable_dynamo def __new__( cls, local_tensors: dict[int, torch.Tensor], + requires_grad: bool = False, ) -> "LocalTensor": if any(t.requires_grad for t in local_tensors.values()): raise AssertionError( @@ -432,57 +557,9 @@ def __new__( "Make a custom autograd function and make sure you detach the inner tensors." ) - it = iter(local_tensors.values()) - first_local_tensor = next(it) - - first_shape = first_local_tensor.shape - first_stride = first_local_tensor.stride() - dtype = first_local_tensor.dtype - device = first_local_tensor.device - layout = first_local_tensor.layout - - extra_dispatch_keys = _get_extra_dispatch_keys(first_local_tensor) - - # Assert that all tensors have the same dtype, layout and dispatch keys. Due - # to uneven sharding, it is possible that tensors will have different shapes. - for local_tensor in it: - assert dtype == local_tensor.dtype, ( - "Tensors representing LocalTensor shards must have the same dtype" - ) - assert layout == local_tensor.layout, ( - "Tensors representing LocalTensor shards must have the same layout" - ) - assert extra_dispatch_keys == _get_extra_dispatch_keys(local_tensor), ( - "Tensors representing LocalTensor shards must have the same set of extra dispatch keys" - ) - - # Compute shape/stride. We allow for non-SPMD'ness here - local_shapes: dict[int, dict[int, int]] = defaultdict( - dict - ) # dim => rank => size - local_strides: dict[int, dict[int, int]] = defaultdict( - dict - ) # dim => rank => size - for r, local_tensor in local_tensors.items(): - for d, size in enumerate(local_tensor.shape): - local_shapes[d][r] = size - local_strides[d][r] = local_tensor.stride(d) - shape = [ - ( - first_shape[d] - if len(set(local_shapes[d])) == 1 - else torch.SymInt(LocalIntNode(local_shapes[d])) - ) - for d in range(len(first_shape)) - ] - strides = [ - ( - first_stride[d] - if len(set(local_strides[d])) == 1 - else torch.SymInt(LocalIntNode(local_strides[d])) - ) - for d in range(len(first_shape)) - ] + (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( + _compute_local_tensor_meta(local_tensors) + ) r = torch.Tensor._make_wrapper_subclass( cls, @@ -491,7 +568,13 @@ def __new__( dtype=dtype, device=device, layout=layout, - requires_grad=False, + # In place ops potentially change local tensor sizes (e.g. resize_). While + # executing an in-place op the return value must be the same as "self" input + # otherwise we can introduce errors due to tensor identity changes. Hence we + # need to be able to update wrapper subclass sizes after in-place ops. This + # dispatch policy allows us to do that. + dispatch_sizes_strides_policy="sizes", + requires_grad=requires_grad, _extra_dispatch_keys=extra_dispatch_keys, ) @@ -501,6 +584,7 @@ def __new__( } r._local_tensors = local_tensors r._ranks = frozenset(local_tensors.keys()) + r._size = shape return r @torch._disable_dynamo @@ -512,9 +596,7 @@ def __deepcopy__(self, memo: dict[Any, Any] | None) -> "LocalTensor": local_tensors_copy = { r: copy.deepcopy(t, memo) for r, t in self._local_tensors.items() } - tensor_copy = LocalTensor(local_tensors_copy) - tensor_copy.requires_grad = self.requires_grad - return tensor_copy + return LocalTensor(local_tensors_copy, self.requires_grad) def __repr__(self) -> str: # type: ignore[override] parts = [] @@ -524,12 +606,21 @@ def __repr__(self) -> str: # type: ignore[override] tensors_str = ",\n".join(parts) return f"LocalTensor(\n{tensors_str}\n)" + def __getattr__(self, name: str) -> Any: + if _is_local_tensor_attr(name): + rank = _from_local_tensor_attr(name) + if rank not in self._ranks: + raise AttributeError(f"Local tensor has no knowledge of rank {rank}") + return self._local_tensors[rank] + return object.__getattribute__(self, name) + def __tensor_flatten__(self) -> tuple[list[str], tuple[Any, ...]]: """ protocol to inform how to flatten a DTensor to local tensor for PT2 tracing """ - return ["_local_tensors"], () + local_tensor_attrs = [_to_local_tensor_attr(r) for r in self._ranks] + return local_tensor_attrs, () @staticmethod def __tensor_unflatten__( @@ -541,8 +632,9 @@ def __tensor_unflatten__( assert flatten_spec is not None, ( "Expecting spec to be not None from `__tensor_flatten__` return value!" ) - local_tensors = inner_tensors["_local_tensors"] - # pyrefly: ignore [bad-argument-type, bad-argument-count] + local_tensors = { + _from_local_tensor_attr(a): t for a, t in inner_tensors.items() + } return LocalTensor(local_tensors) @classmethod @@ -591,24 +683,6 @@ def numpy( else: raise RuntimeError("Numpy is not available") - def __lt__( - self, other: torch.Tensor | bool | complex | float | int - ) -> torch.Tensor: - self_rec = self.reconcile() - other_rec = other - if isinstance(other, LocalTensor): - other_rec = other.reconcile() - return self_rec < other_rec - - def __gt__( - self, other: torch.Tensor | bool | complex | float | int - ) -> torch.Tensor: - self_rec = self.reconcile() - other_rec = other - if isinstance(other, LocalTensor): - other_rec = other.reconcile() - return self_rec > other_rec - def contiguous( self, memory_format: torch.memory_format = torch.contiguous_format, @@ -660,6 +734,13 @@ def reconcile(self) -> torch.Tensor: cl.requires_grad_(self.requires_grad) return cl + def _sync_meta(self) -> None: + with no_dispatch(): + (shape, strides, device, dtype, layout, extra_dispatch_keys) = ( + _compute_local_tensor_meta(self._local_tensors) + ) + self._size = shape + _LOCAL_TENSOR_MODE: list["LocalTensorMode"] = [] @@ -753,6 +834,11 @@ def __torch_dispatch__( f"Input LocalTensor {a} and LocalTensorMode must be configured for the same ranks" ) + if func.overloadpacket == torch.ops.aten.dim: + return len(args[0]._size) + if func.overloadpacket == torch.ops.aten.sym_size: + return tuple(args[0]._size) + if func.namespace == "c10d": if func is torch.ops.c10d.allreduce_.default: return _c10d._local_all_reduce_(*args, **kwargs) diff --git a/torch/distributed/_serialization.py b/torch/distributed/_serialization.py index d9c3bfe6b8d5a..c13ba46ba5757 100644 --- a/torch/distributed/_serialization.py +++ b/torch/distributed/_serialization.py @@ -41,7 +41,7 @@ def write_to(self, f: BufferedIOBase) -> None: pickle.dump(entries, f, protocol=DEFAULT_PROTOCOL) - for key, (data, length) in self.records.items(): + for data, _ in self.records.values(): if isinstance(data, bytes): f.write(data) elif isinstance(data, str): diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py index af4f4f890e901..3a8a05fe79d19 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -128,18 +128,19 @@ def _handle_col_wise_sharding_base( # run the operator's function for all the inputs. results = [] for i, inp in enumerate(gathered_inputs): - if op_func == torch.nn.functional.embedding_bag: + if op_func is torch.nn.functional.embedding_bag: result = op_func( inp, local_shard, offsets=gathered_offsets[i] if gathered_offsets is not None else None, + # pyrefly: ignore [bad-argument-type] mode=mode, per_sample_weights=gathered_per_sample_weights[i] if gathered_per_sample_weights is not None else None, padding_idx=padding_idx, ) - elif op_func == torch.nn.functional.embedding: + elif op_func is torch.nn.functional.embedding: result = op_func( inp, local_shard, diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index bcf5674833439..3d7b91090b125 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -184,12 +184,18 @@ def _iterate_state_dict( if companion_obj is not None: if isinstance(companion_obj, DTensor): - assert isinstance(ret, DTensor) + if not isinstance(ret, DTensor): + raise AssertionError( + "ret must be a DTensor when companion_obj is a DTensor" + ) companion_obj._local_tensor.copy_( ret._local_tensor, non_blocking=non_blocking ) elif isinstance(companion_obj, ShardedTensor): - assert isinstance(ret, ShardedTensor) + if not isinstance(ret, ShardedTensor): + raise AssertionError( + "ret must be a ShardedTensor when companion_obj is a ShardedTensor" + ) for idx, shard in enumerate(companion_obj.local_shards()): shard.tensor.copy_( ret.local_shards()[idx].tensor, non_blocking=non_blocking @@ -548,7 +554,8 @@ def _broadcast_tensors( for key in keys: if dist.get_rank() == 0: full_state = full_state_dict[key] - assert isinstance(full_state, torch.Tensor) + if not isinstance(full_state, torch.Tensor): + raise AssertionError("full_state must be a torch.Tensor") full_tensor = full_state.detach().to(pg_device) else: tensor_info = full_state_dict[key] @@ -707,7 +714,8 @@ def _distribute_state_dict( elif value.dim() == 0: local_state_dict[key] = value.cpu() else: - assert isinstance(value, torch.Tensor) + if not isinstance(value, torch.Tensor): + raise AssertionError("value must be a torch.Tensor") local_state = local_state_dict.get(key) if local_state is None: continue diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py index 3b201b395334b..18bb1a02a0055 100644 --- a/torch/distributed/_tools/fake_collectives.py +++ b/torch/distributed/_tools/fake_collectives.py @@ -276,14 +276,14 @@ def get_comm_tensor_size(func, res, args, kwargs) -> int: # type: ignore[no-unt return res.untyped_storage().nbytes() if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE: return args[0].untyped_storage().nbytes() - if func == c10d._reduce_scatter_base_.default: + if func is c10d._reduce_scatter_base_.default: return args[1].untyped_storage().nbytes() - if func == c10d.alltoall_.default: + if func is c10d.alltoall_.default: # TODO(@sanketpurandare) - Confirm size computation return max( CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1]) ) - if func == c10d.alltoall_base_.default: + if func is c10d.alltoall_base_.default: # TODO(@sanketpurandare) - Confirm size computation return max( args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes() diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index 60ff77d0d4972..9a749922be939 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -507,7 +507,7 @@ def __exit__(self, *args: Any) -> None: def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] if ( - func == torch.ops._c10d_functional.wait_tensor.default + func is torch.ops._c10d_functional.wait_tensor.default and active_fake_mode() ): # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns @@ -525,7 +525,7 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignor reftype = _FSDPRefType.TEMP else: reftype = _FSDPRefType.ACT - if func == c10d._allgather_base_.default and self._fsdp_state in [ + if func is c10d._allgather_base_.default and self._fsdp_state in [ _FSDPState.PRE_FW, _FSDPState.PRE_BW, ]: @@ -537,7 +537,7 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignor update_existing=True, ) if ( - func == c10d._reduce_scatter_base_.default + func is c10d._reduce_scatter_base_.default and self._fsdp_state == _FSDPState.POST_BW ): # pyrefly: ignore [unsupported-operation] diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index 04f5482d7d128..68952c33a6d72 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -930,7 +930,7 @@ def __exit__(self, *args: Any) -> None: def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def] if ( - func == torch.ops._c10d_functional.wait_tensor.default + func is torch.ops._c10d_functional.wait_tensor.default and active_fake_mode() ): # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index 1dc01f62d94e7..890d2be2794a4 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -26,7 +26,7 @@ def __init__(self, memory_tracker) -> None: def __torch_dispatch__(self, func, types, args=..., kwargs=None): rs = func(*args, **kwargs) - if func == torch.ops.aten.detach.default: + if func is torch.ops.aten.detach.default: return rs func_name: str = ( self.memory_tracker._cur_module_name diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py index a1fa1fd64c060..69d8860456135 100644 --- a/torch/distributed/algorithms/_quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -111,7 +111,7 @@ def wrapper(*args, **kwargs): async_op = kwargs.get("async_op", False) if async_op is True: raise RuntimeError("The async_op=True mode is not supported yet.") - if func == dist.all_gather: + if func is dist.all_gather: tensors = args[0] input_tensors = _quantize_tensor(args[1], qtype) out_tensors = _quantize_tensor_list(tensors, qtype) @@ -121,7 +121,7 @@ def wrapper(*args, **kwargs): ): tensors[i] = t - elif func == dist.all_to_all: + elif func is dist.all_to_all: tensors = args[0] input_tensors = _quantize_tensor_list(args[1], qtype) out_tensors = _quantize_tensor_list(tensors, qtype) @@ -131,7 +131,7 @@ def wrapper(*args, **kwargs): ): tensors[i] = t - elif func == dist.all_to_all_single: + elif func is dist.all_to_all_single: tensors = args[0] out_splits = kwargs.get("out_splits") in_splits = kwargs.get("in_splits") diff --git a/torch/distributed/checkpoint/_async_process_executor.py b/torch/distributed/checkpoint/_async_process_executor.py index f7c045cdd27b4..fd6876f506127 100644 --- a/torch/distributed/checkpoint/_async_process_executor.py +++ b/torch/distributed/checkpoint/_async_process_executor.py @@ -10,6 +10,7 @@ import torch.distributed as dist import torch.multiprocessing as mp +from torch.distributed import PrefixStore, TCPStore from torch.distributed.checkpoint._async_executor import _AsyncCheckpointExecutor from torch.distributed.checkpoint.logger import _dcp_method_logger, _init_logger from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE @@ -55,15 +56,17 @@ class _ProcessGroupInitInfo: world_size: int tcp_store_master_addr: str tcp_store_master_port: int + use_prefix_store: bool def __init__(self, process_group: Optional[dist.ProcessGroup] = None): self.local_rank = dist.get_node_local_rank(fallback_rank=0) self.global_rank = dist.get_rank(process_group) self.world_size = dist.get_world_size(process_group) + self.use_prefix_store = os.environ.get("DCP_USE_PREFIX_STORE", "0") == "1" - # Let coordinator rank find a free port on the localhost. - # Broadcast the (master_addr, free_port) to all ranks; each rank in the - # checkpoint daemon process will use TCPStore (master_addr, master_port) + # Let coordinator rank find a port on the localhost. + # Broadcast the (master_addr, port) to all ranks; each rank in the + # checkpoint daemon process will use TCPStore (master_addr, port) # for collective communication. dist_wrapper: _DistWrapper = _DistWrapper( group=process_group, @@ -72,10 +75,23 @@ def __init__(self, process_group: Optional[dist.ProcessGroup] = None): ) def get_master_addr_and_port() -> tuple[str, int]: - master_addr = os.environ.get("MASTER_ADDR") - if master_addr is None: - master_addr = _get_fq_hostname() - return master_addr, get_free_port() + if self.use_prefix_store: + master_addr = os.environ.get("MASTER_ADDR") + master_port = os.environ.get("MASTER_PORT") + assert master_addr is not None, ( + "DCP needs MASTER_ADDR to use prefix store" + ) + assert master_port is not None, ( + "DCP needs MASTER_PORT to use prefix store" + ) + master_port = int(master_port) + else: + master_addr = os.environ.get("MASTER_ADDR") + if master_addr is None: + master_addr = _get_fq_hostname() + master_port = get_free_port() + + return master_addr, master_port self.tcp_store_master_addr, self.tcp_store_master_port = dist_wrapper.broadcast( step="get_master_addr_and_port", @@ -221,10 +237,29 @@ def _checkpointing_subprocess( os.environ["WORLD_SIZE"] = str(pg_init_info.world_size) logger.info( - "Initializing dist.ProcessGroup in checkpoint background process" + "Initializing dist.ProcessGroup in checkpoint background process on port %s", + pg_init_info.tcp_store_master_port, ) # NOTE: GLOO backend is enforced here. - dist.init_process_group(backend=dist.Backend.GLOO) + if pg_init_info.use_prefix_store: + logger.info( + "Initializing dist.ProcessGroup in checkpoint background process with prefix store" + ) + store = PrefixStore( + "AsyncCheckpointProcess/", + TCPStore( + pg_init_info.tcp_store_master_addr, + pg_init_info.tcp_store_master_port, + ), + ) + dist.init_process_group( + backend=dist.Backend.GLOO, + store=store, + world_size=pg_init_info.world_size, + rank=pg_init_info.global_rank, + ) + else: + dist.init_process_group(backend=dist.Backend.GLOO) dist.barrier() logger.info("Checkpoint background process is running...") @@ -365,7 +400,7 @@ def execute_save( global _CHECKPOINT_PROCESS pg_init_info: Optional[_ProcessGroupInitInfo] = None if _CHECKPOINT_PROCESS is None: - # Find a free port on coordinator rank and broadcast + # Find a port on coordinator rank and broadcast # to all ranks. pg_init_info = _ProcessGroupInitInfo(process_group) diff --git a/torch/distributed/checkpoint/_checkpointer.py b/torch/distributed/checkpoint/_checkpointer.py index d54de9092a93f..13b0d627a36cc 100644 --- a/torch/distributed/checkpoint/_checkpointer.py +++ b/torch/distributed/checkpoint/_checkpointer.py @@ -17,7 +17,7 @@ class _Checkpointer: - """This base class specefies a high level API for saving and loading + """This base class specifies a high level API for saving and loading distributed `state_dict` 's. It provides an abstraction over the low-level APIs provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling :py:meth: `torch.distributed.state_dict_saver.save` and diff --git a/torch/distributed/checkpoint/format_utils.py b/torch/distributed/checkpoint/format_utils.py index 129b7cf570c1d..912f983fe2a7c 100644 --- a/torch/distributed/checkpoint/format_utils.py +++ b/torch/distributed/checkpoint/format_utils.py @@ -80,7 +80,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: planner = cast(DefaultLoadPlanner, planner) # data is read in on the coordinator rank, and broadcast afterwards - # this incurrs a communication cost, but it avoids having to load + # this incurs a communication cost, but it avoids having to load # the entire checkpoint on each rank, hopefully preventing OOM issues # TODO: read on each host, instead of only the coordinator if self.is_coordinator: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 479027a2ea9a5..16d988a79103e 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -578,7 +578,7 @@ def _load_model_state_dict( assign = False if info.broadcast_from_rank0 or info.full_state_dict: devices = set() - for key, value in local_state_dict.items(): + for value in local_state_dict.values(): if torch.is_tensor(value) and value.dim() > 0: devices.add(value.device) # In lora state_dict, there could be multiple devices, with meta device inside. diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index b45f0b5cbb4ff..e608e26a3a854 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -104,7 +104,10 @@ def broadcast( if pg is not None: broadcast_list = [sync_obj] dist.broadcast_object_list(broadcast_list, src=rank, group=pg) - assert len(broadcast_list) == 1 + if len(broadcast_list) != 1: + raise AssertionError( + f"Expected broadcast_list to have exactly 1 element, got {len(broadcast_list)}" + ) sync_obj = broadcast_list[0] # failure in any rank will trigger a throw in every rank. @@ -240,8 +243,10 @@ def all_gather_object_enforce_type( def _summarize_ranks(ranks: Iterable[int]) -> str: ranks = sorted(ranks) - assert min(ranks) >= 0, "ranks should all be positive" - assert len(set(ranks)) == len(ranks), "ranks should not contain duplicates" + if min(ranks) < 0: + raise AssertionError("ranks should all be positive") + if len(set(ranks)) != len(ranks): + raise AssertionError("ranks should not contain duplicates") curr: Optional[Union[int, range]] = None ranges = [] while ranks: @@ -255,7 +260,8 @@ def _summarize_ranks(ranks: Iterable[int]) -> str: step = x - curr curr = range(curr, x + step, step) else: - assert isinstance(curr, range) + if not isinstance(curr, range): + raise AssertionError("curr must be an instance of range") if x == curr.stop: curr = range(curr.start, curr.stop + curr.step, curr.step) else: diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index cab3e71d32068..a161a4394a93d 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -237,6 +237,28 @@ def __init__( f"backend_override should have the same length as the number of mesh dimensions, " f"but got {len(backend_override)} and {len(self._layout)}." ) + # Internal bookkeeping for the device mesh. + self._layout = ( + _layout + if _layout + else _MeshLayout(self.mesh.size(), self.mesh.stride()) + ) + if not self._layout.check_non_overlap(): + raise AssertionError( + "Please use a non-overlapping layout when creating a DeviceMesh." + ) + # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here. + if self._layout.numel() != self.mesh.numel(): + raise AssertionError( + "Please use a valid layout when creating a DeviceMesh." + f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}." + ) + + # private field to pre-generate DeviceMesh's hash + self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) + self._thread_id = None + # Initialize instance-specific flatten mapping + self._flatten_mapping = {} # Skip process group initialization if xla device or init backend is False # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. @@ -263,7 +285,10 @@ def __init__( # calculate the coordinates of the current global rank on the mesh rank_coords = (self.mesh == _rank).nonzero() - assert rank_coords.size(0) in (0, 1) + if rank_coords.size(0) not in (0, 1): + raise AssertionError( + f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}" + ) self._coordinate_on_dim: Optional[list[int]] = ( rank_coords[0].tolist() if rank_coords.size(0) > 0 else None ) @@ -350,22 +375,33 @@ def _setup_world_group_and_device(self): return _get_default_group() @staticmethod - def _init_process_groups( - layout: _MeshLayout, + def _init_one_process_group( + sub_layout: _MeshLayout, rank_map: torch.Tensor, - mesh_dim_names: Optional[tuple[str, ...]], - backend_override: tuple[BackendConfig, ...], - ) -> list[str]: - # group_name associated with each mesh dimension, each - # mesh dimension should have one sub-group per rank - # - dim_group_names: list[str] = [] + dim_name: str, + backend_override: BackendConfig, + ) -> Optional[str]: + # Generate a 2D global mesh tensor for the current dim for PG creation. + pg_ranks_by_dim = sub_layout.nest().remap_to_tensor(rank_map) + backend, pg_options = backend_override + # We need to explicitly pass in timeout when specified in option, otherwise + # the default timeout will be used to override the timeout set in option. + # TODO: remove this once we have fixed inside c10d level. + timeout = pg_options._timeout if pg_options else None + + # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description + # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. + # If the mesh doesn't have a mesh_dim_names, then the group description of the + # subgroup would be `mesh_dim_0` and `mesh_dim_1`. + group_desc = f"mesh_{dim_name}" + + dim_group = None default_group = _get_default_group() - if ( - len(layout) == 1 - and layout.numel() == get_world_size() - and backend_override[0] == (None, None) + # Early return if there is only one sub_layout in the mesh layout. + if sub_layout.numel() == get_world_size() and backend_override == ( + None, + None, ): # Append the default pg to the first dim groups only if the default pg is compatible with `self._device_type`. # Otherwise, create new pg. @@ -380,90 +416,80 @@ def _init_process_groups( and get_backend(default_group) == "gloo" else default_group ) - dim_group_names.append(dim_group.group_name) - else: - # create sub pgs base on the mesh argument specified - for dim in range(len(layout)): - # swap the current dim to the last dim - # then reshape to flatten out other dims - pg_ranks_by_dim = layout[dim].nest().remap_to_tensor(rank_map) - backend, pg_options = backend_override[dim] - # We need to explicitly pass in timeout when specified in option, otherwise - # the default timeout will be used to override the timeout set in option. - # TODO: remove this once we have fixed inside c10d level. - timeout = pg_options._timeout if pg_options else None - - # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description - # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. - # If the mesh doesn't not have a mesh_dim_names, then the group description of the - # subgroup would be `mesh_dim_0` and `mesh_dim_1`. - group_desc = ( - f"mesh_{mesh_dim_names[dim]}" - if mesh_dim_names - else f"mesh_dim_{dim}" - ) + return dim_group.group_name # type: ignore[union-attr] + + # If bound_device_id exists, it means the nccl communicator has been eagerly initialized + # so that we can use `split_group` to create subgroups through `ncclCommSplit`. + # In this case, we only need to make one API call (`split_group``) for the subgroup creation + # for each mesh dimension. In a 2 * 4 mesh, we only need to make two API calls per ranks to create + # all the subgroups. + # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The + # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4 + # mesh, we need to make two API calls per ranks to create all the subgroups. + if ( + getattr(default_group, "bound_device_id", None) is not None + and torch.cuda.is_available() + and ( + backend is None + or default_group._get_backend(torch.device("cuda")).name() + == backend + ) + ): + dim_group = split_group( + parent_pg=default_group, + timeout=timeout, + pg_options=pg_options, + split_ranks=pg_ranks_by_dim.tolist(), + group_desc=group_desc, + ) + return dim_group.group_name # type: ignore[union-attr] + + # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim` + # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup. + # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim` + # along with appending information to the `dim_group_names` list whenever necessary. + pg_name = None + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + dim_group = new_group( + ranks=subgroup_ranks, + timeout=timeout, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) - # If bound_device_id exists, it means the nccl communicator has been eagerly initialized - # so that we can use `split_group` to create subgroups through `ncclCommSplit`. - # In this case, we only need to make one API call (`split_group``) for the subgroup creation - # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create - # all the subgroups. - # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The - # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4 - # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups. - dim_group = None - has_split_group = False - if ( - ( - bound_device_id := getattr( - default_group, "bound_device_id", None - ) - ) - is not None - and torch.cuda.is_available() - and ( - backend is None - or default_group._get_backend(torch.device("cuda")).name() - == backend - ) - ): - dim_group = split_group( - parent_pg=default_group, - timeout=timeout, - pg_options=pg_options, - split_ranks=pg_ranks_by_dim.tolist(), - group_desc=group_desc, + # only add to dim_groups if the current rank in the subgroup + if get_rank() in subgroup_ranks: + if pg_name is not None: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {get_rank()} " + f"in {subgroup_ranks}!" ) - has_split_group = True - - # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim` - # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup. - # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim` - # along with appending information to the `dim_group_names` list whenever necessary. - for dim_mesh in pg_ranks_by_dim: - subgroup_ranks = dim_mesh.tolist() - - # We temporarily revert the reuse subgroup, since it breaks two internal tests. - # Temporarily reverting to resolve test timeout while root-causing. - # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. - # pyrefly: ignore [unbound-name] - if bound_device_id is None or not has_split_group: - dim_group = new_group( - ranks=subgroup_ranks, - timeout=timeout, - backend=backend, - pg_options=pg_options, - group_desc=group_desc, - ) - - # only add to dim_groups if the current rank in the subgroup - if get_rank() in subgroup_ranks: - if len(dim_group_names) > dim: - raise RuntimeError( - f"Each device mesh dimension should get only one process group, but got {get_rank()} " - f"in {subgroup_ranks}!" - ) - dim_group_names.append(dim_group.group_name) # type: ignore[union-attr] + pg_name = dim_group.group_name + return pg_name + + @staticmethod + def _init_process_groups( + layout: _MeshLayout, + rank_map: torch.Tensor, + mesh_dim_names: Optional[tuple[str, ...]], + backend_override: tuple[BackendConfig, ...], + ) -> list[str]: + # group_name associated with each mesh dimension, each + # mesh dimension should have one sub-group per rank + dim_group_names: list[str] = [] + # create sub pgs base on the mesh argument specified + for dim in range(len(layout)): + dim_name = mesh_dim_names[dim] if mesh_dim_names else f"dim_{dim}" + dim_group_names.append( + DeviceMesh._init_one_process_group( # type: ignore[arg-type] + layout[dim], rank_map, dim_name, backend_override[dim] + ) + ) + if any(n is None for n in dim_group_names): + assert all(n is None for n in dim_group_names) + return [] return dim_group_names def _get_root_mesh(self) -> "DeviceMesh": @@ -629,7 +655,10 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup: if isinstance(mesh_dim, str) else mesh_dim ) - assert isinstance(mesh_dim, int) + if not isinstance(mesh_dim, int): + raise AssertionError( + f"mesh_dim must be an int, got {type(mesh_dim)}" + ) return not_none(_resolve_process_group(self._dim_group_names[mesh_dim])) def get_all_groups(self) -> list[ProcessGroup]: @@ -736,9 +765,8 @@ def _get_root_mesh_dim(self) -> Optional[int]: root_mesh = self._get_root_mesh() child_mesh_dim_names = self._mesh_dim_names if root_mesh and child_mesh_dim_names: - assert len(child_mesh_dim_names) == 1, ( - "The submesh can only be a 1D mesh." - ) + if len(child_mesh_dim_names) != 1: + raise AssertionError("The submesh can only be a 1D mesh.") child_mesh_dim_name = child_mesh_dim_names[0] return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name) return None @@ -764,12 +792,6 @@ def _get_slice_mesh_layout( """ slice_from_root = True if self != self._get_root_mesh(): - warnings.warn( - "You are attempting to slice a submesh from another submesh. While we support this operation, " - "it is users' responsibility to ensure that the submesh is consistently sliced across all ranks. " - "If not, this may result in some ranks receiving the submesh while others encounter errors.", - stacklevel=2, - ) slice_from_root = False # The slice mesh_dim_names should consist either the current device_mesh's mesh_dim_names @@ -1023,9 +1045,10 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int: mesh_dim = 0 mesh_dim_group = not_none(self.get_group(mesh_dim)) - assert isinstance(mesh_dim_group, ProcessGroup), ( - "We expect ProcessGroup before calling `get_rank`!" - ) + if not isinstance(mesh_dim_group, ProcessGroup): + raise AssertionError( + "We expect ProcessGroup before calling `get_rank`!" + ) return not_none(get_rank(mesh_dim_group)) def get_coordinate(self) -> Optional[list[int]]: diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 0cebfaff6d63a..bc79408a32ff9 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1550,7 +1550,8 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> group = _get_default_group() if _rank_not_in_group(group): raise ValueError("Invalid process group specified") - assert isinstance(group, ProcessGroup) + if not isinstance(group, ProcessGroup): + raise AssertionError(f"Expected ProcessGroup, got {type(group)}") devices = group._device_types backends = set() if torch.device("cpu") in devices and is_gloo_available(): @@ -1583,6 +1584,7 @@ def init_process_group( group_name: str = "", pg_options: Optional[Any] = None, device_id: Optional[Union[torch.device, int]] = None, + _ranks: Optional[list[int]] = None, ) -> None: """ Initialize the default distributed process group. @@ -1657,6 +1659,8 @@ def init_process_group( want to know NCCL initialization error early, you can also use this field. If an `int` is provided, the API assumes that the accelerator type at compile time will be used. + _ranks: The ranks in the process group. If provided, the process + group name will be the hash of all the ranks in the group. .. note:: To enable ``backend == Backend.MPI``, PyTorch needs to be built from source on a system that supports MPI. @@ -1691,13 +1695,14 @@ def init_process_group( if "torch._dynamo" in sys.modules: torch._dynamo.trace_rules.clear_lru_cache() - assert (store is None) or (init_method is None), ( - "Cannot specify both init_method and store." - ) + if not ((store is None) or (init_method is None)): + raise AssertionError("Cannot specify both init_method and store.") if store is not None: - assert world_size > 0, "world_size must be positive if using store" - assert rank >= 0, "rank must be non-negative if using store" + if not world_size > 0: + raise AssertionError("world_size must be positive if using store") + if not rank >= 0: + raise AssertionError("rank must be non-negative if using store") elif init_method is None: init_method = "env://" @@ -1761,7 +1766,10 @@ def init_process_group( internals of c10d. This means we can ignore the value they provide as it not exposed in a public way. """ - group_name = _process_group_name([], use_hashed_name=False) + if _ranks is None or len(_ranks) == 0: + group_name = _process_group_name([], use_hashed_name=False) + else: + group_name = _process_group_name(_ranks, use_hashed_name=True) if backend == Backend.MPI: if world_size != -1 or rank != -1: warnings.warn( @@ -1972,7 +1980,8 @@ def _new_process_group_helper( backend_config = BackendConfig(backend) # Set the default backend when single backend is passed in. if "," not in str(backend) and ":" not in str(backend): - assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" + if backend not in Backend.backend_type_map: + raise AssertionError(f"Unknown backend type {backend}") if backend == Backend.UNDEFINED: # Currently when backend is UNDEFINED, only one backend will be initialized # we use nccl (if cuda is available) or gloo as default backend @@ -2042,9 +2051,10 @@ def _new_process_group_helper( if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") if backend_options is not None: - assert isinstance(backend_options, ProcessGroupNCCL.Options), ( - "Expected backend_options argument to be of type ProcessGroupNCCL.Options" - ) + if not isinstance(backend_options, ProcessGroupNCCL.Options): + raise AssertionError( + "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + ) if backend_options._timeout != timeout: warnings.warn( "backend_options._timeout was specified, " @@ -2095,9 +2105,8 @@ def _new_process_group_helper( ) backend_type = ProcessGroup.BackendType.XCCL else: - assert backend_str.upper() in Backend._plugins, ( - f"Unknown c10d backend type {backend_str.upper()}" - ) + if backend_str.upper() not in Backend._plugins: + raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}") backend_plugin = Backend._plugins[backend_str.upper()] creator_fn = backend_plugin.creator_fn @@ -2122,10 +2131,16 @@ def _new_process_group_helper( # Set sequence numbers for gloo and nccl backends. if backend_str == Backend.GLOO: - assert isinstance(backend_class, ProcessGroupGloo) + if not isinstance(backend_class, ProcessGroupGloo): + raise AssertionError( + f"Expected ProcessGroupGloo, got {type(backend_class)}" + ) backend_class._set_sequence_number_for_group() elif backend_str == Backend.NCCL: - assert isinstance(backend_class, ProcessGroupNCCL) + if not isinstance(backend_class, ProcessGroupNCCL): + raise AssertionError( + f"Expected ProcessGroupNCCL, got {type(backend_class)}" + ) backend_class._set_sequence_number_for_group() # If the type is a subclass of ProcessGroup then return this process group immediately @@ -2172,8 +2187,10 @@ def _new_process_group_helper( pg._register_backend(torch.device(device), backend_type, backend_class) # set group_name and group_dsec to backend - assert group_name is not None - assert group_desc is not None + if group_name is None: + raise AssertionError("group_name must not be None") + if group_desc is None: + raise AssertionError("group_desc must not be None") pg._set_group_name(group_name) pg._set_group_desc(group_desc) @@ -2219,7 +2236,8 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): else: pg = group - assert pg is not None + if pg is None: + raise AssertionError("Process group cannot be None") if _world.pg_map.get(pg, None) is None: raise ValueError("Invalid process group specified") @@ -2310,7 +2328,8 @@ def _abort_process_group(group: Optional[ProcessGroup] = None): pg = group or GroupMember.WORLD - assert pg is not None + if pg is None: + raise AssertionError("Process group cannot be None") if _world.pg_map.get(pg, None) is None: raise ValueError("Invalid process group specified or has been destroyed.") @@ -2666,13 +2685,13 @@ def _coalescing_manager( # - coalesced `all_gather_into_tensor` # - coalesced `reduce_scatter_tensor` op0 = op_list[0].op - if op0 == all_reduce: + if op0 is all_reduce: tensors = [op.tensor for op in op_list] all_reduce_opts = AllreduceCoalescedOptions() all_reduce_opts.reduceOp = not_none(op_list[0].redop) all_reduce_opts.asyncOp = async_ops work = group.allreduce_coalesced(tensors, all_reduce_opts) - elif op0 == all_gather_into_tensor: + elif op0 is all_gather_into_tensor: inputs = [] outputs = [] for op in op_list: @@ -2681,7 +2700,7 @@ def _coalescing_manager( all_gather_opts = AllgatherOptions() all_gather_opts.asyncOp = async_ops work = group.allgather_into_tensor_coalesced(outputs, inputs) - elif op0 == reduce_scatter_tensor: + elif op0 is reduce_scatter_tensor: inputs = [] outputs = [] for op in op_list: @@ -2810,7 +2829,7 @@ def batch_isend_irecv(p2p_op_list: list[P2POp]) -> list[Work]: device = p2p_op_list[0].tensor.device def peer_kwarg(op: P2POp) -> dict[str, int]: - key = "group_dst" if op.op == isend else "group_src" + key = "group_dst" if op.op is isend else "group_src" return {key: op.group_peer} if type(group) is ProcessGroup and group._get_backend(device).supports_coalescing: @@ -3368,8 +3387,9 @@ def gather_object( if my_group_rank != group_dst: return - assert object_gather_list is not None, "Must provide object_gather_list on dst rank" - # pyrefly: ignore [unbound-name] + if object_gather_list is None: + raise AssertionError("Must provide object_gather_list on dst rank") + # pyrefly: ignore # unbound-name for i, tensor in enumerate(output_tensors): tensor = tensor.type(torch.uint8) tensor_size = object_size_list[i] @@ -3624,9 +3644,8 @@ def recv_object_list( rank_objects = get_global_rank(group, group_src) else: rank_objects = recv(object_tensor, group=group, group_src=group_src) - assert rank_sizes == rank_objects, ( - "Mismatch in return ranks for object sizes and objects." - ) + if rank_sizes != rank_objects: + raise AssertionError("Mismatch in return ranks for object sizes and objects.") # Deserialize objects using their stored sizes. offset = 0 for i, obj_size in enumerate(object_sizes_tensor): @@ -5035,7 +5054,8 @@ def _create_process_group_wrapper( world_size: int, timeout: timedelta = default_pg_timeout, ): - assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend." + if not _GLOO_AVAILABLE: + raise AssertionError("ProcessGroupWrapper unsupported without GLOO backend.") # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... @@ -5237,9 +5257,10 @@ def split_group( split_pg.bound_device_id = device_id # type: ignore[union-attr] split_backend_class = split_pg._get_backend(torch.device("cuda")) split_backend_class._set_sequence_number_for_group() - assert split_pg.group_name == group_name, ( - f"group name should be set to {group_name} but got {split_pg.group_name}" - ) + if split_pg.group_name != group_name: + raise AssertionError( + f"group name should be set to {group_name} but got {split_pg.group_name}" + ) # update global state _world.pg_map[split_pg] = (backend, split_pg.get_group_store()) @@ -5371,9 +5392,10 @@ def _new_group_with_tag( if device_id is None: device_id = default_pg.bound_device_id elif default_pg.bound_device_id is not None: - assert device_id == default_pg.bound_device_id, ( - "Mismatched bound device between new pg and the default pg." - ) + if device_id != default_pg.bound_device_id: + raise AssertionError( + "Mismatched bound device between new pg and the default pg." + ) default_backend, default_store = _world.pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() @@ -5687,22 +5709,25 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro def _find_or_create_pg_by_ranks_and_tag( tag: str, ranks: list[int], stride: int ) -> ProcessGroup: - assert len(ranks) % stride == 0, ( - f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" - ) + if len(ranks) % stride != 0: + raise ValueError( + f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + ) my_rank = get_rank() my_ranks = None if stride == len(ranks): my_ranks = ranks.copy() - assert my_rank in my_ranks, "rankset doesn't include the current node" + if my_rank not in my_ranks: + raise AssertionError("rankset doesn't include the current node") else: for i in range(0, len(ranks), stride): rank_set = ranks[i : i + stride] if my_rank in rank_set: my_ranks = rank_set - assert my_ranks is not None, "rankset doesn't include the current node" + if my_ranks is None: + raise AssertionError("rankset doesn't include the current node") my_ranks = sorted(my_ranks) diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 8adde16de6b91..85e4c23d509f8 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -2087,14 +2087,14 @@ def _use_unsharded_grad_views(self) -> None: param.grad.data = view else: param.grad = view - for i, ( + for ( param_name, module, module_name, prim_param_name, prim_module, _, - ) in enumerate(self.flat_param._shared_param_infos): + ) in self.flat_param._shared_param_infos: _p_assert( hasattr(module, param_name), f"{module_name + '.' + param_name if module_name else param_name} is missing", @@ -2171,11 +2171,8 @@ def _use_sharded_views(self) -> None: param.data = flat_param[offset : offset + numel_in_shard] if self.flat_param._shared_params is None: raise AssertionError("Expected _shared_params to be not None") - for i, ( - param, - (param_name, module, _, prim_param_name, prim_module, _), - ) in enumerate( - zip(self.flat_param._shared_params, self.flat_param._shared_param_infos) + for param, (param_name, module, _, prim_param_name, prim_module, _) in zip( + self.flat_param._shared_params, self.flat_param._shared_param_infos ): self._setattr_param(module, param_name, param) prim_param = getattr(prim_module, prim_param_name) @@ -2388,14 +2385,14 @@ def _writeback_orig_params(self) -> bool: # TODO: If we want to handle shared parameters, we need to re-generate # the shared parameter data structures in case sharedness changed. - for i, ( + for ( param_name, module, _, prim_param_name, prim_module, _, - ) in enumerate(flat_param._shared_param_infos): + ) in flat_param._shared_param_infos: if getattr(module, param_name) is not getattr(prim_module, prim_param_name): raise NotImplementedError( "Changing shared parameters is not supported yet" diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index bf3f8eadaaf15..794b755b1f64d 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -492,7 +492,11 @@ def foreach_reduce( force_sum_reduction_for_comms, ) ) - world_size = reduce_scatter_group.size() + + if reduce_scatter_group is None: + world_size = 1 + else: + world_size = reduce_scatter_group.size() device_handle = _get_device_handle(device.type) current_stream = device_handle.current_stream() @@ -547,7 +551,7 @@ def foreach_reduce( reduce_output.copy_(reduce_scatter_input) reduce_scatter_event = reduce_scatter_stream.record_event() post_reduce_stream = reduce_scatter_stream - if all_reduce_group is not None: # HSDP + if all_reduce_group is not None: # HSDP or DDP/replicate # Accumulations must run in the reduce-scatter stream if not all_reduce_grads: if partial_reduce_output is not None: @@ -690,7 +694,7 @@ def _get_all_gather_input_metadatas( def _get_gradient_divide_factors( - reduce_scatter_group: dist.ProcessGroup, + reduce_scatter_group: Optional[dist.ProcessGroup], all_reduce_group: Optional[dist.ProcessGroup], reduce_dtype: torch.dtype, device_type: str = "", @@ -709,8 +713,11 @@ def _get_gradient_divide_factors( # For fp32/bf16, we do not need to worry about overflow/underflow, so we # use NCCL's built-in division to avoid separate div kernels overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16) + if reduce_scatter_group is not None: + data_parallel_size = reduce_scatter_group.size() + else: + data_parallel_size = 1 - data_parallel_size = reduce_scatter_group.size() if all_reduce_group is not None: data_parallel_size *= all_reduce_group.size() diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index d059f697f12ea..476fbd9492894 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -11,6 +11,7 @@ from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo from torch.distributed.tensor import DTensor, Replicate, Shard from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta from torch.distributed.tensor.placement_types import _StridedShard, Placement @@ -306,22 +307,29 @@ def _init_sharded_param( f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." ) self._spmd_placements: tuple[Placement, ...] - dp_shard_tp_placement = ( - ( - _StridedShard(shard_dim, split_factor=split_factor) - if split_factor > 1 - else fsdp_placement - ), - *self._tp_spec.placements, - ) - if dp_mesh.ndim == 1: # FSDP - self._spmd_placements = dp_shard_tp_placement - else: # HSDP + if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP + dp_shard_tp_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else fsdp_placement + ), + *self._tp_spec.placements, + ) + else: # DDP + dp_shard_tp_placement = ( + (Replicate()), + *self._tp_spec.placements, + ) + if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP if self.mesh_info.replicate_mesh_dim != 0: raise AssertionError( f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}" ) self._spmd_placements = (Replicate(),) + dp_shard_tp_placement + else: # FSDP or DDP + self._spmd_placements = dp_shard_tp_placement + self._sharding_spec = DTensorSpec( self._spmd_mesh, self._spmd_placements, @@ -330,10 +338,12 @@ def _init_sharded_param( param_data = cast(DTensor, param)._local_tensor else: self._spmd_mesh = self.mesh_info.mesh - if isinstance(self.mesh_info, HSDPMeshInfo): + if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP self._spmd_placements = (Replicate(), fsdp_placement) - else: + elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP self._spmd_placements = (fsdp_placement,) + elif isinstance(self.mesh_info, DDPMeshInfo): # DDP + self._spmd_placements = (Replicate(),) self._sharding_spec = DTensorSpec( self._spmd_mesh, self._spmd_placements, @@ -351,8 +361,13 @@ def _init_sharded_param( ) self._orig_size = param_data.size() self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) - shard_rank = self.mesh_info.shard_mesh_rank - shard_world_size = self.mesh_info.shard_mesh_size + if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP + shard_rank = self.mesh_info.shard_mesh_rank + shard_world_size = self.mesh_info.shard_mesh_size + else: # DDP + shard_rank = 0 + shard_world_size = 1 + if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0: # If sharding on nonzero dim, require even sharding for now because # the uneven sharding (1) requires extra copies before/after FSDP @@ -401,12 +416,20 @@ def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None if mesh_info is None: raise AssertionError("Expected post_forward_mesh_info to not be None") param_data = param._local_tensor if isinstance(param, DTensor) else param - chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0) - self.sharded_post_forward_size = _get_dim_chunked_size( - chunks[mesh_info.shard_mesh_rank], - param_data.size(), - dim=self.fsdp_placement.dim, - ) + if isinstance(mesh_info, FSDPMeshInfo): + chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0) + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[mesh_info.shard_mesh_rank], + param_data.size(), + dim=self.fsdp_placement.dim, + ) + else: # DDP + chunks = _chunk_with_empty(param_data, 1, dim=0) + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[0], + param_data.size(), + dim=self.fsdp_placement.dim, + ) self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for( self.sharded_post_forward_size ) @@ -832,9 +855,7 @@ def shard_mesh_from_root(self): if mesh.mesh_dim_names is None: raise AssertionError("Expected mesh_dim_names to not be None") shard_dim_name = mesh.mesh_dim_names[-1] - - root_mesh = mesh._get_root_mesh() - return root_mesh[shard_dim_name] + return mesh[shard_dim_name] def _assert_in_states(self, *states: ShardedState) -> None: if self.sharded_state not in states: diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index f2eac802bb672..b70a5f06f4ae9 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -29,6 +29,7 @@ ) from ._fsdp_common import ( compiled_autograd_enabled, + DDPMeshInfo, FSDPMeshInfo, HSDPMeshInfo, is_bw, @@ -315,7 +316,10 @@ def unshard(self, async_op: bool = False): self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) self._reshard_after_forward_event = None - world_size = self._all_gather_process_group.size() + if isinstance(self.mesh_info, FSDPMeshInfo): + world_size = self._all_gather_process_group.size() + else: + world_size = 1 if world_size == 1: # can't skip due to early return in wait_for_unshard if # no self._all_gather_result @@ -356,7 +360,10 @@ def wait_for_unshard(self): if prev_all_gather_state := self.comm_ctx.all_gather_state: self._wait_all_gather_streams_on_event(prev_all_gather_state.event) self.comm_ctx.all_gather_state = None # free the all-gather result - world_size = self._all_gather_process_group.size() + if isinstance(self.mesh_info, FSDPMeshInfo): + world_size = self._all_gather_process_group.size() + else: + world_size = 1 if world_size == 1: # directly initialize unsharded parameters from sharded parameters @@ -531,7 +538,11 @@ def post_backward(self, *unused: Any): self.comm_ctx.reduce_scatter_state.event ) self.comm_ctx.reduce_scatter_state = None - all_reduce_pg = self._all_reduce_process_group if self._is_hsdp else None + all_reduce_pg = ( + self._all_reduce_process_group + if isinstance(self.mesh_info, DDPMeshInfo) + else None + ) all_reduce_stream: torch.cuda.Stream if all_reduce_pg is None and self._all_reduce_hook_stream is not None: # this means the native HSDP is not enabled, @@ -555,14 +566,22 @@ def post_backward(self, *unused: Any): ) = foreach_reduce( fsdp_params_with_grad, unsharded_grads, - self._reduce_scatter_process_group, + ( + self._reduce_scatter_process_group + if isinstance(self.mesh_info, FSDPMeshInfo) + else None # pyre-fixme[6] + ), self.comm_ctx.reduce_scatter_stream, self._reduce_scatter_comm, self._orig_dtype, self._reduce_dtype, self.device, self.gradient_divide_factor, - self._all_reduce_process_group if self._is_hsdp else None, + ( + self._all_reduce_process_group + if isinstance(self.mesh_info, DDPMeshInfo) + else None + ), all_reduce_stream, self.all_reduce_grads, self._partial_reduce_output, @@ -776,9 +795,9 @@ def _reduce_scatter_process_group(self) -> dist.ProcessGroup: @property def _all_reduce_process_group(self) -> dist.ProcessGroup: - if not isinstance(self.mesh_info, HSDPMeshInfo): + if not isinstance(self.mesh_info, DDPMeshInfo): raise AssertionError( - f"Expected mesh_info to be HSDPMeshInfo, got {type(self.mesh_info)}" + f"Expected mesh_info to be DDPMeshInfo or HSDPMeshInfo, got {type(self.mesh_info)}" ) return self.mesh_info.replicate_process_group diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index b6f7cc4085b16..b75db1b11abbc 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -233,8 +233,9 @@ def launch_agent( " log_dir : %(log_dir)s\n" " metrics_cfg : %(metrics_cfg)s\n" " event_log_handler : %(event_log_handler)s\n" - " numa_options : %(numa_options)s\n", - " duplicate_stdout_filters : %(duplicate_stdout_filters)s\n", + " numa_options : %(numa_options)s\n" + " signals_to_handle : %(signals_to_handle)s\n" + " duplicate_stdout_filters : %(duplicate_stdout_filters)s\n" " duplicate_stderr_filters : %(duplicate_stderr_filters)s\n", { "entrypoint": entrypoint_name, diff --git a/torch/distributed/nn/jit/instantiator.py b/torch/distributed/nn/jit/instantiator.py index 9465eb036daab..a6dee7e61ef57 100644 --- a/torch/distributed/nn/jit/instantiator.py +++ b/torch/distributed/nn/jit/instantiator.py @@ -1,12 +1,8 @@ #!/usr/bin/python3 # mypy: allow-untyped-defs -import atexit -import importlib -import logging -import os +import importlib.abc +import importlib.util import sys -import tempfile -from typing import Optional import torch from torch.distributed.nn.jit.templates.remote_module_template import ( @@ -14,15 +10,7 @@ ) -logger = logging.getLogger(__name__) - - _FILE_PREFIX = "_remote_module_" -_TEMP_DIR = tempfile.TemporaryDirectory() -INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name -atexit.register(_TEMP_DIR.cleanup) -logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH) -sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH) def get_arg_return_types_from_interface(module_interface): @@ -63,40 +51,37 @@ def get_arg_return_types_from_interface(module_interface): return args_str, arg_types_str, return_type_str -def _write(out_path, text): - old_text: Optional[str] - try: - with open(out_path) as f: - old_text = f.read() - except OSError: - old_text = None - if old_text != text: - with open(out_path, "w") as f: - logger.info("Writing %s", out_path) - f.write(text) - else: - logger.info("Skipped writing %s", out_path) +class _StringLoader(importlib.abc.SourceLoader): + def __init__(self, data): + self.data = data + + def get_source(self, fullname): + return self.data + + def get_data(self, path): + return self.data.encode("utf-8") + + def get_filename(self, fullname): + return fullname def _do_instantiate_remote_module_template( generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda ): - generated_code_text = get_remote_module_template( - enable_moving_cpu_tensors_to_cuda - ).format(**str_dict) - out_path = os.path.join( - INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py" + if generated_module_name in sys.modules: + return sys.modules[generated_module_name] + + loader = _StringLoader( + get_remote_module_template(enable_moving_cpu_tensors_to_cuda).format(**str_dict) + ) + spec = importlib.util.spec_from_loader( + generated_module_name, loader, origin="torch-git" ) - _write(out_path, generated_code_text) - - # From importlib doc, - # > If you are dynamically importing a module that was created since - # the interpreter began execution (e.g., created a Python source file), - # you may need to call invalidate_caches() in order for the new module - # to be noticed by the import system. - importlib.invalidate_caches() - generated_module = importlib.import_module(f"{generated_module_name}") - return generated_module + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[generated_module_name] = module + loader.exec_module(module) + return module def instantiate_scriptable_remote_module_template( diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 52e56dd3f95ba..62e3764abe055 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -133,7 +133,7 @@ def _insert_stage_symbolic_backward( # In the forward pass, only emit placeholder, module calls, and # getitem calls. If we have a target other than getitem in this # (forward-only) code, there is a bug. - assert node.target == operator.getitem, ( + assert node.target is operator.getitem, ( "Found non-getitem call in forward pass. Please report a bug to PiPPy" ) assert len(node.args) == 2, ( @@ -407,7 +407,7 @@ def dont_traverse_size(a): def call_function(self, target, args, kwargs): # HACK to reroute saved input tensors to point to the detach()ed version - if target == stage_backward: + if target is stage_backward: kwargs = dict(kwargs) kwargs["input_values"] = [ self.value_remap.get(v, v) for v in kwargs["input_values"] @@ -924,7 +924,7 @@ def move_param_to_callee( pass # This is done by (1) `_sink_params` at each submodule; - for name, submod in split.named_children(): + for submod in split.children(): if isinstance(submod, fx.GraphModule): _sink_params(submod, inputs_to_state, []) submod.graph.lint() diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index 9a576d2a829a3..251d53a22bf27 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -509,7 +509,9 @@ def merge_chunks( values_to_cat = [] chunk_start_idx = 0 assert len(partial_values) == len(meta_chunks) - for partial_value, meta_chunk in zip(partial_values, meta_chunks): + for partial_value, meta_chunk in zip( + partial_values, meta_chunks, strict=True + ): chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim) slice_indices = [slice(None, None, None)] * partial_value.ndim diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index c18c4d6f67854..6274689945109 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -252,7 +252,7 @@ def _configure_outputs_meta(self, outputs_meta: tuple[torch.Tensor, ...]): self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: - """Get the output metadata (meta tensors) reprensenting the outputs of this stage""" + """Get the output metadata (meta tensors) representing the outputs of this stage""" assert self._outputs_meta is not None, ( "Attempted to get_outputs_meta() without configuring output meta" ) @@ -723,7 +723,7 @@ def forward_one_chunk( ) self._validate_fwd_outputs(output_tuple) - # We return the original user-provied output, not normalized to tuple. + # We return the original user-provided output, not normalized to tuple. # See [Note: pipeline model output type] return output @@ -1188,7 +1188,7 @@ def find_dst_rank( # No need to send back to rank 0 # - If user.target is stage_backward: # No need to send assuming submod output is stored locally or - # should be re-calucated in case of activation checkpointing + # should be re-calculated in case of activation checkpointing return None def _create_act_send_info(self): diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 0c9ebc468be16..a65bfa783efc3 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -83,9 +83,10 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa world_size = world_size_opt if rank != -1 or world_size != -1 or world_size_opt is None: query_dict = _query_to_dict(result.query) - assert "rank" not in query_dict and "world_size" not in query_dict, ( - f"The url: {url} has node-specific arguments(rank, world_size) already." - ) + if "rank" in query_dict or "world_size" in query_dict: + raise AssertionError( + f"The url: {url} has node-specific arguments(rank, world_size) already." + ) if rank != -1: query_dict["rank"] = str(rank) if world_size != -1 or world_size_opt is None: @@ -227,7 +228,8 @@ def _error(msg): world_size = int(query_dict["world_size"]) use_libuv = _get_use_libuv_from_query_dict(query_dict) - assert result.hostname is not None + if result.hostname is None: + raise AssertionError("hostname cannot be None") store = _create_c10d_store( result.hostname, result.port, rank, world_size, timeout, use_libuv diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 71c111b2f2e65..46eecf19e22c9 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -40,7 +40,7 @@ def _rref_type_cont(rref_fut): rref_fut = rref._get_type(timeout=timeout, blocking=False) - if rpc_api != rpc_async: + if rpc_api is not rpc_async: rref_fut.wait() return _rref_type_cont(rref_fut) else: diff --git a/torch/distributed/run.py b/torch/distributed/run.py index a076c8d5798a3..cd9820e0e10ea 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -821,8 +821,12 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) - assert 0 < min_nodes <= max_nodes - assert args.max_restarts >= 0 + if not (0 < min_nodes <= max_nodes): + raise AssertionError( + f"min_nodes must be > 0 and <= max_nodes, got min_nodes={min_nodes}, max_nodes={max_nodes}" + ) + if args.max_restarts < 0: + raise AssertionError("max_restarts must be >= 0") if ( hasattr(args, "master_addr") @@ -862,7 +866,8 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str if args.local_ranks_filter: try: ranks = set(map(int, args.local_ranks_filter.split(","))) - assert ranks + if not ranks: + raise AssertionError("ranks set cannot be empty") except Exception as e: raise ValueError( "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 865de11daccb2..de86d7923ae65 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -967,7 +967,7 @@ def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: if partition_fn is None: # if partition_fn not specified, we by default replicate # all module params/buffers - for name, submod in module.named_modules(): + for submod in module.modules(): replicate_module_params_buffers(submod, device_mesh) else: # apply partition_fun to submodules @@ -1060,10 +1060,10 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] ) # initialize the local tensor - if init_op == torch.full: + if init_op is torch.full: fill_value = kwargs.pop("fill_value", 0) local_tensor = init_op(local_shape, fill_value, **kwargs) - elif init_op == torch.rand or init_op == torch.randn: + elif init_op is torch.rand or init_op is torch.randn: # this tensor meta is not used except `shape` dtype = kwargs.get("dtype", torch.get_default_dtype()) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index 8d3a89a1a647b..6fc3cc1d4e670 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -24,7 +24,7 @@ """ from collections.abc import Sequence -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import cached_property from typing import Any, Optional, Union from typing_extensions import deprecated @@ -335,8 +335,6 @@ class OpSchema: _comparison_key: Optional[tuple[object, ...]] = None - has_symints: bool = field(init=False) - @property def args_spec(self) -> tuple[DTensorSpec, ...]: """ diff --git a/torch/distributed/tensor/_ops/_einsum_strategy.py b/torch/distributed/tensor/_ops/_einsum_strategy.py index 506103d70a599..9d46ede21f97b 100644 --- a/torch/distributed/tensor/_ops/_einsum_strategy.py +++ b/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -45,7 +45,7 @@ def parse_dims(cls, input_dims: list[str], output_dim: str) -> "EinsumDims": for input_dim in input_dims: dim_char_set.update(input_dim) - # get a determinisitc order of all dim chars + # get a deterministic order of all dim chars all_dim_chars = sorted(dim_char_set) # parse input and output dimensions @@ -170,7 +170,7 @@ def gen_einsum_strategies( # linearity strategy if linearity: linearity_placement_list: list[Placement] = [Partial()] - for input_dim in input_dims: + for _ in input_dims: linearity_placement_list.append(Partial()) strategies_over_one_mesh_dim.append(linearity_placement_list) diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 93030c7142b3e..45a786b9058e2 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -335,8 +335,8 @@ def common_reduction_strategy( LINEAR_REDUCTION_OP_MAP = { - aten.all.default: "sum", - aten.all.dim: "sum", + aten.all.default: "product", + aten.all.dim: "product", aten.sum.default: "sum", aten.sum.dim_IntList: "sum", aten.any.default: "sum", diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index e3134c26a9158..43722c11c2a99 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -484,7 +484,7 @@ def replicate_tensor_dim( def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: # 1. number of dimensions in input and src need to match. # 2. number of elements on all non-dim need to match between input and src. - # 3. numer of elements in src in dim need to match the slice size. + # 3. number of elements in src in dim need to match the slice size. # Given the above: # - We suggest for src to follow the sharding of input, except on the scatter dimension, # where our best bet for now is to make them replicated as a fall-back. diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 463c34c8fb436..a407ba6ca91df 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -592,7 +592,7 @@ def generate_greedy_transform_infos( current = current_placements[mesh_dim] target = target_placements[mesh_dim] # If target is not Shard, we can directly redistribute since we - # are traversing from innner to outer placements here + # are traversing from inner to outer placements here if isinstance(target, Shard): # If target is Shard, check for nested sharding on the # tensor dim BEFORE the current mesh_dim diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 7325fc2daf095..d192ddf7c35b3 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -220,63 +220,7 @@ def _compute_local_shape_and_global_offset( return tuple(local_shape), tuple(global_offset) -def compute_global_tensor_info( - tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] -) -> tuple[list[int], list[int]]: - """ - Compute the global size and stride of a DTensor from the given local tensor. - The local size is multiplited by `world_size` per Sharding dim. - The local stride is multiplited by `world_size` per Sharding dim, as long as the - dimension is outside sharding dim. - - For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8). - If the DTensor placements are [Shard(2)] and world_size is 2; - then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8). - - Args: - tensor (:class:`torch.Tensor`): - Local tensor which DTensor will be constructed from. - mesh (:class:`DeviceMesh`): - Object which describes the mesh topology - of devices for the DTensor. - placements (Sequence[:class:`Placement`]]): - The attribute of the DTensor that describes its layout - on the mesh topology. - - Return: - tensor_shape: A List of int which specifies the size of DTensor which build - on top of the local tensor. - tensor_stride: A List of int which specifies the stride of DTensor. - """ - tensor_shape = list(tensor.size()) - tensor_stride = list(tensor.stride()) - for idx, placement in enumerate(placements): - mesh_dim_size = mesh.size(idx) - if placement.is_shard(): - shard_placement = cast(Shard, placement) - if shard_placement.dim < 0: - raise AssertionError( - "Shard placements should have negative dims normalized in " - f"the user-facing APIs: {shard_placement}" - ) - shard_dim = shard_placement.dim - - assert shard_dim < tensor.ndim, ( - f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." - ) - - local_dim_size = tensor_shape[shard_dim] - tensor_shape[shard_dim] = local_dim_size * mesh_dim_size - - # recover tensor stride by modifying the stride that larger than - # the current stride on the shard_dim - for i in range(len(tensor_stride)): - if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]: - # rescale the stride by the shard size - tensor_stride[i] = tensor_stride[i] * mesh_dim_size - elif not isinstance(placement, (Replicate, Partial)): - raise RuntimeError(f"placement type {type(placement)} not supported!") - return tensor_shape, tensor_stride +compute_global_tensor_info = torch._C._DTensor_compute_global_tensor_info def compute_local_tensor_info( diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 566390b8a039a..2444467a3595f 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -1,1650 +1,41 @@ -import contextlib -import itertools -import logging -import types -from abc import ABC, abstractmethod -from collections.abc import Callable, Generator, Mapping, Sequence -from dataclasses import dataclass -from enum import auto, Enum -from functools import partial -from typing import Any, cast, Optional, Protocol, TypeAlias - -import torch -import torch.distributed as dist -import torch.distributed._functional_collectives as ft_c -import torch.distributed.distributed_c10d as c10d -import torch.nn as nn -import torch.nn.functional as F -from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import distribute_tensor, DTensor, Shard -from torch.distributed.tensor.experimental._load_balancer import ( - _create_default_load_balancer, - _LoadBalancer, +# Copyright (c) Meta Platforms, Inc. and affiliates +# Backward compatibility stub - this module has been moved to _context_parallel/_attention.py + +from ._context_parallel._attention import ( + _CausalBehavior, + _context_parallel_shard, + _ContextParallel, + _cp_options, + _disable_context_parallel_dispatcher, + _enable_context_parallel_dispatcher, + _is_causal_behavior, + _RotateMethod, + context_parallel, + context_parallel_unshard, + set_rotate_method, ) -from torch.distributed.tensor.parallel import ParallelStyle -from torch.nn.attention.flex_attention import ( - _mask_mod_signature, - BlockMask, - create_block_mask, -) -from torch.utils._pytree import tree_flatten, tree_unflatten - -from ._cp_custom_ops import flex_cp_allgather - - -__all__ = ["context_parallel", "set_rotate_method"] - - -class _CausalBehavior(Enum): - SKIP = None - NOT_IS_CAUSAL = False - IS_CAUSAL = True - - -class _RotateMethod(Enum): - ALL_TO_ALL = auto() - ALL_GATHER = auto() - - -aten = torch.ops.aten -logger = logging.getLogger(__name__) - - -class _DispatchMode(Enum): - MONKEY_PATCH = auto() - MODULE_WRAPPER = auto() - - -_dispatch_mode: _DispatchMode = _DispatchMode.MONKEY_PATCH - - -@dataclass -class _ContextParallelOptions: - # Whether to upcast parameters and gradients to float32 to avoid accumulation - # errors. It is likely this is always True, but we currently keep this variable - # for experimental purposes. - convert_to_f32: bool = True - enable_load_balance: bool = True - rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER - - -_cp_options = _ContextParallelOptions() - - -def _is_causal_behavior( - rank: int, world_size: int, i: int, is_causal: bool -) -> _CausalBehavior: - """ - Calculate is_causal behavior for each KV block. The attention can either be - calculated in full, not at all or with the causal mask applied. - """ - if not is_causal: - return _CausalBehavior.NOT_IS_CAUSAL - - if i == 0: - return _CausalBehavior.IS_CAUSAL - - source_rank = (rank - i) % world_size - if source_rank < rank or _cp_options.enable_load_balance: - return _CausalBehavior.NOT_IS_CAUSAL - else: - return _CausalBehavior.SKIP - - -def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: - """ - When tracing the code, the result tensor is not an AsyncCollectiveTensor, - so we cannot call ``wait()``. - """ - if isinstance(tensor, ft_c.AsyncCollectiveTensor): - return tensor.wait() - return tensor - - -def _partial_update( - original: torch.Tensor, - new: torch.Tensor, - dim: int, - n_chunks: int, - idx: int, - add: bool, -) -> torch.Tensor: - """ - This API partially updates a chunk of ``original`` tensor. The ``original`` - tensor will be first chunked along ``dim`` dimension, then the ``idx`` chunk - will be updated with ``new``. If ``add`` is True, the chunk will be added - with ``new``, otherwise the chunk will be replaced by ``new``. - - The result is a tensor that is the same size as ``original``. - """ - chunks = list(original.chunk(n_chunks, dim=dim)) - assert chunks[idx].shape == new.shape, (original.shape, new.shape, idx) - if add: - chunks[idx] += new - else: - chunks[idx] = new - return torch.cat(chunks, dim=dim) - - -class _SDPAMerger: - """A class to help merge the local SDPA result.""" - - def __init__(self, convert_to_f32: bool, seq_dim: int): - self._seq_dim = seq_dim - self._out: Optional[torch.Tensor] = None - self._lse: Optional[torch.Tensor] = None - self._should_lse_squeeze = False - self._convert_to_f32 = convert_to_f32 - self._out_dtype = torch.float32 - self._lse_dtype = torch.float32 - - def _merge_one( - self, block_out: torch.Tensor, block_lse: torch.Tensor, partial: bool - ) -> None: - # The cuDNN backend preserves the last dimension for LSE. - # Apply unsqueeze only if the input does not already have - # the required dimensionality. - if len(block_lse.shape) < len(block_out.shape): - block_lse = block_lse.unsqueeze(dim=-1) - self._should_lse_squeeze = True - assert len(block_lse.shape) == len(block_out.shape) - - if self._lse is None: - self._lse = block_lse - self._out = block_out - else: - ROUND_ROBIN_CYCLE = 2 - assert self._lse is not None - assert self._out is not None - lse = ( - self._lse.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] - if partial - else self._lse - ) - out = ( - self._out.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] - if partial - else self._out - ) - - # The algorithm from - # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 - # gives a relatively stable result. - out = out - F.sigmoid(block_lse - lse) * (out - block_out) - lse = lse - F.logsigmoid(lse - block_lse) - if partial: - self._lse = _partial_update( - self._lse, - lse, - dim=self._seq_dim, - n_chunks=ROUND_ROBIN_CYCLE, - idx=1, - add=False, - ) - self._out = _partial_update( - self._out, - out, - dim=self._seq_dim, - n_chunks=ROUND_ROBIN_CYCLE, - idx=1, - add=False, - ) - else: - self._lse = lse - self._out = out - - def step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool) -> None: - self._out_dtype = out.dtype - self._lse_dtype = lse.dtype - - if self._convert_to_f32: - out = out.to(torch.float32) - lse = lse.to(torch.float32) - - self._merge_one(out, lse, partial) - - def results(self) -> tuple[torch.Tensor, torch.Tensor]: - assert self._out is not None - assert self._lse is not None - out = self._out.to(self._out_dtype) - if self._should_lse_squeeze: - lse = self._lse.squeeze(-1).to(self._lse_dtype) - else: - lse = self._lse.to(self._lse_dtype) - return out, lse - - -class _AttentionOp(Protocol): - def __call__( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - **kwargs: object, - ) -> tuple[torch.Tensor, ...]: ... - - -class _RingRotater(ABC): - @abstractmethod - def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ... - - @abstractmethod - def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ... - - @abstractmethod - def next_buffer(self) -> torch.Tensor: ... - - -class _AllToAllRotater(_RingRotater): - """Use all_to_all to send the kv to the next rank.""" - - def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: - self._pg = pg - self._seq_dim = seq_dim - self._buffer: Optional[torch.Tensor] = None - - def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: - curr_buffer = curr_buffer.contiguous() - size = dist.get_world_size(self._pg) - dsts = list(range(1, size)) + [0] - self._buffer = ft_c.permute_tensor(curr_buffer, dsts, self._pg) - - def next_buffer(self) -> torch.Tensor: - assert self._buffer is not None - return _maybe_wait(self._buffer) - - -class _AllGatherRotater(_RingRotater): - """ - Allgather the kv and return only the required kv. - Only one communication will be done. - """ - - def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: - self._pg = pg - self._seq_dim = seq_dim - self._aggregated_buffer: Optional[torch.Tensor] = None - self._idx = 0 - - def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: - # We only need to perform allgather once. - self._idx += 1 - if self._aggregated_buffer is None: - self._aggregated_buffer = ft_c.all_gather_tensor( - curr_buffer.contiguous(), gather_dim=0, group=self._pg - ) - - def next_buffer(self) -> torch.Tensor: - rank = dist.get_rank(self._pg) - idx = rank - self._idx - - assert self._aggregated_buffer is not None - self._aggregated_buffer = _maybe_wait(self._aggregated_buffer) - return self._aggregated_buffer.chunk(dist.get_world_size(self._pg))[idx] - - -def _create_rotater( - pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None -) -> _RingRotater: - if method is None: - method = _cp_options.rotate_method - - if method == _RotateMethod.ALL_TO_ALL: - return _AllToAllRotater(pg, seq_dim) - elif method == _RotateMethod.ALL_GATHER: - return _AllGatherRotater(pg, seq_dim) - else: - raise NotImplementedError(f"Unknown method {method}") - - -def _templated_ring_attention( - group: dist.ProcessGroup, - seq_dim: int, - op: _AttentionOp, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - is_causal: bool = False, - **kwargs: object, -) -> tuple[torch.Tensor, ...]: - """ - A generalized ring attention implementation that can support multiple attention ops. - - Note [Context parallelism load balance algorithm for causal masking] - ===================== - This explanation uses an example to illustrate the CP algorithm with causal - masking. - - Consider a scenario where the sequence length of q, k, and v is 4 (e.g., - q = (q0, q1, q2, q3)), and there are two ranks. For simplicity, we will discuss - only q and k, as v follows the same pattern as k. - - The diagram below represents a complete QK^T operation without parallelism. - The `****` entries indicate that the result is not required due to causal - masking (e.g., q0k1 is marked as `****`). - - +----+------------------------+ - | | k0 k1 k2 k3 | - +----+------------------------+ - | q0 | q0k0, ****, ****, **** | - | q1 | q1k0, q1k1, ****, **** | - | q2 | q2k0, q2k1, q2k2, **** | - | q3 | q3k0, q3k1, q3k2, q3k3 | - +----+------------------------+ - - ### No Load Balance: - - In this scenario, each rank owns a local chunk of q, k, and v, with each chunk - containing two elements. Rank0 is responsible for managing (q0, q1) and (k0, k1), - while rank1 manages (q2, q3) and (k2, k3). - - First Iteration: Both rank0 and rank1 perform SDPA with their local qkv pairs. - Causal masking is enabled as some results are not required (e.g., q0k1). - - Second Iteration: Local queries remain the same, but local kv pairs are exchanged. - Rank0 now has (q0, q1) and (k2, k3); rank1 has (q2, q3) and (k0, k1). Rank0 performs - no computation, while rank1 computes locally without causal masking since all results - (q2k0, q2k1, q3k0, q3k1) are needed. - - ### Round-robin Load Balance: - - In this setup, each rank owns two local chunks of q, k, and v, with each chunk - containing one element. Rank0 manages (q0, q3) and (k0, k3); Rank1 manages (q1, q2) - and (k1, k2). Although the local chunks are not consecutive, they are concatenated to - enable SDPA to be performed in a single call for each step. Consequently, the chunk() - function may be required to prepare the correct q, k, and v configurations. - - First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the - no-load-balance case. This iteration corresponds to the `if` of the - (`if, `elif`, `else`) in the implementation. - - Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and - (k0, k3). For rank0, no computation is needed for q0. However, computations for - q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the - `else` of the (`if`, `elif`, `else`) in the implementation. - For rank1, k3 is not needed for q1 and q2, so only k0 is used for SDPA. This - corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. - - Parameters - ---------- - op: - The attention op to use - *args: - additional args are passed to the op - **kwargs: - additional kwargs are passed to the op - - Returns - ------- - out: - The merged attention output - softmax_lse: - The logsumexp of the merged attention output - """ - if is_causal and (query.size(2) != key.size(2)): - raise NotImplementedError( - "is_causal requires the same query and context sequence lengths" - ) - if not is_causal and _cp_options.enable_load_balance: - raise RuntimeError("Load balancing requires `is_causal=True`.") - - assert isinstance(group, dist.ProcessGroup), ( - "process group must be single dimension" - ) - rank = dist.get_rank(group) - size = dist.get_world_size(group) - - next_kv = None - - # Without making key and value contiguous(), the loss curve is bad. - # TODO(fegin): figure out why this is a requirement since SDPA does not have - # this requirement. - key = key.contiguous() - value = value.contiguous() - - sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim) - - rest: list[Any] - out: torch.Tensor - logsumexp: torch.Tensor - - rotater = _create_rotater(group, 2) - - for i in range(size): - if i > 0: - # Wait for the kv from the (cp_rank - 1) rank. - next_kv = rotater.next_buffer() - key = next_kv[: key.numel()].reshape(key.shape) - value = next_kv[key.numel() :].reshape(value.shape) - - if i < (size - 1): - # Send the k, v to the next rank - next_kv = torch.cat([key.flatten(), value.flatten()]) - next_kv = rotater.exchange_buffers(next_kv) - - is_causal_behavior = _is_causal_behavior( - rank=rank, world_size=size, i=i, is_causal=is_causal - ) - - # For a detailed understanding of the load balancing algorithm, see - # Note [Context parallelism load balance algorithm for causal masking] - if is_causal_behavior == _CausalBehavior.SKIP: - # If i > rank and load balancing is not turned on. - continue - - if i == 0 or (not _cp_options.enable_load_balance or not is_causal): - # When local balance is enabled, we still need to do SDPA with - # the both local chunks of q, k, v for the first iteration. - q, k, v, partial = (query, key, value, False) - elif i <= rank: - # Round-robin load balancing case, and i <= rank. - # We need to do SDPA with only the first local chunk of k, v. - # Note that q, k, v each contains two local chunks. - ROUND_ROBIN_CYCLE = 2 - q, k, v, partial = ( - query, - key.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], - value.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], - False, - ) - else: - # Round-robin load balancing case, and i > rank. - # We need to do SDPA with only the second half of q, and update - # only the second part of logsumexp. So partial is True. - # Note that q, k, v each contains two chunks. - q, k, v, partial = query.chunk(2, dim=2)[1], key, value, True - - # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 - # for the SDPA kernel definitions. - out, logsumexp, *rest = op( - q, - k, - v, - is_causal=is_causal_behavior.value, - **kwargs, - ) - sdpa_merger.step(out, logsumexp, partial) - - # pyrefly: ignore [unbound-name] - return *sdpa_merger.results(), *rest - - -def _templated_ring_attention_backward( - group: dist.ProcessGroup, - seq_dim: int, - op: _AttentionOp, - grad_out: torch.Tensor, - grad_out_name: str, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - logsumexp: torch.Tensor, - is_causal: bool, - **kwargs: Any, -) -> tuple[torch.Tensor, ...]: - """This API implements the backward pass of the ring attention.""" - if not is_causal and _cp_options.enable_load_balance: - raise RuntimeError("Load balancing requires `is_causal=True`.") - rank = dist.get_rank(group) - size = dist.get_world_size(group) - next_kv = None - next_grad_kv = None - rest: list[Any] - grad_query_, grad_key_, grad_value_ = None, None, None - - accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype - grad_query = torch.zeros_like(query, dtype=accum_dtype) - grad_key = torch.zeros_like(key, dtype=accum_dtype) - grad_value = torch.zeros_like(value, dtype=accum_dtype) - - key = key.contiguous() - value = value.contiguous() - kv_rotater = _create_rotater(group, 2) - dkv_rotater = _create_rotater(group, 2, method=_RotateMethod.ALL_TO_ALL) - for i in range(size): - if i > 0: - # Wait for the kv from the (cp_rank - 1) rank. - buffer = kv_rotater.next_buffer() - pointer = 0 - key = buffer[pointer : pointer + key.numel()].reshape(key.shape) - pointer += key.numel() - value = buffer[pointer : pointer + value.numel()].reshape(value.shape) - pointer += value.numel() - - if i != size - 1: - # Send the kv to the next rank. - next_kv = torch.cat([key.flatten(), value.flatten()]) - kv_rotater.exchange_buffers(next_kv) - - is_causal_behavior = _is_causal_behavior( - rank=rank, world_size=size, i=i, is_causal=is_causal - ) - - if is_causal_behavior != _CausalBehavior.SKIP: - if i == 0 or (not _cp_options.enable_load_balance or not is_causal): - # We need to do SDPA with the full local q, k, v. - q, k, v, out_, dout, lse = (query, key, value, out, grad_out, logsumexp) - elif i <= rank: - # Round-robin load balancing case, and i <= rank. - # We need to do SDPA with only the first half of k, v. - # Note that q, k, v each contains two chunks. - q, k, v, out_, dout, lse = ( - query, - key.chunk(2, dim=seq_dim)[0], - value.chunk(2, dim=seq_dim)[0], - out, - grad_out, - logsumexp, - ) - else: - # Round-robin load balancing case, and i > rank. - # We need to do SDPA with only the second half of q. - # Note that q, k, v each contains two chunks. - q, k, v, out_, dout, lse = ( - query.chunk(2, dim=seq_dim)[1], - key, - value, - out.chunk(2, dim=seq_dim)[1], - grad_out.chunk(2, dim=seq_dim)[1], - # Need to make logsumexp contiguous, otherwise there will - # be numerical error. - logsumexp.chunk(2, dim=seq_dim)[1].contiguous(), - ) - - kwargs[grad_out_name] = dout - # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 - # for the SDPA kernel definitions. - grad_query_, grad_key_, grad_value_, *rest = op( - query=q, - key=k, - value=v, - out=out_, - logsumexp=lse, - is_causal=is_causal_behavior.value, - **kwargs, - ) - else: - grad_query_ = torch.zeros_like(query, dtype=accum_dtype) - grad_key_ = torch.zeros_like(key, dtype=accum_dtype) - grad_value_ = torch.zeros_like(value, dtype=accum_dtype) - - ROUND_ROBIN_CYCLE = 2 - if i == 0: - grad_key += grad_key_ - grad_value += grad_value_ - else: - pointer = 0 - # Wait for the kv gradient from (cp_rank - 1) rank. - next_grad_kv = dkv_rotater.next_buffer() - grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( - grad_key.shape - ) - pointer += grad_key.numel() - grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape( - grad_value.shape - ) - - if i <= rank and _cp_options.enable_load_balance: - grad_key = _partial_update( - grad_key, - grad_key_, - dim=seq_dim, - n_chunks=ROUND_ROBIN_CYCLE, - idx=0, - add=True, - ) - grad_value = _partial_update( - grad_value, - grad_value_, - dim=seq_dim, - n_chunks=ROUND_ROBIN_CYCLE, - idx=0, - add=True, - ) - else: - grad_key += grad_key_ - grad_value += grad_value_ - - next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) - # Send the grad key and grad value to the next rank. - dkv_rotater.exchange_buffers(next_grad_kv) - - if i <= rank or not _cp_options.enable_load_balance: - grad_query += grad_query_ - else: - grad_query = _partial_update( - grad_query, - grad_query_, - dim=seq_dim, - n_chunks=ROUND_ROBIN_CYCLE, - idx=1, - add=True, - ) - - assert grad_key_ is not None - assert grad_value_ is not None - grad_query = grad_query.to(query.dtype) - next_grad_kv = dkv_rotater.next_buffer().to(key.dtype) - grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape) - grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape) - return ( - grad_query, - grad_key, - grad_value, - # pyrefly: ignore [unbound-name] - *rest, - ) - - -def _scaled_dot_product_ring_flash_attention( - mesh: DeviceMesh, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - if return_debug_mask: - raise NotImplementedError("return_debug_mask is not supported yet") - - # TODO: remove this hardcoding - seq_dim = 2 - group = mesh.get_group() - return _templated_ring_attention( - group, - seq_dim, - aten._scaled_dot_product_flash_attention, - query=query, - key=key, - value=value, - is_causal=is_causal, - dropout_p=dropout_p, - scale=scale, - ) - - -def _scaled_dot_product_ring_efficient_attention( - mesh: DeviceMesh, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - compute_log_sumexp: bool = True, - dropout_p: float = 0.0, - is_causal: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - if attn_bias is not None: - raise NotImplementedError("attn_bias is not supported yet") - - if not compute_log_sumexp: - # CP requires compute_log_sumexp to be True because it always merges LSE - compute_log_sumexp = True - - # TODO: remove this hardcoding - seq_dim = 2 - group = mesh.get_group() - return _templated_ring_attention( - group, - seq_dim, - aten._scaled_dot_product_efficient_attention, - query=query, - key=key, - value=value, - is_causal=is_causal, - attn_bias=attn_bias, - dropout_p=dropout_p, - scale=scale, - compute_log_sumexp=compute_log_sumexp, - ) - - -def _scaled_dot_product_ring_cudnn_attention( - mesh: DeviceMesh, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_bias: Optional[torch.Tensor] = None, - compute_log_sumexp: bool = True, - dropout_p: float = 0.0, - is_causal: bool = False, - return_debug_mask: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - if attn_bias is not None: - raise NotImplementedError("attn_bias is not supported yet") - - if not compute_log_sumexp: - # CP requires compute_log_sumexp to be True because it always merges LSE - compute_log_sumexp = True - - # TODO: remove this hardcoding - seq_dim = 2 - group = mesh.get_group() - return _templated_ring_attention( - group, - seq_dim, - aten._scaled_dot_product_cudnn_attention, - query=query, - key=key, - value=value, - attn_bias=attn_bias, - compute_log_sumexp=compute_log_sumexp, - dropout_p=dropout_p, - is_causal=is_causal, - return_debug_mask=return_debug_mask, - scale=scale, - ) - - -def _scaled_dot_product_ring_flash_attention_backward( - mesh: DeviceMesh, - grad_out: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - logsumexp: torch.Tensor, - cum_seq_q: torch.Tensor, - cum_seq_k: torch.Tensor, - max_q: int, - max_k: int, - dropout_p: float, - is_causal: bool, - philox_seed: torch.Tensor, - philox_offset: torch.Tensor, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - # TODO: remove this hardcoding - seq_dim = 2 - group = mesh.get_group() - return _templated_ring_attention_backward( - group, - seq_dim, - aten._scaled_dot_product_flash_attention_backward.default, - grad_out=grad_out, - grad_out_name="grad_out", - query=query, - key=key, - value=value, - out=out, - logsumexp=logsumexp, - is_causal=is_causal, - cum_seq_q=cum_seq_q, - cum_seq_k=cum_seq_k, - max_q=max_q, - max_k=max_k, - dropout_p=dropout_p, - philox_seed=philox_seed, - philox_offset=philox_offset, - scale=scale, - ) - - -def _scaled_dot_product_ring_efficient_attention_backward( - mesh: DeviceMesh, - grad_out: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - bias: torch.Tensor, - out: torch.Tensor, - logsumexp: torch.Tensor, - philox_seed: torch.Tensor, - philox_offset: torch.Tensor, - dropout_p: float, - grad_input_mask: tuple[bool, ...], - is_causal: bool = False, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - # TODO: remove this hardcoding - seq_dim = 2 - group = mesh.get_group() - return _templated_ring_attention_backward( - group, - seq_dim, - aten._scaled_dot_product_efficient_attention_backward.default, - grad_out=grad_out, - grad_out_name="grad_out_", - query=query, - key=key, - value=value, - attn_bias=bias, - out=out, - logsumexp=logsumexp, - philox_seed=philox_seed, - philox_offset=philox_offset, - dropout_p=dropout_p, - grad_input_mask=grad_input_mask, - is_causal=is_causal, - scale=scale, - ) - - -def _scaled_dot_product_ring_cudnn_attention_backward( - mesh: DeviceMesh, - grad_out: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - out: torch.Tensor, - logsumexp: torch.Tensor, - philox_seed: torch.Tensor, - philox_offset: torch.Tensor, - attn_bias: torch.Tensor, - cum_seq_q: torch.Tensor, - cum_seq_k: torch.Tensor, - max_q: int, - max_k: int, - dropout_p: float, - is_causal: bool, - *, - scale: Optional[float] = None, -) -> tuple[torch.Tensor, ...]: - # TODO: remove this hardcoding - seq_dim = 2 - group = mesh.get_group() - return _templated_ring_attention_backward( - group, - seq_dim, - aten._scaled_dot_product_cudnn_attention_backward.default, - grad_out=grad_out, - grad_out_name="grad_out", - query=query, - key=key, - value=value, - out=out, - logsumexp=logsumexp, - philox_seed=philox_seed, - philox_offset=philox_offset, - attn_bias=attn_bias, - cum_seq_q=cum_seq_q, - cum_seq_k=cum_seq_k, - max_q=max_q, - max_k=max_k, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - ) - - -def _sdpa_handler( - op_call: torch._ops.OpOverload, - args: tuple[object, ...], - kwargs: dict[str, object], -) -> object: - # extract local tensor and sharding infos to a OpInfo - op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) - logger.debug("Dispatching op_call: %s", op_info.schema) - - # sharding propagation - # TODO: remove the context parallel strategy from the default propagation - # rule. Either figure out how to dynamically enable it or just don't call - # propagate. - DTensor._op_dispatcher.sharding_propagator.propagate(op_info) - output_sharding = op_info.output_sharding - assert output_sharding is not None, "output sharding should not be None" - assert not output_sharding.needs_redistribute, "inputs need to be redistributed" - - call_maps: dict[torch._ops.OpOverload, Callable] = { - aten._scaled_dot_product_flash_attention.default: _scaled_dot_product_ring_flash_attention, - aten._scaled_dot_product_efficient_attention.default: _scaled_dot_product_ring_efficient_attention, - aten._scaled_dot_product_cudnn_attention.default: _scaled_dot_product_ring_cudnn_attention, - aten._scaled_dot_product_flash_attention_backward.default: _scaled_dot_product_ring_flash_attention_backward, - aten._scaled_dot_product_efficient_attention_backward.default: _scaled_dot_product_ring_efficient_attention_backward, - aten._scaled_dot_product_cudnn_attention_backward.default: _scaled_dot_product_ring_cudnn_attention_backward, - } - if op_call in call_maps: - local_results = call_maps[op_call]( - op_info.compute_mesh, - *op_info.local_args, # type: ignore[arg-type] - **op_info.local_kwargs, # type: ignore[arg-type] - ) - else: - raise NotImplementedError( - "CP only supports flash attention and memory efficient attention now." - ) - - return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) - - -custom_ops = { - aten._scaled_dot_product_flash_attention.default: _sdpa_handler, - aten._scaled_dot_product_flash_attention_backward.default: _sdpa_handler, - aten._scaled_dot_product_efficient_attention.default: _sdpa_handler, - aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_handler, - aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, - aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_handler, -} -exitsing_custom_ops = DTensor._op_dispatcher._custom_op_handlers - - -ArgsType = tuple[Any, ...] -KwargsType = dict[str, Any] -InputFnType = Callable[[Optional[nn.Module], ArgsType, KwargsType, DeviceMesh], Any] -OutputFnType = Callable[[Optional[nn.Module], Any, Any, DeviceMesh], Any] - -_replaced_functions: dict[Callable, tuple[str, Callable]] = {} - - -def _distribute_function( - fn: Callable, - fn_module: types.ModuleType, - device_mesh: DeviceMesh, - input_fn: InputFnType, - output_fn: OutputFnType, -) -> None: - """ - A helper function to replace a function with a distributed version by - using the monkey patching approach. - - This function is for the CP internal usage only. - """ - - def wrapper( - target_fn: Callable, input_fn: InputFnType, output_fn: OutputFnType - ) -> Callable: - def inner_fn(*args: ArgsType, **kwargs: KwargsType) -> Any: - args, kwargs = input_fn(None, args, kwargs, device_mesh) - outputs = target_fn(*args, **kwargs) - return output_fn(None, (args, kwargs), outputs, device_mesh) - - return inner_fn - - global _replaced_functions - - if fn in _replaced_functions: - return - - wrapper_fn = wrapper(fn, input_fn, output_fn) - setattr(fn_module, fn.__name__, wrapper_fn) - _replaced_functions[wrapper_fn] = (fn.__name__, fn) - - -def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None: - """Restore the function that is replaced by _distribute_function.""" - if fn not in _replaced_functions: - return - - original_name, original_fn = _replaced_functions[fn] - setattr(fn_module, original_name, original_fn) - - -def _enable_cp_dtensor_dispatcher() -> None: - """Enables DTensor dispatcher to dispatch SDPA to CP.""" - DTensor._op_dispatcher._custom_op_handlers = { - **exitsing_custom_ops, - **custom_ops, - } - - -def _disable_cp_dtensor_dispatcher() -> None: - """Disables DTensor dispatcher to dispatch SDPA to CP.""" - DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops - - -def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None: - sdpa_cp = _ContextParallel( - seq_dim=seq_dim, - attention_type=_ContextParallel.AttentionType.SDPA, - ) - - if _dispatch_mode == _DispatchMode.MONKEY_PATCH: - _distribute_function( - F.scaled_dot_product_attention, - F, - mesh, - sdpa_cp.sdpa_input_fn, - sdpa_cp.sdpa_output_fn, - ) - _enable_cp_dtensor_dispatcher() - elif _dispatch_mode == _DispatchMode.MODULE_WRAPPER: - _enable_cp_dtensor_dispatcher() - else: - raise ValueError(f"Unknown dispatch mode: {_dispatch_mode}") - - -def _disable_context_parallel_dispatcher_impl() -> None: - if _dispatch_mode == _DispatchMode.MONKEY_PATCH: - _restore_function(F.scaled_dot_product_attention, F) - elif _dispatch_mode == _DispatchMode.MODULE_WRAPPER: - pass - else: - raise NotImplementedError(f"Unknown dispatch mode: {_dispatch_mode}") - - _disable_cp_dtensor_dispatcher() - - -_compiled_create_block_mask = torch.compile( - create_block_mask, dynamic=False, fullgraph=True +from ._context_parallel._load_balancer import ( + _HeadTailLoadBalancer, + _LoadBalancer, + _PerDocumentHeadTailLoadBalancer, + _PTRRLoadBalancer, ) -def _context_parallel_buffers( - mesh: DeviceMesh, - buffers: list[torch.Tensor | BlockMask], - buffer_seq_dims: list[int], - load_balancer: Optional[_LoadBalancer] = None, -) -> list[torch.Tensor | BlockMask]: - """ - Shard the buffers along the sequence dimensions according to CP rules. - Args: - mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. - buffers (List[torch.Tensor]): the buffers to be sharded. - seq_dims (List[int]): the sequence dimensions of ``buffers``. This list - must have the same length as ``buffers``. - load_balancer (Optional[:class:`_LoadBalancer`]): an optional `_LoadBalancer` - object. If this argument is `None`, it means the `buffers` need no - rearrangement before being sharded. If this argument is a `_LoadBalancer` - object, call its `_generate_indices(restore=False)` to generate the - rearrangement indices such that each shard of `buffer[rearrange_idx]` is - well-balanced (i.e., having close sparsities). - - Returns: - List[torch.Tensor]: the sharded buffers. - - Note: - For `_context_parallel_shard` we require a non-None `load_balancer` object to be - explicitly passed if load-balancing is needed. - """ - # generate the index tensor for rearranging the buffer if a load-balance - # is available - load_balance_indices = load_balancer._generate_indices() if load_balancer else None - assert load_balance_indices is None or load_balance_indices.ndim == 2, ( - "load balance index expects shape (1, seq_len) or (B, seq_len) " - f"but got {load_balance_indices.shape}." - ) - - new_buffers = [] - sharded_buffer: torch.Tensor | BlockMask - for buffer, seq_dim in zip(buffers, buffer_seq_dims): - if isinstance(buffer, torch.Tensor): - # TODO: the load balance doesn't perform error handling. - - # NOTE: assuming batch dim is 0 - - if load_balance_indices is not None: - # TODO: we should expclitly ask users to unsqueeze the batch dim. - # But this is a BC breaking ask. - # However, what we have done today is also not very safe. - idx_batch_size = load_balance_indices.size(0) - data_batch_size = buffer.size(0) if seq_dim > 0 else 1 - - if idx_batch_size != 1 and idx_batch_size != data_batch_size: - raise ValueError( - "Cannot rearrange buffer: " - f"load_balance_indices has shape {load_balance_indices.shape}, " - f"but buffer has shape {buffer.shape}." - ) - - if seq_dim == 0: - buffer = torch.index_select( - buffer, dim=0, index=load_balance_indices[0] - ) - else: - indices = load_balance_indices - if idx_batch_size == 1: - size = [data_batch_size] + list(indices.size())[1:] - indices = indices.expand(*size) - - for i in range(data_batch_size): - buffer[i] = torch.index_select( - buffer[i], dim=seq_dim - 1, index=indices[i] - ) - - # use DTensor to shard the buffer on sequence dimension, retain the local tensor - sharded_buffer = distribute_tensor( - buffer, mesh, [Shard(seq_dim)], src_data_rank=None - ).to_local() - elif isinstance(buffer, BlockMask): - sharded_buffer = _create_cp_block_mask( - mask_mod=buffer.mask_mod, - B=buffer.kv_num_blocks.shape[0], - H=buffer.kv_num_blocks.shape[1], - Q_LEN=buffer.seq_lengths[0], - KV_LEN=buffer.seq_lengths[1], - device_mesh=mesh, - load_balancer=load_balancer, - ) - else: - raise ValueError(f"Unknown buffer type: {type(buffer)}") - - new_buffers.append(sharded_buffer) - - return new_buffers - - -def _create_cp_block_mask( - mask_mod: _mask_mod_signature, - B: int, - H: int, - Q_LEN: int, - KV_LEN: int, - device_mesh: DeviceMesh, - load_balancer: Optional[_LoadBalancer] = None, -) -> BlockMask: - """ - Creates a specialized BlockMask for Context Parallel FlexAttention. - - This function creates a BlockMask that enables computation of attention results - for sharded Q attending to global KV. The mask appropriately handles the query - index offset required when each rank operates on a shard of the query sequence - while accessing the full key-value sequence. - - The function internally rewrites the provided mask_mod function to translate local - query indices to global query indices, ensuring that the masking logic is applied - correctly across the distributed computation. - - Args: - mask_mod (Callable): Mask function that operates on global attention indices. - B (int): Batch size. - H (int): Number of query heads. - Q_LEN (int): Global sequence length of the query. - KV_LEN (int): Global sequence length of the key/value. - device_mesh (DeviceMesh): Device mesh used for context parallelism. - load_balancer (Optional[:class:`_LoadBalancer`]): The load-balancer used to rearrange - QKV before sharding. This will be used to modify the block_mask generated. - - Returns: - BlockMask: A block mask configured for the local query shard that can be used - with flex_attention() for the given cp_mesh. - - Raises: - NotImplementedError: If Q_LEN is not divisible by (CP world size * BLOCK_SIZE). - - Warning: - Currently requires Q_LEN to be divisible by CP mesh world size * BLOCK_SIZE - (BLOCK_SIZE defaults to 128). This constraint exists because the BlockMask - must handle both padding and offsets correctly. For example, if Q_LEN is 384, - CP world size is 2, and BLOCK_SIZE is 128, the local Q_LEN would be 192. In - such cases, both rank0 and rank1 would have paddings in their local BlockMasks. - Support for padding in this scenario is planned for future work. - - """ - - from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE - - if Q_LEN % (device_mesh.size() * _DEFAULT_SPARSE_BLOCK_SIZE) != 0: - raise NotImplementedError( - f"Q_LEN {Q_LEN} is not divisible by CP mesh world size {device_mesh.size()} * " - f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. " - ) - - compiled_create_block_mask = torch.compile( - create_block_mask, dynamic=False, fullgraph=True - ) - - def _rewrite_mask_mod( - mask_mod: _mask_mod_signature, - rank: int, - block_size: int, - local_q_size: int, - qkv_rearrange_indices: Optional[torch.Tensor] = None, - ) -> _mask_mod_signature: - assert qkv_rearrange_indices is None or qkv_rearrange_indices.ndim == 2, ( - "load balance index expects shape (1, seq_len) or (B, seq_len) " - f"but got {qkv_rearrange_indices.shape}." - ) - - def qkv_idx_restore( - b: torch.Tensor, idx_post_rearrange: torch.Tensor - ) -> torch.Tensor: - if qkv_rearrange_indices is not None: - if ( - qkv_rearrange_indices.size(0) == 1 - ): # identical load-balance in batch - idx_pre_rearrange = qkv_rearrange_indices[0][idx_post_rearrange] - else: - idx_pre_rearrange = qkv_rearrange_indices[b][idx_post_rearrange] - else: - idx_pre_rearrange = idx_post_rearrange - - return idx_pre_rearrange - - def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor: - # calculate local block_idx and block_offset - local_blk_idx, local_blk_offset = ( - local_q_idx // block_size, - local_q_idx % block_size, - ) - # NOTE: load balancing is not used - local_num_blocks = local_q_size // block_size - blk_idx = local_num_blocks * rank + local_blk_idx - return blk_idx * block_size + local_blk_offset - - return lambda b, h, q_idx, kv_idx: mask_mod( - b, - h, - qkv_idx_restore(b, local_q_idx_to_q_idx(q_idx)), - qkv_idx_restore(b, kv_idx), - ) - - cp_rank = device_mesh.get_local_rank() - cp_group_size = device_mesh.size() - load_balancer = load_balancer or _create_default_load_balancer( - Q_LEN, cp_group_size, device_mesh.device_type - ) - Q_SHARD_LEN = Q_LEN // cp_group_size - block_size = _DEFAULT_SPARSE_BLOCK_SIZE - - rearrange_indices = ( - load_balancer._generate_indices(restore=False) if load_balancer else None - ) - block_mask = compiled_create_block_mask( - _rewrite_mask_mod( - mask_mod, - cp_rank, - block_size, - Q_SHARD_LEN, - qkv_rearrange_indices=rearrange_indices, - ), - B, - H, - Q_SHARD_LEN, - KV_LEN, - device=device_mesh.device_type, - BLOCK_SIZE=(block_size, block_size), - ) - return block_mask - - -##################### -# Experimental APIs -##################### - - -class _ContextParallel(ParallelStyle): - class AttentionType(Enum): - FLEX = "flex_attention" - SDPA = "scaled_dot_product_attention" - - def __init__( - self, - seq_dim: int, - attention_type: AttentionType, - ) -> None: - super().__init__() - self.seq_dim = seq_dim - self.attention_type = attention_type - - def _apply(self, module: nn.Module, mesh: DeviceMesh) -> nn.Module: - if self.attention_type == self.AttentionType.FLEX: - module.register_forward_pre_hook( - partial(self.flex_input_fn, mesh=mesh), with_kwargs=True - ) - return module - elif self.attention_type == self.AttentionType.SDPA: - module.register_forward_pre_hook( - partial(self.sdpa_input_fn, mesh=mesh), with_kwargs=True - ) - module.register_forward_hook(partial(self.sdpa_output_fn, mesh=mesh)) - return module - else: - raise ValueError(f"Unknown attention type: {self.attention_type}") - - def flex_input_fn( - self, module: Optional[nn.Module], args: Any, kwargs: Any, mesh: DeviceMesh - ) -> Any: - args_list = list(args) - for idx, name in enumerate( - ("query", "key", "value", "score_mod", "block_mask") - ): - if idx >= len(args): - args_list.append(kwargs.pop(name, None)) - - query, key, value, score_mod, block_mask = args_list[:5] - assert isinstance(query, torch.Tensor) - assert isinstance(key, torch.Tensor) - assert isinstance(value, torch.Tensor) - assert isinstance(block_mask, BlockMask | tuple) - - key = key.contiguous() - value = value.contiguous() - - global_key, global_value = flex_cp_allgather( - key, value, self.seq_dim, c10d._get_process_group_name(mesh.get_group()) - ) - args_list[1] = global_key - args_list[2] = global_value - - return tuple(args_list), kwargs - - def sdpa_input_fn( - self, - module: Optional[nn.Module], - args: tuple[Any, ...], - kwargs: dict[str, Any], - mesh: DeviceMesh, - ) -> tuple[tuple[Any, ...], dict[str, Any]]: - placement = [Shard(self.seq_dim)] - all_args = [] - - for arg in itertools.chain(args, kwargs.values()): - if isinstance(arg, torch.Tensor): - if isinstance(arg, DTensor): - assert arg._spec.placements == placement - else: - arg = DTensor.from_local(arg, mesh, placement, run_check=False) - - all_args.append(arg) - - new_args = tuple(all_args[0 : len(args)]) - new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :])) - return new_args, new_kwargs - - def sdpa_output_fn( - self, module: Optional[nn.Module], inputs: Any, outputs: Any, mesh: DeviceMesh - ) -> Any: - new_outputs = [] - for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs: - output = output.to_local() if isinstance(output, DTensor) else output - new_outputs.append(output) - - if isinstance(outputs, torch.Tensor): - return new_outputs[0] - - return tuple(new_outputs) - - -CPBuffer: TypeAlias = torch.Tensor | BlockMask -CPBufferContainer: TypeAlias = Sequence[CPBuffer] | Mapping[str, CPBuffer] -CPBufferSeqDims: TypeAlias = Sequence[int] | Mapping[str, int] - - -def _context_parallel_shard( - mesh: DeviceMesh, - buffers: CPBufferContainer, - seq_dims: CPBufferSeqDims, - load_balancer: Optional[_LoadBalancer] = None, -) -> list[torch.Tensor | BlockMask]: - """ - Shard the buffers along the specified sequence dimensions (`seq_dims`), so that each - rank retains only its corresponding shard according to the provided `mesh`. If a - `load_balancer` is provided, the buffers will be rearranged by the load balancer - before sharding to improve load balance. Buffers can be either tensors or `BlockMask` - objects. If a buffer is a `BlockMask`, its sharding dimension is determined by the - `BlockMask` implementation, and the corresponding `seq_dim` is ignored. - - Note: - For `_context_parallel_shard`, a non-None `load_balancer` must be explicitly passed - if load balancing is required. - - Args: - mesh (DeviceMesh): The device mesh used for context parallelism. - buffers (List[torch.Tensor | BlockMask]): Buffers whose usage depends on the sequence - dimension. Examples include input batches, labels, and positional embedding buffers. - These buffers must be sharded along the sequence dimension to ensure correctness. - seq_dims (List[int]): The sequence dimensions for each buffer in `buffers`. Must have - the same length as `buffers`. - load_balancer (Optional[_LoadBalancer]): An optional load balancer object. If provided, - it rearranges the buffers before sharding to achieve better load balance. If not - provided, no rearrangement is performed. - - Returns: - List[torch.Tensor | BlockMask]: The sharded buffers, each corresponding to the local - shard for the current rank. - """ - # TODO: these global variables are going to bite us someday. - # We will have to remove them soon. - # For the new API, we only support the module wrapper mode. - global _dispatch_mode - _dispatch_mode = _DispatchMode.MODULE_WRAPPER - global _cp_options - if load_balancer is not None: - _cp_options.enable_load_balance = True - else: - _cp_options.enable_load_balance = False - - if len(buffers) != len(seq_dims): - raise ValueError( - "`seq_dims` must have the same number of elements as `buffers`." - ) - - flat_buffers, spec = tree_flatten(buffers) - flat_seq_dims, _ = tree_flatten(seq_dims) - if len(flat_buffers) != len(flat_seq_dims): - raise ValueError("`seq_dims` must have the pytree structure as `buffers`.") - - if isinstance(flat_buffers[0], torch.Tensor): - device = flat_buffers[0].device - else: - device = flat_buffers[0].kv_num_blocks.device - for buffer in flat_buffers: - if isinstance(buffer, torch.Tensor): - assert device == buffer.device, "All buffers must be on the same device" - else: - assert device == buffer.kv_num_blocks.device, ( - "All buffers must be on the same device" - ) - - flat_sharded_buffers = _context_parallel_buffers( - mesh, flat_buffers, flat_seq_dims, load_balancer - ) - - return tree_unflatten(flat_sharded_buffers, spec) - - -def _enable_context_parallel_dispatcher() -> None: - """ - Enable the context parallel dispatcher. This API is experimental and subject to change. - """ - _enable_cp_dtensor_dispatcher() - - -def _disable_context_parallel_dispatcher() -> None: - """ - Disable the context parallel dispatcher. This API is experimental and subject to change. - """ - _disable_cp_dtensor_dispatcher() - - -##################################################### -# Current public APIs, but are also subject to change -##################################################### -@contextlib.contextmanager -@torch.no_grad() -def context_parallel( - mesh: DeviceMesh, - *, - buffers: Optional[list[torch.Tensor]] = None, - buffer_seq_dims: Optional[list[int]] = None, - no_restore_buffers: Optional[set[torch.Tensor]] = None, -) -> Generator[None, None, None]: - """ - - ``context_parallel`` is an experimental API to enable context - parallelism (CP). This API performs two actions: 1) patch the SDPA - (``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled - one, 2) shard ``buffers`` along the sequence dimension and each rank will - preserve the corresponding shard according ``mesh``. - - Args: - mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. - buffers (Optional[List[torch.Tensor]]): buffers that the usage depend - on the sequence dimension. Examples are input batch, labels and - positional embedding buffers. These buffers must be sharded along - the sequence dimension to ensure the accuracy. The sharding will - happen in-place, the buffer's shape will change within the context. - The buffers will be restored after the context finishes. - ``no_restore_buffers`` can be used to specify which buffers don't - need to be restored. Note that ``buffers`` should not contain any - nn.Parameter. - buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``. - no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in these set - won't be restored after the context exits. This set must be a subset - of ``buffers``. If the buffers won't be used after the context exits, - these buffers can be put in this list to avoid extra restore time. - - .. warning:: - `torch.distributed.tensor.experimental.context_parallel` is a - prototype feature in PyTorch. The API is subject to change. - """ - # For the legacy API, we only support the monkey-patch mode. - # We will deprecate this API once the new API is widely used. - global _dispatch_mode - _dispatch_mode = _DispatchMode.MONKEY_PATCH - - buffers = [] if buffers is None else buffers - buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims - no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers - - if len(buffers) != len(buffer_seq_dims): - raise ValueError( - "`seq_dims` must have the same number of elements as `buffers`." - ) - - for buffer in no_restore_buffers: - # Cannot use `if not buffer in buffers` which will incur tensor comparison. - if not any(b is buffer for b in buffers): - raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") - - original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] - - device = buffers[0].device - seq_length = buffers[0].shape[buffer_seq_dims[0]] - cp_world_size = mesh.size() - - # If `enable_load_balance` is True, the default Head-tail load balancer - # (:class:`_HeadTailLoadBalancer`) is used to rearrange the buffers before - # sharding. Otherwise, we don't do any load-balance rearrange by passing - # `None` to `_context_parallel_shard()`. - load_balancer = _create_default_load_balancer(seq_length, cp_world_size, device) - shards = _context_parallel_buffers( - mesh, - cast(list[torch.Tensor | BlockMask], buffers), - buffer_seq_dims, - load_balancer, - ) - for buffer, shard in zip(buffers, shards): - assert isinstance(shard, torch.Tensor), "ContextParallel only supports Tensor" - shard = shard.clone() - buffer.resize_(shard.shape) - buffer.copy_(shard) - - _enable_context_parallel_dispatcher_impl(seq_dim=2, mesh=mesh) - yield - _disable_context_parallel_dispatcher_impl() - - for buffer, original_buffer in zip(buffers, original_buffers): - if original_buffer is not None: - buffer.resize_(original_buffer.shape) - buffer.copy_(original_buffer) - - -@torch.no_grad() -def context_parallel_unshard( - mesh: DeviceMesh, - buffers: list[torch.Tensor], - seq_dims: list[int], - load_balancer: Optional[_LoadBalancer] = None, -) -> list[torch.Tensor]: - """ - Unshard the tensors (e.g., output) that are sharded due to context parallelism. - - Args: - mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. - buffers (List[torch.Tensor]): the buffers to be unsharded. - seq_dims (List[int]): the sequence dimensions of ``buffers``. This list - must have the same length as ``buffers``. - load_balancer (Optional[:class:`_Loadbalancer`]): an optional `_LoadBalancer` - object. If this argument is `None`, it means the `buffers` were not - rearranged when being sharded and there's no need to put it back to order - after unsharding. If this argument is a `_LoadBalancer` object, call - its `_generate_indices(restore=True)` to generate the restore indices such - that `unsharded[restore_idx]` is the original buffer. - - Returns: - List[torch.Tensor]: the unsharded buffers. - - Note: - For `context_parallel_unshard` we require not-None `load_balancer` object be - explicitly passed if flex_attention() is to be used and load-balancing is needed. - This is different from the case of SDPA though we strongly suggest users follow - the same convention. - """ - device = buffers[0].device - cp_world_size = mesh.size() - seq_length = buffers[0].shape[seq_dims[0]] * cp_world_size - - # If users don't pass in a `load_balancer`: - # - if `enable_load_balance` is True, we use the default round-robin - # load balancer. - # - if `enable_load_balance` is False, we don't do any load balancing - # by passing in `None` as `restore_indices`. - load_balancer = load_balancer or _create_default_load_balancer( - seq_length, cp_world_size, device - ) - restore_indices = ( - load_balancer._generate_indices(restore=True) if load_balancer else None - ) - - assert restore_indices is None or restore_indices.ndim == 2, ( - "load balance restore index expects shape (1, seq_len) or (B, seq_len) " - f"but got {restore_indices.shape}." - ) - unsharded_buffers = [] - for b, dim in zip(buffers, seq_dims): - b = b.contiguous() - unsharded_b = _maybe_wait(ft_c.all_gather_tensor(b, dim, mesh)) - - if restore_indices is not None: - # NOTE: assuming batch dim is 0 - idx_batch_size = restore_indices.size(0) - data_batch_size = unsharded_b.size(0) - if idx_batch_size != 1 and idx_batch_size != data_batch_size: - raise ValueError( - "Cannot restore buffer: " - f"restore_indices has shape {restore_indices.shape}, " - f"but unsharded_b has shape {unsharded_b.shape}." - ) - - for i in range(data_batch_size): - index = ( - restore_indices[0] # identical load-balance in batch - if idx_batch_size == 1 - else restore_indices[i] - ) - unsharded_b_batch_i = torch.index_select( - unsharded_b[i], dim=dim - 1, index=index - ) - unsharded_b[i] = unsharded_b_batch_i - - unsharded_buffers.append(unsharded_b) - - return unsharded_buffers - - -def set_rotate_method(rotate_method: str) -> None: - """ - Context Parallel SDPA requires the rotation of kv shards. Users can call this - API to specify which rotation method to use. "alltoall" shuffles the kv shards - using all-to-all collective. While "allgather" gathers the kv shards using - all-gather collective after the first sub-SDPA computation. If this API has not - been called, the default rotate method is "allgather". - - Args: - rotate_method (str): the rotate method to use. Currently only supports - "allgather" and "alltoall". If a different string other than these two - is passed in, the function will raise an error. - - Returns: - None - """ - logger.info("Note that FlexAttention CP doesn't support alltoall yet.") - if rotate_method == "allgather": - _cp_options.rotate_method = _RotateMethod.ALL_GATHER - elif rotate_method == "alltoall": - _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL - else: - raise NotImplementedError( - "Context Parallel does not support " - f"using {rotate_method} for kv shards rotation" - ) +__all__ = [ + "_CausalBehavior", + "_context_parallel_shard", + "_ContextParallel", + "_cp_options", + "_disable_context_parallel_dispatcher", + "_enable_context_parallel_dispatcher", + "_is_causal_behavior", + "_RotateMethod", + "context_parallel", + "context_parallel_unshard", + "set_rotate_method", + "_HeadTailLoadBalancer", + "_LoadBalancer", + "_PerDocumentHeadTailLoadBalancer", + "_PTRRLoadBalancer", +] diff --git a/torch/distributed/tensor/experimental/_context_parallel/__init__.py b/torch/distributed/tensor/experimental/_context_parallel/__init__.py new file mode 100644 index 0000000000000..009255631796f --- /dev/null +++ b/torch/distributed/tensor/experimental/_context_parallel/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Context Parallel components + +from ._attention import ( + _CausalBehavior, + _context_parallel_shard, + _ContextParallel, + _cp_options, + _disable_context_parallel_dispatcher, + _enable_context_parallel_dispatcher, + _is_causal_behavior, + _RotateMethod, + context_parallel, + context_parallel_unshard, + set_rotate_method, +) +from ._cp_custom_ops import flex_cp_allgather +from ._load_balancer import ( + _HeadTailLoadBalancer, + _LoadBalancer, + _PerDocumentHeadTailLoadBalancer, + _PTRRLoadBalancer, +) + + +__all__ = [ + # From _attention + "_CausalBehavior", + "_context_parallel_shard", + "_ContextParallel", + "_cp_options", + "_disable_context_parallel_dispatcher", + "_enable_context_parallel_dispatcher", + "_is_causal_behavior", + "_RotateMethod", + "context_parallel", + "context_parallel_unshard", + "set_rotate_method", + # From _cp_custom_ops + "flex_cp_allgather", + # From _load_balancer + "_HeadTailLoadBalancer", + "_LoadBalancer", + "_PerDocumentHeadTailLoadBalancer", + "_PTRRLoadBalancer", +] diff --git a/torch/distributed/tensor/experimental/_context_parallel/_attention.py b/torch/distributed/tensor/experimental/_context_parallel/_attention.py new file mode 100644 index 0000000000000..09a86081df522 --- /dev/null +++ b/torch/distributed/tensor/experimental/_context_parallel/_attention.py @@ -0,0 +1,1659 @@ +import contextlib +import itertools +import logging +import types +from abc import ABC, abstractmethod +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from enum import auto, Enum +from functools import partial +from typing import Any, cast, Optional, Protocol, TypeAlias + +import torch +import torch.distributed as dist +import torch.distributed._functional_collectives as ft_c +import torch.distributed.distributed_c10d as c10d +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import distribute_tensor, DTensor, Shard +from torch.distributed.tensor.parallel import ParallelStyle +from torch.nn.attention.flex_attention import ( + _mask_mod_signature, + BlockMask, + create_block_mask, +) +from torch.utils._pytree import tree_flatten, tree_unflatten + +from ._cp_custom_ops import flex_cp_allgather +from ._load_balancer import _create_default_load_balancer, _LoadBalancer + + +__all__ = [ + "_CausalBehavior", + "_context_parallel_shard", + "_ContextParallel", + "_cp_options", + "_disable_context_parallel_dispatcher", + "_enable_context_parallel_dispatcher", + "_is_causal_behavior", + "_RotateMethod", + "context_parallel", + "context_parallel_unshard", + "set_rotate_method", +] + + +class _CausalBehavior(Enum): + SKIP = None + NOT_IS_CAUSAL = False + IS_CAUSAL = True + + +class _RotateMethod(Enum): + ALL_TO_ALL = auto() + ALL_GATHER = auto() + + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +class _DispatchMode(Enum): + MONKEY_PATCH = auto() + MODULE_WRAPPER = auto() + + +_dispatch_mode: _DispatchMode = _DispatchMode.MONKEY_PATCH + + +@dataclass +class _ContextParallelOptions: + # Whether to upcast parameters and gradients to float32 to avoid accumulation + # errors. It is likely this is always True, but we currently keep this variable + # for experimental purposes. + convert_to_f32: bool = True + enable_load_balance: bool = True + rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER + + +_cp_options = _ContextParallelOptions() + + +def _is_causal_behavior( + rank: int, world_size: int, i: int, is_causal: bool +) -> _CausalBehavior: + """ + Calculate is_causal behavior for each KV block. The attention can either be + calculated in full, not at all or with the causal mask applied. + """ + if not is_causal: + return _CausalBehavior.NOT_IS_CAUSAL + + if i == 0: + return _CausalBehavior.IS_CAUSAL + + source_rank = (rank - i) % world_size + if source_rank < rank or _cp_options.enable_load_balance: + return _CausalBehavior.NOT_IS_CAUSAL + else: + return _CausalBehavior.SKIP + + +def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: + """ + When tracing the code, the result tensor is not an AsyncCollectiveTensor, + so we cannot call ``wait()``. + """ + if isinstance(tensor, ft_c.AsyncCollectiveTensor): + return tensor.wait() + return tensor + + +def _partial_update( + original: torch.Tensor, + new: torch.Tensor, + dim: int, + n_chunks: int, + idx: int, + add: bool, +) -> torch.Tensor: + """ + This API partially updates a chunk of ``original`` tensor. The ``original`` + tensor will be first chunked along ``dim`` dimension, then the ``idx`` chunk + will be updated with ``new``. If ``add`` is True, the chunk will be added + with ``new``, otherwise the chunk will be replaced by ``new``. + + The result is a tensor that is the same size as ``original``. + """ + chunks = list(original.chunk(n_chunks, dim=dim)) + assert chunks[idx].shape == new.shape, (original.shape, new.shape, idx) + if add: + chunks[idx] += new + else: + chunks[idx] = new + return torch.cat(chunks, dim=dim) + + +class _SDPAMerger: + """A class to help merge the local SDPA result.""" + + def __init__(self, convert_to_f32: bool, seq_dim: int): + self._seq_dim = seq_dim + self._out: Optional[torch.Tensor] = None + self._lse: Optional[torch.Tensor] = None + self._should_lse_squeeze = False + self._convert_to_f32 = convert_to_f32 + self._out_dtype = torch.float32 + self._lse_dtype = torch.float32 + + def _merge_one( + self, block_out: torch.Tensor, block_lse: torch.Tensor, partial: bool + ) -> None: + # The cuDNN backend preserves the last dimension for LSE. + # Apply unsqueeze only if the input does not already have + # the required dimensionality. + if len(block_lse.shape) < len(block_out.shape): + block_lse = block_lse.unsqueeze(dim=-1) + self._should_lse_squeeze = True + assert len(block_lse.shape) == len(block_out.shape) + + if self._lse is None: + self._lse = block_lse + self._out = block_out + else: + ROUND_ROBIN_CYCLE = 2 + assert self._lse is not None + assert self._out is not None + lse = ( + self._lse.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._lse + ) + out = ( + self._out.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._out + ) + + # The algorithm from + # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + # gives a relatively stable result. + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + if partial: + self._lse = _partial_update( + self._lse, + lse, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + self._out = _partial_update( + self._out, + out, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + else: + self._lse = lse + self._out = out + + def step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool) -> None: + self._out_dtype = out.dtype + self._lse_dtype = lse.dtype + + if self._convert_to_f32: + out = out.to(torch.float32) + lse = lse.to(torch.float32) + + self._merge_one(out, lse, partial) + + def results(self) -> tuple[torch.Tensor, torch.Tensor]: + assert self._out is not None + assert self._lse is not None + out = self._out.to(self._out_dtype) + if self._should_lse_squeeze: + lse = self._lse.squeeze(-1).to(self._lse_dtype) + else: + lse = self._lse.to(self._lse_dtype) + return out, lse + + +class _AttentionOp(Protocol): + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + **kwargs: object, + ) -> tuple[torch.Tensor, ...]: ... + + +class _RingRotater(ABC): + @abstractmethod + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ... + + @abstractmethod + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ... + + @abstractmethod + def next_buffer(self) -> torch.Tensor: ... + + +class _AllToAllRotater(_RingRotater): + """Use all_to_all to send the kv to the next rank.""" + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._buffer: Optional[torch.Tensor] = None + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + curr_buffer = curr_buffer.contiguous() + size = dist.get_world_size(self._pg) + dsts = list(range(1, size)) + [0] + self._buffer = ft_c.permute_tensor(curr_buffer, dsts, self._pg) + + def next_buffer(self) -> torch.Tensor: + assert self._buffer is not None + return _maybe_wait(self._buffer) + + +class _AllGatherRotater(_RingRotater): + """ + Allgather the kv and return only the required kv. + Only one communication will be done. + """ + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._aggregated_buffer: Optional[torch.Tensor] = None + self._idx = 0 + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + # We only need to perform allgather once. + self._idx += 1 + if self._aggregated_buffer is None: + self._aggregated_buffer = ft_c.all_gather_tensor( + curr_buffer.contiguous(), gather_dim=0, group=self._pg + ) + + def next_buffer(self) -> torch.Tensor: + rank = dist.get_rank(self._pg) + idx = rank - self._idx + + assert self._aggregated_buffer is not None + self._aggregated_buffer = _maybe_wait(self._aggregated_buffer) + return self._aggregated_buffer.chunk(dist.get_world_size(self._pg))[idx] + + +def _create_rotater( + pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None +) -> _RingRotater: + if method is None: + method = _cp_options.rotate_method + + if method == _RotateMethod.ALL_TO_ALL: + return _AllToAllRotater(pg, seq_dim) + elif method == _RotateMethod.ALL_GATHER: + return _AllGatherRotater(pg, seq_dim) + else: + raise NotImplementedError(f"Unknown method {method}") + + +def _templated_ring_attention( + group: dist.ProcessGroup, + seq_dim: int, + op: _AttentionOp, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + **kwargs: object, +) -> tuple[torch.Tensor, ...]: + """ + A generalized ring attention implementation that can support multiple attention ops. + + Note [Context parallelism load balance algorithm for causal masking] + ===================== + This explanation uses an example to illustrate the CP algorithm with causal + masking. + + Consider a scenario where the sequence length of q, k, and v is 4 (e.g., + q = (q0, q1, q2, q3)), and there are two ranks. For simplicity, we will discuss + only q and k, as v follows the same pattern as k. + + The diagram below represents a complete QK^T operation without parallelism. + The `****` entries indicate that the result is not required due to causal + masking (e.g., q0k1 is marked as `****`). + + +----+------------------------+ + | | k0 k1 k2 k3 | + +----+------------------------+ + | q0 | q0k0, ****, ****, **** | + | q1 | q1k0, q1k1, ****, **** | + | q2 | q2k0, q2k1, q2k2, **** | + | q3 | q3k0, q3k1, q3k2, q3k3 | + +----+------------------------+ + + ### No Load Balance: + + In this scenario, each rank owns a local chunk of q, k, and v, with each chunk + containing two elements. Rank0 is responsible for managing (q0, q1) and (k0, k1), + while rank1 manages (q2, q3) and (k2, k3). + + First Iteration: Both rank0 and rank1 perform SDPA with their local qkv pairs. + Causal masking is enabled as some results are not required (e.g., q0k1). + + Second Iteration: Local queries remain the same, but local kv pairs are exchanged. + Rank0 now has (q0, q1) and (k2, k3); rank1 has (q2, q3) and (k0, k1). Rank0 performs + no computation, while rank1 computes locally without causal masking since all results + (q2k0, q2k1, q3k0, q3k1) are needed. + + ### Round-robin Load Balance: + + In this setup, each rank owns two local chunks of q, k, and v, with each chunk + containing one element. Rank0 manages (q0, q3) and (k0, k3); Rank1 manages (q1, q2) + and (k1, k2). Although the local chunks are not consecutive, they are concatenated to + enable SDPA to be performed in a single call for each step. Consequently, the chunk() + function may be required to prepare the correct q, k, and v configurations. + + First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the + no-load-balance case. This iteration corresponds to the `if` of the + (`if, `elif`, `else`) in the implementation. + + Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and + (k0, k3). For rank0, no computation is needed for q0. However, computations for + q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the + `else` of the (`if`, `elif`, `else`) in the implementation. + For rank1, k3 is not needed for q1 and q2, so only k0 is used for SDPA. This + corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. + + Parameters + ---------- + op: + The attention op to use + *args: + additional args are passed to the op + **kwargs: + additional kwargs are passed to the op + + Returns + ------- + out: + The merged attention output + softmax_lse: + The logsumexp of the merged attention output + """ + if is_causal and (query.size(2) != key.size(2)): + raise NotImplementedError( + "is_causal requires the same query and context sequence lengths" + ) + if not is_causal and _cp_options.enable_load_balance: + raise RuntimeError("Load balancing requires `is_causal=True`.") + + assert isinstance(group, dist.ProcessGroup), ( + "process group must be single dimension" + ) + rank = dist.get_rank(group) + size = dist.get_world_size(group) + + next_kv = None + + # Without making key and value contiguous(), the loss curve is bad. + # TODO(fegin): figure out why this is a requirement since SDPA does not have + # this requirement. + key = key.contiguous() + value = value.contiguous() + + sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim) + + rest: list[Any] + out: torch.Tensor + logsumexp: torch.Tensor + + rotater = _create_rotater(group, 2) + + for i in range(size): + if i > 0: + # Wait for the kv from the (cp_rank - 1) rank. + next_kv = rotater.next_buffer() + key = next_kv[: key.numel()].reshape(key.shape) + value = next_kv[key.numel() :].reshape(value.shape) + + if i < (size - 1): + # Send the k, v to the next rank + next_kv = torch.cat([key.flatten(), value.flatten()]) + next_kv = rotater.exchange_buffers(next_kv) + + is_causal_behavior = _is_causal_behavior( + rank=rank, world_size=size, i=i, is_causal=is_causal + ) + + # For a detailed understanding of the load balancing algorithm, see + # Note [Context parallelism load balance algorithm for causal masking] + if is_causal_behavior == _CausalBehavior.SKIP: + # If i > rank and load balancing is not turned on. + continue + + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # When local balance is enabled, we still need to do SDPA with + # the both local chunks of q, k, v for the first iteration. + q, k, v, partial = (query, key, value, False) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SDPA with only the first local chunk of k, v. + # Note that q, k, v each contains two local chunks. + ROUND_ROBIN_CYCLE = 2 + q, k, v, partial = ( + query, + key.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + value.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + False, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SDPA with only the second half of q, and update + # only the second part of logsumexp. So partial is True. + # Note that q, k, v each contains two chunks. + q, k, v, partial = query.chunk(2, dim=2)[1], key, value, True + + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + out, logsumexp, *rest = op( + q, + k, + v, + is_causal=is_causal_behavior.value, + **kwargs, + ) + sdpa_merger.step(out, logsumexp, partial) + + # pyrefly: ignore [unbound-name] + return *sdpa_merger.results(), *rest + + +def _templated_ring_attention_backward( + group: dist.ProcessGroup, + seq_dim: int, + op: _AttentionOp, + grad_out: torch.Tensor, + grad_out_name: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + is_causal: bool, + **kwargs: Any, +) -> tuple[torch.Tensor, ...]: + """This API implements the backward pass of the ring attention.""" + if not is_causal and _cp_options.enable_load_balance: + raise RuntimeError("Load balancing requires `is_causal=True`.") + rank = dist.get_rank(group) + size = dist.get_world_size(group) + next_kv = None + next_grad_kv = None + rest: list[Any] + grad_query_, grad_key_, grad_value_ = None, None, None + + accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype + grad_query = torch.zeros_like(query, dtype=accum_dtype) + grad_key = torch.zeros_like(key, dtype=accum_dtype) + grad_value = torch.zeros_like(value, dtype=accum_dtype) + + key = key.contiguous() + value = value.contiguous() + kv_rotater = _create_rotater(group, 2) + dkv_rotater = _create_rotater(group, 2, method=_RotateMethod.ALL_TO_ALL) + for i in range(size): + if i > 0: + # Wait for the kv from the (cp_rank - 1) rank. + buffer = kv_rotater.next_buffer() + pointer = 0 + key = buffer[pointer : pointer + key.numel()].reshape(key.shape) + pointer += key.numel() + value = buffer[pointer : pointer + value.numel()].reshape(value.shape) + pointer += value.numel() + + if i != size - 1: + # Send the kv to the next rank. + next_kv = torch.cat([key.flatten(), value.flatten()]) + kv_rotater.exchange_buffers(next_kv) + + is_causal_behavior = _is_causal_behavior( + rank=rank, world_size=size, i=i, is_causal=is_causal + ) + + if is_causal_behavior != _CausalBehavior.SKIP: + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # We need to do SDPA with the full local q, k, v. + q, k, v, out_, dout, lse = (query, key, value, out, grad_out, logsumexp) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SDPA with only the first half of k, v. + # Note that q, k, v each contains two chunks. + q, k, v, out_, dout, lse = ( + query, + key.chunk(2, dim=seq_dim)[0], + value.chunk(2, dim=seq_dim)[0], + out, + grad_out, + logsumexp, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SDPA with only the second half of q. + # Note that q, k, v each contains two chunks. + q, k, v, out_, dout, lse = ( + query.chunk(2, dim=seq_dim)[1], + key, + value, + out.chunk(2, dim=seq_dim)[1], + grad_out.chunk(2, dim=seq_dim)[1], + # Need to make logsumexp contiguous, otherwise there will + # be numerical error. + logsumexp.chunk(2, dim=seq_dim)[1].contiguous(), + ) + + kwargs[grad_out_name] = dout + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + grad_query_, grad_key_, grad_value_, *rest = op( + query=q, + key=k, + value=v, + out=out_, + logsumexp=lse, + is_causal=is_causal_behavior.value, + **kwargs, + ) + else: + grad_query_ = torch.zeros_like(query, dtype=accum_dtype) + grad_key_ = torch.zeros_like(key, dtype=accum_dtype) + grad_value_ = torch.zeros_like(value, dtype=accum_dtype) + + ROUND_ROBIN_CYCLE = 2 + if i == 0: + grad_key += grad_key_ + grad_value += grad_value_ + else: + pointer = 0 + # Wait for the kv gradient from (cp_rank - 1) rank. + next_grad_kv = dkv_rotater.next_buffer() + grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( + grad_key.shape + ) + pointer += grad_key.numel() + grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape( + grad_value.shape + ) + + if i <= rank and _cp_options.enable_load_balance: + grad_key = _partial_update( + grad_key, + grad_key_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + grad_value = _partial_update( + grad_value, + grad_value_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + else: + grad_key += grad_key_ + grad_value += grad_value_ + + next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) + # Send the grad key and grad value to the next rank. + dkv_rotater.exchange_buffers(next_grad_kv) + + if i <= rank or not _cp_options.enable_load_balance: + grad_query += grad_query_ + else: + grad_query = _partial_update( + grad_query, + grad_query_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=True, + ) + + assert grad_key_ is not None + assert grad_value_ is not None + grad_query = grad_query.to(query.dtype) + next_grad_kv = dkv_rotater.next_buffer().to(key.dtype) + grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape) + grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape) + return ( + grad_query, + grad_key, + grad_value, + # pyrefly: ignore [unbound-name] + *rest, + ) + + +def _scaled_dot_product_ring_flash_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if return_debug_mask: + raise NotImplementedError("return_debug_mask is not supported yet") + + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_flash_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_efficient_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + attn_bias=attn_bias, + dropout_p=dropout_p, + scale=scale, + compute_log_sumexp=compute_log_sumexp, + ) + + +def _scaled_dot_product_ring_cudnn_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention( + group, + seq_dim, + aten._scaled_dot_product_cudnn_attention, + query=query, + key=key, + value=value, + attn_bias=attn_bias, + compute_log_sumexp=compute_log_sumexp, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=return_debug_mask, + scale=scale, + ) + + +def _scaled_dot_product_ring_flash_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + cum_seq_q: torch.Tensor, + cum_seq_k: torch.Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention_backward( + group, + seq_dim, + aten._scaled_dot_product_flash_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=logsumexp, + is_causal=is_causal, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bias: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + dropout_p: float, + grad_input_mask: tuple[bool, ...], + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention_backward( + group, + seq_dim, + aten._scaled_dot_product_efficient_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out_", + query=query, + key=key, + value=value, + attn_bias=bias, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + dropout_p=dropout_p, + grad_input_mask=grad_input_mask, + is_causal=is_causal, + scale=scale, + ) + + +def _scaled_dot_product_ring_cudnn_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + attn_bias: torch.Tensor, + cum_seq_q: torch.Tensor, + cum_seq_k: torch.Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + # TODO: remove this hardcoding + seq_dim = 2 + group = mesh.get_group() + return _templated_ring_attention_backward( + group, + seq_dim, + aten._scaled_dot_product_cudnn_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=attn_bias, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + +def _sdpa_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + # TODO: remove the context parallel strategy from the default propagation + # rule. Either figure out how to dynamically enable it or just don't call + # propagate. + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + call_maps: dict[torch._ops.OpOverload, Callable] = { + aten._scaled_dot_product_flash_attention.default: _scaled_dot_product_ring_flash_attention, + aten._scaled_dot_product_efficient_attention.default: _scaled_dot_product_ring_efficient_attention, + aten._scaled_dot_product_cudnn_attention.default: _scaled_dot_product_ring_cudnn_attention, + aten._scaled_dot_product_flash_attention_backward.default: _scaled_dot_product_ring_flash_attention_backward, + aten._scaled_dot_product_efficient_attention_backward.default: _scaled_dot_product_ring_efficient_attention_backward, + aten._scaled_dot_product_cudnn_attention_backward.default: _scaled_dot_product_ring_cudnn_attention_backward, + } + if op_call in call_maps: + local_results = call_maps[op_call]( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError( + "CP only supports flash attention and memory efficient attention now." + ) + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + +custom_ops = { + aten._scaled_dot_product_flash_attention.default: _sdpa_handler, + aten._scaled_dot_product_flash_attention_backward.default: _sdpa_handler, + aten._scaled_dot_product_efficient_attention.default: _sdpa_handler, + aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_handler, + aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, + aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_handler, +} +exitsing_custom_ops = DTensor._op_dispatcher._custom_op_handlers + + +ArgsType = tuple[Any, ...] +KwargsType = dict[str, Any] +InputFnType = Callable[[Optional[nn.Module], ArgsType, KwargsType, DeviceMesh], Any] +OutputFnType = Callable[[Optional[nn.Module], Any, Any, DeviceMesh], Any] + +_replaced_functions: dict[Callable, tuple[str, Callable]] = {} + + +def _distribute_function( + fn: Callable, + fn_module: types.ModuleType, + device_mesh: DeviceMesh, + input_fn: InputFnType, + output_fn: OutputFnType, +) -> None: + """ + A helper function to replace a function with a distributed version by + using the monkey patching approach. + + This function is for the CP internal usage only. + """ + + def wrapper( + target_fn: Callable, input_fn: InputFnType, output_fn: OutputFnType + ) -> Callable: + def inner_fn(*args: ArgsType, **kwargs: KwargsType) -> Any: + args, kwargs = input_fn(None, args, kwargs, device_mesh) + outputs = target_fn(*args, **kwargs) + return output_fn(None, (args, kwargs), outputs, device_mesh) + + return inner_fn + + global _replaced_functions + + if fn in _replaced_functions: + return + + wrapper_fn = wrapper(fn, input_fn, output_fn) + setattr(fn_module, fn.__name__, wrapper_fn) + _replaced_functions[wrapper_fn] = (fn.__name__, fn) + + +def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None: + """Restore the function that is replaced by _distribute_function.""" + if fn not in _replaced_functions: + return + + original_name, original_fn = _replaced_functions[fn] + setattr(fn_module, original_name, original_fn) + + +def _enable_cp_dtensor_dispatcher() -> None: + """Enables DTensor dispatcher to dispatch SDPA to CP.""" + DTensor._op_dispatcher._custom_op_handlers = { + **exitsing_custom_ops, + **custom_ops, + } + + +def _disable_cp_dtensor_dispatcher() -> None: + """Disables DTensor dispatcher to dispatch SDPA to CP.""" + DTensor._op_dispatcher._custom_op_handlers = exitsing_custom_ops + + +def _enable_context_parallel_dispatcher_impl(seq_dim: int, mesh: DeviceMesh) -> None: + sdpa_cp = _ContextParallel( + seq_dim=seq_dim, + attention_type=_ContextParallel.AttentionType.SDPA, + ) + + if _dispatch_mode == _DispatchMode.MONKEY_PATCH: + _distribute_function( + F.scaled_dot_product_attention, + F, + mesh, + sdpa_cp.sdpa_input_fn, + sdpa_cp.sdpa_output_fn, + ) + _enable_cp_dtensor_dispatcher() + elif _dispatch_mode == _DispatchMode.MODULE_WRAPPER: + _enable_cp_dtensor_dispatcher() + else: + raise ValueError(f"Unknown dispatch mode: {_dispatch_mode}") + + +def _disable_context_parallel_dispatcher_impl() -> None: + if _dispatch_mode == _DispatchMode.MONKEY_PATCH: + _restore_function(F.scaled_dot_product_attention, F) + elif _dispatch_mode == _DispatchMode.MODULE_WRAPPER: + pass + else: + raise NotImplementedError(f"Unknown dispatch mode: {_dispatch_mode}") + + _disable_cp_dtensor_dispatcher() + + +_compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True +) + + +def _context_parallel_buffers( + mesh: DeviceMesh, + buffers: list[torch.Tensor | BlockMask], + buffer_seq_dims: list[int], + load_balancer: Optional[_LoadBalancer] = None, +) -> list[torch.Tensor | BlockMask]: + """ + Shard the buffers along the sequence dimensions according to CP rules. + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (List[torch.Tensor]): the buffers to be sharded. + seq_dims (List[int]): the sequence dimensions of ``buffers``. This list + must have the same length as ``buffers``. + load_balancer (Optional[:class:`_LoadBalancer`]): an optional `_LoadBalancer` + object. If this argument is `None`, it means the `buffers` need no + rearrangement before being sharded. If this argument is a `_LoadBalancer` + object, call its `_generate_indices(restore=False)` to generate the + rearrangement indices such that each shard of `buffer[rearrange_idx]` is + well-balanced (i.e., having close sparsities). + + Returns: + List[torch.Tensor]: the sharded buffers. + + Note: + For `_context_parallel_shard` we require a non-None `load_balancer` object to be + explicitly passed if load-balancing is needed. + """ + # generate the index tensor for rearranging the buffer if a load-balance + # is available + load_balance_indices = load_balancer._generate_indices() if load_balancer else None + assert load_balance_indices is None or load_balance_indices.ndim == 2, ( + "load balance index expects shape (1, seq_len) or (B, seq_len) " + f"but got {load_balance_indices.shape}." + ) + + new_buffers = [] + sharded_buffer: torch.Tensor | BlockMask + for buffer, seq_dim in zip(buffers, buffer_seq_dims): + if isinstance(buffer, torch.Tensor): + # TODO: the load balance doesn't perform error handling. + + # NOTE: assuming batch dim is 0 + + if load_balance_indices is not None: + # TODO: we should expclitly ask users to unsqueeze the batch dim. + # But this is a BC breaking ask. + # However, what we have done today is also not very safe. + idx_batch_size = load_balance_indices.size(0) + data_batch_size = buffer.size(0) if seq_dim > 0 else 1 + + if idx_batch_size != 1 and idx_batch_size != data_batch_size: + raise ValueError( + "Cannot rearrange buffer: " + f"load_balance_indices has shape {load_balance_indices.shape}, " + f"but buffer has shape {buffer.shape}." + ) + + if seq_dim == 0: + buffer = torch.index_select( + buffer, dim=0, index=load_balance_indices[0] + ) + else: + indices = load_balance_indices + if idx_batch_size == 1: + size = [data_batch_size] + list(indices.size())[1:] + indices = indices.expand(*size) + + for i in range(data_batch_size): + buffer[i] = torch.index_select( + buffer[i], dim=seq_dim - 1, index=indices[i] + ) + + # use DTensor to shard the buffer on sequence dimension, retain the local tensor + sharded_buffer = distribute_tensor( + buffer, mesh, [Shard(seq_dim)], src_data_rank=None + ).to_local() + elif isinstance(buffer, BlockMask): + sharded_buffer = _create_cp_block_mask( + mask_mod=buffer.mask_mod, + B=buffer.kv_num_blocks.shape[0], + H=buffer.kv_num_blocks.shape[1], + Q_LEN=buffer.seq_lengths[0], + KV_LEN=buffer.seq_lengths[1], + device_mesh=mesh, + load_balancer=load_balancer, + ) + else: + raise ValueError(f"Unknown buffer type: {type(buffer)}") + + new_buffers.append(sharded_buffer) + + return new_buffers + + +def _create_cp_block_mask( + mask_mod: _mask_mod_signature, + B: int, + H: int, + Q_LEN: int, + KV_LEN: int, + device_mesh: DeviceMesh, + load_balancer: Optional[_LoadBalancer] = None, +) -> BlockMask: + """ + Creates a specialized BlockMask for Context Parallel FlexAttention. + + This function creates a BlockMask that enables computation of attention results + for sharded Q attending to global KV. The mask appropriately handles the query + index offset required when each rank operates on a shard of the query sequence + while accessing the full key-value sequence. + + The function internally rewrites the provided mask_mod function to translate local + query indices to global query indices, ensuring that the masking logic is applied + correctly across the distributed computation. + + Args: + mask_mod (Callable): Mask function that operates on global attention indices. + B (int): Batch size. + H (int): Number of query heads. + Q_LEN (int): Global sequence length of the query. + KV_LEN (int): Global sequence length of the key/value. + device_mesh (DeviceMesh): Device mesh used for context parallelism. + load_balancer (Optional[:class:`_LoadBalancer`]): The load-balancer used to rearrange + QKV before sharding. This will be used to modify the block_mask generated. + + Returns: + BlockMask: A block mask configured for the local query shard that can be used + with flex_attention() for the given cp_mesh. + + Raises: + NotImplementedError: If Q_LEN is not divisible by (CP world size * BLOCK_SIZE). + + Warning: + Currently requires Q_LEN to be divisible by CP mesh world size * BLOCK_SIZE + (BLOCK_SIZE defaults to 128). This constraint exists because the BlockMask + must handle both padding and offsets correctly. For example, if Q_LEN is 384, + CP world size is 2, and BLOCK_SIZE is 128, the local Q_LEN would be 192. In + such cases, both rank0 and rank1 would have paddings in their local BlockMasks. + Support for padding in this scenario is planned for future work. + + """ + + from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE + + if Q_LEN % (device_mesh.size() * _DEFAULT_SPARSE_BLOCK_SIZE) != 0: + raise NotImplementedError( + f"Q_LEN {Q_LEN} is not divisible by CP mesh world size {device_mesh.size()} * " + f"BLOCK_SIZE {_DEFAULT_SPARSE_BLOCK_SIZE}. This is not supported yet. " + ) + + compiled_create_block_mask = torch.compile( + create_block_mask, dynamic=False, fullgraph=True + ) + + def _rewrite_mask_mod( + mask_mod: _mask_mod_signature, + rank: int, + block_size: int, + local_q_size: int, + qkv_rearrange_indices: Optional[torch.Tensor] = None, + ) -> _mask_mod_signature: + assert qkv_rearrange_indices is None or qkv_rearrange_indices.ndim == 2, ( + "load balance index expects shape (1, seq_len) or (B, seq_len) " + f"but got {qkv_rearrange_indices.shape}." + ) + + def qkv_idx_restore( + b: torch.Tensor, idx_post_rearrange: torch.Tensor + ) -> torch.Tensor: + if qkv_rearrange_indices is not None: + if ( + qkv_rearrange_indices.size(0) == 1 + ): # identical load-balance in batch + idx_pre_rearrange = qkv_rearrange_indices[0][idx_post_rearrange] + else: + idx_pre_rearrange = qkv_rearrange_indices[b][idx_post_rearrange] + else: + idx_pre_rearrange = idx_post_rearrange + + return idx_pre_rearrange + + def local_q_idx_to_q_idx(local_q_idx: torch.Tensor) -> torch.Tensor: + # calculate local block_idx and block_offset + local_blk_idx, local_blk_offset = ( + local_q_idx // block_size, + local_q_idx % block_size, + ) + # NOTE: load balancing is not used + local_num_blocks = local_q_size // block_size + blk_idx = local_num_blocks * rank + local_blk_idx + return blk_idx * block_size + local_blk_offset + + return lambda b, h, q_idx, kv_idx: mask_mod( + b, + h, + qkv_idx_restore(b, local_q_idx_to_q_idx(q_idx)), + qkv_idx_restore(b, kv_idx), + ) + + cp_rank = device_mesh.get_local_rank() + cp_group_size = device_mesh.size() + load_balancer = load_balancer or _create_default_load_balancer( + Q_LEN, cp_group_size, device_mesh.device_type + ) + Q_SHARD_LEN = Q_LEN // cp_group_size + block_size = _DEFAULT_SPARSE_BLOCK_SIZE + + rearrange_indices = ( + load_balancer._generate_indices(restore=False) if load_balancer else None + ) + block_mask = compiled_create_block_mask( + _rewrite_mask_mod( + mask_mod, + cp_rank, + block_size, + Q_SHARD_LEN, + qkv_rearrange_indices=rearrange_indices, + ), + B, + H, + Q_SHARD_LEN, + KV_LEN, + device=device_mesh.device_type, + BLOCK_SIZE=(block_size, block_size), + ) + return block_mask + + +##################### +# Experimental APIs +##################### + + +class _ContextParallel(ParallelStyle): + class AttentionType(Enum): + FLEX = "flex_attention" + SDPA = "scaled_dot_product_attention" + + def __init__( + self, + seq_dim: int, + attention_type: AttentionType, + ) -> None: + super().__init__() + self.seq_dim = seq_dim + self.attention_type = attention_type + + def _apply(self, module: nn.Module, mesh: DeviceMesh) -> nn.Module: + if self.attention_type == self.AttentionType.FLEX: + module.register_forward_pre_hook( + partial(self.flex_input_fn, mesh=mesh), with_kwargs=True + ) + return module + elif self.attention_type == self.AttentionType.SDPA: + module.register_forward_pre_hook( + partial(self.sdpa_input_fn, mesh=mesh), with_kwargs=True + ) + module.register_forward_hook(partial(self.sdpa_output_fn, mesh=mesh)) + return module + else: + raise ValueError(f"Unknown attention type: {self.attention_type}") + + def flex_input_fn( + self, module: Optional[nn.Module], args: Any, kwargs: Any, mesh: DeviceMesh + ) -> Any: + args_list = list(args) + for idx, name in enumerate( + ("query", "key", "value", "score_mod", "block_mask") + ): + if idx >= len(args): + args_list.append(kwargs.pop(name, None)) + + query, key, value, score_mod, block_mask = args_list[:5] + assert isinstance(query, torch.Tensor) + assert isinstance(key, torch.Tensor) + assert isinstance(value, torch.Tensor) + assert isinstance(block_mask, BlockMask | tuple) + + key = key.contiguous() + value = value.contiguous() + + global_key, global_value = flex_cp_allgather( + key, value, self.seq_dim, c10d._get_process_group_name(mesh.get_group()) + ) + args_list[1] = global_key + args_list[2] = global_value + + return tuple(args_list), kwargs + + def sdpa_input_fn( + self, + module: Optional[nn.Module], + args: tuple[Any, ...], + kwargs: dict[str, Any], + mesh: DeviceMesh, + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + placement = [Shard(self.seq_dim)] + all_args = [] + + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, torch.Tensor): + if isinstance(arg, DTensor): + assert arg._spec.placements == placement + else: + arg = DTensor.from_local(arg, mesh, placement, run_check=False) + + all_args.append(arg) + + new_args = tuple(all_args[0 : len(args)]) + new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :])) + return new_args, new_kwargs + + def sdpa_output_fn( + self, module: Optional[nn.Module], inputs: Any, outputs: Any, mesh: DeviceMesh + ) -> Any: + new_outputs = [] + for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs: + output = output.to_local() if isinstance(output, DTensor) else output + new_outputs.append(output) + + if isinstance(outputs, torch.Tensor): + return new_outputs[0] + + return tuple(new_outputs) + + +CPBuffer: TypeAlias = torch.Tensor | BlockMask +CPBufferContainer: TypeAlias = Sequence[CPBuffer] | Mapping[str, CPBuffer] +CPBufferSeqDims: TypeAlias = Sequence[int] | Mapping[str, int] + + +def _context_parallel_shard( + mesh: DeviceMesh, + buffers: CPBufferContainer, + seq_dims: CPBufferSeqDims, + load_balancer: Optional[_LoadBalancer] = None, +) -> list[torch.Tensor | BlockMask]: + """ + Shard the buffers along the specified sequence dimensions (`seq_dims`), so that each + rank retains only its corresponding shard according to the provided `mesh`. If a + `load_balancer` is provided, the buffers will be rearranged by the load balancer + before sharding to improve load balance. Buffers can be either tensors or `BlockMask` + objects. If a buffer is a `BlockMask`, its sharding dimension is determined by the + `BlockMask` implementation, and the corresponding `seq_dim` is ignored. + + Note: + For `_context_parallel_shard`, a non-None `load_balancer` must be explicitly passed + if load balancing is required. + + Args: + mesh (DeviceMesh): The device mesh used for context parallelism. + buffers (List[torch.Tensor | BlockMask]): Buffers whose usage depends on the sequence + dimension. Examples include input batches, labels, and positional embedding buffers. + These buffers must be sharded along the sequence dimension to ensure correctness. + seq_dims (List[int]): The sequence dimensions for each buffer in `buffers`. Must have + the same length as `buffers`. + load_balancer (Optional[_LoadBalancer]): An optional load balancer object. If provided, + it rearranges the buffers before sharding to achieve better load balance. If not + provided, no rearrangement is performed. + + Returns: + List[torch.Tensor | BlockMask]: The sharded buffers, each corresponding to the local + shard for the current rank. + """ + # TODO: these global variables are going to bite us someday. + # We will have to remove them soon. + # For the new API, we only support the module wrapper mode. + global _dispatch_mode + _dispatch_mode = _DispatchMode.MODULE_WRAPPER + global _cp_options + if load_balancer is not None: + _cp_options.enable_load_balance = True + else: + _cp_options.enable_load_balance = False + + if len(buffers) != len(seq_dims): + raise ValueError( + "`seq_dims` must have the same number of elements as `buffers`." + ) + + flat_buffers, spec = tree_flatten(buffers) + flat_seq_dims, _ = tree_flatten(seq_dims) + if len(flat_buffers) != len(flat_seq_dims): + raise ValueError("`seq_dims` must have the pytree structure as `buffers`.") + + if isinstance(flat_buffers[0], torch.Tensor): + device = flat_buffers[0].device + else: + device = flat_buffers[0].kv_num_blocks.device + for buffer in flat_buffers: + if isinstance(buffer, torch.Tensor): + assert device == buffer.device, "All buffers must be on the same device" + else: + assert device == buffer.kv_num_blocks.device, ( + "All buffers must be on the same device" + ) + + flat_sharded_buffers = _context_parallel_buffers( + mesh, flat_buffers, flat_seq_dims, load_balancer + ) + + return tree_unflatten(flat_sharded_buffers, spec) + + +def _enable_context_parallel_dispatcher() -> None: + """ + Enable the context parallel dispatcher. This API is experimental and subject to change. + """ + _enable_cp_dtensor_dispatcher() + + +def _disable_context_parallel_dispatcher() -> None: + """ + Disable the context parallel dispatcher. This API is experimental and subject to change. + """ + _disable_cp_dtensor_dispatcher() + + +##################################################### +# Current public APIs, but are also subject to change +##################################################### +@contextlib.contextmanager +@torch.no_grad() +def context_parallel( + mesh: DeviceMesh, + *, + buffers: Optional[list[torch.Tensor]] = None, + buffer_seq_dims: Optional[list[int]] = None, + no_restore_buffers: Optional[set[torch.Tensor]] = None, +) -> Generator[None, None, None]: + """ + + ``context_parallel`` is an experimental API to enable context + parallelism (CP). This API performs two actions: 1) patch the SDPA + (``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled + one, 2) shard ``buffers`` along the sequence dimension and each rank will + preserve the corresponding shard according ``mesh``. + + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (Optional[List[torch.Tensor]]): buffers that the usage depend + on the sequence dimension. Examples are input batch, labels and + positional embedding buffers. These buffers must be sharded along + the sequence dimension to ensure the accuracy. The sharding will + happen in-place, the buffer's shape will change within the context. + The buffers will be restored after the context finishes. + ``no_restore_buffers`` can be used to specify which buffers don't + need to be restored. Note that ``buffers`` should not contain any + nn.Parameter. + buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``. + no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in these set + won't be restored after the context exits. This set must be a subset + of ``buffers``. If the buffers won't be used after the context exits, + these buffers can be put in this list to avoid extra restore time. + + .. warning:: + `torch.distributed.tensor.experimental.context_parallel` is a + prototype feature in PyTorch. The API is subject to change. + """ + # For the legacy API, we only support the monkey-patch mode. + # We will deprecate this API once the new API is widely used. + global _dispatch_mode + _dispatch_mode = _DispatchMode.MONKEY_PATCH + + buffers = [] if buffers is None else buffers + buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims + no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers + + if len(buffers) != len(buffer_seq_dims): + raise ValueError( + "`seq_dims` must have the same number of elements as `buffers`." + ) + + for buffer in no_restore_buffers: + # Cannot use `if not buffer in buffers` which will incur tensor comparison. + if not any(b is buffer for b in buffers): + raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") + + original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] + + device = buffers[0].device + seq_length = buffers[0].shape[buffer_seq_dims[0]] + cp_world_size = mesh.size() + + # If `enable_load_balance` is True, the default Head-tail load balancer + # (:class:`_HeadTailLoadBalancer`) is used to rearrange the buffers before + # sharding. Otherwise, we don't do any load-balance rearrange by passing + # `None` to `_context_parallel_shard()`. + load_balancer = _create_default_load_balancer(seq_length, cp_world_size, device) + shards = _context_parallel_buffers( + mesh, + cast(list[torch.Tensor | BlockMask], buffers), + buffer_seq_dims, + load_balancer, + ) + for buffer, shard in zip(buffers, shards): + assert isinstance(shard, torch.Tensor), "ContextParallel only supports Tensor" + shard = shard.clone() + buffer.resize_(shard.shape) + buffer.copy_(shard) + + _enable_context_parallel_dispatcher_impl(seq_dim=2, mesh=mesh) + yield + _disable_context_parallel_dispatcher_impl() + + for buffer, original_buffer in zip(buffers, original_buffers): + if original_buffer is not None: + buffer.resize_(original_buffer.shape) + buffer.copy_(original_buffer) + + +@torch.no_grad() +def context_parallel_unshard( + mesh: DeviceMesh, + buffers: list[torch.Tensor], + seq_dims: list[int], + load_balancer: Optional[_LoadBalancer] = None, +) -> list[torch.Tensor]: + """ + Unshard the tensors (e.g., output) that are sharded due to context parallelism. + + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (List[torch.Tensor]): the buffers to be unsharded. + seq_dims (List[int]): the sequence dimensions of ``buffers``. This list + must have the same length as ``buffers``. + load_balancer (Optional[:class:`_Loadbalancer`]): an optional `_LoadBalancer` + object. If this argument is `None`, it means the `buffers` were not + rearranged when being sharded and there's no need to put it back to order + after unsharding. If this argument is a `_LoadBalancer` object, call + its `_generate_indices(restore=True)` to generate the restore indices such + that `unsharded[restore_idx]` is the original buffer. + + Returns: + List[torch.Tensor]: the unsharded buffers. + + Note: + For `context_parallel_unshard` we require not-None `load_balancer` object be + explicitly passed if flex_attention() is to be used and load-balancing is needed. + This is different from the case of SDPA though we strongly suggest users follow + the same convention. + """ + device = buffers[0].device + cp_world_size = mesh.size() + seq_length = buffers[0].shape[seq_dims[0]] * cp_world_size + + # If users don't pass in a `load_balancer`: + # - if `enable_load_balance` is True, we use the default round-robin + # load balancer. + # - if `enable_load_balance` is False, we don't do any load balancing + # by passing in `None` as `restore_indices`. + load_balancer = load_balancer or _create_default_load_balancer( + seq_length, cp_world_size, device + ) + restore_indices = ( + load_balancer._generate_indices(restore=True) if load_balancer else None + ) + + assert restore_indices is None or restore_indices.ndim == 2, ( + "load balance restore index expects shape (1, seq_len) or (B, seq_len) " + f"but got {restore_indices.shape}." + ) + unsharded_buffers = [] + for b, dim in zip(buffers, seq_dims): + b = b.contiguous() + unsharded_b = _maybe_wait(ft_c.all_gather_tensor(b, dim, mesh)) + + if restore_indices is not None: + # NOTE: assuming batch dim is 0 + idx_batch_size = restore_indices.size(0) + data_batch_size = unsharded_b.size(0) + if idx_batch_size != 1 and idx_batch_size != data_batch_size: + raise ValueError( + "Cannot restore buffer: " + f"restore_indices has shape {restore_indices.shape}, " + f"but unsharded_b has shape {unsharded_b.shape}." + ) + + for i in range(data_batch_size): + index = ( + restore_indices[0] # identical load-balance in batch + if idx_batch_size == 1 + else restore_indices[i] + ) + unsharded_b_batch_i = torch.index_select( + unsharded_b[i], dim=dim - 1, index=index + ) + unsharded_b[i] = unsharded_b_batch_i + + unsharded_buffers.append(unsharded_b) + + return unsharded_buffers + + +def set_rotate_method(rotate_method: str) -> None: + """ + Context Parallel SDPA requires the rotation of kv shards. Users can call this + API to specify which rotation method to use. "alltoall" shuffles the kv shards + using all-to-all collective. While "allgather" gathers the kv shards using + all-gather collective after the first sub-SDPA computation. If this API has not + been called, the default rotate method is "allgather". + + Args: + rotate_method (str): the rotate method to use. Currently only supports + "allgather" and "alltoall". If a different string other than these two + is passed in, the function will raise an error. + + Returns: + None + """ + logger.info("Note that FlexAttention CP doesn't support alltoall yet.") + if rotate_method == "allgather": + _cp_options.rotate_method = _RotateMethod.ALL_GATHER + elif rotate_method == "alltoall": + _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL + else: + raise NotImplementedError( + "Context Parallel does not support " + f"using {rotate_method} for kv shards rotation" + ) diff --git a/torch/distributed/tensor/experimental/_cp_custom_ops.py b/torch/distributed/tensor/experimental/_context_parallel/_cp_custom_ops.py similarity index 100% rename from torch/distributed/tensor/experimental/_cp_custom_ops.py rename to torch/distributed/tensor/experimental/_context_parallel/_cp_custom_ops.py diff --git a/torch/distributed/tensor/experimental/_load_balancer.py b/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py similarity index 99% rename from torch/distributed/tensor/experimental/_load_balancer.py rename to torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py index befda2c736ed5..e5230092b41d7 100644 --- a/torch/distributed/tensor/experimental/_load_balancer.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py @@ -479,7 +479,7 @@ def _generate_indices(self, restore: bool = False) -> Tensor: def _create_default_load_balancer( seq_length: int, world_size: int, device: str | torch.device ) -> Optional[_LoadBalancer]: - from torch.distributed.tensor.experimental._attention import _cp_options + from ._attention import _cp_options if _cp_options.enable_load_balance: return _HeadTailLoadBalancer(seq_length, world_size, device) diff --git a/torch/distributed/tensor/experimental/_tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py index f66ab2b2e39d2..426eb2ac83b38 100644 --- a/torch/distributed/tensor/experimental/_tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -203,7 +203,7 @@ def _mark_sharding( ) node.meta["sharding"] = placement_strategies[node] elif node.op == "call_function": - if node.target == operator.getitem: + if node.target is operator.getitem: input_nodes = node.all_input_nodes assert len(input_nodes) == 1, ( f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}" diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 2b1a88f1a126f..275814693354f 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -69,9 +69,8 @@ def _unpack_kwargs( flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...] ) -> tuple[tuple[Any, ...], dict[str, Any]]: """See _pack_kwargs.""" - assert len(kwarg_keys) <= len(flat_args), ( - f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" - ) + if len(kwarg_keys) > len(flat_args): + raise AssertionError(f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}") if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] @@ -127,7 +126,8 @@ def to_map(obj): if isinstance(obj, PackedSequence): output.data.record_stream(current_stream) # type: ignore[arg-type] else: - assert isinstance(output, torch.Tensor) + if not isinstance(output, torch.Tensor): + raise AssertionError("output must be a torch.Tensor") output.record_stream(current_stream) # type: ignore[arg-type] return (output,) diff --git a/torch/export/_remove_auto_functionalized_pass.py b/torch/export/_remove_auto_functionalized_pass.py index 67f84e49af643..1b48339276567 100644 --- a/torch/export/_remove_auto_functionalized_pass.py +++ b/torch/export/_remove_auto_functionalized_pass.py @@ -17,7 +17,7 @@ def remove_self_clone(graph: Graph) -> None: for node in graph.nodes: - if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]: + if node.target is torch.ops.aten.copy_.default and node.args[0] == node.args[1]: node.replace_all_uses_with(node.args[0]) graph.erase_node(node) diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index bde7eb6042245..21930d81fe092 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -52,7 +52,7 @@ def _remove_effect_tokens_from_graph_helper( func = node.args[1] assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) - if func == torch.ops.higher_order.call_torchbind: + if func is torch.ops.higher_order.call_torchbind: custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] assert isinstance(custom_obj_meta, CustomObjArgument) if custom_obj_meta.fake_val: @@ -83,7 +83,7 @@ def _remove_effect_tokens_from_graph_helper( # Update user getitem nodes for user in list(new_node.users.keys()): - assert user.target == operator.getitem + assert user.target is operator.getitem # getitem(with_effects, 0) == token if user.args[1] == 0: ep.graph.erase_node(user) diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 333d70c2b64d5..6c93bb8c33a74 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -26,7 +26,7 @@ def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]: if user.op == "output": continue - assert user.op == "call_function" and user.target == operator.getitem, ( + assert user.op == "call_function" and user.target is operator.getitem, ( f"Expected getitem node as user for {node}, instead got {user}" ) getitem_users.update(list(user.users.keys())) @@ -69,7 +69,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: flatten_node = curr_module_users[0] assert ( flatten_node.op == "call_function" - and flatten_node.target == fx_pytree.tree_flatten_spec + and flatten_node.target is fx_pytree.tree_flatten_spec ) flatten_getitem_users = _get_getitem_users(flatten_node) @@ -85,7 +85,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: unflatten_node = next(iter(flatten_getitem_users)) if not ( unflatten_node.op == "call_function" - and unflatten_node.target == pytree.tree_unflatten + and unflatten_node.target is pytree.tree_unflatten ): log.debug( "Flatten node %s's user is not a pytree.tree_unflatten. " @@ -110,7 +110,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: # pyrefly: ignore [missing-attribute] arg.op == "call_function" # pyrefly: ignore [missing-attribute] - and arg.target == operator.getitem + and arg.target is operator.getitem # pyrefly: ignore [missing-attribute] and arg.args[1] == i ): diff --git a/torch/export/_trace.py b/torch/export/_trace.py index b1926abebaa8b..934ee44882052 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -222,7 +222,7 @@ def _rewrite_tracepoint_node(gm: torch.fx.GraphModule): that has the same target and args, but with the _export_root stripped from path. """ for node in gm.graph.nodes: - if node.target == torch.ops.higher_order._export_tracepoint: + if node.target is torch.ops.higher_order._export_tracepoint: if "path" in node.kwargs: path = _strip_root(node.kwargs["path"]) with gm.graph.inserting_before(node): @@ -922,7 +922,7 @@ def _export_to_aten_ir( if decompose_custom_triton_ops else _disable_custom_triton_op_functional_decomposition ) - # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode, # otherwise aot_export_module will error out because it sees a mix of fake_modes. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. with ExitStack() as stack: @@ -1843,7 +1843,7 @@ def _is_impure(node): ) return gm, sig - # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # This _reparameterize_module makes sure inputs and module.params/buffers have the same fake_mode, # otherwise aot_export_module will error out because it sees a mix of fake_modes. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. with ExitStack() as stack: diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py index a261ce3c8b2c8..d3097734c8a35 100644 --- a/torch/export/decomp_utils.py +++ b/torch/export/decomp_utils.py @@ -21,6 +21,10 @@ PRESERVED_ATEN_CIA_OPS = { torch.ops.aten.upsample_bilinear2d.vec, torch.ops.aten.upsample_nearest2d.vec, + # NB: don't use the C++ decomp, because it is not functional! + torch.ops.aten.silu_backward.default, + torch.ops.aten.mish_backward.default, + torch.ops.aten._fused_rms_norm.default, } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index fd26d684b2b38..1e1f1f409857b 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1333,7 +1333,7 @@ def refine_dynamic_shapes_from_suggested_fixes( roots.add(c.root.__name__) # type: ignore[attr-defined] # check keys are existing dims or new roots - for k, c in shape_fixes.items(): + for k in shape_fixes.keys(): assert k in name_to_dim or k in roots # cache so we don't produce multiple derived dim objects diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 2705b59a9075a..ec5e73cad85d4 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -50,9 +50,9 @@ def _remove_detach_pass( if node.op != "call_function": continue if ( - node.target == torch.ops.aten.detach.default + node.target is torch.ops.aten.detach.default and len(node.users) == 1 - and next(iter(node.users)).target == torch.ops.aten.detach.default + and next(iter(node.users)).target is torch.ops.aten.detach.default ): next(iter(node.users)).replace_all_uses_with(node) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index f7bc5531677f9..58bd4b9087d21 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -281,7 +281,7 @@ def _split_decomp_table_to_cia_and_python_decomp( for op in list(decomp_table.keys()): # TODO we are silently allowing non-safe(non-functional) ops through a crack # due to core aten decomp table having non-functional entries. Once we have - # a tigher check around core aten decomp, we should warn users about them. + # a tighter check around core aten decomp, we should warn users about them. # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759) # if it is a valid CIA op we can mess with in export, we check if it is: @@ -798,7 +798,7 @@ def _remove_unnecessary_copy_op_pass( ): if ( out.op == "call_function" - and out.target == torch.ops.aten.copy.default + and out.target is torch.ops.aten.copy.default ): out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] gm.graph.erase_node(out) @@ -817,7 +817,7 @@ def _common_getitem_elimination_pass( node_id: dict[torch.fx.Node, str] = {} getitems: dict[str, torch.fx.Node] = {} for node in list(module.graph.nodes): - if node.op == "call_function" and node.target == operator.getitem: + if node.op == "call_function" and node.target is operator.getitem: source, idx = node.args new_id = f"{node_id[source]}.{idx}" if new_id in getitems: diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index d36985180f5fd..90430608cab21 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -78,7 +78,7 @@ def _get_new_device( if ( node.op == "call_function" - and node.target == torch.ops.aten.to.device + and node.target is torch.ops.aten.to.device ): args = list(node.args) # pyrefly: ignore [unsupported-operation] diff --git a/torch/functional.py b/torch/functional.py index 3054f54b7cd40..013832d59cfb3 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1829,7 +1829,7 @@ def norm( # noqa: F811 return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined] # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed - # remove the overloads where dim is an int and replace with BraodcastingList1 + # remove the overloads where dim is an int and replace with BroadcastingList1 # and remove next four lines, replace _dim with dim if dim is not None: if isinstance(dim, (int, torch.SymInt)): diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 5aa976a2a1218..7cfd41b039e9e 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -581,7 +581,7 @@ def dump_dag(self, module_with_submodules: GraphModule) -> DAG: break if node.op in {"placeholder", "get_attr"}: continue - if node.target == operator.__getitem__: + if node.target is operator.__getitem__: continue input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index d1ca9bc0c8805..58a62aee31460 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -101,11 +101,11 @@ def broadcast_types(t1, t2): # We make the types the same length which is the first requirement # for consistency if s1 > s2: - for i in range(s1 - s2): + for _ in range(s1 - s2): new_t2.insert(0, 1) elif s2 > s1: - for i in range(s2 - s1): + for _ in range(s2 - s1): new_t1.insert(0, 1) # we replace occurrences of "1" with each tensor with @@ -250,7 +250,7 @@ def transpose_inference_rule(n: Node): We check that dimensions for the transpose operations are within range of the tensor type of the node """ - if n.target == torch.transpose: + if n.target is torch.transpose: assert isinstance(n.args[0], Node) t = n.args[0].type @@ -674,7 +674,7 @@ def type_check_node(self, n: Node): return n.type elif n.op == "call_function": - if n.target == getattr: + if n.target is getattr: assert getattr in _INFERENCE_RULES return _INFERENCE_RULES[n.target](n, self.traced) diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 6ff260f227e9e..28e5c7c215e64 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1050,9 +1050,9 @@ def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): @register_inference_rule(operator.add) def broadcasting_inference_rule(n: Node, symbols, constraints, counter): op_code = None - if n.target == operator.add or n.target == torch.add: + if n.target is operator.add or n.target is torch.add: op_code = op_add - elif n.target == operator.mul: + elif n.target is operator.mul: op_code = op_mul if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index f4152621a5dd7..a2bb9a7549c5e 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -40,6 +40,7 @@ from torch import SymBool, SymInt, Tensor from torch._dispatch.python import enable_python_dispatcher from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_type from torch._logging import trace_structured from torch._subclasses.fake_impls import fast_detach from torch._subclasses.fake_tensor import ( @@ -63,6 +64,7 @@ _disable_infra_mode, _push_mode, _unset_infra_mode, + autograd_would_have_decomposed, TorchDispatchMode, ) from torch.utils._stats import count @@ -408,6 +410,7 @@ def get_proxy_slot( tracker = tracer.symnode_tracker # pyrefly: ignore [index-error] + # pyrefly: ignore [no-matching-overload, bad-argument-type] value = tracker.get(obj) if value is None and isinstance(obj, py_sym_types): @@ -1030,11 +1033,16 @@ def can_handle_tensor(x: Tensor) -> bool: return r # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. - if not pre_dispatch and func not in [ - torch.ops.aten.size.default, - torch.ops.aten.stride.default, - torch.ops.aten.storage_offset.default, - ]: + if ( + not pre_dispatch + and func + not in [ + torch.ops.aten.size.default, + torch.ops.aten.stride.default, + torch.ops.aten.storage_offset.default, + ] + and autograd_would_have_decomposed(func, flat_args_kwargs) + ): with proxy_mode: r = func.decompose(*args, **kwargs) if r is not NotImplemented: @@ -1587,11 +1595,11 @@ def __torch_function__( # TODO(tmanlaibaatar): we should systematically couple it with export verifier, # instead of hardcoding it here. # T203648563 - if func == torch.amp.autocast_mode._exit_autocast: + if func is torch.amp.autocast_mode._exit_autocast: enter_node = self.enter_autocast_nodes.pop() args = (enter_node,) node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] - if func == torch.amp.autocast_mode._enter_autocast: + if func is torch.amp.autocast_mode._enter_autocast: self.enter_autocast_nodes.append(node) if func in [ torch._C._set_grad_enabled, @@ -1717,7 +1725,7 @@ def __sym_dispatch__( ) -> object: # Peephole optimize multiply by one # NB: be careful not to trigger guards here! - if func == operator.mul: + if func is operator.mul: if isinstance(args[1], int) and args[1] == 1: return args[0] elif isinstance(args[0], int) and args[0] == 1: @@ -1968,6 +1976,7 @@ def __init__(self, base: Union[Module, _AttrProxy], path: str) -> None: # Warning: We blow away our own attributes here to mimic the base class # - so don't expect `self.x` to do anything useful. # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-override] self.__class__ = type( base.__class__.__name__, (self.__class__, base.__class__), @@ -2427,7 +2436,7 @@ def inner_wrap_fake(x: object) -> object: hint=x, source=source, ) - elif isinstance(x, torch.ScriptObject): + elif isinstance(x, torch.ScriptObject) or is_opaque_type(type(x)): return torch._library.fake_class_registry.maybe_to_fake_obj( self.fake_tensor_mode, x ) diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index a617d4fe558cd..d07d235e51321 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1871,7 +1871,7 @@ def round_magic_impl(self, ndigits=None): setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl) -for method, func in magic_methods.items(): # type: ignore[assignment] +for method in magic_methods.keys(): # type: ignore[assignment] if method in only_bool_magic_methods: _make_user_magic(method, SymBool) continue diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 5eb5f6688e212..aeccdfbe000db 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -841,7 +841,7 @@ def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr: factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) if factor == 1: return expr - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] atoms = [div_by_factor(x, factor) for x in atoms] return _sympy_from_args( sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative @@ -2207,7 +2207,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext): def __post_init__(self) -> None: super().__post_init__() if self.inner_contexts is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.inner_contexts = {} @@ -2296,12 +2296,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT: # only re-create the objects if any of the args changed to avoid expensive # checks when re-creating objects. new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return _fast_expand(expr.func(*new_args)) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if expr.is_Pow: base: sympy.Expr exp: sympy.Expr @@ -2311,11 +2311,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT: return sympy.expand_multinomial(expr, deep=False) elif exp < 0: return S.One / sympy.expand_multinomial(S.One / expr, deep=False) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] elif expr.is_Mul: num: list[sympy.Expr] = [] den: list[sympy.Expr] = [] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for arg in expr.args: if arg.is_Pow and arg.args[1] == -1: den.append(S.One / arg) # type: ignore[operator, arg-type] @@ -2437,7 +2437,7 @@ def _maybe_evaluate_static_worker( # TODO: remove this try catch (esp for unbacked_only) try: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] new_expr = expr.xreplace(new_shape_env) except RecursionError: log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) @@ -2658,7 +2658,9 @@ def _print_Symbol(self, expr: sympy.Symbol) -> str: Convert a sympy Symbol to its source representation. This method looks up the symbol in symbol_to_source mapping and returns - the string representation of its first source. + the string representation of its first source. If the symbol is not in + symbol_to_source (which can happen when symbols appear in guard expressions + through simplification or substitution), it falls back to var_to_sources. Args: expr: The sympy Symbol to convert @@ -2667,24 +2669,30 @@ def _print_Symbol(self, expr: sympy.Symbol) -> str: String representation of the symbol's source Raises: - AssertionError: If the symbol is not found in symbol_to_source + AssertionError: If the symbol is not found in either mapping """ assert isinstance(expr, sympy.Symbol), str(type(expr)) - def repr_symbol_to_source() -> str: - return repr( - { - symbol: [s.name() for s in sources] - for symbol, sources in self.symbol_to_source.items() - } - ) + # Try symbol_to_source first, fall back to var_to_sources if not found + if source := self.symbol_to_source.get(expr): + return self.print_source(source[0]) + elif source := self.var_to_sources.get(expr): + return self.print_source(source[0]) + else: - assert self.symbol_to_source.get(expr), ( - f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) " - f"not in {repr_symbol_to_source()}. If this assert is failing, it could be " - "due to the issue described in https://github.com/pytorch/pytorch/pull/90665" - ) - return self.print_source(self.symbol_to_source[expr][0]) + def repr_sources(src: Mapping[sympy.Symbol, list[Source]]) -> str: + return repr( + { + symbol: [s.name() for s in sources] + for symbol, sources in src.items() + } + ) + + raise RuntimeError( + f"{expr} not in {repr_sources(self.symbol_to_source)} or " + f"{repr_sources(self.var_to_sources)}. This could be due to " + "the issue described in https://github.com/pytorch/pytorch/pull/90665" + ) @abc.abstractmethod def print_source(self, source: Source) -> str: @@ -2975,19 +2983,19 @@ def floor_div_handler(*args: sympy.Expr) -> sympy.Expr: # is_integer tests though haha return (base - mod_reduced) / divisor - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if expr.has(Mod): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] expr = expr.replace(Mod, mod_handler) # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative # arguments should be OK. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if expr.has(PythonMod): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] expr = expr.replace(PythonMod, mod_handler) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if expr.has(FloorDiv): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3342,7 +3350,7 @@ def _check_same_range(c: Mapping[str, int], dim: object) -> bool: # alter derivations that depend on old root, to unify to new root # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2 for old_root in introduced_roots.values(): - for k, c in list(results.items()): + for c in results.values(): if ( "eq" in c and isinstance(c["eq"], sympy.Expr) @@ -4522,7 +4530,7 @@ def create_symbolic_sizes_strides_storage_offset( # The order of checking the guards matters. In this specific example: # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, - # we may have an unnecessary shape speciliazation for y. + # we may have an unnecessary shape specialization for y. def _maybe_specialize_sym_int_with_hint( self, maybe_sym: IntLikeType ) -> IntLikeType: @@ -5106,7 +5114,7 @@ def create_symbol( if duck: # Make sure to reuse this symbol for subsequent duck shaping - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self.val_to_var[val] = sympy_expr if isinstance(val, int): @@ -5338,9 +5346,9 @@ def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: # Expand optional inputs, or verify invariants are upheld if input_contexts is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] input_contexts = [ - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None for t in placeholders ] @@ -5350,7 +5358,7 @@ def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: for i, (t, context) in enumerate(zip(placeholders, input_contexts)): if isinstance(t, Tensorlike): if context is None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] input_contexts[i] = _create_no_constraints_context(t) else: assert isinstance(t, (SymInt, int, SymFloat, float)) @@ -5636,7 +5644,7 @@ def track_symfloat(source: Source, val: FloatLikeType) -> None: s = sympy.Float(val) input_guards.append((source, s)) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] for t, source, context in zip(placeholders, sources, input_contexts): if isinstance(source, str): from torch._dynamo.source import LocalSource @@ -5830,7 +5838,7 @@ def track_symfloat(source: Source, val: FloatLikeType) -> None: def issue_guard(guard: ShapeGuard) -> None: expr = self.simplify(guard.expr) - # Avoid re-issueing the same guard. + # Avoid re-issuing the same guard. if expr in issued: return @@ -5999,7 +6007,7 @@ def issue_guard(guard: ShapeGuard) -> None: else: str_msg = f" - {msg_cb()}" error_msgs.append(str_msg) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] debug_names.add(debug_name) if len(error_msgs) > 0: debug_names_str = ", ".join(sorted(debug_names)) @@ -6133,7 +6141,7 @@ def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard] Get a list of guards, but pruned so it only provides guards that reference symints from the passed in input """ - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] symints = { s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol) } @@ -6396,7 +6404,7 @@ def replace(self, expr: _SympyT) -> _SympyT: Apply symbol replacements to any symbols in the given expression. """ replacements = {} - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for s in expr.free_symbols: r = self._find(s) @@ -6406,7 +6414,7 @@ def replace(self, expr: _SympyT) -> _SympyT: if not r.is_Symbol or r != s: replacements[s] = r if replacements: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return safe_expand(expr.xreplace(replacements)) else: return expr @@ -6899,9 +6907,6 @@ def _maybe_guard_rel(self, expr: sympy.Expr) -> None: self._maybe_guard_rel(arg) return elif not isinstance(expr, sympy.Rel): - log.warning( - "_maybe_guard_rel() was called on non-relation expression %s", expr - ) return # A good example of what goes wrong if you don't do this is @@ -7181,7 +7186,7 @@ def _find_frame_locals(self) -> _FrameLocalResult: instructions = list(dis.Bytecode(frame.f_code)) co_lines, offset = inspect.getsourcelines(frame.f_code) start, end, cur = None, None, None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for i, instr in enumerate(instructions): if instr.starts_line is not None: cur = instr.starts_line diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 44a893ad56a40..181e0e8dd167a 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -118,7 +118,7 @@ def edge(a, b, tie_breaker=hash): """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ - # A either supercedes B and B does not supercede A or if B does then call + # A either supersedes B and B does not supersede A or if B does then call # tie_breaker return supercedes(a, b) and ( not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index 1410bbc5239c3..e2459b82247bc 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -238,7 +238,7 @@ def add(self, signature, func): "To use a variadic union type place the desired types " "inside of a tuple, e.g., [(int, str)]" ) - # pyrefly: ignore # bad-specialization + # pyrefly: ignore [bad-specialization] new_signature.append(Variadic[typ[0]]) else: new_signature.append(typ) @@ -407,7 +407,7 @@ class MethodDispatcher(Dispatcher): Dispatcher """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] __slots__ = ("obj", "cls") @classmethod diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index f29bc8b525500..8b4216a79ad03 100644 --- a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -298,7 +298,7 @@ def update_in(d, keys, func, default=None, factory=dict): rv = inner = factory() rv.update(d) - # pyrefly: ignore # not-iterable + # pyrefly: ignore [not-iterable] for key in ks: if k in d: d = d[k] diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index eb55b6c2050ca..96d815750c206 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -181,6 +181,7 @@ def to_int(x: z3.ArithRef) -> z3.ArithRef: return x if x.is_int() else z3.ToInt(x) def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef: + # pyrefly: ignore return sum(args) # Implements Python division semantics. @@ -357,7 +358,7 @@ def placeholder( def call_function( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] ) -> Any: - if target != torch._assert: + if target is not torch._assert: # Lift and runs the node target function return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] # Adds the Z3 expression corresponding to the first argument @@ -814,7 +815,7 @@ def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: # Bisection happens on the assertion nodes of the recorded FX graph for # dynamic shapes. assert_nodes = [ - node for node in shape_env.graph.nodes if node.target == torch._assert + node for node in shape_env.graph.nodes if node.target is torch._assert ] # Preparing the indices for binary search. diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 4467d252ad499..fc6f4c5b27021 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1380,7 +1380,7 @@ def erase_node(self, to_erase: Node) -> None: f(to_erase) self._find_nodes_lookup_table.remove(to_erase) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] to_erase._remove_from_list() to_erase._erased = True # iterators may retain handles to erased nodes self._len -= 1 @@ -1941,7 +1941,7 @@ def check_arg(arg: Node, n: Optional[Node] = None) -> None: "a str is expected" ) if node.op in ["get_attr", "call_module"]: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] target_atoms = node.target.split(".") m_itr = self.owning_module for i, atom in enumerate(target_atoms): diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index d0604d96b012a..159926bc8ba49 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -535,7 +535,7 @@ def __init__( self.graph._tracer_cls and "" not in self.graph._tracer_cls.__qualname__ ): - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._tracer_cls = self.graph._tracer_cls self._tracer_extras = {} diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index a6cbe1cfe2c82..a3114a14a657e 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -55,7 +55,7 @@ class NegSigmSwapInterpreter(Interpreter): def call_function( self, target: Target, args: Tuple, kwargs: Dict ) -> Any: - if target == torch.sigmoid: + if target is torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) @@ -489,7 +489,7 @@ def call_function( args: Tuple[Argument, ...], kwargs: Dict[str, Any], ) -> Any: - if target == torch.sigmoid: + if target is torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) diff --git a/torch/fx/node.py b/torch/fx/node.py index 48f57d588631c..1d72a75a6ccf4 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -754,6 +754,26 @@ def is_impure(self, impure_random: bool = True) -> bool: return self.target in _side_effectful_functions + def subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: + """ + Return True if a GraphModule type subgraph contains any impure op, else False. + """ + assert isinstance(module, torch.fx.GraphModule), ( + "caller should only pass GraphModule to subgraph_has_impure_ops check" + ) + for node in module.graph.nodes: + if node.op == "call_function" and node.is_impure(impure_random): + return True + if ( + # pyrefly: ignore [invalid-argument] + node.op == "call_module" + # pyrefly: ignore [not-callable] + and (submodule := module.get_submodule(node.target)) + and isinstance(submodule, torch.fx.GraphModule) + ): + return subgraph_has_impure_ops(submodule) + return False + # Check if an impure module. if self.op == "call_module": assert self.graph.owning_module is not None, ( @@ -763,7 +783,10 @@ def is_impure(self, impure_random: bool = True) -> bool: assert target_mod is not None, ( f"Did not find expected submodule target {self.target}" ) - return getattr(target_mod, "_is_impure", False) + if isinstance(target_mod, torch.fx.GraphModule): + return subgraph_has_impure_ops(target_mod) + else: + return getattr(target_mod, "_is_impure", False) return False diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 1234d13b3b11f..397d4c5996ee9 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -80,6 +80,7 @@ def __getattr__(self, name): "NoneType": type(None), "Storage": torch.UntypedStorage, "t": typing.TypeVar("t"), + "PyObject": Any, } for k in dir(typing): _type_eval_globals[k] = getattr(typing, k) diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index bf6a2d99c40d3..7d7a4c04cff2f 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -165,12 +165,12 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: node = graph.call_function( torch.ops.aten.scalar_tensor.default, - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] (c,), {"dtype": dtype}, ) with fake_mode: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype) expr_to_tensor_proxy[expr] = MetaProxy( node, @@ -223,13 +223,13 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: expr_to_sym_proxy[s] = MetaProxy( node, tracer=tracer, fake_mode=fake_mode ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] elif (sym_expr := _get_sym_val(node)) is not None: if sym_expr not in expr_to_sym_proxy and not isinstance( sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) ): expr_to_sym_proxy[sym_expr] = MetaProxy( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] node, tracer=tracer, fake_mode=fake_mode, @@ -238,7 +238,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: # Specialize all dimensions that contain symfloats. Here's # an example test that requires this: # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950 - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] val = node.meta.get("val") if isinstance(val, FakeTensor): for dim in val.shape: @@ -257,17 +257,17 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: should_restart = True # Look for functions to convert - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if node.op == "call_function" and ( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] replacement_op := SUPPORTED_OPS.get(node.target) ): args: list[Any] = [] transform = False - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] compute_dtype = get_computation_dtype(node.meta["val"].dtype) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for a in node.args: if ( isinstance(a, fx.Node) @@ -304,7 +304,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: if transform: replacement_proxy = replacement_op(*args) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if compute_dtype != node.meta["val"].dtype: replacement_proxy = ( torch.ops.prims.convert_element_type.default( @@ -313,9 +313,9 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: ) ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] node.replace_all_uses_with(replacement_proxy.node) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] graph.erase_node(node) metrics_context = get_metrics_context() @@ -324,16 +324,16 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: "tensorify_float_success", True, overwrite=True ) else: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for a in node.args: if ( isinstance(a, fx.Node) and "val" in a.meta and isinstance(zf := a.meta["val"], torch.SymFloat) ): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] failed_tensorify_ops.update(str(node.target)) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] log.info("Failed to tensorify %s", str(node.target)) # Now do one more pass that specializes all symfloats we didn't manage diff --git a/torch/fx/passes/annotate_getitem_nodes.py b/torch/fx/passes/annotate_getitem_nodes.py index 0a31a76420b34..17b77f6396206 100644 --- a/torch/fx/passes/annotate_getitem_nodes.py +++ b/torch/fx/passes/annotate_getitem_nodes.py @@ -16,7 +16,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: graph (Graph): The graph to be annotated """ for node in graph.nodes: - if node.target == operator.getitem: + if node.target is operator.getitem: sequence_node, index_node = node.args if not sequence_node.type: continue diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py index 657c7578f5fa5..97496fbc9b2a2 100644 --- a/torch/fx/passes/backends/cudagraphs.py +++ b/torch/fx/passes/backends/cudagraphs.py @@ -15,10 +15,10 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: if node.op not in CALLABLE_NODE_OPS: return False - if node.target == torch.ops.aten.embedding_dense_backward.default: + if node.target is torch.ops.aten.embedding_dense_backward.default: return False - if node.target == operator.getitem: + if node.target is operator.getitem: return True found_not_cuda = False diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 313766d51028e..92ce645df8fa9 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -437,13 +437,13 @@ def _to_dot( ) current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] current_graph.add_node(dot_node) def get_module_params_or_buffers(): for pname, ptensor in chain( leaf_module.named_parameters(), - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] leaf_module.named_buffers(), ): pname1 = node.name + "." + pname diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index ef8e79e578696..32c641031b31f 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -11,7 +11,7 @@ @compatibility(is_backward_compatible=False) -# pyrefly: ignore # invalid-inheritance +# pyrefly: ignore [invalid-inheritance] class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): """ Result of a pass: diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index e13ca72fd2408..87fb6e70037f9 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -31,7 +31,7 @@ def pass_result_wrapper(fn: Callable) -> Callable: wrapped_fn (Callable[Module, PassResult]) """ if fn is None: - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return None @wraps(fn) diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index b4a82f10177dc..e98bad06e5a55 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -396,25 +396,25 @@ def _run_and_compare( report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] if self.module_exporter: if isinstance(result_key, tuple): # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] result_key = result_key[-1] # If the result is still a tuple (happens in non-sequential mode), # we only use the first element as name. if isinstance(result_key, tuple): # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] result_key = str(result_key[0]) # pyre-ignore[29]: not a function self.module_exporter( a_input, submodule, - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] result_key + "_cpu", ) # pyre-ignore[29]: not a function self.module_exporter( b_input, submodule, - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] result_key + "_acc", ) raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 30f1549389610..2dba9f0ca12f0 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -219,7 +219,7 @@ def _add_if_tensor(x, set_): if n in tensor_aliases: if ( isinstance(n.target, torch._ops.OpOverload) - or n.target == _operator.getitem + or n.target is _operator.getitem ): continue nodes_used_after.add(n) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 46298304adbde..58aa801062824 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -360,7 +360,7 @@ def match_symbol(symint, cb): ): # this guards against deleting calls like item() that produce new untracked symbols def has_new_untracked_symbols(): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for symbol in sym_expr.free_symbols: if symbol not in expr_to_proxy: return True @@ -376,7 +376,7 @@ def has_new_untracked_symbols(): assert resolved_unbacked_bindings is not None def has_new_unbacked_bindings(): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for key in resolved_unbacked_bindings.keys(): if key not in expr_to_proxy: return True @@ -606,7 +606,7 @@ def convert(s): if ( expr_to_proxy[i0].node.target - != cast_symbool_to_symint_guardless + is not cast_symbool_to_symint_guardless ): # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts # raises AOTAutograd errors on cast_symbool_to_symint_guardless diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index fdbec419041da..a4b244750f33d 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -326,18 +326,18 @@ def instantiate_node_partition_mapping(node): instantiate_node_partition_mapping(node) if node.op == "call_function" and node.target in GLOBAL_STATE_NODES: - if node.target == torch._C._set_grad_enabled: + if node.target is torch._C._set_grad_enabled: assert len(node.args) == 1 assert isinstance(node.args[0], bool) active_grad = node grad_regions[active_grad] = set({split_callback(node)}) - elif node.target == torch.amp._enter_autocast: + elif node.target is torch.amp._enter_autocast: # Should all be python constants assert all(not isinstance(arg, Node) for arg in node.args) active_autocasts.add(node) autocast_regions[node] = set({split_callback(node)}) autocast_exits[node] = None - elif node.target == torch.amp._exit_autocast: + elif node.target is torch.amp._exit_autocast: assert len(node.args) == 1 autocast_regions[node.args[0]].add(split_callback(node)) active_autocasts.remove(node.args[0]) @@ -351,9 +351,9 @@ def instantiate_node_partition_mapping(node): assert all(v is not None for v in autocast_exits.values()), "autocast must exit" - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()} - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] grad_regions = {k: sorted(v) for k, v in grad_regions.items()} if _LOGGER.isEnabledFor(logging.DEBUG): @@ -418,9 +418,9 @@ def instantiate_node_partition_mapping(node): for regions_mapping in [autocast_regions, grad_regions]: for node, regions in regions_mapping.items(): assert len(regions) > 0 - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] partitions[str(regions[0])].environment[node] = node - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] for r in regions[1:]: partition = partitions[str(r)] new_node = partition.graph.create_node( @@ -520,7 +520,7 @@ def add_placeholder(): for node in reversed(regions_mapping): regions = regions_mapping[node] assert len(regions) > 0 - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] for r in regions[:-1]: partition = partitions[str(r)] exit_node = autocast_exits[node] diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index 3924a93d22cff..4c97aa4093571 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -64,7 +64,7 @@ def lift_subgraph_as_module( for name in target_name_parts[:-1]: if not hasattr(curr, name): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] curr.add_module(name, HolderModule({})) curr = getattr(curr, name) diff --git a/torch/headeronly/README.md b/torch/headeronly/README.md new file mode 100644 index 0000000000000..a24734dd1a07c --- /dev/null +++ b/torch/headeronly/README.md @@ -0,0 +1,7 @@ +## torch/headeronly + +The inlined C++ headers in the `torch::headeronly` namespace living this subdirectory are completely decoupled from LibTorch. These APIs are also globally listed in [torch/header_only_apis.txt](https://github.com/pytorch/pytorch/blob/main/torch/header_only_apis.txt). + +There are two types of LibTorch independent header-only headers: +1. OG header-only. Originally header-only APIs, such as `ScalarType`, `Half`, `BFloat16`, have always been implemented in headers only. For them to move into torch/headeronly only required a code migration, a copy-pasta, if you will. +2. Made to be header-only. There are also APIs that were NOT header-only that we made to be header-only. One example of such an API is `STD_TORCH_CHECK`, which was derived from `TORCH_CHECK`. `STD_TORCH_CHECK` calls into `std::runtime_error` instead of relying on `c10::Error`, which relies on libtorch.so. As a result, `STD_TORCH_CHECK` does not have the full `TORCH_CHECK` functionality that displays a fanciful traceback when the check is not met. We intentionally maintain the design that functions that do different things should be explicitly named differently. diff --git a/torch/headeronly/core/ScalarType.h b/torch/headeronly/core/ScalarType.h index 9b403df12f3a4..19262e51de529 100644 --- a/torch/headeronly/core/ScalarType.h +++ b/torch/headeronly/core/ScalarType.h @@ -19,6 +19,8 @@ #include +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") + namespace c10 { // dummy struct for uint1 to uint7, actual functionality @@ -351,3 +353,5 @@ using c10::impl::ScalarTypeToCPPTypeT; } // namespace impl HIDDEN_NAMESPACE_END(torch, headeronly) + +C10_DIAGNOSTIC_POP() diff --git a/torch/headeronly/macros/Macros.h b/torch/headeronly/macros/Macros.h index 5d52ad326726e..63aa0d20d8e54 100644 --- a/torch/headeronly/macros/Macros.h +++ b/torch/headeronly/macros/Macros.h @@ -1,6 +1,11 @@ #ifndef C10_MACROS_MACROS_H_ #define C10_MACROS_MACROS_H_ + +#ifdef __cplusplus #include +#else +#include +#endif /* Main entry for torch/headeronly/macros (used to be c10/macros). * @@ -139,6 +144,8 @@ #define C10_RESTRICT __restrict +#ifdef __cplusplus + // Simply define the namespace, in case a dependent library want to refer to // the c10 namespace but not any nontrivial files. namespace c10 {} @@ -176,6 +183,8 @@ namespace at::xpu { using namespace c10::xpu; } // namespace at::xpu +#endif // __cplusplus + // C10_LIKELY/C10_UNLIKELY // // These macros provide parentheses, so you can use these macros as: @@ -236,7 +245,11 @@ using namespace c10::xpu; #define C10_ERASE C10_ALWAYS_INLINE C10_ATTR_VISIBILITY_HIDDEN +#ifdef __cplusplus #include +#else +#include +#endif #ifdef __HIPCC__ // Unlike CUDA, HIP requires a HIP header to be included for __host__ to work. @@ -467,7 +480,7 @@ __host__ __device__ // a non-negligible performance impact even if the assert condition is // never triggered. We choose to use abort() instead which will still // terminate the application but without a more useful error message. -#if !defined(C10_USE_ROCM_KERNEL_ASSERT) and defined(USE_ROCM) +#if !defined(C10_USE_ROCM_KERNEL_ASSERT) && defined(USE_ROCM) #define CUDA_KERNEL_ASSERT(cond) \ if C10_UNLIKELY (!(cond)) { \ abort(); \ @@ -517,9 +530,21 @@ __host__ __device__ __assert_fail( \ #cond, __FILE__, static_cast(__LINE__), __func__); \ } -#endif // C10_USE_ROCM_KERNEL_ASSERT and USE_ROCM +#endif // C10_USE_ROCM_KERNEL_ASSERT && USE_ROCM #endif // __APPLE__ +// Compile-time switch to control how assertions are logged inside CUDA kernels. +// If C10_CUDA_VERBOSE_ASSERT is defined, CUDA_KERNEL_ASSERT_VERBOSE will +// take addition information passed to the macro and forward them to +// CUDA_KERNEL_ASSERT_PRINTF If C10_CUDA_VERBOSE_ASSERT is not defined, +// CUDA_KERNEL_ASSERT_VERBOSE will behave the same as CUDA_KERNEL_ASSERT. +#ifdef C10_ENABLE_VERBOSE_ASSERT +#define CUDA_KERNEL_ASSERT_VERBOSE(cond, ...) \ + CUDA_KERNEL_ASSERT_PRINTF(cond, __VA_ARGS__) +#else +#define CUDA_KERNEL_ASSERT_VERBOSE(cond, ...) CUDA_KERNEL_ASSERT(cond) +#endif + #ifdef __APPLE__ #include #endif diff --git a/torch/headeronly/util/TypeSafeSignMath.h b/torch/headeronly/util/TypeSafeSignMath.h index f41269082d9b4..c33a286bc5b55 100644 --- a/torch/headeronly/util/TypeSafeSignMath.h +++ b/torch/headeronly/util/TypeSafeSignMath.h @@ -79,7 +79,7 @@ template inline constexpr bool greater_than_max(const T& x) { constexpr bool can_overflow = std::numeric_limits::digits > std::numeric_limits::digits; - return can_overflow && x > (std::numeric_limits::max)(); + return can_overflow && x > std::numeric_limits::max(); } #ifdef __GNUC__ diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 8b2ecf566a351..a8bb3ba9bd8f5 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -909,7 +909,7 @@ def __dir__(self): self_method = self.__dir__ if ( self_method.__func__ # type: ignore[attr-defined] - == _get_function_from_type(RecursiveScriptModule, "__dir__") + is _get_function_from_type(RecursiveScriptModule, "__dir__") ): return super().__dir__() return self_method() @@ -921,7 +921,7 @@ def __bool__(self): self_method = self.__bool__ if ( self_method.__func__ # type: ignore[attr-defined] - == _get_function_from_type(RecursiveScriptModule, "__bool__") + is _get_function_from_type(RecursiveScriptModule, "__bool__") ): return True return self_method() @@ -1066,7 +1066,7 @@ def call_prepare_scriptable_func_impl(obj, memo): else: new_obj_dict[name] = sub_module - for k, v in new_obj_dict.items(): + for v in new_obj_dict.values(): obj.__dict__[name] = v return obj diff --git a/torch/lib/libshm/core.cpp b/torch/lib/libshm/core.cpp index 72edb235888c4..4b8aabac54ad2 100644 --- a/torch/lib/libshm/core.cpp +++ b/torch/lib/libshm/core.cpp @@ -56,7 +56,7 @@ static void start_manager() { } SYSCHECK_ERR_RETURN_NEG1(close(pipe_ends[0])); - TORCH_CHECK(handle.length() != 0, "no response from torch_shm_manager at \"", manager_executable_path, "\""); + TORCH_CHECK(!handle.empty(), "no response from torch_shm_manager at \"", manager_executable_path, "\""); handle.pop_back(); // remove \n TORCH_CHECK( diff --git a/torch/lib/libshm/manager.cpp b/torch/lib/libshm/manager.cpp index ec0519d83b752..5647f5a350c8e 100644 --- a/torch/lib/libshm/manager.cpp +++ b/torch/lib/libshm/manager.cpp @@ -105,12 +105,12 @@ int main(int argc, char* argv[]) { srv_socket = std::make_unique(tempfile); register_fd(srv_socket->socket_fd); - print_init_message(tempfile.c_str()); + print_init_message(tempfile); DEBUG("opened socket %s", tempfile.c_str()); } catch (const std::exception& e) { std::string message("ERROR: "); message += e.what(); - print_init_message(message.c_str()); + print_init_message(message); return 1; } catch (...) { print_init_message("ERROR: unhandled exception"); diff --git a/torch/library.h b/torch/library.h index 816f88b13f30d..b244654916b91 100644 --- a/torch/library.h +++ b/torch/library.h @@ -353,6 +353,7 @@ inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) { template inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) { auto deviceTypeToDispatchKey = [](c10::DeviceType t) { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") switch (t) { // This list is synchronized with the k-constants in c10/core/DeviceType.h case c10::DeviceType::CPU: @@ -389,6 +390,7 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) { " cannot be overloaded at dispatch time, " "please file a bug report explaining what you were trying to do."); } + C10_DIAGNOSTIC_POP() }; return dispatch(deviceTypeToDispatchKey(type), std::forward(raw_f)); } diff --git a/torch/library.py b/torch/library.py index 0490e68b5d1e8..b9b56f6aa7c46 100644 --- a/torch/library.py +++ b/torch/library.py @@ -242,7 +242,7 @@ def _impl_with_aoti_compile(self, op_name, dispatch_key=""): if dispatch_key == "": dispatch_key = self.dispatch_key - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense) if isinstance(op_name, str): diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index e330b59f47a2d..4bae914f0292b 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -484,7 +484,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]: raise IndexError( f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})" ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dims.append(d % ndim) return tuple(sorted(dims)) @@ -1017,7 +1017,7 @@ def helper(input, mask): class Combine(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, input, mask): """Return input with masked-out elements eliminated for the given operations.""" ctx.save_for_backward(mask) @@ -1028,7 +1028,7 @@ def forward(ctx, input, mask): return helper(input, mask) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): (mask,) = ctx.saved_tensors grad_data = ( @@ -1403,18 +1403,18 @@ def mean( if input.layout == torch.strided: if mask is None: # TODO: compute count analytically - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] count = sum( torch.ones(input.shape, dtype=torch.int64, device=input.device), dim, keepdim=keepdim, ) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] total = sum(input, dim, keepdim=keepdim, dtype=dtype) else: inmask = _input_mask(input, mask=mask) count = inmask.sum(dim=dim, keepdim=bool(keepdim)) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask) return total / count elif input.layout == torch.sparse_csr: @@ -1625,18 +1625,18 @@ def _std_var( if input.layout == torch.strided: if mask is None: # TODO: compute count analytically - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] count = sum( torch.ones(input.shape, dtype=torch.int64, device=input.device), dim, keepdim=True, ) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] sample_total = sum(input, dim, keepdim=True, dtype=dtype) else: inmask = _input_mask(input, mask=mask) count = inmask.sum(dim=dim, keepdim=True) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask) # TODO: replace torch.subtract/divide/square/maximum with # masked subtract/divide/square/maximum when these will be @@ -1644,7 +1644,7 @@ def _std_var( sample_mean = torch.divide(sample_total, count) x = torch.subtract(input, sample_mean) if mask is None: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) else: total = sum( diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index a2bb25c9f6cf8..d9add0a1dfbae 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -47,7 +47,7 @@ def _check_args_kwargs_length( class _MaskedContiguous(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, input): if not is_masked_tensor(input): raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.") @@ -61,14 +61,14 @@ def forward(ctx, input): return MaskedTensor(data.contiguous(), mask.contiguous()) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return grad_output class _MaskedToDense(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, input): if not is_masked_tensor(input): raise ValueError("MaskedToDense forward: input must be a MaskedTensor.") @@ -83,7 +83,7 @@ def forward(ctx, input): return MaskedTensor(data.to_dense(), mask.to_dense()) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): layout = ctx.layout @@ -98,7 +98,7 @@ def backward(ctx, grad_output): class _MaskedToSparse(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, input): if not is_masked_tensor(input): raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.") @@ -115,14 +115,14 @@ def forward(ctx, input): return MaskedTensor(sparse_data, sparse_mask) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return grad_output.to_dense() class _MaskedToSparseCsr(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, input): if not is_masked_tensor(input): raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.") @@ -143,21 +143,21 @@ def forward(ctx, input): return MaskedTensor(sparse_data, sparse_mask) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return grad_output.to_dense() class _MaskedWhere(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, cond, self, other): ctx.mark_non_differentiable(cond) ctx.save_for_backward(cond) return torch.ops.aten.where(cond, self, other) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): (cond,) = ctx.saved_tensors diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 75a41e705b180..111680c1f019e 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -174,7 +174,7 @@ def __new__(cls, data, mask, requires_grad=False): UserWarning, stacklevel=2, ) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) def _preprocess_data(self, data, mask): @@ -244,12 +244,12 @@ def _from_values(data, mask): class Constructor(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, data, mask): return MaskedTensor(data, mask) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return grad_output, None @@ -336,12 +336,12 @@ def to_tensor(self, value): def get_data(self): class GetData(torch.autograd.Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, self): return self._masked_data.detach() @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): if is_masked_tensor(grad_output): return grad_output diff --git a/torch/multiprocessing/spawn.py b/torch/multiprocessing/spawn.py index 8506bf7986852..f553f7cacd753 100644 --- a/torch/multiprocessing/spawn.py +++ b/torch/multiprocessing/spawn.py @@ -114,7 +114,7 @@ def _join_procs_with_timeout(self, timeout: float): """Attempt to join all processes with a shared timeout.""" end = time.monotonic() + timeout for process in self.processes: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] time_to_wait = max(0, end - time.monotonic()) process.join(time_to_wait) diff --git a/torch/nativert/OVERVIEW.md b/torch/nativert/OVERVIEW.md index bfe97c9aefc75..d8a7d255d921a 100644 --- a/torch/nativert/OVERVIEW.md +++ b/torch/nativert/OVERVIEW.md @@ -282,7 +282,7 @@ RuntimeConfigs { Constant folding is the process of finding all of the constant-evaluable subgraphs, evaluating them at startup, and then storing their results as -constants as opposed to re-evaluting them every time. +constants as opposed to re-evaluating them every time. To enable constant folding, you can set the following configurations. diff --git a/torch/nativert/executor/DelegateExecutor.cpp b/torch/nativert/executor/DelegateExecutor.cpp index 6585ac34ddd6c..eb738e5b38daf 100644 --- a/torch/nativert/executor/DelegateExecutor.cpp +++ b/torch/nativert/executor/DelegateExecutor.cpp @@ -54,8 +54,8 @@ std::string extractToTemporaryFolder( << " from archive path: " << path << " size: " << dataSize; File extracted(extractedFilename, O_CREAT | O_WRONLY, 0640); - const auto bytesWritten = writeFull( - extracted.fd(), const_cast(dataPointer.get()), dataSize); + const auto bytesWritten = + writeFull(extracted.fd(), dataPointer.get(), dataSize); TORCH_CHECK( bytesWritten != -1, "failure copying from archive path ", diff --git a/torch/nativert/executor/Weights.h b/torch/nativert/executor/Weights.h index 7791a329ec498..39653d0bed561 100644 --- a/torch/nativert/executor/Weights.h +++ b/torch/nativert/executor/Weights.h @@ -137,8 +137,8 @@ class Weights { // every instance of Weight has a unique version number static WeightVersion globalVersion_; - std::function skipSizeCheck_ = {}; - std::function skipDtypeCheck_ = {}; + std::function skipSizeCheck_; + std::function skipDtypeCheck_; // save the names of unused weights std::unordered_set unusedWeights_; diff --git a/torch/nativert/executor/memory/LayoutManager.h b/torch/nativert/executor/memory/LayoutManager.h index d98700e7f0215..c6aabce2db675 100644 --- a/torch/nativert/executor/memory/LayoutManager.h +++ b/torch/nativert/executor/memory/LayoutManager.h @@ -70,7 +70,7 @@ struct ContiguousLayoutBuffer { size_t size_{0}; // the dataptr returned by the allocator - at::DataPtr data_ptr_{}; + at::DataPtr data_ptr_; }; struct ContiguousStorageImplBuffer { @@ -198,7 +198,7 @@ class LayoutManager { #else auto alignment = c10::gAlignment; #endif - return ((nbytes) + alignment - 1) & (~(alignment - 1)); + return (nbytes + alignment - 1) & (~(alignment - 1)); } void allocate_plan(const LayoutPlan& plan); diff --git a/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h b/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h index eda8e57c64d19..90d04ab25d5ef 100644 --- a/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h +++ b/torch/nativert/executor/memory/LayoutPlannerAlgorithm.h @@ -59,7 +59,7 @@ struct AllocationLifetime { }; struct AllocationSpec { - AllocationLifetime lifetime{}; + AllocationLifetime lifetime; size_t size{0}; bool not_overlapping_with(const AllocationSpec& other) const { diff --git a/torch/nativert/executor/triton/CudaTritonKernelManager.cpp b/torch/nativert/executor/triton/CudaTritonKernelManager.cpp index d18efcc178f46..20324e1e2f5d6 100644 --- a/torch/nativert/executor/triton/CudaTritonKernelManager.cpp +++ b/torch/nativert/executor/triton/CudaTritonKernelManager.cpp @@ -29,7 +29,11 @@ namespace torch::nativert { class CudaKernelInputs final : public KernelInputs { public: CudaKernelInputs(size_t num_args, size_t num_attrs) - : KernelInputs(num_args, num_attrs), arg_ptrs_(num_args) {} + : KernelInputs(num_args, num_attrs), + arg_ptrs_(num_args), + global_scratch_(0) { + inputs_.push_back(&global_scratch_); + } ~CudaKernelInputs() final = default; void add_arg(void* arg) override { @@ -41,6 +45,7 @@ class CudaKernelInputs final : public KernelInputs { private: std::vector arg_ptrs_; + CUdeviceptr global_scratch_; }; class CudaTritonKernelManager final : public TritonKernelManager { diff --git a/torch/nativert/graph/GraphPasses.cpp b/torch/nativert/graph/GraphPasses.cpp index 6cb378af80dbd..46cdd38cca54c 100644 --- a/torch/nativert/graph/GraphPasses.cpp +++ b/torch/nativert/graph/GraphPasses.cpp @@ -124,7 +124,7 @@ std::string selectScalarOverloadName(const Node& node) { for (const auto& variant : {"Scalar_mode", "Scalar", "Scalar_Tensor", "Tensor_Scalar"}) { if (auto schema = c10::Dispatcher::singleton().findSchema( - {fmt::format("{}::{}", ns, opName.c_str()).c_str(), variant})) { + {fmt::format("{}::{}", ns, opName.c_str()), variant})) { if (schemaTypeMatch(schema->schema(), node)) { return variant; } diff --git a/torch/nativert/graph/passes/SubgraphRewriter.cpp b/torch/nativert/graph/passes/SubgraphRewriter.cpp index f4aa743d0214f..ef385a15e33ae 100644 --- a/torch/nativert/graph/passes/SubgraphRewriter.cpp +++ b/torch/nativert/graph/passes/SubgraphRewriter.cpp @@ -265,8 +265,8 @@ bool SubgraphRewriter::run( for (const auto& [pattern, replacement] : patterns_) { const auto& pattern_graph = stringToGraph(pattern); const auto& replacement_graph = stringToGraph(replacement); - mutated |= runForPattern( - graph, *pattern_graph.get(), *replacement_graph.get(), filters); + mutated |= + runForPattern(graph, *pattern_graph, *replacement_graph, filters); } return mutated; } diff --git a/torch/nativert/kernels/PrimKernelRegistry.cpp b/torch/nativert/kernels/PrimKernelRegistry.cpp index b00a7b7715109..e3bf60d466adb 100644 --- a/torch/nativert/kernels/PrimKernelRegistry.cpp +++ b/torch/nativert/kernels/PrimKernelRegistry.cpp @@ -77,7 +77,7 @@ class OpKernel_variadic_concat : public OpKernel { public: explicit OpKernel_variadic_concat(const Node* node) : OpKernel(node, OpKernelKind::kPrimKernel) { - dim_ = node_->attributes().size() > 0 + dim_ = !node_->attributes().empty() ? constantToIValue(node_->getAttribute("dim").value).toInt() : 0; } @@ -122,7 +122,7 @@ class OpKernel_variadic_stack : public OpKernel { public: explicit OpKernel_variadic_stack(const Node* node) : OpKernel(node, OpKernelKind::kPrimKernel) { - dim_ = node_->attributes().size() > 0 + dim_ = !node_->attributes().empty() ? constantToIValue(node_->getAttribute("dim").value).toInt() : 0; } diff --git a/torch/nativert/kernels/TritonKernel.cpp b/torch/nativert/kernels/TritonKernel.cpp index 81606729e645a..081c81f7c646b 100644 --- a/torch/nativert/kernels/TritonKernel.cpp +++ b/torch/nativert/kernels/TritonKernel.cpp @@ -37,20 +37,40 @@ TritonKernel::TritonKernel( TORCH_CHECK(reader != nullptr, "reader is null"); std::string kernel_name{}; + std::string symbol_name{}; bool found_grid = false; + + // To prevent vector reallocation and dangling pointers + size_t num_double_attrs = 0; + for (const auto& attr : node_->attributes()) { + if (attr.name.empty() && std::holds_alternative(attr.value)) { + ++num_double_attrs; + } + } + float_attrs_.reserve(num_double_attrs); + for (const auto& attr : node_->attributes()) { if (attr.name.empty()) { attr_ptrs_.emplace_back(std::visit( - [](auto&& arg) -> void* { + [this](auto&& arg) -> void* { using T = std::decay_t; if constexpr (std::is_same_v) { return nullptr; + } else if constexpr (std::is_same_v) { + // Triton always uses fp32 for floats. See create_specialize_impl + // in jit.py. However, due to the Thrift schema, floats are + // serialized as doubles here. But, Triton kernels read them as + // floats. So, we need to downcast double to float here. + float_attrs_.push_back(static_cast(arg)); + return static_cast(&float_attrs_.back()); } return static_cast(const_cast(&arg)); }, attr.value)); } else if (attr.name == "name") { kernel_name = std::get(attr.value); + size_t last_underscore = kernel_name.find_last_of('_'); + symbol_name = kernel_name.substr(0, last_underscore); } else if (attr.name == "grid") { found_grid = true; auto grid = std::get>(attr.value); @@ -82,6 +102,7 @@ TritonKernel::TritonKernel( } TORCH_CHECK(!kernel_name.empty(), "kernel name not found"); + TORCH_CHECK(!symbol_name.empty(), "symbol_name not found"); TORCH_CHECK(found_grid, "grid attribute not found"); TORCH_CHECK(!output_indices_.empty(), "output_indices attribute not found"); @@ -91,20 +112,20 @@ TritonKernel::TritonKernel( if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".cubin")) { loader_ = TritonKernelManagerRegistry()->Create( - at::kCUDA, kernel_name, tmp_dir + kernel_name + ".cubin", ""); + at::kCUDA, symbol_name, tmp_dir + kernel_name + ".cubin", ""); TORCH_CHECK( loader_ != nullptr, "couldn't find cuda loader -- is this a gpu build?"); } else if (reader->hasRecord(kernel_prefix + "/" + kernel_name + ".hsaco")) { loader_ = TritonKernelManagerRegistry()->Create( - at::kHIP, kernel_name, tmp_dir + kernel_name + ".hsaco", ""); + at::kHIP, symbol_name, tmp_dir + kernel_name + ".hsaco", ""); TORCH_CHECK( loader_ != nullptr, "couldn't find cuda loader -- is this a gpu build?"); } else { loader_ = TritonKernelManagerRegistry()->Create( at::kCPU, - kernel_name, + symbol_name, tmp_dir + kernel_name + ".so", tmp_dir + kernel_name + ".launcher.so"); } diff --git a/torch/nativert/kernels/TritonKernel.h b/torch/nativert/kernels/TritonKernel.h index 4f9f0e47b00cd..29453ca190449 100644 --- a/torch/nativert/kernels/TritonKernel.h +++ b/torch/nativert/kernels/TritonKernel.h @@ -24,6 +24,8 @@ class TritonKernel : public OpKernel { // unnamed node attributes will be passed as arguments to the kernel std::vector attr_ptrs_; + // Storage for float attributes that were serialized as doubles + std::vector float_attrs_; std::vector output_indices_; LaunchParams launch_params_; }; diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index 756dc643baf63..a84a5b681d638 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -51,7 +51,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False): # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1. # For other dims, subtract 1 to convert to inner space. return ( - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] ragged_dim - 1 if dim == 0 else dim - 1 ) @@ -502,13 +502,13 @@ def _rms_norm_sig(input, normalized_shape, weight=None, eps=None): "self: jt_all", ) def tensor_attr_supported_getter(func, *args, **kwargs): - if func == torch.ops.aten.is_non_overlapping_and_dense.default: + if func is torch.ops.aten.is_non_overlapping_and_dense.default: return False - if func == torch.ops.aten.sym_size.default: + if func is torch.ops.aten.sym_size.default: return args[0]._size - if func == torch.ops.aten.dim.default: + if func is torch.ops.aten.dim.default: return len(args[0]._size) if func in (torch.ops.aten.sym_numel.default, torch.ops.aten.numel.default): @@ -516,10 +516,10 @@ def tensor_attr_supported_getter(func, *args, **kwargs): return int(sum(args[0]._lengths) * math.prod(args[0]._size[2:])) return args[0]._values.numel() - if func == torch.ops.aten.sym_stride.default: + if func is torch.ops.aten.sym_stride.default: return args[0]._strides - if func == torch.ops.aten.sym_storage_offset.default: + if func is torch.ops.aten.sym_storage_offset.default: return args[0]._values.storage_offset() @@ -533,7 +533,7 @@ def prim_layout_default(func, *args, **kwargs): "self: jt_all", ) def tensor_attr_unsupported_getter(func, *args, **kwargs): - if func == torch.ops.aten.size.default: + if func is torch.ops.aten.size.default: raise RuntimeError( "NestedTensor does not support directly calling torch.ops.aten.size; " "please use `nested_tensor.size()` instead." @@ -1995,7 +1995,7 @@ def index_put_(func, *args, **kwargs): max_seqlen=max_seqlen, ) - if func == torch.ops.aten.index_put_.default: + if func is torch.ops.aten.index_put_.default: inp._values.copy_(new_njt.values()) return inp return new_njt @@ -2008,7 +2008,7 @@ def index_put_(func, *args, **kwargs): else: lengths = inp.lengths() torch._assert_async( - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] torch.all(indices[inp._ragged_idx] < lengths), "Some indices in the ragged dimension are out of bounds!", ) @@ -2024,7 +2024,7 @@ def index_put_(func, *args, **kwargs): + indices[inp._ragged_idx + 1 :] ) - if func == torch.ops.aten.index_put_.default: + if func is torch.ops.aten.index_put_.default: inp._values = func(inp._values, func_indices, **new_kwargs) return inp diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 4e8d430e845b5..fe385dc5c766f 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -438,7 +438,7 @@ def _view_as_dense( # # this is because needs_broadcast indicates that the batch_size is 1 # # and hence there is only 1 value for seq_len # # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1), -# # ..., outut_batch_size * {*}_t.size(1)] +# # ..., output_batch_size * {*}_t.size(1)] # # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1) # if q_batch_size_needs_broadcast or not q_t.is_nested: diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 9113fd7e37912..5e6e0fa5fae3b 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -14,14 +14,11 @@ SDPAParams, ) -from .varlen import varlen_attn - __all__: list[str] = [ "SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS", - "varlen_attn", ] # Note: [SDPA warnings] diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index cbf1efdd7571d..b79b86a29afb6 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -738,7 +738,7 @@ def causal_mask(b, h, q_idx, kv_idx): (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1)) if isinstance(i, int) else i - for i, n in zip(padded, sizes) + for i, n in zip(padded, sizes, strict=True) ) new_kv_num_blocks = self.kv_num_blocks[index] new_kv_indices = self.kv_indices[index] diff --git a/torch/nn/attention/varlen.py b/torch/nn/attention/varlen.py index 7234dd5e7912d..3a81a3b0cee39 100644 --- a/torch/nn/attention/varlen.py +++ b/torch/nn/attention/varlen.py @@ -7,7 +7,7 @@ import logging from functools import lru_cache -from typing import NamedTuple, Optional, Union +from typing import Any, NamedTuple, Optional, Union import torch @@ -33,8 +33,7 @@ class AuxRequest(NamedTuple): lse: bool = False -# import failures when I try to register as custom op -# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={}) +@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={}) def _varlen_attn( query: torch.Tensor, key: torch.Tensor, @@ -44,7 +43,7 @@ def _varlen_attn( max_q: int, max_k: int, is_causal: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Private custom op for variable-length attention. @@ -70,7 +69,7 @@ def _varlen_attn( False, # return_debug_mask ) # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask) - output, softmax_lse = result[0], result[1] + output, softmax_lse, rng_state = result[0], result[1], result[6] else: log.info("Using Flash Attention backend for varlen_attn") output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward( @@ -86,10 +85,13 @@ def _varlen_attn( return_debug_mask=False, ) - return output, softmax_lse + rng_state_ = torch.zeros( + (2,), dtype=torch.uint64, device=query.device + ) # hardcoded since dropout is hardcoded to 0 + return output, softmax_lse, rng_state_ -# @_varlen_attn.register_fake +@_varlen_attn.register_fake def _varlen_attn_fake( query: torch.Tensor, key: torch.Tensor, @@ -99,7 +101,7 @@ def _varlen_attn_fake( max_q: int, max_k: int, is_causal: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Fake implementation for meta tensor computation and tracing. @@ -117,7 +119,9 @@ def _varlen_attn_fake( (num_heads, total_q), dtype=torch.float, device=query.device ) - return output, logsumexp + rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device) + + return output, logsumexp, rng_state def varlen_attn( @@ -191,9 +195,132 @@ def varlen_attn( ... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False ... ) """ - out, lse = _varlen_attn( + out, lse, _ = torch.ops.torch_attn._varlen_attn( query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal ) if return_aux is not None and return_aux.lse: return out, lse return out + + +def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None: + query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs + out, lse, rng_state = output + + ctx.save_for_backward(query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state) + + ctx.max_q = max_q + ctx.max_k = max_k + ctx.is_causal = is_causal + + +@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={}) +def _varlen_attn_backward( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool, + rng_state: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + unused = torch.empty(0, device=query.device) + + use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index) + if use_cudnn: + log.info("Using cuDNN backend for varlen_attn") + dq, dk, dv = torch.ops.aten._cudnn_attention_backward( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + 0.0, + is_causal, + rng_state, + unused, + ) + else: + log.info("Using Flash Attention backend for varlen_attn") + dq, dk, dv = torch.ops.aten._flash_attention_backward( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + 0.0, + is_causal, + rng_state, + unused, + ) + return dq, dk, dv + + +@_varlen_attn_backward.register_fake +def _varlen_attn_backward_fake( + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool, + rng_state: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Fake implementation for meta tensor computation and tracing. + """ + + grad_query = torch.empty_like(query) + grad_key = torch.empty_like(key) + grad_value = torch.empty_like(value) + + return grad_query, grad_key, grad_value + + +def _backward( + ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor +) -> tuple[Optional[torch.Tensor], ...]: + query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state = ctx.saved_tensors + + max_q = ctx.max_q + max_k = ctx.max_k + is_causal = ctx.is_causal + + dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + is_causal, + rng_state, + ) + return dq, dk, dv, None, None, None, None, None, None + + +_varlen_attn.register_autograd(_backward, setup_context=_setup_context) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 360d687094d9b..bc1e873c428fb 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2229,7 +2229,7 @@ def gumbel_softmax( ).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: - # Reparametrization trick. + # Reparameterization trick. ret = y_soft return ret @@ -3304,7 +3304,8 @@ def gaussian_nll_loss( # or input.size = (4, 3, 32, 32), var.size = (4, 1, 32, 32) elif ( input.ndim == var.ndim - and sum(y for x, y in zip(input.size(), var.size()) if x != y) == 1 + and sum(y for x, y in zip(input.size(), var.size(), strict=True) if x != y) + == 1 ): # Heteroscedastic case pass diff --git a/torch/nn/modules/adaptive.py b/torch/nn/modules/adaptive.py index 26103c4f2a7b9..4267ed9993bff 100644 --- a/torch/nn/modules/adaptive.py +++ b/torch/nn/modules/adaptive.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs +import itertools from collections import namedtuple from collections.abc import Sequence @@ -273,7 +274,7 @@ def _get_full_log_prob(self, input, head_output): out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size] - for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])): + for i, (start_idx, stop_idx) in enumerate(itertools.pairwise(self.cutoffs)): cluster_output = self.tail[i](input) cluster_logprob = F.log_softmax(cluster_output, dim=1) output_logprob = cluster_logprob + head_logprob[ diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index a93843f859a7b..1132dc2bb0d4d 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -150,7 +150,9 @@ def __delitem__(self, idx: Union[slice, int]) -> None: delattr(self, key) # To preserve numbering str_indices = [str(i) for i in range(len(self._modules))] - self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + self._modules = OrderedDict( + zip(str_indices, self._modules.values(), strict=True) + ) @_copy_to_script_wrapper def __len__(self) -> int: @@ -395,7 +397,9 @@ def __delitem__(self, idx: Union[int, slice]) -> None: delattr(self, self._get_abs_string_index(idx)) # To preserve numbering, self._modules is being reconstructed with modules after deletion str_indices = [str(i) for i in range(len(self._modules))] - self._modules = OrderedDict(list(zip(str_indices, self._modules.values()))) + self._modules = OrderedDict( + zip(str_indices, self._modules.values(), strict=True) + ) @_copy_to_script_wrapper def __len__(self) -> int: @@ -432,7 +436,9 @@ def __repr__(self) -> str: lines = [] main_str = self._get_name() + "(" - for (start_id, end_id), b in zip(start_end_indices, repeated_blocks): + for (start_id, end_id), b in zip( + start_end_indices, repeated_blocks, strict=True + ): local_repr = f"({start_id}): {b}" # default repr if start_id != end_id: diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index f06e38c2abae2..e0923fb786493 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -142,7 +142,10 @@ def __init__( self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size) if padding == "same": for d, k, i in zip( - dilation, kernel_size, range(len(kernel_size) - 1, -1, -1) + dilation, + kernel_size, + range(len(kernel_size) - 1, -1, -1), + strict=False, ): total_padding = d * (k - 1) left_pad = total_padding // 2 @@ -1468,7 +1471,7 @@ def _get_num_spatial_dims(self) -> int: raise NotImplementedError -# LazyConv1d defines weight as a Tensor but derived class defines it as UnitializeParameter +# LazyConv1d defines weight as a Tensor but derived class defines it as UninitializeParameter class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc] r"""A :class:`torch.nn.Conv1d` module with lazy initialization of the ``in_channels`` argument. @@ -1540,7 +1543,7 @@ def _get_num_spatial_dims(self) -> int: return 1 -# LazyConv2d defines weight as a Tensor but derived class defines it as UnitializeParameter +# LazyConv2d defines weight as a Tensor but derived class defines it as UninitializeParameter class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc] r"""A :class:`torch.nn.Conv2d` module with lazy initialization of the ``in_channels`` argument. @@ -1612,7 +1615,7 @@ def _get_num_spatial_dims(self) -> int: return 2 -# LazyConv3d defines weight as a Tensor but derived class defines it as UnitializeParameter +# LazyConv3d defines weight as a Tensor but derived class defines it as UninitializeParameter class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc] r"""A :class:`torch.nn.Conv3d` module with lazy initialization of the ``in_channels`` argument. @@ -1685,7 +1688,7 @@ def _get_num_spatial_dims(self) -> int: return 3 -# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UnitializeParameter +# LazyConvTranspose1d defines weight as a Tensor but derived class defines it as UninitializeParameter class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[misc] r"""A :class:`torch.nn.ConvTranspose1d` module with lazy initialization of the ``in_channels`` argument. @@ -1757,7 +1760,7 @@ def _get_num_spatial_dims(self) -> int: return 1 -# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UnitializeParameter +# LazyConvTranspose2d defines weight as a Tensor but derived class defines it as UninitializeParameter class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[misc] r"""A :class:`torch.nn.ConvTranspose2d` module with lazy initialization of the ``in_channels`` argument. @@ -1829,7 +1832,7 @@ def _get_num_spatial_dims(self) -> int: return 2 -# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UnitializeParameter +# LazyConvTranspose3d defines weight as a Tensor but derived class defines it as UninitializeParameter class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[misc] r"""A :class:`torch.nn.ConvTranspose3d` module with lazy initialization of the ``in_channels`` argument. diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index c7b44b61354a6..13cd9ec08cb55 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -196,7 +196,7 @@ def __init__( param_names += ["weight_hr_l{}{}"] param_names = [x.format(layer, suffix) for x in param_names] - for name, param in zip(param_names, layer_params): + for name, param in zip(param_names, layer_params, strict=True): setattr(self, name, param) self._flat_weights_names.extend(param_names) self._all_weights.append(param_names) @@ -352,7 +352,9 @@ def _weights_have_changed(self): # Returns True if the weight tensors have changed since the last forward pass. # This is the case when used with torch.func.functional_call(), for example. weights_changed = False - for ref, name in zip(self._flat_weight_refs, self._flat_weights_names): + for ref, name in zip( + self._flat_weight_refs, self._flat_weights_names, strict=True + ): weight = getattr(self, name) if hasattr(self, name) else None if weight is not None and ref is not None and ref() is not weight: weights_changed = True diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index cfe621983dc21..5dffadefe152d 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -41,7 +41,8 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]: if len(defaults) <= len(out_size): raise ValueError(f"Input dimension should be at least {len(out_size) + 1}") return [ - v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :]) + v if v is not None else d + for v, d in zip(out_size, defaults[-len(out_size) :], strict=False) ] diff --git a/torch/nn/parallel/comm.py b/torch/nn/parallel/comm.py index 3df1b4b4eadcf..255c0c4b33271 100644 --- a/torch/nn/parallel/comm.py +++ b/torch/nn/parallel/comm.py @@ -141,18 +141,18 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760): output = [] ref_order = [] # process sparse ones first since they may have different sizes on different gpus - for tensor_at_gpus in zip(*inputs): + for tensor_at_gpus in zip(*inputs, strict=True): if all(t.is_sparse for t in tensor_at_gpus): result = reduce_add(tensor_at_gpus, destination) # this will be sparse too output.append(result) ref_order.append(tensor_at_gpus[0]) else: - for coll, t in zip(dense_tensors, tensor_at_gpus): + for coll, t in zip(dense_tensors, tensor_at_gpus, strict=True): coll.append(t.to_dense() if t.is_sparse else t) ref_order.append(dense_tensors[0][-1]) itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors] # now the dense ones, which have consistent sizes - for chunks in zip(*itrs): + for chunks in zip(*itrs, strict=True): flat_tensors = [ _flatten_dense_tensors(chunk) for chunk in chunks ] # (num_gpus,) diff --git a/torch/nn/parallel/parallel_apply.py b/torch/nn/parallel/parallel_apply.py index aa8db823e1185..d0b50bbc20208 100644 --- a/torch/nn/parallel/parallel_apply.py +++ b/torch/nn/parallel/parallel_apply.py @@ -115,7 +115,7 @@ def _worker( target=_worker, args=(i, module, input, kwargs, device, stream) ) for i, (module, input, kwargs, device, stream) in enumerate( - zip(modules, inputs, kwargs_tup, devices, streams) + zip(modules, inputs, kwargs_tup, devices, streams, strict=True) ) ] diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index a2917bddd0327..96a1237275252 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -57,16 +57,26 @@ def scatter_map(obj): return Scatter.apply(target_gpus, None, dim, obj) if _is_namedtuple(obj): # pyrefly: ignore [no-matching-overload] - return [type(obj)(*args) for args in zip(*map(scatter_map, obj))] + return [ + # pyrefly: ignore [no-matching-overload] + type(obj)(*args) + # pyrefly: ignore # no-matching-overload + for args in zip(*map(scatter_map, obj), strict=False) + ] if isinstance(obj, tuple) and len(obj) > 0: # pyrefly: ignore [no-matching-overload] - return list(zip(*map(scatter_map, obj))) + return list(zip(*map(scatter_map, obj), strict=False)) if isinstance(obj, list) and len(obj) > 0: # pyrefly: ignore [no-matching-overload] - return [list(i) for i in zip(*map(scatter_map, obj))] + return [list(i) for i in zip(*map(scatter_map, obj), strict=False)] if isinstance(obj, dict) and len(obj) > 0: # pyrefly: ignore [no-matching-overload] - return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))] + return [ + # pyrefly: ignore [no-matching-overload] + type(obj)(i) + # pyrefly: ignore # no-matching-overload + for i in zip(*map(scatter_map, obj.items()), strict=False) + ] return [obj for _ in target_gpus] # After scatter_map is called, a scatter_map cell will exist. This cell @@ -131,9 +141,9 @@ def gather_map(outputs): return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) if _is_namedtuple(out): # pyrefly: ignore [no-matching-overload] - return type(out)._make(map(gather_map, zip(*outputs))) + return type(out)._make(map(gather_map, zip(*outputs, strict=True))) # pyrefly: ignore [no-matching-overload] - return type(out)(map(gather_map, zip(*outputs))) + return type(out)(map(gather_map, zip(*outputs, strict=True))) # Recursive function calls like this create reference cycles. # Setting the function to None clears the refcycle. diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py index ed9a83b133896..e9253264d1e0e 100644 --- a/torch/nn/utils/__init__.py +++ b/torch/nn/utils/__init__.py @@ -1,5 +1,5 @@ from . import parametrizations, parametrize, rnn, stateless -from .clip_grad import ( # pyrefly: ignore # deprecated +from .clip_grad import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] _clip_grads_with_norm_ as clip_grads_with_norm_, _get_total_norm as get_total_norm, clip_grad_norm, diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index e755362a4f201..d68a82b71268b 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -14,12 +14,12 @@ def conv_picker(func, conv1dOpt, conv2dOpt, conv3dOpt): - if func == F.conv1d: + if func is F.conv1d: return conv1dOpt - if func == F.conv2d: + if func is F.conv2d: return conv2dOpt else: - assert func == F.conv3d + assert func is F.conv3d return conv3dOpt @@ -28,7 +28,7 @@ def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs): kwargs = expanded_args_and_kwargs[ len(expanded_args_and_kwargs) - len(kwarg_names) : ] - kwargs = dict(zip(kwarg_names, kwargs)) + kwargs = dict(zip(kwarg_names, kwargs, strict=True)) return conv_normalizer(*args, **kwargs) diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index 36e3ee7a58909..cfb1d99ac30ec 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -140,7 +140,7 @@ def __torch_function__(cls, func, _, args=(), kwargs=None): if decomp is not None: with setup_rnn(use_input_variant, args, kwargs): return decomp(*args, **kwargs) - if func == torch._cudnn_rnn_flatten_weight: + if func is torch._cudnn_rnn_flatten_weight: # since we aren't using the fused cuda kernels for RNNs, don't do this return if func in cls.handled_functions: diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index b3f674d3233d8..ec6d55305fb46 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -32,7 +32,7 @@ def standard_kwargs(kwarg_names, expanded_args): expanded_args_without_kwargs = expanded_args[ : len(expanded_args) - len(kwarg_names) ] - expanded_kwargs = dict(zip(kwarg_names, kwarg_values)) + expanded_kwargs = dict(zip(kwarg_names, kwarg_values, strict=True)) return expanded_args_without_kwargs, expanded_kwargs diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index e815265fec633..0935490856aeb 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -250,7 +250,7 @@ def set_tensors(self, names: Iterable[str], values: Iterable[torch.Tensor]) -> N values = list(values) assert len(names) == len(values), "names and values must have the same length" - for name, value in zip(names, values): + for name, value in zip(names, values, strict=True): self.set_tensor(name, value) def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None: @@ -298,7 +298,7 @@ def swap_tensors( return [ self.swap_tensor(name, value, allow_missing=allow_missing) - for name, value in zip(names, values) + for name, value in zip(names, values, strict=True) ] def swap_tensors_dict( diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 99a1439ec5c8f..3c1a800085951 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -144,7 +144,7 @@ def _get_composite_method(cls, module, name, *args, **kwargs): method = _get_composite_method(cls, module, name, *args, **kwargs) # at this point we have no forward_pre_hooks but we could have an - # active reparametrization of the tensor if another pruning method + # active reparameterization of the tensor if another pruning method # had been applied (in which case `method` would be a PruningContainer # and not a simple pruning method). diff --git a/torch/nn/utils/rnn.py b/torch/nn/utils/rnn.py index 5529cbc83ef0a..47bd937a32ae0 100644 --- a/torch/nn/utils/rnn.py +++ b/torch/nn/utils/rnn.py @@ -528,7 +528,7 @@ def unpad_sequence( max_length = padded_sequences.shape[1] idx = torch.arange(max_length, device=lengths.device) - for seq, length in zip(padded_sequences, lengths): + for seq, length in zip(padded_sequences, lengths, strict=True): mask = idx < length unpacked_seq = seq[mask] unpadded_sequences.append(unpacked_seq) diff --git a/torch/numa/binding.py b/torch/numa/binding.py index d256c2af055d6..89602e2136ad8 100644 --- a/torch/numa/binding.py +++ b/torch/numa/binding.py @@ -102,6 +102,7 @@ def maybe_wrap_command_args_with_numa_binding( ) return wrapped_command_args except Exception: + # pyrefly: ignore [bad-argument-type] _handle_exception(numa_options=numa_options, logger_kwargs=kwargs) return command_args @@ -140,6 +141,7 @@ def maybe_wrap_with_numa_binding( def wrapped(*args: _TParams.args, **kwargs: _TParams.kwargs) -> _TReturn: _maybe_apply_numa_binding_to_current_process( gpu_index=gpu_index, + # pyrefly: ignore [bad-argument-type] numa_options=numa_options, ) return func(*args, **kwargs) @@ -174,6 +176,7 @@ def _maybe_apply_numa_binding_to_current_process( }, ) except Exception: + # pyrefly: ignore [bad-argument-type] _handle_exception(numa_options=numa_options, logger_kwargs=kwargs) @@ -237,24 +240,24 @@ def _bind_all_threads_in_current_process_to_logical_cpus( *, logical_cpu_indices: set[int] ) -> None: # Save the original affinity of the main thread before changing it - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] original_main_thread_affinity = os.sched_getaffinity(0) # type: ignore[attr-defined] # 0 represents the current thread. # This is outside the try/except because the main thread should always bind successfully. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] os.sched_setaffinity(0, logical_cpu_indices) # type: ignore[attr-defined] for tid_str in os.listdir("/proc/self/task"): try: tid = int(tid_str) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] tid_affinity = os.sched_getaffinity(tid) # type: ignore[attr-defined] # Defensive check to ensure we do not overwrite affinity on any threads # that have already had their affinity set elsewhere. if tid_affinity == original_main_thread_affinity: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] os.sched_setaffinity(tid, logical_cpu_indices) # type: ignore[attr-defined] except Exception: # Thread may have exited or otherwise become invalid @@ -668,5 +671,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]: def _get_allowed_cpu_indices_for_current_thread() -> set[int]: # 0 denotes current thread - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return os.sched_getaffinity(0) # type:ignore[attr-defined] diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index df0bf2cd1a225..c19827c7de156 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -324,12 +324,10 @@ def forward(self, x): warnings.warn( "You are using the legacy TorchScript-based ONNX export. Starting in PyTorch 2.9, " - "the new torch.export-based ONNX exporter will be the default. To switch now, set " - "dynamo=True in torch.onnx.export. This new exporter supports features like exporting " - "LLMs with DynamicCache. We encourage you to try it and share feedback to help improve " - "the experience. Learn more about the new export logic: " - "https://pytorch.org/docs/stable/onnx_dynamo.html. For exporting control flow: " - "https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html.", + "the new torch.export-based ONNX exporter has become the default. " + "Learn more about the new export logic: https://docs.pytorch.org/docs/stable/onnx_export.html. " + "For exporting control flow: " + "https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html", category=DeprecationWarning, stacklevel=2, ) diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index bb07bf97d7bdd..2dbcf8f083877 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -430,6 +430,7 @@ def _process_python_sequences( # when the expected input type is INT64 # We assume this only happens for 0D cases if all(isinstance(val, ir.Value) for val in arg): + # pyrefly: ignore expanded_args = [_reshape_to_1d_tensor(opset, val) for val in arg] named_inputs[name] = opset.Concat(*expanded_args, axis=0) continue diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 71803852690e3..7be2dee45668e 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -82,7 +82,7 @@ def export_compat( "different behavior during inference. " "Calling model.eval() before export is recommended.", UserWarning, - stacklevel=2, + stacklevel=3, ) if isinstance(model, torch.export.ExportedProgram): diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 4458e00d7679a..b618943c3f21b 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -392,7 +392,7 @@ def _handle_call_function_node( node: The FX node to translate. node_name_to_values: A mapping of FX node names to their produced ir.Value. """ - if node.target == operator.getitem: + if node.target is operator.getitem: _handle_getitem_node(node, node_name_to_values) # Add op to the graph op = str(node.target) @@ -402,7 +402,7 @@ def _handle_call_function_node( if input_ is None: inputs.append(None) elif hasattr(input_, "name"): - if isinstance(input_, torch.fx.Node) and input_.target == operator.getitem: + if isinstance(input_, torch.fx.Node) and input_.target is operator.getitem: actual_input = _handle_getitem_node(input_, node_name_to_values) inputs.append(actual_input) else: @@ -456,7 +456,7 @@ def _convert_fx_arg_to_onnx_arg( # The actual dropping of a None attribute value is done by OpRecorder return None if hasattr(arg, "name"): - if isinstance(arg, torch.fx.Node) and arg.target == operator.getitem: + if isinstance(arg, torch.fx.Node) and arg.target is operator.getitem: source = arg.all_input_nodes[0] source_outputs = node_name_to_values[source.name] if isinstance(source_outputs, Sequence): @@ -527,7 +527,7 @@ def _handle_call_function_node_with_lowering( opset: The ONNX Script opset object for constructing ONNX nodes. node_name_to_local_functions: A mapping of subgraph names to the corresponding ONNX functions. """ - if node.target == operator.getitem: + if node.target is operator.getitem: source = node.all_input_nodes[0] source_outputs = node_name_to_values[source.name] if isinstance(source_outputs, Sequence): diff --git a/torch/onnx/_internal/exporter/_ir_passes.py b/torch/onnx/_internal/exporter/_ir_passes.py index 9391b642b009b..fe10ccd15862d 100644 --- a/torch/onnx/_internal/exporter/_ir_passes.py +++ b/torch/onnx/_internal/exporter/_ir_passes.py @@ -76,6 +76,7 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None: continue dim_name = dim.value if dim_name in sorted_rename_mapping: + # pyrefly: ignore new_shape.append(sorted_rename_mapping[dim_name]) changed = True elif dim_name is not None: diff --git a/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py b/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py index dd2ee6c81792b..c0faf24f6f269 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py +++ b/torch/onnx/_internal/exporter/_torchlib/_tensor_typing.py @@ -1,8 +1,3 @@ -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - """Typings for function definitions.""" from __future__ import annotations diff --git a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py index a2f86a6ccf266..6db344123519e 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py +++ b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -1,5 +1,3 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. """Registry for aten functions.""" from __future__ import annotations diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/core.py b/torch/onnx/_internal/exporter/_torchlib/ops/core.py index 3b3b5691bd741..36d53b113edc2 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/core.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/core.py @@ -1,5 +1,6 @@ """torch.ops.aten operators under the `core` module.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# pyrefly: ignore-errors # ruff: noqa: TCH001,TCH002 from __future__ import annotations diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py index 1623028e1f553..31f87046315b6 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -1,11 +1,12 @@ """torch.ops.aten operators under the `core` module.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# pyrefly: ignore-errors # ruff: noqa: TCH001,TCH002 # flake8: noqa: B950 from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import Optional, Sequence, TYPE_CHECKING from onnxscript.onnx_opset import ( # type: ignore[attr-defined] opset20 as op20, @@ -24,9 +25,6 @@ aten = torch.ops.aten -_INT64_MAX = 9223372036854775807 -_INT64_MIN = -9223372036854775808 - @onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) def aten_gelu_opset20( @@ -50,11 +48,9 @@ def aten_group_norm( c = op21.Shape(input, start=1, end=2) if weight is None: - # pyrefly: ignore [missing-attribute] - weight = op21.ConstantOfShape(c, value=ir.tensor(1.0, dtype=input.dtype)) + weight = op21.ConstantOfShape(c, value=ir.tensor([1.0], dtype=input.dtype)) if bias is None: - # pyrefly: ignore [missing-attribute] - bias = op21.ConstantOfShape(c, value=ir.tensor(0.0, dtype=input.dtype)) + bias = op21.ConstantOfShape(c, value=ir.tensor([0.0], dtype=input.dtype)) return op21.GroupNormalization( input, weight, bias, epsilon=eps, num_groups=num_groups ) @@ -63,7 +59,7 @@ def aten_group_norm( @onnx_impl(aten.rms_norm.default, trace_only=True, opset_introduced=23) def aten_rms_norm( input: TFloat, - normalized_shape: list[int], + normalized_shape: Sequence[int], weight: Optional[TFloat] = None, eps: Optional[float] = None, ) -> TFloat: @@ -82,8 +78,9 @@ def aten_rms_norm( # Create weight tensor if not provided if weight is None: - # pyrefly: ignore [missing-attribute] - weight = op23.Constant(value=ir.tensor(1.0, dtype=input.dtype)) + weight = op23.ConstantOfShape( + op23.Shape(input), value=ir.tensor([1], dtype=input.dtype) + ) return op23.RMSNormalization(input, weight, axis=axis, epsilon=eps) @@ -131,7 +128,6 @@ def aten_scaled_dot_product_attention_23( assert (not is_causal) or (is_causal and attn_mask is None), ( "is_causal and attn_mask cannot be set at the same time" ) - # pyrefly: ignore [missing-attribute] assert len(query.shape) == 4 and len(key.shape) == 4 and len(value.shape) == 4, ( "only 4D query, key, and value are supported" ) @@ -140,15 +136,12 @@ def aten_scaled_dot_product_attention_23( if dropout_p == 0: if enable_gqa: assert ( - # pyrefly: ignore [index-error] query.shape[1] > key.shape[1] == value.shape[1] - # pyrefly: ignore [index-error] and query.shape[1] % key.shape[1] == 0 ), ( "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" ) else: - # pyrefly: ignore [index-error] assert query.shape[1] == key.shape[1] == value.shape[1], ( "SDPA (MHA) requires q_num_heads = kv_num_heads" ) @@ -209,9 +202,7 @@ def _attention_repeat_kv_for_group_query( """ assert ( - # pyrefly: ignore [missing-attribute] query.shape[1] > key.shape[1] == value.shape[1] - # pyrefly: ignore [missing-attribute] and query.shape[1] % key.shape[1] == 0 ), ( "SDPA (GQA or MQA) requires q_num_heads > kv_num_heads & q_num_heads % kv_num_heads == 0" diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/symops.py b/torch/onnx/_internal/exporter/_torchlib/ops/symops.py index bba780fed535e..2a21dc4ce8e13 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/symops.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/symops.py @@ -1,6 +1,7 @@ """Implementation for torch.sym* ops.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# pyrefly: ignore-errors # ruff: noqa: TCH001,TCH002 from __future__ import annotations diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index 968f69328011d..072f9f10e2646 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -164,7 +164,7 @@ def from_scalar_type_to_torch_dtype(scalar_type: type) -> torch.dtype | None: SYM_VALUE_TYPE = Union[torch.SymInt, torch.SymFloat, torch.SymBool] META_VALUE_TYPE = Union[fake_tensor.FakeTensor, SYM_VALUE_TYPE, int, float, bool] -# NOTE: Belows are from torch/fx/node.py +# NOTE: Below are from torch/fx/node.py BaseArgumentTypes = Union[ str, int, diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py index 72c5074eb3856..e1b34469fbf20 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset9.py @@ -6099,7 +6099,7 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): if other_dim_rank != self_dim_rank: delta = self_dim_rank - other_dim_rank - for i in range(delta): + for _ in range(delta): other = symbolic_helper._unsqueeze_helper( g, other, [symbolic_helper._get_tensor_rank(other)] ) @@ -6126,10 +6126,10 @@ def index_add(g: jit_utils.GraphContext, self, dim, index, other, alpha=None): ) other = expand_as(g, other, new_shape) - for i in range(dim): + for _ in range(dim): index = symbolic_helper._unsqueeze_helper(g, index, [0]) - for i in range(self_dim_rank - dim - 1): + for _ in range(self_dim_rank - dim - 1): index = symbolic_helper._unsqueeze_helper( g, index, [symbolic_helper._get_tensor_rank(index)] ) diff --git a/torch/onnx/_internal/torchscript_exporter/verification.py b/torch/onnx/_internal/torchscript_exporter/verification.py index a79540f7155c9..32885d1f63774 100644 --- a/torch/onnx/_internal/torchscript_exporter/verification.py +++ b/torch/onnx/_internal/torchscript_exporter/verification.py @@ -251,7 +251,7 @@ def _compare_onnx_pytorch_outputs_in_np( # pyrefly: ignore [missing-attribute] if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8: warnings.warn("ONNX output is quantized", stacklevel=2) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8: warnings.warn("PyTorch output is quantized", stacklevel=2) raise diff --git a/torch/onnx/ops/_symbolic_impl.py b/torch/onnx/ops/_symbolic_impl.py index 4876612ad978b..aafe9c00828cc 100644 --- a/torch/onnx/ops/_symbolic_impl.py +++ b/torch/onnx/ops/_symbolic_impl.py @@ -78,7 +78,7 @@ def from_dict( attr_floats=[], attr_strs=[], ) - for i, (k, v) in enumerate(attrs.items()): + for k, v in attrs.items(): encoded.attr_keys.append(k) if isinstance(v, int): start_pos = len(encoded.attr_ints) diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index 3b529201d7ff2..4def193daf190 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -532,7 +532,7 @@ def _multi_tensor_adafactor( alphas = [ max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r - for p, r in zip(device_params, rho_ts) + for p, r in zip(device_params, rho_ts, strict=True) ] # Perform stepweight decay @@ -566,7 +566,9 @@ def _multi_tensor_adafactor( var_estimates = [ row_var @ col_var - for row_var, col_var in zip(device_row_vars, device_col_vars) + for row_var, col_var in zip( + device_row_vars, device_col_vars, strict=True + ) ] row_var_means = [ row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars @@ -594,7 +596,7 @@ def _multi_tensor_adafactor( alphas = [ -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d))) - for a, update in zip(alphas, updates) + for a, update in zip(alphas, updates, strict=True) ] torch._foreach_mul_(updates, alphas) torch._foreach_add_(device_params, updates) diff --git a/torch/optim/_muon.py b/torch/optim/_muon.py index 7c2d5465c63cd..7b7167a40fc1c 100644 --- a/torch/optim/_muon.py +++ b/torch/optim/_muon.py @@ -78,7 +78,7 @@ def _adjust_lr( A, B = param_shape[:2] if adjust_lr_fn is None or adjust_lr_fn == "original": - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] adjusted_ratio = math.sqrt(max(1, A / B)) elif adjust_lr_fn == "match_rms_adamw": adjusted_ratio = 0.2 * math.sqrt(max(A, B)) diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 2261faa3908da..4a893026451ae 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -266,7 +266,7 @@ def _single_tensor_adadelta( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." @@ -276,7 +276,7 @@ def _single_tensor_adadelta( lr = _to_scalar(lr) for param, grad, square_avg, acc_delta, step in zip( - params, grads, square_avgs, acc_deltas, state_steps + params, grads, square_avgs, acc_deltas, state_steps, strict=True ): step += 1 grad = grad if not maximize else -grad @@ -329,7 +329,7 @@ def _multi_tensor_adadelta( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 4afb0b60d4951..4d2523b2a16af 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -343,7 +343,9 @@ def _single_tensor_adagrad( if not torch.jit.is_scripting(): lr = _to_scalar(lr) - for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps): + for param, grad, state_sum, step_t in zip( + params, grads, state_sums, state_steps, strict=True + ): # update step step_t += 1 step = _get_value(step_t) diff --git a/torch/optim/adam.py b/torch/optim/adam.py index b10dadd9e5098..5ceadccce86a5 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -423,7 +423,7 @@ def _single_tensor_adam( if weight_decay.requires_grad: grad = grad.addcmul_(param.clone(), weight_decay) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] grad = grad.add(param, alpha=weight_decay) else: grad = grad.add(param, alpha=weight_decay) @@ -608,7 +608,7 @@ def _multi_tensor_adam( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 42440024f1249..76d784d6ea764 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -334,7 +334,7 @@ def _multi_tensor_adamax( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 831ce069f5a15..0008694bda18b 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -264,7 +264,7 @@ def _single_tensor_asgd( ax.copy_(param) if capturable: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha)) mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t))) else: @@ -307,7 +307,7 @@ def _multi_tensor_asgd( if not all( p.device.type == mu.device.type == eta.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, mu, eta, step in zip(params, mus, etas, state_steps) + for p, mu, eta, step in zip(params, mus, etas, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params, mus, etas, and state_steps must be on " diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index 8c191f71f04e8..ae4b286ffa225 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -113,16 +113,18 @@ def _strong_wolfe( # compute new trial value t = _cubic_interpolate( - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] + # pyrefly: ignore [unbound-name] bracket[0], - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_f[0], bracket_gtd[0], # type: ignore[possibly-undefined] - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] + # pyrefly: ignore [unbound-name] bracket[1], - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_f[1], - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_gtd[1], ) @@ -133,20 +135,20 @@ def _strong_wolfe( # + `t` is at one of the boundary, # we will move `t` to a position which is `0.1 * len(bracket)` # away from the nearest boundary point. - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] eps = 0.1 * (max(bracket) - min(bracket)) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if min(max(bracket) - t, t - min(bracket)) < eps: # interpolation close to boundary - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if insuf_progress or t >= max(bracket) or t <= min(bracket): # evaluate at 0.1 away from boundary - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if abs(t - max(bracket)) < abs(t - min(bracket)): - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] t = max(bracket) - eps else: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] t = min(bracket) + eps insuf_progress = False else: @@ -160,45 +162,49 @@ def _strong_wolfe( gtd_new = g_new.dot(d) ls_iter += 1 - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]: # Armijo condition not satisfied or not lower than lowest point - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore [unbound-name] bracket[high_pos] = t - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_f[high_pos] = f_new bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_gtd[high_pos] = gtd_new - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0) else: if abs(gtd_new) <= -c2 * gtd: # Wolfe conditions satisfied done = True - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] + # pyrefly: ignore [unbound-name] elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0: # old high becomes new low - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore [unbound-name] bracket[high_pos] = bracket[low_pos] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_f[high_pos] = bracket_f[low_pos] bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_gtd[high_pos] = bracket_gtd[low_pos] # new point becomes new low - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] + # pyrefly: ignore [unbound-name] bracket[low_pos] = t - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_f[low_pos] = f_new bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] bracket_gtd[low_pos] = gtd_new # return stuff t = bracket[low_pos] # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] f_new = bracket_f[low_pos] g_new = bracket_g[low_pos] # type: ignore[possibly-undefined] return f_new, g_new, t, ls_func_evals @@ -276,7 +282,7 @@ def __init__( def _numel(self): if self._numel_cache is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._numel_cache = sum( 2 * p.numel() if torch.is_complex(p) else p.numel() for p in self._params @@ -314,7 +320,7 @@ def _clone_param(self): return [p.clone(memory_format=torch.contiguous_format) for p in self._params] def _set_param(self, params_data): - for p, pdata in zip(self._params, params_data): + for p, pdata in zip(self._params, params_data, strict=True): p.copy_(pdata) def _directional_evaluate(self, closure, x, t, d): diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 3cc6649e0d80c..71dcb6129a8ec 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -302,7 +302,7 @@ def _update_lr(self, epoch: Optional[int] = None): else: values = self.get_lr() - for param_group, lr in zip(self.optimizer.param_groups, values): + for param_group, lr in zip(self.optimizer.param_groups, values, strict=True): _update_param_group_val(param_group, "lr", lr) self._last_lr: list[float | Tensor] = _param_groups_val_list( @@ -422,7 +422,7 @@ def state_dict(self) -> dict[str, Any]: for idx, fn in enumerate(self.lr_lambdas): if not isinstance(fn, types.FunctionType): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] state_dict["lr_lambdas"][idx] = fn.__dict__.copy() return state_dict @@ -472,7 +472,7 @@ def get_lr(self) -> list[float | Tensor]: return [ base_lr * lmbda(self.last_epoch) - for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs, strict=True) ] @@ -542,7 +542,7 @@ def state_dict(self) -> dict[str, Any]: for idx, fn in enumerate(self.lr_lambdas): if not isinstance(fn, types.FunctionType): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] state_dict["lr_lambdas"][idx] = fn.__dict__.copy() return state_dict @@ -592,7 +592,9 @@ def get_lr(self) -> list[float | Tensor]: if not self._is_initial: return [ group["lr"] * lmbda(self.last_epoch) - for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + for lmbda, group in zip( + self.lr_lambdas, self.optimizer.param_groups, strict=True + ) ] else: return _param_groups_val_list(self.optimizer, "lr") @@ -1219,7 +1221,7 @@ def state_dict(self) -> dict[str, Any]: state_dict["_schedulers"] = [None] * len(self._schedulers) for idx, s in enumerate(self._schedulers): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] state_dict["_schedulers"][idx] = s.state_dict() return state_dict @@ -1441,13 +1443,17 @@ def get_lr(self) -> list[float | Tensor]: + (base_lr - self.eta_min) * (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2 - for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + for base_lr, group in zip( + self.base_lrs, self.optimizer.param_groups, strict=True + ) ] elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: return [ group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 - for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + for base_lr, group in zip( + self.base_lrs, self.optimizer.param_groups, strict=True + ) ] return [ (1 + math.cos(math.pi * self.last_epoch / self.T_max)) @@ -1562,7 +1568,7 @@ def state_dict(self) -> dict[str, Any]: state_dict["_schedulers"] = [None] * len(self._schedulers) for idx, s in enumerate(self._schedulers): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] state_dict["_schedulers"][idx] = s.state_dict() return state_dict @@ -1671,7 +1677,7 @@ def __init__( self.default_min_lr = None self.min_lrs = list(min_lr) else: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.default_min_lr = min_lr self.min_lrs = [min_lr] * len(optimizer.param_groups) @@ -1731,7 +1737,7 @@ def _reduce_lr(self, epoch): "of the `optimizer` param groups." ) else: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups) for i, param_group in enumerate(self.optimizer.param_groups): @@ -1906,18 +1912,18 @@ def __init__( base_lrs = _format_param("base_lr", optimizer, base_lr) if last_epoch == -1: - for lr, group in zip(base_lrs, optimizer.param_groups): + for lr, group in zip(base_lrs, optimizer.param_groups, strict=True): _update_param_group_val(group, "lr", lr) self.max_lrs = _format_param("max_lr", optimizer, max_lr) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] step_size_up = float(step_size_up) step_size_down = ( - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] float(step_size_down) if step_size_down is not None else step_size_up ) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self.total_size = step_size_up + step_size_down self.step_ratio = step_size_up / self.total_size @@ -1949,7 +1955,10 @@ def __init__( self.max_momentums = _format_param("max_momentum", optimizer, max_momentum) if last_epoch == -1: for m_momentum, b_momentum, group in zip( - self.max_momentums, self.base_momentums, optimizer.param_groups + self.max_momentums, + self.base_momentums, + optimizer.param_groups, + strict=True, ): if self.use_beta1: group["betas"] = (m_momentum, *group["betas"][1:]) @@ -2033,7 +2042,7 @@ def get_lr(self) -> list[float | Tensor]: scale_factor = (x - 1) / (self.step_ratio - 1) lrs = [] - for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs, strict=True): base_height = (max_lr - base_lr) * scale_factor if self.scale_mode == "cycle": lr = base_lr + base_height * self.scale_fn(cycle) @@ -2044,7 +2053,7 @@ def get_lr(self) -> list[float | Tensor]: if self.cycle_momentum: momentums = [] for base_momentum, max_momentum in zip( - self.base_momentums, self.max_momentums + self.base_momentums, self.max_momentums, strict=True ): base_height = (max_momentum - base_momentum) * scale_factor if self.scale_mode == "cycle": @@ -2054,7 +2063,9 @@ def get_lr(self) -> list[float | Tensor]: self.last_epoch ) momentums.append(momentum) - for param_group, momentum in zip(self.optimizer.param_groups, momentums): + for param_group, momentum in zip( + self.optimizer.param_groups, momentums, strict=True + ): if self.use_beta1: param_group["betas"] = (momentum, *param_group["betas"][1:]) else: @@ -2260,7 +2271,9 @@ def step(self, epoch=None) -> None: self.last_epoch = math.floor(epoch) with _enable_get_lr_call(self): - for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + for param_group, lr in zip( + self.optimizer.param_groups, self.get_lr(), strict=True + ): _update_param_group_val(param_group, "lr", lr) self._last_lr = _param_groups_val_list(self.optimizer, "lr") @@ -2500,7 +2513,7 @@ def __init__( base_momentums = _format_param("base_momentum", optimizer, base_momentum) if last_epoch == -1: for m_momentum, b_momentum, group in zip( - max_momentums, base_momentums, optimizer.param_groups + max_momentums, base_momentums, optimizer.param_groups, strict=True ): if self.use_beta1: group["betas"] = (m_momentum, *group["betas"][1:]) diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 59ed0e5d54bf9..508648a65c14a 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -412,7 +412,7 @@ def _multi_tensor_nadam( if not all( p.device.type == mp.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, mp, step in zip(params, mu_products, state_steps) + for p, mp, step in zip(params, mu_products, state_steps, strict=True) ): raise AssertionError( "If capturable=True, " @@ -570,7 +570,7 @@ def _multi_tensor_nadam( step_size_grads = _stack_if_compiling( [ (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1 - for mu_product, mu in zip(grouped_mu_products, mus) + for mu_product, mu in zip(grouped_mu_products, mus, strict=True) ] ) step_size_expavg = _stack_if_compiling( @@ -581,7 +581,9 @@ def _multi_tensor_nadam( / (1.0 - _get_value(mu_product) * mu_next) ) * -1 - for mu_product, mu_next in zip(grouped_mu_products, mu_nexts) + for mu_product, mu_next in zip( + grouped_mu_products, mu_nexts, strict=True + ) ] ) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 5475b2755d4be..6a336fa5bab70 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -62,7 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]: def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T: import torch._dynamo - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] self = cast(Optimizer, args[0]) # assume first positional arg is `self` prev_grad = torch.is_grad_enabled() try: @@ -136,13 +136,13 @@ def maybe_fallback(*args: _P.args, **kwargs: _P.kwargs): if torch.compiler.is_compiling() and ( not kwargs.get("capturable", False) and has_state_steps - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] and (arg := args[state_steps_ind]) and isinstance(arg, Sequence) and arg[0].is_cuda or ( "state_steps" in kwargs - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] and (kwarg := kwargs["state_steps"]) and isinstance(kwarg, Sequence) and kwarg[0].is_cuda @@ -362,18 +362,18 @@ class Optimizer: _optimizer_step_pre_hooks: dict[int, OptimizerPreHook] _optimizer_step_post_hooks: dict[int, OptimizerPostHook] - # pyrefly: ignore # not-a-type + # pyrefly: ignore [not-a-type] _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' _optimizer_state_dict_post_hooks: ( - # pyrefly: ignore # not-a-type + # pyrefly: ignore [not-a-type] 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' ) _optimizer_load_state_dict_pre_hooks: ( - # pyrefly: ignore # not-a-type + # pyrefly: ignore [not-a-type] 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' ) _optimizer_load_state_dict_post_hooks: ( - # pyrefly: ignore # not-a-type + # pyrefly: ignore [not-a-type] 'OrderedDict[int, Callable[["Optimizer"], None]]' ) @@ -522,7 +522,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." ) - # pyrefly: ignore # invalid-param-spec + # pyrefly: ignore [invalid-param-spec] out = func(*args, **kwargs) self._optimizer_step_code() @@ -941,7 +941,9 @@ def load_state_dict(self, state_dict: StateDict) -> None: ) param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + if any( + p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True) + ): raise ValueError( "loaded state dict contains a parameter group " "that doesn't match the size of optimizer's group" @@ -952,6 +954,7 @@ def load_state_dict(self, state_dict: StateDict) -> None: zip( chain.from_iterable(g["params"] for g in saved_groups), chain.from_iterable(g["params"] for g in groups), + strict=True, ) ) @@ -961,9 +964,9 @@ def _cast(param, value, param_id=None, param_groups=None, key=None): return Optimizer._process_value_according_to_param_policy( param, value, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] param_id, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] param_groups, key, ) @@ -976,7 +979,7 @@ def _cast(param, value, param_id=None, param_groups=None, key=None): } elif isinstance(value, Iterable): return type(value)( - # pyrefly: ignore # bad-argument-count + # pyrefly: ignore [bad-argument-count] _cast(param, v, param_id=param_id, param_groups=param_groups) for v in value ) # type: ignore[call-arg] @@ -1005,7 +1008,9 @@ def update_group( new_group["param_names"] = group["param_names"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True) + ] self.__setstate__({"state": state, "param_groups": param_groups}) for post_hook in self._optimizer_load_state_dict_post_hooks.values(): diff --git a/torch/optim/radam.py b/torch/optim/radam.py index 22c00c3d0766d..e13e6806e43a7 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -323,7 +323,7 @@ def _single_tensor_radam( rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2 def _compute_rect(): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return ( (rho_t - 4) * (rho_t - 2) @@ -338,7 +338,7 @@ def _compute_adaptive_lr(): else: exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return (bias_correction2**0.5) / exp_avg_sq_sqrt # Compute the variance rectification term and update parameters accordingly @@ -392,7 +392,7 @@ def _multi_tensor_radam( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." @@ -501,7 +501,8 @@ def _multi_tensor_radam( # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884 rect = [ - torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list) + torch.where(rho_t > 5.0, n, 0.0) + for n, rho_t in zip(num, rho_t_list, strict=True) ] del num del rho_t_list @@ -544,11 +545,14 @@ def _multi_tensor_radam( 1 - beta1 ** _get_value(step) for step in grouped_state_steps ] unrect_step_size = [ - (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1) + (lr * rect / bc) * -1 + for rect, bc in zip(unrectified, bias_correction1, strict=True) ] bias_correction2 = [ ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1 - for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1) + for step, rect, bc in zip( + grouped_state_steps, rect, bias_correction1, strict=True + ) ] buffer = torch._foreach_sqrt(grouped_exp_avg_sqs) diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 10e7ce74509b0..04981d517d1ef 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -370,7 +370,7 @@ def _multi_tensor_rmsprop( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 1ca4aefae3456..8ad7faf130e39 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -319,7 +319,7 @@ def _multi_tensor_rprop( if not all( p.device.type == step.device.type and p.device.type in capturable_supported_devices - for p, step in zip(params, state_steps) + for p, step in zip(params, state_steps, strict=True) ): raise AssertionError( f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index b630abea96770..9c2c5a0eab3d0 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -143,7 +143,9 @@ def step(self, closure=None): if group["momentum"] != 0: # update momentum_buffers in state - for p, momentum_buffer in zip(params, momentum_buffer_list): + for p, momentum_buffer in zip( + params, momentum_buffer_list, strict=True + ): state = self.state[p] state["momentum_buffer"] = momentum_buffer @@ -348,7 +350,7 @@ def _single_tensor_sgd( # usually this is the differentiable path, which is why the param.clone() is needed grad = grad.addcmul_(param.clone(), weight_decay) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] grad = grad.add(param, alpha=weight_decay) else: grad = grad.add(param, alpha=weight_decay) @@ -372,7 +374,7 @@ def _single_tensor_sgd( if lr.requires_grad: param.addcmul_(grad, lr, value=-1) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] param.add_(grad, alpha=-lr) else: param.add_(grad, alpha=-lr) diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 08cd0b504dc87..1ab915d27cd66 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -50,7 +50,7 @@ def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): ): torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay) else: - for p_ema, p_model in zip(ema_param_list, current_param_list): + for p_ema, p_model in zip(ema_param_list, current_param_list, strict=True): p_ema.copy_(p_ema * decay + p_model * (1 - decay)) return ema_update @@ -250,13 +250,13 @@ def forward(self, *args, **kwargs): def update_parameters(self, model: Module): """Update model parameters.""" self_param = ( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] itertools.chain(self.module.parameters(), self.module.buffers()) if self.use_buffers else self.parameters() ) model_param = ( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] itertools.chain(model.parameters(), model.buffers()) if self.use_buffers else model.parameters() @@ -264,7 +264,7 @@ def update_parameters(self, model: Module): self_param_detached: list[Optional[Tensor]] = [] model_param_detached: list[Optional[Tensor]] = [] copy_param = bool(self.n_averaged == 0) - for p_averaged, p_model in zip(self_param, model_param): + for p_averaged, p_model in zip(self_param, model_param, strict=False): p_model_ = p_model.detach().to(p_averaged.device) self_param_detached.append(p_averaged.detach()) model_param_detached.append(p_model_) @@ -297,25 +297,29 @@ def update_parameters(self, model: Module): else: avg_fn = get_swa_avg_fn() n_averaged = self.n_averaged.to(device) - for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment] - # pyrefly: ignore # missing-attribute + for p_averaged, p_model in zip( # type: ignore[assignment] + self_params, model_params, strict=True + ): + # pyrefly: ignore [missing-attribute] p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged)) else: for p_averaged, p_model in zip( # type: ignore[assignment] - self_param_detached, model_param_detached + self_param_detached, model_param_detached, strict=True ): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] n_averaged = self.n_averaged.to(p_averaged.device) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] p_averaged.detach().copy_( - # pyrefly: ignore # missing-attribute, bad-argument-type + # pyrefly: ignore [missing-attribute, bad-argument-type] self.avg_fn(p_averaged.detach(), p_model, n_averaged) ) if not self.use_buffers: # If not apply running averages to the buffers, # keep the buffers in sync with the source model. - for b_swa, b_model in zip(self.module.buffers(), model.buffers()): + for b_swa, b_model in zip( + self.module.buffers(), model.buffers(), strict=True + ): b_swa.detach().copy_(b_model.detach().to(b_swa.device)) self.n_averaged += 1 @@ -432,7 +436,7 @@ def __init__( last_epoch=-1, ): # noqa: D107 swa_lrs = _format_param("swa_lr", optimizer, swa_lr) - for swa_lr, group in zip(swa_lrs, optimizer.param_groups): + for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True): group["swa_lr"] = swa_lr if anneal_strategy not in ["cos", "linear"]: raise ValueError( @@ -497,19 +501,19 @@ def get_lr(self): step = self._step_count - 1 if self.anneal_epochs == 0: step = max(1, step) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs))) prev_alpha = self.anneal_func(prev_t) prev_lrs = [ self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha) for group in self.optimizer.param_groups ] - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] t = max(0, min(1, step / max(1, self.anneal_epochs))) alpha = self.anneal_func(t) return [ group["swa_lr"] * alpha + lr * (1 - alpha) - for group, lr in zip(self.optimizer.param_groups, prev_lrs) + for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True) ] def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]): diff --git a/torch/overrides.py b/torch/overrides.py index db4a7535a36fd..dea75f69ea49b 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -221,8 +221,11 @@ def get_ignored_functions() -> set[Callable]: torch.ones, torch.promote_types, torch.rand, + torch.rand_like, torch.randn, + torch.randn_like, torch.randint, + torch.randint_like, torch.randperm, torch.range, torch.result_type, @@ -1075,9 +1078,6 @@ def get_testing_overrides() -> dict[Callable, Callable]: lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 ), torch.rad2deg: lambda input, out=None: -1, - torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, - torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, - torch.randn_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, torch.ravel: lambda input: -1, torch.real: lambda input, out=None: -1, torch.vdot: lambda input, other, out=None: -1, @@ -1657,7 +1657,8 @@ def _get_overloaded_args( if ( arg_type not in overloaded_types and hasattr(arg_type, "__torch_function__") - and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl + and arg_type.__torch_function__ + is not torch._C._disabled_torch_function_impl ): # Create lists explicitly for the first type (usually the only one # done) to avoid setting up the iterator for overloaded_args. diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index 31898c96f1b08..a66c14adfe86f 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -# pyrefly: ignore # missing-module-attribute +# pyrefly: ignore [missing-module-attribute] from pickle import ( # type: ignore[attr-defined] _compat_pickle, _extension_registry, diff --git a/torch/package/importer.py b/torch/package/importer.py index 3984ddfc40fbc..fc0e735890634 100644 --- a/torch/package/importer.py +++ b/torch/package/importer.py @@ -1,9 +1,10 @@ # mypy: allow-untyped-defs import importlib import logging +import sys from abc import ABC, abstractmethod -# pyrefly: ignore # missing-module-attribute +# pyrefly: ignore [missing-module-attribute] from pickle import ( # type: ignore[attr-defined] _getattribute, _Pickler, @@ -102,7 +103,12 @@ def get_name(self, obj: Any, name: Optional[str] = None) -> tuple[str, str]: # Check that this name will indeed return the correct object try: module = self.import_module(module_name) - obj2, _ = _getattribute(module, name) + if sys.version_info >= (3, 14): + # pickle._getatribute signature changes in 3.14 + # to take iterable and return just one object + obj2 = _getattribute(module, name.split(".")) + else: + obj2, _ = _getattribute(module, name) except (ImportError, KeyError, AttributeError): raise ObjNotFoundError( f"{obj} was not found as {module_name}.{name}" diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 50e9cbe92fb08..b25ebca23095f 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -652,7 +652,7 @@ def _check_mocked_error(module: Optional[str], field: Optional[str]): memo: defaultdict[int, str] = defaultdict(None) memo_count = 0 # pickletools.dis(data_value) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for opcode, arg, _pos in pickletools.genops(data_value): if pickle_protocol == 4: if ( diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 1fa07c90fde15..3f21ce81171d7 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -230,7 +230,7 @@ def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> tuple[Optional[bool], .. for schema in cls.match_schemas(t): mutable = mutable or [False for _ in schema.arguments] for i, arg in enumerate(schema.arguments): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] mutable[i] |= getattr(arg.alias_info, "is_write", False) return tuple(mutable or (None for _ in t.inputs)) @@ -254,7 +254,9 @@ def match_schemas(cls, t: _ExtraFields_TorchOp) -> tuple[FunctionSchema, ...]: def matches(schema) -> bool: return len(schema.arguments) == len(signature) and all( cls._types_match(observed, schema_arg.type) - for observed, schema_arg in zip(signature, schema.arguments) + for observed, schema_arg in zip( + signature, schema.arguments, strict=True + ) ) return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s)) @@ -377,7 +379,9 @@ def _update_values(self, t: Optional[_TensorMetadata]) -> None: key = TensorKey.from_tensor(t) if key is not None and t is not None and t.layout == torch.strided: # Scalars are represented as zero dim Tensors - n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1])) + n = max( + i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1], strict=True) + ) num_bytes = n * _element_size(t.dtype) assert num_bytes >= 0, f"{num_bytes}" @@ -430,7 +434,7 @@ def _determine_edges(self) -> dict[TensorKey, DataFlowEdge]: mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {} for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp): for op_input, mutable in zip( - op.inputs, SchemaMatcher.inputs_are_mutable(op) + op.inputs, SchemaMatcher.inputs_are_mutable(op), strict=True ): # Tensor if isinstance(op_input, _TensorMetadata): @@ -1084,7 +1088,7 @@ def get_category_index(key, version): if action in (Action.PREEXISTING, Action.CREATE): raw_events.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] ( t, _ACTION_TO_INDEX[action], @@ -1095,7 +1099,7 @@ def get_category_index(key, version): elif action == Action.INCREMENT_VERSION: raw_events.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] ( t, _ACTION_TO_INDEX[action], @@ -1104,7 +1108,7 @@ def get_category_index(key, version): ) ) raw_events.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] ( t, _ACTION_TO_INDEX[action], @@ -1115,7 +1119,7 @@ def get_category_index(key, version): elif action == Action.DESTROY: raw_events.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] ( t, _ACTION_TO_INDEX[action], diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index cee47f28eb04a..97fe8b5edc22a 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -276,7 +276,7 @@ def match(self, event: _ProfilerEvent): def same_ops(list1, list2) -> bool: if len(list1) != len(list2): return False - for op1, op2 in zip(list1, list2): + for op1, op2 in zip(list1, list2, strict=True): if op1.name != op2.name: return False return True diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 04b0fc62189e8..2c6e06b2cb3c9 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -211,7 +211,7 @@ def new_old_event_comparator(event): # Find latest cuda kernel event if hasattr(event, "start_us"): start_time = event.start_us() * 1000 - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] end_time = (event.start_us() + event.duration_us()) * 1000 # Find current spawned cuda kernel event if event in kernel_mapping and kernel_mapping[event] is not None: @@ -336,7 +336,7 @@ def rank_events(self, length): event_list = [ event for _, event in sorted( - zip(heuristic_score_list, event_list), + zip(heuristic_score_list, event_list, strict=True), key=operator.itemgetter(0), reverse=True, ) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index aa046db445494..ee0ea85e1694b 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -161,19 +161,19 @@ def __init__( self.mem_tl: Optional[MemoryProfileTimeline] = None self.use_device = None if ProfilerActivity.CUDA in self.activities: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.use_device = "cuda" elif ProfilerActivity.XPU in self.activities: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.use_device = "xpu" elif ProfilerActivity.MTIA in self.activities: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.use_device = "mtia" elif ProfilerActivity.HPU in self.activities: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.use_device = "hpu" elif ProfilerActivity.PrivateUse1 in self.activities: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.use_device = _get_privateuse1_backend_name() # user-defined metadata to be amended to the trace @@ -385,7 +385,7 @@ def _get_distributed_info(self): } if backend == "nccl": nccl_version = torch.cuda.nccl.version() - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version) return dist_info diff --git a/torch/quantization/_quantized_conversions.py b/torch/quantization/_quantized_conversions.py index 54f40dcf7b25e..0fcb1004f7047 100644 --- a/torch/quantization/_quantized_conversions.py +++ b/torch/quantization/_quantized_conversions.py @@ -71,7 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass( nrows // 16, 16 ) ).view(-1) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] outp = outp.index_copy(1, cols_permuted, outp) # interleave_column_major_tensor diff --git a/torch/serialization.py b/torch/serialization.py index 268540752343d..bc209350708cf 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -790,7 +790,7 @@ def __init__(self, name: str) -> None: # PyTorchFileWriter only supports ascii filename. # For filenames with non-ascii characters, we rely on Python # for writing out the file. - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.file_stream = io.FileIO(self.name, mode="w") super().__init__( torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index 96fe6932de87f..8478d0df574dc 100644 --- a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -397,15 +397,15 @@ def kaiser( ) # Avoid NaNs by casting `beta` to the appropriate dtype. - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] beta = torch.tensor(beta, dtype=dtype, device=device) start = -beta constant = 2.0 * beta / (M if not sym else M - 1) end = torch.minimum( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] beta, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] start + (M - 1) * constant, ) @@ -420,7 +420,7 @@ def kaiser( ) return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] beta ) diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 448b50eda0200..2ddd930cd8521 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -623,20 +623,20 @@ def convert_to_strided_representation(args): ) obj = obj.to_dense().sparse_mask(full_mask) if obj.layout is torch.sparse_coo: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] d.update( indices=obj._indices(), is_coalesced=obj.is_coalesced() ) values = obj._values() elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] d.update( compressed_indices=obj.crow_indices(), plain_indices=obj.col_indices(), ) values = obj.values() else: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] d.update( compressed_indices=obj.ccol_indices(), plain_indices=obj.row_indices(), diff --git a/torch/sparse/_semi_structured_conversions.py b/torch/sparse/_semi_structured_conversions.py index c98205f567070..354acdee16a26 100644 --- a/torch/sparse/_semi_structured_conversions.py +++ b/torch/sparse/_semi_structured_conversions.py @@ -140,7 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): if dense.dtype != torch.float: sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) else: @@ -173,7 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense): meta_offsets = _calculate_meta_reordering_scatter_offsets( m, meta_ncols, meta_dtype, device ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) return (sparse, meta_reordered.view(m, meta_ncols)) diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index 55cb0a8c113ef..8870dce504190 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -67,7 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor: # Because we cannot go from the compressed representation back to the dense representation currently, # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix # is the first or second argument, we expect an even / odd number of calls to transpose respectively. - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return self.__class__( torch.Size([self.shape[-1], self.shape[0]]), packed=self.packed_t, diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index 942e5e8dca3f1..2011930d78fbf 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -121,7 +121,7 @@ def slicer(dim, slice_range, *tensors): def multidim_slicer(dims, slices, *tensors): for t in tensors: s = [slice(None)] * t.dim() - for d, d_slice in zip(dims, slices): + for d, d_slice in zip(dims, slices, strict=False): if d is not None: s[d] = d_slice yield t[tuple(s)] @@ -140,7 +140,7 @@ def grid_partitioner(full_grid, grid_blocks, tensor_dims_map): import itertools def generate_grid_points(): - for fg, mg in zip(full_grid, grid_blocks): + for fg, mg in zip(full_grid, grid_blocks, strict=False): yield range(0, fg, mg) def generate_sliced_tensors(slices): @@ -149,9 +149,10 @@ def generate_sliced_tensors(slices): for grid_point in itertools.product(*generate_grid_points()): grid = [ - min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks) + min(fg - gp, mg) + for fg, gp, mg in zip(full_grid, grid_point, grid_blocks, strict=False) ] - slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)] + slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid, strict=False)] # grid_points are iterated in a "contiguous" order, i.e. # left dimensions traversed slower than right dimensions. # This order is reversed for CUDA grids. @@ -173,7 +174,8 @@ def valid_grid_dim(g, mg): return max(1, min(g, mg)) grid_blocks = tuple( - valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid) + valid_grid_dim(g, mg) + for g, mg in zip(grid_blocks, cuda_max_grid, strict=False) ) # type: ignore[assignment] for grid, *sliced_tensors in grid_partitioner( @@ -1297,20 +1299,31 @@ def bsr_dense_addmm( assert alpha != 0 def kernel(grid, *sliced_tensors): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] _bsr_strided_addmm_kernel[grid]( *ptr_stride_extractor(*sliced_tensors), + # pyrefly: ignore # bad-argument-count beta, alpha, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type beta_is_one=beta == 1, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type beta_is_nonzero=beta != 0, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type alpha_is_one=alpha == 1, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type left_alpha_is_one=left_alpha_is_one, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type right_alpha_is_one=right_alpha_is_one, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type BLOCKSIZE_ROW=BM, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type BLOCKSIZE_INNER=BK, + # pyrefly: ignore # bad-keyword-argument BLOCKSIZE_COL=BN, + # pyrefly: ignore # bad-keyword-argument allow_tf32=dot_out_dtype == tl.float32, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type acc_dtype=dot_out_dtype, **meta, ) @@ -1427,7 +1440,7 @@ def _sampled_addmm_kernel( mat1_block = tl.load( mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :], - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] mask=mask_k[None, :], other=0.0, ) @@ -1436,7 +1449,7 @@ def _sampled_addmm_kernel( mat2_block_ptrs + mat2_tiled_col_stride * col_block + mat2_row_block_stride * k_offsets[:, None], - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] mask=mask_k[:, None], other=0.0, ) @@ -1631,12 +1644,17 @@ def kernel(grid, *sliced_tensors): beta, is_beta_zero, *blocksize, + # pyrefly: ignore # bad-argument-count k, tile_k, *ptr_stride_extractor(*sliced_tensors), + # pyrefly: ignore # bad-keyword-argument, bad-argument-type acc_dtype=acc_dtype, + # pyrefly: ignore # bad-keyword-argument, bad-argument-type allow_tf32=allow_tf32, + # pyrefly: ignore # unexpected-keyword num_stages=1, + # pyrefly: ignore # unexpected-keyword num_warps=4, ) @@ -1921,6 +1939,7 @@ def bsr_softmax(input, max_row_nnz=None): def kernel(grid, *sliced_tensors): _bsr_softmax_kernel[grid]( *ptr_stride_extractor(*sliced_tensors), + # pyrefly: ignore # bad-argument-count row_block, col_block, max_row_nnz, @@ -1974,7 +1993,7 @@ def _scaled_dot_product_attention( if attn_mask.dtype is not torch.bool: check_dtype(f_name, attn_mask, query.dtype) - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] sdpa = sampled_addmm( attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False ) @@ -1986,10 +2005,10 @@ def _scaled_dot_product_attention( ) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale sdpa.values().mul_(scale_factor) - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] sdpa = bsr_softmax(sdpa) torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True) - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] sdpa = bsr_dense_mm(sdpa, value) return sdpa @@ -2094,8 +2113,11 @@ def grid(META): if "allow_tf32" not in meta: meta.update(allow_tf32=dot_out_dtype == tl.float32) _scatter_mm2_kernel[grid]( + # pyrefly: ignore # bad-argument-type M, + # pyrefly: ignore # bad-argument-type K, + # pyrefly: ignore # bad-argument-type N, blocks, blocks.stride(0), @@ -2114,7 +2136,9 @@ def grid(META): pq_indices, pq_indices.stride(0), pq_indices.stride(1), + # pyrefly: ignore # bad-argument-type dot_out_dtype=dot_out_dtype, + # pyrefly: ignore # bad-argument-type **meta, ) @@ -2297,6 +2321,7 @@ def grid(META): _scatter_mm6_kernel[grid]( B, Ms, + # pyrefly: ignore # bad-argument-type Ks, N, blocks, @@ -2315,6 +2340,7 @@ def grid(META): r_offsets, p_offsets, q_offsets, + # pyrefly: ignore # bad-argument-type dot_out_dtype=dot_out_dtype, **meta, ) diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 78bdbf07b2b3c..38749d00f0eb4 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -155,7 +155,11 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F matching_data = {} if "*" in key: for op_key in op_data: - if [None for k1, k2 in zip(op_key, key) if k2 != "*" and k1 != k2]: + if [ + None + for k1, k2 in zip(op_key, key, strict=True) + if k2 != "*" and k1 != k2 + ]: continue matching_data[op_key] = op_data[op_key] else: @@ -173,10 +177,14 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F "num_stages", "num_warps", ) - meta = dict(zip(names, values)) + meta = dict(zip(names, values, strict=True)) elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}: meta = dict( - zip(("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"), values) + zip( + ("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"), + values, + strict=True, + ) ) else: raise NotImplementedError(f"names for {op=}") @@ -234,10 +242,10 @@ def sort_key(key): part2 = current_content[end_data_index:] data_part = [] for op_key in sorted(_operation_device_version_data, key=sort_key): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] data_part.append(" " + repr(op_key).replace("'", '"') + ": {") op_data = _operation_device_version_data[op_key] - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data)) data_part.append(" },") new_content = part1 + "\n".join(data_part) + "\n" + part2 @@ -289,7 +297,7 @@ def to_key(parameters): return tuple(parameters[k] for k in sorted(parameters)) def from_key(key, parameters): - return dict(zip(sorted(parameters), key)) + return dict(zip(sorted(parameters), key, strict=True)) if all_values is None: all_values = {} @@ -347,7 +355,7 @@ def from_key(key, parameters): for i, (_, d_tuple) in enumerate(all_directions): pbar.update(1) next_parameters = parameters.copy() - for name, direction in zip(names, d_tuple): + for name, direction in zip(names, d_tuple, strict=True): value = next_parameters[name] if direction == 0: continue @@ -371,7 +379,7 @@ def from_key(key, parameters): if next_target < minimal_target: minimal_target = next_target parameters = next_parameters - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] pbar.total += i + 1 break else: diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index da5b8b4798a9c..df5e3508e5256 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -185,7 +185,7 @@ def __tensor_unflatten__( outer_stride, ) -> torch.Tensor: shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return cls( shape=shape, packed=inner_tensors.get("packed", None), @@ -415,7 +415,7 @@ def from_dense( sparse_tensor_cutlass, meta_tensor_cutlass, ) = sparse_semi_structured_from_dense_cutlass(original_tensor) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return cls( original_tensor.shape, packed=sparse_tensor_cutlass, @@ -502,7 +502,7 @@ def prune_dense_static_sort( original_tensor, algorithm=algorithm, use_cutlass=True ) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return cls( original_tensor.shape, packed=packed, @@ -564,7 +564,7 @@ def from_dense( cls, original_tensor: torch.Tensor ) -> "SparseSemiStructuredTensorCUSPARSELT": cls._validate_device_dim_dtype_shape(original_tensor) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return cls( shape=original_tensor.shape, packed=torch._cslt_compress(original_tensor), @@ -631,7 +631,7 @@ def prune_dense_static_sort( packed = packed.view(original_tensor.shape[0], -1) packed_t = packed_t.view(original_tensor.shape[1], -1) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return cls( original_tensor.shape, packed=packed, diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 9e8f41008dcfb..6724bd3d523b0 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -2,6 +2,6 @@ from . import _utils -# pyrefly: ignore # deprecated +# pyrefly: ignore [deprecated] from ._comparison import assert_allclose, assert_close as assert_close from ._creation import make_tensor as make_tensor diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 1d4a050b80472..45622ec7f15e9 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -243,7 +243,7 @@ def make_scalar_mismatch_msg( Defaults to "Scalars". """ abs_diff = abs(actual - expected) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected) return _make_mismatch_msg( default_identifier="Scalars", @@ -487,7 +487,7 @@ def __init__( def _supported_types(self) -> tuple[type, ...]: cls: list[type] = [bool] if HAS_NUMPY: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] cls.append(np.bool_) return tuple(cls) @@ -503,7 +503,7 @@ def _process_inputs( def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool: if isinstance(bool_like, bool): return bool_like - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] elif isinstance(bool_like, np.bool_): return bool_like.item() else: @@ -583,7 +583,7 @@ def __init__( def _supported_types(self) -> tuple[type, ...]: cls = list(self._NUMBER_TYPES) if HAS_NUMPY: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] cls.append(np.number) return tuple(cls) @@ -599,7 +599,7 @@ def _process_inputs( def _to_number( self, number_like: Any, *, id: tuple[Any, ...] ) -> Union[int, float, complex]: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if HAS_NUMPY and isinstance(number_like, np.number): return number_like.item() elif isinstance(number_like, self._NUMBER_TYPES): @@ -1122,7 +1122,7 @@ def originate_pairs( mapping_types: tuple[type, ...] = (collections.abc.Mapping,), id: tuple[Any, ...] = (), **options: Any, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> list[Pair]: """Originates pairs from the individual inputs. @@ -1221,7 +1221,7 @@ def originate_pairs( else: for pair_type in pair_types: try: - # pyrefly: ignore # bad-instantiation + # pyrefly: ignore [bad-instantiation] return [pair_type(actual, expected, id=id, **options)] # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the # inputs. Thus, we try the next pair type. @@ -1319,9 +1319,9 @@ def not_close_error_metas( # would not get freed until cycle collection, leaking cuda memory in tests. # We break the cycle by removing the reference to the error_meta objects # from this frame as it returns. - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] error_metas = [error_metas] - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return error_metas.pop() diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 3f475bd6823b5..5e4b65f68cd72 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -60,9 +60,9 @@ def CDNA2OrLater(): def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1201", "gfx950"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": - arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] + arch_list += ["gfx1100", "gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return not IS_WINDOWS and SM80OrLater @@ -70,9 +70,9 @@ def evaluate_platform_supports_flash_attention(): def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1201", "gfx950"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": - arch_list += ["gfx1101", "gfx1150", "gfx1151", "gfx1200"] + arch_list += ["gfx1100", "gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return True @@ -384,6 +384,20 @@ def xfailIfSM120OrLater(func): def xfailIfDistributedNotSupported(func): return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func) +def _check_has_working_nvml() -> bool: + try: + if not torch.cuda.is_available(): + return False + import pynvml + torch.cuda.device_memory_used() + return True + except ModuleNotFoundError: + return False + except pynvml.NVMLError_NotSupported: + return False + +HAS_WORKING_NVML = _check_has_working_nvml() + # Importing this module should NOT eagerly initialize CUDA if not CUDA_ALREADY_INITIALIZED_ON_IMPORT: assert not torch.cuda.is_initialized() diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 0cecc762bce4a..98e17404d1ad2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -445,11 +445,9 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): ) # Checking for permutations of weights and biases as `None` - weights = [channels, None, None] - biases = [None, channels, None] is_training = [True, False, False] - for weight, bias, training in zip(weights, biases, is_training, strict=True): + for training in is_training: yield SampleInput( make_arg(input_shape), args=( @@ -9765,6 +9763,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): supports_autograd=True, supports_inplace_autograd=True, supports_forward_ad=True, + supports_sparse=True, decorators=( DecorateInfo( unittest.expectedFailure, diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index 45ede2d5e433f..b32898531926d 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -92,6 +92,8 @@ def mps_ops_modifier( "log1p", "log2", "log", + "logaddexp", + "logaddexp2", "mH", "mT", "masked_fill", diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 3153359326dca..7a276144e53bd 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -120,7 +120,9 @@ def get_weight(m): desc='no_bias', reference_fn=lambda i, p, _: torch.mm(i, p[0].t()), with_tf32=True, - tf32_precision=0.05 if TEST_WITH_ROCM else 0.005, + tf32_precision=0.005, + # ROCM: skipping tf32 test on gfx94 archs due to tolerance issue. + test_cuda=not (TEST_WITH_ROCM and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName), default_dtype=torch.double, ), dict( diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index fde4f396b2b91..c88f7ad45c7ea 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -810,7 +810,7 @@ def check_eager_serialization(self, ref_model, loaded_model, x): b = io.BytesIO() torch.save(model_dict, b) b.seek(0) - # weights_only=False as we sometimes get a ScriptObect here (weird) + # weights_only=False as we sometimes get a ScriptObject here (weird) loaded_dict = torch.load(b, weights_only=False) loaded_model.load_state_dict(loaded_dict) ref_out = ref_model(*x) diff --git a/torch/testing/_internal/common_quantized.py b/torch/testing/_internal/common_quantized.py index 10670123d630f..6bd57fa976ebc 100644 --- a/torch/testing/_internal/common_quantized.py +++ b/torch/testing/_internal/common_quantized.py @@ -530,10 +530,37 @@ def to_blocked(input_matrix) -> torch.Tensor: return rearranged.flatten() + +def down_size(size): + assert size[-1] % 2 == 0, f"{size} last dim not divisible by two" + return (*size[:-1], size[-1] // 2) + + +def pack_uint4(uint8_data) -> torch.Tensor: + # converting to uint8 for operations + shape = uint8_data.shape + assert shape[-1] % 2 == 0 + uint8_data = uint8_data.contiguous().view(-1) + return (uint8_data[1::2] << 4 | uint8_data[::2]).view(down_size(shape)) + + +# exponent and mantissa bits of `torch.float4_e2m1fn_x2` +FP4_EBITS, FP4_MBITS = 2, 1 + + +def _bfloat16_to_float4_e2m1fn_x2(x): + assert x.dtype == torch.bfloat16 + x = _f32_to_floatx_unpacked(x.float(), FP4_EBITS, FP4_MBITS) + x = pack_uint4(x) + x = x.view(torch.float4_e2m1fn_x2) + return x + + # This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142 -def to_mxfp8( +def to_mxfp( data_hp: torch.Tensor, block_size: int = 32, + format: str = "mxfp8", ): assert data_hp.dtype in ( torch.bfloat16, @@ -554,8 +581,12 @@ def to_mxfp8( data_hp = data_hp.to(torch.float32) max_abs = max_abs.to(torch.float32) - F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 - max_pos = F8E4M3_MAX + if format == "mxfp8": + F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + max_pos = F8E4M3_MAX + elif format == "mxfp4": + F4E2M1_MAX = 6. + max_pos = F4E2M1_MAX # RCEIL def _to_mx_rceil( @@ -592,9 +623,15 @@ def _to_mx_rceil( scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) # cast to target dtype - data_lp = data_lp.to(torch.float8_e4m3fn) - # need to reshape at the end to help inductor fuse things - data_lp = data_lp.reshape(orig_shape) + if format == "mxfp8": + data_lp = data_lp.to(torch.float8_e4m3fn) + # need to reshape at the end to help inductor fuse things + data_lp = data_lp.reshape(orig_shape) + elif format == "mxfp4": + data_lp = _bfloat16_to_float4_e2m1fn_x2(data_lp.to(torch.bfloat16)) + final_shape = list(orig_shape) + final_shape[-1] //= 2 + data_lp = data_lp.reshape(final_shape) scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index 527fc8a5826e8..773bea63eef82 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -186,7 +186,7 @@ def unwrap(e): def wrap(e): return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e - if func == torch.ops.aten._local_scalar_dense.default: + if func is torch.ops.aten._local_scalar_dense.default: raise RuntimeError( ".item() is not allowed to be called inside of composite " "functions in the PyTorch library because not all backends " diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 117f6ec8c1b25..17140f40684dd 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -17,7 +17,6 @@ import torch.nn as nn import torch.nn.functional as F from torch.distributed._local_tensor import ( - local_tensor_mode, LocalIntNode, LocalTensor, LocalTensorMode, @@ -715,9 +714,6 @@ def _handle_test_skip(self, msg: str) -> None: self.skipTest(msg) def _get_local_tensor_mode(self): - lm = local_tensor_mode() - if lm is not None: - breakpoint() return LocalTensorMode(frozenset(range(self.world_size))) def setUp(self) -> None: diff --git a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py index ca9bc297010ac..32498f6d14917 100644 --- a/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py +++ b/torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py @@ -465,7 +465,7 @@ def do_test_on_master( ) # Destroy process groups - for idx, trainer_rref in enumerate(trainer_rrefs): + for trainer_rref in trainer_rrefs: _remote_method_async(Trainer.destroy_pg, trainer_rref).wait() # Send shutdown signals. diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 499341b079518..a14f670d788be 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -6094,7 +6094,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( dim=1, ).cuda(rank) - for i in range(100): + for _ in range(100): y = model(input_var[rank].cuda(rank)) y.mean().backward() @@ -6467,7 +6467,7 @@ def test_SyncBatchNorm_process_group(self): def _run_reduction_test( self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None ): - if reduction_fn != dist.all_reduce and dst is None: + if reduction_fn is not dist.all_reduce and dst is None: raise ValueError(f"Reduction fn {reduction_fn} must specify dst!") if dst is not None: reduction_fn(tensor, dst, op) diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index 3c5c9101e43c4..1b371d3ee6ea0 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -1988,7 +1988,7 @@ def test_clean_context_during_backward(self): self.assertEqual(self.world_size - 1, len(known_context_ids)) t1 = torch.rand((3, 3), requires_grad=True) - for i in range(100): + for _ in range(100): dst = self._next_rank() t1 = rpc.rpc_sync(worker_name(dst), torch.add, args=(t1, t1)) diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index d6dce75437d17..f9dc471ca98aa 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -42,7 +42,7 @@ # supports `exclude` argument. # For more context: https://github.com/pytorch/pytorch/pull/56352#discussion_r633277617 def sample_inputs_i0_i1(op_info, device, dtype, requires_grad, **kwargs): - exclude_zero = requires_grad and op_info.op == torch.special.i0e + exclude_zero = requires_grad and op_info.op is torch.special.i0e make_arg = partial( make_tensor, dtype=dtype, diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 0964c68ebb20b..a0fcbaee30f52 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -823,7 +823,7 @@ def add_4_times_kernel( mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) - for i in range(2): + for _ in range(2): output = x + y tl.store(out_ptr + offsets, output, mask=mask) i = 2 diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 51f1704e4a22c..603625ed97c12 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -53,6 +53,7 @@ "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn", + "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 7c9d1a850a46f..21f27243914f9 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs import contextlib -from typing import Optional, TYPE_CHECKING +import functools +import traceback +from typing import Any, Callable, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -10,7 +12,8 @@ _get_current_dispatch_mode_stack, TorchDispatchMode, ) -from torch.utils._pytree import tree_map +from torch.utils._pytree import tree_all, tree_map +from torch.utils._traceback import CapturedTraceback if TYPE_CHECKING: @@ -19,7 +22,10 @@ __all__ = ["DebugMode", "get_active_debug_mode"] + REDISTRIBUTE_FUNC = "redistribute_input" +_DISPATCH_RECORD_HOOKS: list[Callable] = [] +_DISPATCH_LOG_HOOKS: list[Callable] = [] def _stringify_shape(shape) -> str: @@ -81,11 +87,62 @@ def to_str(x): return str(arg) +def default_hash_fn(t: torch.Tensor, use_scalar: bool = False) -> torch.Tensor: + """ + from Observer. Computes a hash for a tensor by converting it to float (if needed), making it contiguous, + replacing NaN/inf values with fixed numbers, and then computing the L1 norm in float64 or complex128. + This is used to generate a deterministic summary value for tensor comparison. + """ + if not (t.is_floating_point() or t.is_complex()): + t = t.float() + t = t.contiguous() + # Clean the tensor to handle NaN/inf values, then compute norm + t_clean = torch.nan_to_num(t, nan=0.0, posinf=1.0, neginf=-1.0) + + dtype = torch.complex128 if t.is_complex() else torch.float64 + out = t_clean.norm(p=1, dtype=dtype) + if use_scalar: + return out.item() + return out + + +def _get_stack_trace() -> str: + from torch.fx.experimental.symbolic_shapes import uninteresting_files + + summary = CapturedTraceback.extract().summary() + summary = summary[:-4] # filter out DebugMode frames + summary = [ + frame for frame in summary if frame.filename not in uninteresting_files() + ] + summary = traceback.StackSummary.from_list(summary) + return "".join(summary.format()) + + class _DebugCall: """Base class for tracking operator calls in DebugMode""" - def __init__(self, call_depth: int): + def __init__( + self, + call_depth: int, + record: Optional[dict[str, Any]] = None, + log: Optional[dict[str, Any]] = None, + stack: bool = False, + ): self.call_depth = call_depth + if stack: + self.stack_trace = _get_stack_trace() + + # results from dispatch hooks + self.record = record + self.log = log + + def stringify_args(self, attributes: list[str]) -> None: + """ + To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. + """ + raise NotImplementedError( + "Subclasses must implement stringify_args(), even if no-op" + ) def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -97,21 +154,48 @@ def __repr__(self) -> str: class _OpCall(_DebugCall): """Normal operator call""" - def __init__(self, op, args: tuple, kwargs: dict, call_depth: int): - super().__init__(call_depth) + def __init__( + self, + op, + args: tuple, + kwargs: dict, + call_depth: int, + stack: bool = False, + ): + super().__init__(call_depth, stack=stack) self.op = op self.args = args self.kwargs = kwargs - def render(self, attributes: list[str]) -> str: - args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + self.args_str: Optional[str] = None + self.kwargs_str: Optional[str] = None + def stringify_args(self, attributes: list[str]) -> None: + self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) if self.kwargs: - kwargs_str = ", " + ", ".join( + self.kwargs_str = ", " + ", ".join( f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() ) else: - kwargs_str = "" + self.kwargs_str = "" + del self.args + del self.kwargs + + def render(self, attributes: list[str]) -> str: + if self.args_str is not None: + args_str = self.args_str + else: + args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + + if self.kwargs_str is not None: + kwargs_str = self.kwargs_str + else: + if self.kwargs: + kwargs_str = ", " + ", ".join( + f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() + ) + else: + kwargs_str = "" if isinstance(self.op, torch._ops.OpOverload): op_name = self.op.__qualname__ @@ -120,27 +204,50 @@ def render(self, attributes: list[str]) -> str: else: op_name = str(self.op) - return f"{op_name}({args_str}{kwargs_str})" + base_str = f"{op_name}({args_str}{kwargs_str})" + + if self.log: + base_str += f" # {self.log}" + return base_str def __iter__(self): # for BC; tuple(self) returns (op, args, kwargs, call_depth) - yield from [self.op, self.args, self.kwargs, self.call_depth] + if self.args_str is not None: + yield from [self.op, self.args_str, self.kwargs_str, self.call_depth] + else: + yield from [self.op, self.args, self.kwargs, self.call_depth] class _RedistributeCall(_DebugCall): """Redistribute call from DTensor dispatch""" def __init__( - self, arg, src_placement, dst_placement, transform_info_str, call_depth + self, + arg, + src_placement, + dst_placement, + transform_info_str, + call_depth, + stack=False, ): - super().__init__(call_depth) + super().__init__(call_depth, stack=stack) self.arg = arg self.src_placement = src_placement self.dst_placement = dst_placement self.transform_info_str = transform_info_str + self.arg_str: Optional[str] = None + + def stringify_args(self, attributes: list[str]) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes)}" + del self.arg + def render(self, attributes: list[str]) -> str: - arg_str = f"{_arg_to_str(self.arg, attributes)}" + if self.arg_str is not None: + arg_str = self.arg_str + else: + arg_str = f"{_arg_to_str(self.arg, attributes)}" + if self.transform_info_str is not None: # prioritize over src/dst placements placement_str = f"trace: {self.transform_info_str}" else: @@ -151,11 +258,16 @@ def render(self, attributes: list[str]) -> str: def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) + if self.arg_str is not None: + arg = self.arg_str + else: + arg = self.arg + yield REDISTRIBUTE_FUNC if self.transform_info_str: - yield [self.arg, self.transform_info_str] + yield [arg, self.transform_info_str] else: - yield [self.arg, self.src_placement, self.dst_placement] + yield [arg, self.src_placement, self.dst_placement] yield {} yield self.call_depth @@ -163,10 +275,13 @@ def __iter__(self): class _NNModuleCall(_DebugCall): """Designates entering an nn.Module's forward method""" - def __init__(self, module_name: str, call_depth: int): - super().__init__(call_depth) + def __init__(self, module_name: str, call_depth: int, stack: bool = False): + super().__init__(call_depth, stack=stack) self.module_name = module_name + def stringify_args(self, attributes: list[str]) -> None: + pass # nothing to stringify + def render(self, attributes: list[str]) -> str: return f"[nn.Mod] {self.module_name}" @@ -179,6 +294,33 @@ def __iter__(self): ] +def _run_hook(hook, *args): + out = hook(*args) + assert out is None or isinstance(out, dict) + return out + + +def _run_dispatch_hooks(call: _DebugCall, func, types, args, kwargs, result) -> None: + global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS + if _DISPATCH_RECORD_HOOKS: + record = {} + for hook in _DISPATCH_RECORD_HOOKS: + hook_out = _run_hook(hook, func, types, args, kwargs, result) + if hook_out is not None: + record.update(hook_out) + if record: + call.record = record + + if _DISPATCH_LOG_HOOKS: + log = {} + for hook in _DISPATCH_LOG_HOOKS: + hook_out = _run_hook(hook, func, types, args, kwargs, result) + if hook_out is not None: + log.update(hook_out) + if log: + call.log = log + + class DebugMode(TorchDispatchMode): def __init__( self, @@ -188,22 +330,42 @@ def __init__( record_realtensor=True, record_tensor_attributes=None, record_nn_module=False, + store_original_args=False, + record_stack_trace=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 self.supports_higher_order_operators = True + + # Pushes DebugMode onto the torchfunction stack, and records __torch_function__ calls as well. + # WARNING: currently incompatible with torch.compile due to dynamo guard failures. self.record_torchfunction = record_torchfunction + + # Records __torch_dispatch__ calls on FakeTensors. self.record_faketensor = record_faketensor + + # Records __torch_dispatch__ calls on real tensors. self.record_realtensor = record_realtensor + + # Optional list[str] of tensor attributes, to be annotated in the string dump. self.record_tensor_attributes = record_tensor_attributes or [] + # Uses ModTracker to record nn.Module entrances, as _NNModuleCall entries. + # This flag currently has no effect on torch.compiled-regions. self.record_nn_module = record_nn_module self.module_tracker: Optional[ModTracker] = None if self.record_nn_module: self.module_tracker_setup() + # If True, stores call args/kwargs in logs, without immediately stringifying. + # Defaults to False for memory concerns. + self.store_original_args = store_original_args + + # For stack trace recording, stores log call stack traces in .stack_trace. + self.record_stack_trace = record_stack_trace + self.operators = [] self.call_depth = 0 @@ -214,11 +376,18 @@ def __init__( def ignore_compile_internals(cls): return True + def _record_call(self, call): + if not self.store_original_args: + call.stringify_args(self.record_tensor_attributes) + self.operators.append(call) + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - self.operators.append(_OpCall(func, args, kwargs, self.call_depth)) + self._record_call( + _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) + ) try: self.call_depth += 1 @@ -231,22 +400,40 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = {} # Record the operation with its call depth + call = None if torch.distributed.tensor.DTensor in types: - self.operators.append(_OpCall(func, args, kwargs, self.call_depth)) + call = _OpCall( + func, args, kwargs, self.call_depth, stack=self.record_stack_trace + ) + self._record_call(call) return NotImplemented elif FakeTensor in types or isinstance( _get_current_dispatch_mode(), FakeTensorMode ): if self.record_faketensor: if func != torch.ops.prim.device.default: - self.operators.append( - _OpCall(func, args, kwargs, self.call_depth + 1) + call = _OpCall( + func, + args, + kwargs, + self.call_depth + 1, + stack=self.record_stack_trace, ) + self._record_call(call) elif len(types) == 0: if self.record_realtensor: - self.operators.append(_OpCall(func, args, kwargs, self.call_depth + 1)) + call = _OpCall( + func, + args, + kwargs, + self.call_depth + 1, + stack=self.record_stack_trace, + ) + self._record_call(call) result = func(*args, **kwargs) + if call: + _run_dispatch_hooks(call, func, types, args, kwargs, result) return result @@ -296,13 +483,14 @@ def record_redistribute_calls( transform_info_str: Optional[str] = None, ): try: - self.operators.append( + self._record_call( _RedistributeCall( arg, src_placement=src_placement, dst_placement=dst_placement, transform_info_str=transform_info_str, call_depth=self.call_depth + 1, + stack=self.record_stack_trace, ) ) self.call_depth += 1 @@ -319,6 +507,89 @@ def debug_string(self) -> str: ) return result + @staticmethod + @contextlib.contextmanager + def dispatch_hooks( + record_hook: Optional[Callable] = None, + log_hook: Optional[Callable] = None, + ): + """ + Allows installing post-hooks on arguments to intercepted __torch_dispatch__ calls; + hook signatures are expected as (func, types, args, kwargs, result), + i.e. __torch_dispatch__ args + return value. + + Logging hook outputs are stored in call.log and annotate calls in debug_string(), + while recording hook outputs are just stored in call.record. + For now hooks are expected to return dictionaries. + """ + global _DISPATCH_RECORD_HOOKS, _DISPATCH_LOG_HOOKS + + if record_hook: + _DISPATCH_RECORD_HOOKS.append(record_hook) + if log_hook: + _DISPATCH_LOG_HOOKS.append(log_hook) + try: + yield + finally: + if record_hook: + _DISPATCH_RECORD_HOOKS.pop() + if log_hook: + _DISPATCH_LOG_HOOKS.pop() + + @staticmethod + @contextlib.contextmanager + def record_outputs(): + """ + Hook for storing cloned output tensors in .record["output"]. + """ + + def dispatch_hook(func, types, args, kwargs, result): + with torch._C._DisablePythonDispatcher(): + out = tree_map( + lambda x: x.clone() if isinstance(x, torch.Tensor) else x, result + ) + return {"output": out} + + with DebugMode.dispatch_hooks(record_hook=dispatch_hook): + yield + + @staticmethod + @contextlib.contextmanager + def log_tensor_hashes( + hash_fn: Optional[Callable] = None, hash_inputs: bool = False + ): + """ + Installs hook for tensor hash logging. + + hash_fn: optional function for custom hashing + hash_inputs: if True, also hashes tensors in (args, kwargs), storing them in "input_hash". + NOTE: this is currently a post-hook, so e.g. inplace ops will log the "output" hashes. + """ + if hash_fn is None: + hash_fn = functools.partial(default_hash_fn, use_scalar=True) + + def _tree_hash(obj): + with torch._C._DisablePythonDispatcher(): + return tree_map( + lambda x: hash_fn(x) if isinstance(x, torch.Tensor) else None, obj + ) + + def _dispatch_hash_hook(func, types, args, kwargs, result): + if "empty" in str(func) or "profiler" in str(func): + return None + + out = {} + out["hash"] = _tree_hash(result) + if hash_inputs: + out["input_hash"] = _tree_hash((args, kwargs)) + + if tree_all(lambda x: x is None, out.values()): + return None + return out + + with DebugMode.dispatch_hooks(log_hook=_dispatch_hash_hook): + yield + def get_active_debug_mode() -> Optional[DebugMode]: debug_mode = None diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 4ab48bc41ba53..52be3280c9c39 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs +from __future__ import annotations + import contextlib import functools import warnings from collections import deque -from collections.abc import Sequence from dataclasses import dataclass -from typing import Optional, overload, Protocol, Union +from typing import cast, Optional, overload, Protocol, TYPE_CHECKING, Union from typing_extensions import TypeIs import torch @@ -20,6 +21,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Sequence + + # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: # - We need a better user-facing api for _DisableTorchDispatch that # is able to selectively disable __torch_dispatch__ of a particular class. @@ -430,7 +435,7 @@ def to( @overload def to( self, - device: Optional["torch._prims_common.DeviceLikeType"] = None, + device: Optional[torch._prims_common.DeviceLikeType] = None, dtype: Optional[torch.types._dtype] = None, non_blocking: bool = False, copy: bool = False, @@ -601,13 +606,20 @@ def alias_non_inplace_storage(arg, ret): raise AssertionError(f"expected torch.Tensor, got {type(ret)}") torch._functionalize_unsafe_set(ret, arg) - for arg_idx, schema_arg in enumerate(schema_info.args): - for return_idx, schema_out in enumerate(schema_info.outs): - is_read_only_alias_match = ( - schema_arg.alias_set & schema_out.alias_set - ) and not schema_arg.is_write - if is_read_only_alias_match: - alias_non_inplace_storage(args[arg_idx], outs[return_idx]) + for arg_idx, return_idx in schema_info.read_only_alias_match_indexes: + alias_non_inplace_storage(args[arg_idx], outs[return_idx]) + + +def _get_write_alias(x) -> Optional[str]: + alias_set = x.alias_set + if not alias_set or not x.is_write: + return None + # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing + if len(alias_set) != 1: + raise AssertionError("Expected alias_set to contain exactly one element") + # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for + # set of size 1 on Python 3.13. + return next(iter(alias_set)) # This abstracts over the fact that in return_and_correct_aliasing, @@ -625,13 +637,16 @@ class SchemaInfo: args: list[AliasInfo] outs: list[AliasInfo] - # NOTE[SchemaInfo int_tags]: This has nothing to do with aliasing, but we take - # advantage of our existing caching of data for each OpOverload to paper over an - # efficiency problem with pybind11::enum_ (which currently is used to implement - # torch.Tag): a scan over a list of pybind enums using `in` is inefficient because - # each element must be converted to int with the __int__ method, which incurs a lot - # of overhead. Converting to int once and caching removes this per-op overhead. - int_tags: list[int] + is_inplace_view_op: bool + + # [_get_write_alias(x) for x in outs]. Guaranteed to contain no Nones; we coerce + # all-Nones result to empty list instead, and we don't support + # some-but-not-all-Nones. + outs_write_aliases: Optional[list[str]] + + # List of (arg_idx, return_idx) where args[arg_idx].alias_set & + # outs[out_idx].alias_set is not empty, and not args[arg_idx].is_write. + read_only_alias_match_indexes: list[tuple[int, int]] # Given an OpOverload, returns schema information on it. @@ -702,14 +717,92 @@ def get_alias_info(func) -> SchemaInfo: ) for a in func._schema.returns ] + read_only_alias_match_indexes = [] + for arg_idx, schema_arg in enumerate(arg_schemas): + for return_idx, schema_out in enumerate(out_schemas): + is_read_only_alias_match = ( + schema_arg.alias_set & schema_out.alias_set + ) and not schema_arg.is_write + if is_read_only_alias_match: + read_only_alias_match_indexes.append((arg_idx, return_idx)) + + outs_write_aliases_list: list[Optional[str]] = [ + _get_write_alias(r) for r in out_schemas + ] + non_nones = sum(x is not None for x in outs_write_aliases_list) + if non_nones == 0: + outs_write_aliases: Optional[list[str]] = None + elif non_nones != len(outs_write_aliases_list): + # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" + raise RuntimeError("Unsupported schema: " + str(func._schema)) + else: + outs_write_aliases = cast(list[str], outs_write_aliases_list) + schema_info = SchemaInfo( - args=arg_schemas, outs=out_schemas, int_tags=[int(x) for x in func.tags] + args=arg_schemas, + outs=out_schemas, + # This check is surprisingly expensive because pybind11 enum_s are + # inefficient. Just cache it. + is_inplace_view_op=torch.Tag.inplace_view in func.tags, + outs_write_aliases=outs_write_aliases, + read_only_alias_match_indexes=read_only_alias_match_indexes, ) return schema_info -# See NOTE[SchemaInfo int_tags] above. -_TORCH_TAG_INPLACE_VIEW_INT = int(torch.Tag.inplace_view) # type: ignore[call-overload] +def autograd_would_have_decomposed( + func: torch._ops.OpOverload, flat_args: Sequence[Union[torch.Tensor, object]] +) -> bool: + """ + Suppose that an operator has CompositeImplicitAutograd decomp registered. + Would autograd have used this decomposition? It will only use it if there + isn't an explicit backend registration for the device as well. This function + will tell if this would have occurred. + + Why do we need to apply these decompositions later? When inference mode is + on, the autograd key is bypassed entirely, so a lower level mode cannot rely + on the decomposition have been applied. It's easy to accidentally never apply + the decomposition, resulting in an operator showing up in a graph that + is unexpected. + + Why do we need to AVOID applying the decomposition when autograd wouldn't + have decomposed? If autograd doesn't decompose, this means in eager mode + we would have run the fused kernel. It must be possible to trace this + fused kernel directly into the graph for fidelity with eager (NB: a user + has the option of then further decomposing at proxy tensor mode via + decomposition table, but we must preserve it to proxy mode to have the + choice.) + + Why does functionalization need to also perform the test here? This is + because some CompositeImplicitAutograd decompositions are not functional. + If we are eventually going to decompose, we need to do this while we can + still turn functionalization back on, so those decompositions get functionalized. + So an early decomposition in functionalization may still be necessary. Note that + if proxy tensor decomposition process could turn functionalization back on, this + wouldn't be necessary, and maybe that is a useful thing to do anyway because + the decomposition table is user specified and a user could violate the functional + decomp requirement with a bad decomp. If this happened, then you could always + pass through functionalization. + """ + has_backend_registration = False + for a in flat_args: + if isinstance(a, torch.Tensor): + backend_key = torch._C._parse_dispatch_key( + torch._C._dispatch_key_for_device(a.device.type) + ) + assert backend_key is not None + # TODO: use func.has_kernel_for_dispatch_key(backend_key) + # but this one checks py_impl and CompositeImplicitAutograd + # incorrectly shows up as has backend reg here + has_backend_registration = torch._C._dispatch_has_kernel_for_dispatch_key( + func.name(), backend_key + ) + + # in theory we should take all backend keys and take the highest priority one + # to properly mimic the dispatcher, + # this just grabs the first tensor and takes its device key + break + return not has_backend_registration def return_and_correct_aliasing(func, args, kwargs, out): @@ -732,17 +825,6 @@ def return_and_correct_aliasing(func, args, kwargs, out): # once for every op in the graph during functionalization. schema_info = get_alias_info(func) - def get_write_alias(x): - alias_set = x.alias_set - if not alias_set or not x.is_write: - return None - # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing - if len(alias_set) != 1: - raise AssertionError("Expected alias_set to contain exactly one element") - # timeit says next(iter(alias_set)) is faster than list(alias_set)[0] even for - # set of size 1 on Python 3.13. - return next(iter(alias_set)) - def get_arg_from_alias(output_alias, schema_info, args, kwargs): new_args, new_kwargs = torch.fx.operator_schemas.normalize_function( # type: ignore[misc] func, args=args, kwargs=kwargs @@ -770,14 +852,13 @@ def get_arg_from_alias(output_alias, schema_info, args, kwargs): # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's # metadata is set correctly. - # See NOTE[SchemaInfo int_tags] above. - if _TORCH_TAG_INPLACE_VIEW_INT in schema_info.int_tags: + if schema_info.is_inplace_view_op: # no_dispatch() to make sure that we secretly change the metadata on the wrapper, # but don't end up dispatching the op anywhere else. mutated_args = [ x for i, x in enumerate(args) - if get_write_alias(schema_info.args[i]) is not None + if _get_write_alias(schema_info.args[i]) is not None ] # Assumption: we have a very small number of inplace_view ops that follow a strict schema: # there is only a single argument that gets its metadata mutated. @@ -803,16 +884,11 @@ def get_arg_from_alias(output_alias, schema_info, args, kwargs): # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()). - # Compute write aliases once instead of repeatedly. - schema_info_outs_write_aliases = [get_write_alias(r) for r in schema_info.outs] + schema_info_outs_write_aliases = schema_info.outs_write_aliases # simple case: none of our outputs have mutable aliases, so we can return the output as-is - if not any(x is not None for x in schema_info_outs_write_aliases): + if schema_info_outs_write_aliases is None: return out - # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)" - if not all(x is not None for x in schema_info_outs_write_aliases): - raise RuntimeError("Unsupported schema: " + str(func._schema)) - if len(schema_info_outs_write_aliases) == 1: return get_arg_from_alias( schema_info_outs_write_aliases[0], schema_info, args, kwargs diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index b2c7c9985bf52..56704bb3f8024 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -39,7 +39,7 @@ TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self +from typing_extensions import deprecated, NamedTuple, Self, TypeAlias from torch.torch_version import TorchVersion as _TorchVersion @@ -52,6 +52,7 @@ "DumpableContext", "ToDumpableContextFn", "FromDumpableContextFn", + "PyTreeSpec", "TreeSpec", "LeafSpec", "keystr", @@ -364,11 +365,15 @@ def _flatten_fn(obj: Any) -> tuple[list[Any], Context]: def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: flat_names, none_names = context - return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) + return cls( + **dict(zip(flat_names, values, strict=True)), **dict.fromkeys(none_names) + ) def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] - return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + return [ + (GetAttrKey(k), v) for k, v in zip(flat_names, flattened, strict=True) + ], flat_names _private_register_pytree_node( cls, @@ -788,11 +793,11 @@ def _dict_flatten_with_keys( ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _dict_flatten(d) # pyrefly: ignore [bad-return] - return [(MappingKey(k), v) for k, v in zip(context, values)], context + return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]: - return dict(zip(context, values)) + return dict(zip(context, values, strict=True)) def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]: @@ -805,7 +810,10 @@ def _namedtuple_flatten_with_keys( values, context = _namedtuple_flatten(d) # pyrefly: ignore [bad-return] return ( - [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], + [ + (GetAttrKey(field), v) + for field, v in zip(context._fields, values, strict=True) + ], context, ) @@ -854,14 +862,14 @@ def _ordereddict_flatten_with_keys( ) -> tuple[list[tuple[KeyEntry, T]], Context]: values, context = _ordereddict_flatten(d) # pyrefly: ignore [bad-return] - return [(MappingKey(k), v) for k, v in zip(context, values)], context + return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context def _ordereddict_unflatten( values: Iterable[T], context: Context, ) -> OrderedDict[Any, T]: - return OrderedDict((key, value) for key, value in zip(context, values)) + return OrderedDict((key, value) for key, value in zip(context, values, strict=True)) _odict_flatten = _ordereddict_flatten @@ -879,7 +887,9 @@ def _defaultdict_flatten_with_keys( values, context = _defaultdict_flatten(d) _, dict_context = context # pyrefly: ignore [bad-return] - return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context + return [ + (MappingKey(k), v) for k, v in zip(dict_context, values, strict=True) + ], context def _defaultdict_unflatten( @@ -1067,11 +1077,14 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) - # A TreeSpec represents the structure of a pytree. It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children(): specs for each child of the root Node +# num_nodes: the total number of nodes +# num_leaves: the number of leaves +# num_children: the number of children of the root Node (i.e., len(children())) +# is_leaf(): whether the root Node is a leaf +@dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False) class TreeSpec: type: Any _context: Context @@ -1081,6 +1094,17 @@ class TreeSpec: num_leaves: int = dataclasses.field(init=False) num_children: int = dataclasses.field(init=False) + def __init__( + self, + type: Any, + context: Context, # keep for backward compatibility + children_specs: list[Self], # keep for backward compatibility + ) -> None: + object.__setattr__(self, "type", type) + object.__setattr__(self, "_context", context) + object.__setattr__(self, "_children", children_specs) + self.__post_init__() + def __post_init__(self) -> None: if self.type is None: assert self._context is None @@ -1224,7 +1248,7 @@ def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: f"expected {treespec._context!r}, but got {context!r}.", # namedtuple type mismatch ) - for subtree, subspec in zip(children, treespec._children): + for subtree, subspec in zip(children, treespec._children, strict=True): helper(subspec, subtree, subtrees) subtrees: list[PyTree] = [] @@ -1277,6 +1301,9 @@ def __hash__(self) -> int: return hash((node_type, hashable_context, tuple(self._children))) +PyTreeSpec: TypeAlias = TreeSpec + + # NOTE: subclassing a dataclass is subtle. In order to enable reasoning about # this class with `dataclasses.fields`, etc., while having a simplified # constructor that takes no argument, we wrap with `dataclass(init=True, ...)` @@ -1820,7 +1847,7 @@ def _broadcast_to_and_flatten( # Recursively flatten the children result: list[Any] = [] - for child, child_spec in zip(child_pytrees, treespec._children): + for child, child_spec in zip(child_pytrees, treespec._children, strict=True): flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) if flat is not None: result += flat @@ -2122,9 +2149,9 @@ def tree_map_with_path( ``xs`` is the tuple of values at corresponding nodes in ``rests``. """ keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) - keypath_leaves = list(zip(*keypath_leaves)) + keypath_leaves = list(zip(*keypath_leaves, strict=True)) all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] - return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) + return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves, strict=True)) def keystr(kp: KeyPath) -> str: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index d152b719bcde5..297d7f4eec9a8 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1223,7 +1223,8 @@ def eval(cls, *args): # When all strides are integral, we can sort, and the size for the # largest stride doesn't matter and can be arbitrarily symbolic s_sizes, s_strides = zip( - *sorted(zip(sizes, strides), key=operator.itemgetter(1)) + *sorted(zip(sizes, strides, strict=True), key=operator.itemgetter(1)), + strict=True, ) # Put something arbitrary in the max size spot, it'll be ignored if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]): diff --git a/torch/utils/benchmark/examples/op_benchmark.py b/torch/utils/benchmark/examples/op_benchmark.py index 55f25e5c896d5..8a76331d3404f 100644 --- a/torch/utils/benchmark/examples/op_benchmark.py +++ b/torch/utils/benchmark/examples/op_benchmark.py @@ -32,7 +32,7 @@ def run(n, stmt, fuzzer_cls): float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n) int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n) raw_results = [] - for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)): + for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter, strict=True)): float_tensors, float_tensor_params, float_params = float_values int_tensors, int_tensor_params, int_params = int_values @@ -89,7 +89,7 @@ def run(n, stmt, fuzzer_cls): for t_float, t_int, rel_diff, descriptions in results: time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"] time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]]) - for t_str, (name, shape, order, steps) in zip(time_str, descriptions): + for t_str, (name, shape, order, steps) in zip(time_str, descriptions, strict=True): name = f"{name}:".ljust(name_len + 1) shape = shape.ljust(shape_len + 10) order = order.ljust(order_len) diff --git a/torch/utils/benchmark/examples/sparse/op_benchmark.py b/torch/utils/benchmark/examples/sparse/op_benchmark.py index bd52084fbc0cc..b574b0223d489 100644 --- a/torch/utils/benchmark/examples/sparse/op_benchmark.py +++ b/torch/utils/benchmark/examples/sparse/op_benchmark.py @@ -29,7 +29,7 @@ def run(n, stmt, fuzzer_cls): float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n) double_iter = fuzzer_cls(seed=0, dtype=torch.float64).take(n) raw_results = [] - for i, (float_values, int_values) in enumerate(zip(float_iter, double_iter)): + for i, (float_values, int_values) in enumerate(zip(float_iter, double_iter, strict=True)): float_tensors, float_tensor_params, float_params = float_values int_tensors, int_tensor_params, int_params = int_values @@ -84,7 +84,7 @@ def run(n, stmt, fuzzer_cls): for t_float, t_int, rel_diff, descriptions in results: time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"] time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]]) - for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions): + for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions, strict=True): name = f"{name}:".ljust(name_len + 1) shape = shape.ljust(shape_len + 10) sparse_dim = sparse_dim.ljust(sparse_dim_len) diff --git a/torch/utils/benchmark/utils/compare.py b/torch/utils/benchmark/utils/compare.py index 0b8a2163b3c4c..21a83926a2e82 100644 --- a/torch/utils/benchmark/utils/compare.py +++ b/torch/utils/benchmark/utils/compare.py @@ -51,7 +51,7 @@ def __init__( unit_digits = max(d for d in leading_digits if d is not None) decimal_digits = min( max(m.significant_figures - digits, 0) - for digits, m in zip(leading_digits, self._flat_results) + for digits, m in zip(leading_digits, self._flat_results, strict=True) if (m is not None) and (digits is not None) ) if self._trim_significant_figures else 1 length = unit_digits + decimal_digits + (1 if decimal_digits else 0) @@ -99,7 +99,7 @@ def as_column_strings(self): env = f"({concrete_results[0].env})" if self._render_env else "" env = env.ljust(self._env_str_len + 4) output = [" " + env + concrete_results[0].as_row_name] - for m, col in zip(self._results, self._columns or ()): + for m, col in zip(self._results, self._columns or (), strict=False): if m is None: output.append(col.num_to_str(None, 1, None)) else: @@ -141,7 +141,7 @@ def finalize_column_strings(self, column_strings, col_widths): ] row_contents = [column_strings[0].ljust(col_widths[0])] - for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values): + for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values, strict=False): col_str = col_str.center(width) if self._colorize != Colorize.NONE and result is not None and best_value is not None: col_str = self.color_segment(col_str, result.median, best_value) @@ -206,7 +206,7 @@ def populate_rows_and_columns(self) -> tuple[tuple[_Row, ...], tuple[_Column, .. prior_env = "" row_group = -1 rows_by_group: list[list[list[Optional[common.Measurement]]]] = [] - for (num_threads, env, _), row in zip(self.row_keys, ordered_results): + for (num_threads, env, _), row in zip(self.row_keys, ordered_results, strict=True): thread_transition = (num_threads != prior_num_threads) if thread_transition: prior_num_threads = num_threads @@ -250,10 +250,10 @@ def render(self) -> str: for sr in string_rows: sr.extend(["" for _ in range(num_cols - len(sr))]) - col_widths = [max(len(j) for j in i) for i in zip(*string_rows)] - finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))] + col_widths = [max(len(j) for j in i) for i in zip(*string_rows, strict=True)] + finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths, strict=True))] overall_width = len(finalized_columns[0]) - for string_row, row in zip(string_rows[1:], self.rows): + for string_row, row in zip(string_rows[1:], self.rows, strict=True): finalized_columns.extend(row.row_separator(overall_width)) finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths))) diff --git a/torch/utils/benchmark/utils/fuzzer.py b/torch/utils/benchmark/utils/fuzzer.py index f7fc21ceaf88b..f343722ef686d 100644 --- a/torch/utils/benchmark/utils/fuzzer.py +++ b/torch/utils/benchmark/utils/fuzzer.py @@ -295,7 +295,7 @@ def _make_tensor(self, params, state): raw_tensor = raw_tensor.permute(tuple(order)).contiguous() raw_tensor = raw_tensor.permute(tuple(np.argsort(order))) - slices = [slice(0, size * step, step) for size, step in zip(size, steps)] + slices = [slice(0, size * step, step) for size, step in zip(size, steps, strict=True)] tensor = raw_tensor[tuple(slices)] properties = { @@ -326,7 +326,7 @@ def resolve(values, dim): size = resolve(self._size, dim) steps = resolve(self._steps or (), dim) - allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps)) + allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps, strict=True)) return size, steps, allocation_size def satisfies_constraints(self, params): diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index 1b16da9f242f3..d9802c06e9444 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -196,7 +196,7 @@ def set_device_states(devices, states, *, device_type=None) -> None: if device_type == "meta": return device_module = _get_device_module(device_type) - for device, state in zip(devices, states): + for device, state in zip(devices, states, strict=False): with device_module.device(device): device_module.set_rng_state(state) @@ -794,7 +794,7 @@ def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: # Only tensors can be saved with ctx.save_for_backward, everything else # is captured by get_args, which is saved directly on ctx tensor_indices, tensors = zip( - *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] + *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)], strict=False ) idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} # args but with tensors replaced with None as placeholders @@ -1020,7 +1020,7 @@ def unpack_error_cb(e: CheckpointError): def get_str_tb(label, capture_logs): out = "" total_len = len(capture_logs.logs) - for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): + for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs, strict=False)): out += f"{log} ({i + 1} of {total_len} in {label})\n\n" found_torch_dispatch = False for line in tb: @@ -1501,7 +1501,7 @@ def _checkpoint_without_reentrant_generator( unpack_error_cb = None if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug: - if context_fn != noop_context_fn: + if context_fn is not noop_context_fn: raise ValueError( "debug=True is incompatible with non-default context_fn" ) @@ -1518,7 +1518,7 @@ def _checkpoint_without_reentrant_generator( device_type = _infer_device_type(*args) device_module = _get_device_module(device_type) forward_context, recompute_context = context_fn() - if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: + if _is_compiling(fn, args, kwargs) and context_fn is not noop_context_fn: if ( not isinstance(forward_context, TorchDispatchMode) or not isinstance(recompute_context, TorchDispatchMode) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index aa5ca8c4b5d3a..235b7e104c702 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -2947,7 +2947,7 @@ def sanitize_flags(flags): # Emit one build rule per source to enable incremental build. build = [] - for source_file, object_file in zip(sources, objects): + for source_file, object_file in zip(sources, objects, strict=True): is_cuda_source = _is_cuda_file(source_file) and with_cuda is_sycl_source = _is_sycl_file(source_file) and with_sycl if is_cuda_source: diff --git a/torch/utils/data/_utils/collate.py b/torch/utils/data/_utils/collate.py index efe50ba22e8e6..cb051f6642dcf 100644 --- a/torch/utils/data/_utils/collate.py +++ b/torch/utils/data/_utils/collate.py @@ -197,7 +197,7 @@ def collate( return elem_type( *( collate(samples, collate_fn_map=collate_fn_map) - for samples in zip(*batch) + for samples in zip(*batch, strict=False) ) ) elif isinstance(elem, collections.abc.Sequence): @@ -207,7 +207,9 @@ def collate( # pyrefly: ignore [not-iterable] if not all(len(elem) == elem_size for elem in it): raise RuntimeError("each element in list of batch should be of equal size") - transposed = list(zip(*batch)) # It may be accessed twice, so we use a list. + transposed = list( + zip(*batch, strict=False) + ) # It may be accessed twice, so we use a list. if isinstance(elem, tuple): return [ diff --git a/torch/utils/data/dataframes_pipes.ipynb b/torch/utils/data/dataframes_pipes.ipynb index 2f995aab05abd..bc4abeba15b33 100644 --- a/torch/utils/data/dataframes_pipes.ipynb +++ b/torch/utils/data/dataframes_pipes.ipynb @@ -355,7 +355,7 @@ "dp = dp.shuffle()\n", "dp = dp.batch(2)\n", "print(\"Iterate over DataFrame batches\")\n", - "for i,v in enumerate(dp):\n", + "for v in dp:\n", " print(v)\n", "\n", "# this is similar to batching of regular DataPipe\n", diff --git a/torch/utils/data/datapipes/_typing.py b/torch/utils/data/datapipes/_typing.py index 32777cfd01d34..5392d71bce804 100644 --- a/torch/utils/data/datapipes/_typing.py +++ b/torch/utils/data/datapipes/_typing.py @@ -184,7 +184,7 @@ def _issubtype_with_constraints(variant, constraints, recursive=True): and len(v_args) == len(c_args) and all( issubtype(v_arg, c_arg) - for v_arg, c_arg in zip(v_args, c_args) + for v_arg, c_arg in zip(v_args, c_args, strict=True) ) ): return True @@ -207,7 +207,7 @@ def issubinstance(data, data_type): return True if len(dt_args) != len(data): return False - return all(issubinstance(d, t) for d, t in zip(data, dt_args)) + return all(issubinstance(d, t) for d, t in zip(data, dt_args, strict=True)) elif isinstance(data, (list, set)): if dt_args is None or len(dt_args) == 0: return True diff --git a/torch/utils/data/datapipes/dataframe/datapipes.py b/torch/utils/data/datapipes/dataframe/datapipes.py index edb08d77a81d9..0526b472ad194 100644 --- a/torch/utils/data/datapipes/dataframe/datapipes.py +++ b/torch/utils/data/datapipes/dataframe/datapipes.py @@ -101,7 +101,7 @@ def __iter__(self): filter_res.append(self.filter_fn(df.iloc[i])) buffer = [] - for df, res in zip(all_buffer, filter_res): + for df, res in zip(all_buffer, filter_res, strict=True): if res: buffer.append(df) if len(buffer) == size: diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 2e3d371244253..6efaa8c3d8be9 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -705,7 +705,7 @@ def __init__(self, *datapipes: IterDataPipe): def __iter__(self) -> Iterator[tuple[_T_co]]: iterators = [iter(datapipe) for datapipe in self.datapipes] - yield from zip(*iterators) + yield from zip(*iterators, strict=False) def __len__(self) -> int: if all(isinstance(dp, Sized) for dp in self.datapipes): diff --git a/torch/utils/data/dataset.py b/torch/utils/data/dataset.py index f4e61963cd01e..b77ff892e6662 100644 --- a/torch/utils/data/dataset.py +++ b/torch/utils/data/dataset.py @@ -267,10 +267,10 @@ def __getitems__(self, indices: list): "Nested dataset's output size mismatch." f" Expected {len(indices)}, got {len(items)}" ) - for data, d_sample in zip(items, dict_batch): + for data, d_sample in zip(items, dict_batch, strict=True): d_sample[k] = data else: - for idx, d_sample in zip(indices, dict_batch): + for idx, d_sample in zip(indices, dict_batch, strict=True): d_sample[k] = dataset[idx] return dict_batch @@ -284,10 +284,10 @@ def __getitems__(self, indices: list): "Nested dataset's output size mismatch." f" Expected {len(indices)}, got {len(items)}" ) - for data, t_sample in zip(items, list_batch): + for data, t_sample in zip(items, list_batch, strict=True): t_sample.append(data) else: - for idx, t_sample in zip(indices, list_batch): + for idx, t_sample in zip(indices, list_batch, strict=True): t_sample.append(dataset[idx]) tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch] return tuple_batch @@ -477,5 +477,5 @@ def random_split( lengths = cast(Sequence[int], lengths) return [ Subset(dataset, indices[offset - length : offset]) - for offset, length in zip(itertools.accumulate(lengths), lengths) + for offset, length in zip(itertools.accumulate(lengths), lengths, strict=True) ] diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 81f05a936df8f..f36f15ee09589 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -335,7 +335,7 @@ def __iter__(self) -> Iterator[list[int]]: if self.drop_last: # Create multiple references to the same iterator args = [sampler_iter] * self.batch_size - for batch_droplast in zip(*args): + for batch_droplast in zip(*args, strict=False): yield [*batch_droplast] else: batch = [*itertools.islice(sampler_iter, self.batch_size)] diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index e368d52de0c53..28754b1ae277c 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -341,7 +341,7 @@ def _unpack_flash_attention_nested_shapes( raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape") seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q) seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k) - for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths): + for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths, strict=True): new_query_shape = (1, h_q, seq_q_len, d_q) new_key_shape = (1, h_k, seq_k_len, d_k) new_value_shape = (1, h_v, seq_k_len, d_v) @@ -396,7 +396,7 @@ def _unpack_efficient_attention_nested_shapes( "cu_seqlens_q and cu_seqlens_k must have the same shape") seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q) seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k) - for len_q, len_k in zip(seqlens_q, seqlens_k): + for len_q, len_k in zip(seqlens_q, seqlens_k, strict=True): new_query_shape = (1, h_q, len_q, d_q) new_key_shape = (1, h_k, len_k, d_k) new_value_shape = (1, h_v, len_k, d_v) diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index e52a57d709951..3c022a4e85508 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -114,7 +114,7 @@ def __init__(self, module, user_hooks, user_pre_hooks): def _pack_with_none(self, indices, values, size): res = [None] * size - for idx, val in zip(indices, values): + for idx, val in zip(indices, values, strict=True): res[idx] = val return tuple(res) @@ -180,7 +180,7 @@ def _apply_on_tensors(self, fn, args): fn(grad_fns[0]) arg_list = list(args) - for idx, val in zip(tensors_idx, new_tensors): + for idx, val in zip(tensors_idx, new_tensors, strict=True): arg_list[idx] = val if type(args) is tuple: diff --git a/torch/utils/tensorboard/_onnx_graph.py b/torch/utils/tensorboard/_onnx_graph.py index 3b7381737b3e7..abadb7c9fdb42 100644 --- a/torch/utils/tensorboard/_onnx_graph.py +++ b/torch/utils/tensorboard/_onnx_graph.py @@ -24,6 +24,7 @@ def parse(graph): print(node.name) shapeproto = TensorShapeProto( dim=[ + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=d.dim_value) for d in node.type.tensor_type.shape.dim ] diff --git a/torch/utils/tensorboard/_proto_graph.py b/torch/utils/tensorboard/_proto_graph.py index c4e234dff6ba0..c32be5b2cae36 100644 --- a/torch/utils/tensorboard/_proto_graph.py +++ b/torch/utils/tensorboard/_proto_graph.py @@ -7,6 +7,7 @@ from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto +# pyrefly: ignore [not-a-type] def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[str]) -> dict[str, AttrValue]: """Create a dict of objects matching a NodeDef's attr field. @@ -19,15 +20,18 @@ def attr_value_proto(dtype: object, shape: Optional[Sequence[int]], s: Optional[ attr["attr"] = AttrValue(s=s.encode(encoding="utf_8")) if shape is not None: shapeproto = tensor_shape_proto(shape) + # pyrefly: ignore [missing-attribute] attr["_output_shapes"] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto])) return attr +# pyrefly: ignore [not-a-type] def tensor_shape_proto(outputsize: Sequence[int]) -> TensorShapeProto: """Create an object matching a tensor_shape field. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto . """ + # pyrefly: ignore [missing-attribute] return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize]) @@ -39,7 +43,7 @@ def node_proto( shape: Optional[tuple[int, ...]] = None, outputsize: Optional[Sequence[int]] = None, attributes: str = "", -) -> NodeDef: +) -> NodeDef: # pyrefly: ignore [not-a-type] """Create an object matching a NodeDef. Follows https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto . diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index b3ef6a468dca5..859f80e691ce5 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -167,7 +167,7 @@ def find_common_root(self): def populate_namespace_from_OP_to_IO(self): for node in self.nodes_op: - for node_output, outputSize in zip(node.outputs, node.outputstensor_size): + for node_output, outputSize in zip(node.outputs, node.outputstensor_size, strict=True): self.scope_name_appeared.append(node.scopeName) self.nodes_io[node_output] = NodeBase( node_output, diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index ae3b6a7a19a51..f36382cb42e16 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -261,6 +261,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): hps.append( HParamInfo( name=k, + # pyrefly: ignore [missing-attribute] type=DataType.Value("DATA_TYPE_FLOAT64"), domain_discrete=domain_discrete, ) @@ -283,6 +284,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): hps.append( HParamInfo( name=k, + # pyrefly: ignore [missing-attribute] type=DataType.Value("DATA_TYPE_STRING"), domain_discrete=domain_discrete, ) @@ -305,6 +307,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): hps.append( HParamInfo( name=k, + # pyrefly: ignore [missing-attribute] type=DataType.Value("DATA_TYPE_BOOL"), domain_discrete=domain_discrete, ) @@ -314,6 +317,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): if isinstance(v, torch.Tensor): v = make_np(v)[0] ssi.hparams[k].number_value = v + # pyrefly: ignore [missing-attribute] hps.append(HParamInfo(name=k, type=DataType.Value("DATA_TYPE_FLOAT64"))) continue raise ValueError( @@ -322,10 +326,12 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): content = HParamsPluginData(session_start_info=ssi, version=PLUGIN_DATA_VERSION) smd = SummaryMetadata( + # pyrefly: ignore [missing-attribute] plugin_data=SummaryMetadata.PluginData( plugin_name=PLUGIN_NAME, content=content.SerializeToString() ) ) + # pyrefly: ignore [missing-attribute] ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()] @@ -334,19 +340,24 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): content = HParamsPluginData(experiment=exp, version=PLUGIN_DATA_VERSION) smd = SummaryMetadata( + # pyrefly: ignore [missing-attribute] plugin_data=SummaryMetadata.PluginData( plugin_name=PLUGIN_NAME, content=content.SerializeToString() ) ) + # pyrefly: ignore [missing-attribute] exp = Summary(value=[Summary.Value(tag=EXPERIMENT_TAG, metadata=smd)]) + # pyrefly: ignore [missing-attribute] sei = SessionEndInfo(status=Status.Value("STATUS_SUCCESS")) content = HParamsPluginData(session_end_info=sei, version=PLUGIN_DATA_VERSION) smd = SummaryMetadata( + # pyrefly: ignore [missing-attribute] plugin_data=SummaryMetadata.PluginData( plugin_name=PLUGIN_NAME, content=content.SerializeToString() ) ) + # pyrefly: ignore [missing-attribute] sei = Summary(value=[Summary.Value(tag=SESSION_END_INFO_TAG, metadata=smd)]) return exp, ssi, sei @@ -380,10 +391,12 @@ def scalar(name, tensor, collections=None, new_style=False, double_precision=Fal if double_precision: tensor_proto = TensorProto(double_val=[scalar], dtype="DT_DOUBLE") + # pyrefly: ignore [missing-attribute] plugin_data = SummaryMetadata.PluginData(plugin_name="scalars") smd = SummaryMetadata(plugin_data=plugin_data) return Summary( value=[ + # pyrefly: ignore [missing-attribute] Summary.Value( tag=name, tensor=tensor_proto, @@ -392,6 +405,7 @@ def scalar(name, tensor, collections=None, new_style=False, double_precision=Fal ] ) else: + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=name, simple_value=scalar)]) @@ -419,6 +433,7 @@ def tensor_proto(tag, tensor): **{ "dtype": dtype, "tensor_shape": TensorShapeProto( + # pyrefly: ignore [missing-attribute] dim=[TensorShapeProto.Dim(size=x) for x in tensor.shape] ), field_name: conversion_fn(tensor), @@ -427,8 +442,10 @@ def tensor_proto(tag, tensor): else: raise ValueError(f"{tag} has unsupported tensor dtype {tensor.dtype}") + # pyrefly: ignore [missing-attribute] plugin_data = SummaryMetadata.PluginData(plugin_name="tensor") smd = SummaryMetadata(plugin_data=plugin_data) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor_proto)]) @@ -462,6 +479,7 @@ def histogram_raw(name, min, max, num, sum, sum_squares, bucket_limits, bucket_c bucket_limit=bucket_limits, bucket=bucket_counts, ) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=name, histo=hist)]) @@ -484,6 +502,7 @@ def histogram(name, values, bins, max_bins=None): """ values = make_np(values) hist = make_histogram(values.astype(float), bins, max_bins) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=name, histo=hist)]) @@ -577,6 +596,7 @@ def image(tag, tensor, rescale=1, dataformats="NCHW"): tensor = tensor.astype(np.float32) tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8) image = make_image(tensor, rescale=rescale) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, image=image)]) @@ -594,6 +614,7 @@ def image_boxes( rois=tensor_boxes, labels=labels, ) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, image=image)]) @@ -632,6 +653,7 @@ def make_image(tensor, rescale=1, rois=None, labels=None): image.save(output, format="PNG") image_string = output.getvalue() output.close() + # pyrefly: ignore [missing-attribute] return Summary.Image( height=height, width=width, @@ -648,6 +670,7 @@ def video(tag, tensor, fps=4): tensor = tensor.astype(np.float32) tensor = (tensor * scale_factor).clip(0, 255).astype(np.uint8) video = make_video(tensor, fps) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, image=video)]) @@ -689,6 +712,7 @@ def make_video(tensor, fps): except OSError: logger.warning("The temporary file used by moviepy cannot be deleted.") + # pyrefly: ignore [missing-attribute] return Summary.Image( height=h, width=w, colorspace=c, encoded_image_string=tensor_string ) @@ -715,6 +739,7 @@ def audio(tag, tensor, sample_rate=44100): wave_write.writeframes(array.data) audio_string = fio.getvalue() fio.close() + # pyrefly: ignore [missing-attribute] audio = Summary.Audio( sample_rate=sample_rate, num_channels=1, @@ -722,6 +747,7 @@ def audio(tag, tensor, sample_rate=44100): encoded_audio_string=audio_string, content_type="audio/wav", ) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, audio=audio)]) @@ -736,6 +762,7 @@ def custom_scalars(layout): raise AssertionError("len(tags) != 3") mgcc = layout_pb2.MarginChartContent( series=[ + # pyrefly: ignore [missing-attribute] layout_pb2.MarginChartContent.Series( value=tags[0], lower=tags[1], upper=tags[2] ) @@ -749,6 +776,7 @@ def custom_scalars(layout): categories.append(layout_pb2.Category(title=k, chart=charts)) layout = layout_pb2.Layout(category=categories) + # pyrefly: ignore [missing-attribute] plugin_data = SummaryMetadata.PluginData(plugin_name="custom_scalars") smd = SummaryMetadata(plugin_data=plugin_data) tensor = TensorProto( @@ -758,12 +786,14 @@ def custom_scalars(layout): ) return Summary( value=[ + # pyrefly: ignore [missing-attribute] Summary.Value(tag="custom_scalars__config__", tensor=tensor, metadata=smd) ] ) def text(tag, text): + # pyrefly: ignore [missing-attribute] plugin_data = SummaryMetadata.PluginData( plugin_name="text", content=TextPluginData(version=0).SerializeToString() ) @@ -771,9 +801,11 @@ def text(tag, text): tensor = TensorProto( dtype="DT_STRING", string_val=[text.encode(encoding="utf_8")], + # pyrefly: ignore [missing-attribute] tensor_shape=TensorShapeProto(dim=[TensorShapeProto.Dim(size=1)]), ) return Summary( + # pyrefly: ignore [missing-attribute] value=[Summary.Value(tag=tag + "/text_summary", metadata=smd, tensor=tensor)] ) @@ -787,6 +819,7 @@ def pr_curve_raw( pr_curve_plugin_data = PrCurvePluginData( version=0, num_thresholds=num_thresholds ).SerializeToString() + # pyrefly: ignore [missing-attribute] plugin_data = SummaryMetadata.PluginData( plugin_name="pr_curves", content=pr_curve_plugin_data ) @@ -796,11 +829,14 @@ def pr_curve_raw( float_val=data.reshape(-1).tolist(), tensor_shape=TensorShapeProto( dim=[ + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=data.shape[0]), + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=data.shape[1]), ] ), ) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) @@ -813,6 +849,7 @@ def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None): pr_curve_plugin_data = PrCurvePluginData( version=0, num_thresholds=num_thresholds ).SerializeToString() + # pyrefly: ignore [missing-attribute] plugin_data = SummaryMetadata.PluginData( plugin_name="pr_curves", content=pr_curve_plugin_data ) @@ -822,11 +859,14 @@ def pr_curve(tag, labels, predictions, num_thresholds=127, weights=None): float_val=data.reshape(-1).tolist(), tensor_shape=TensorShapeProto( dim=[ + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=data.shape[0]), + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=data.shape[1]), ] ), ) + # pyrefly: ignore [missing-attribute] return Summary(value=[Summary.Value(tag=tag, metadata=smd, tensor=tensor)]) @@ -911,13 +951,17 @@ def _get_tensor_summary( float_val=tensor.reshape(-1).tolist(), tensor_shape=TensorShapeProto( dim=[ + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=tensor.shape[0]), + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=tensor.shape[1]), + # pyrefly: ignore [missing-attribute] TensorShapeProto.Dim(size=tensor.shape[2]), ] ), ) + # pyrefly: ignore [missing-attribute] tensor_summary = Summary.Value( tag=metadata.get_instance_name(name, content_type), tensor=tensor, @@ -965,8 +1009,11 @@ def mesh( summaries = [] tensors = [ + # pyrefly: ignore [missing-attribute] (vertices, MeshPluginData.VERTEX), + # pyrefly: ignore [missing-attribute] (faces, MeshPluginData.FACE), + # pyrefly: ignore [missing-attribute] (colors, MeshPluginData.COLOR), ] tensors = [tensor for tensor in tensors if tensor[0] is not None] diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index e100ddb179f62..4fab33dc7ff09 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -166,6 +166,7 @@ def reopen(self): The events will go into a new events file. Does nothing if the EventFileWriter was not closed. """ + # pyrefly: ignore [missing-attribute] self.event_writer.reopen() @@ -280,6 +281,7 @@ def _get_file_writer(self): self.file_writer.add_event( Event( step=most_recent_step, + # pyrefly: ignore [missing-attribute] session_log=SessionLog(status=SessionLog.START), ) ) diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 72ceba903aad2..9587a8d682e5b 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -212,7 +212,7 @@ def object_annotation(obj): """ def format_sequence(obj): - body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj)) + body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for x in obj[:8]) if len(obj) > 8: body = f'{body}, ...{len(obj) - 8}' return body diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 5fec24c74de5a..6f1671e4e7a43 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -245,7 +245,7 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]: (default). Returns: - Dict[str, Any]: the xpu capability dictionary of the device + dict[str, Any]: the xpu capability dictionary of the device """ props = get_device_properties(device) # Only keep attributes that are safe for dictionary serialization. @@ -521,6 +521,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: # import here to avoid circular import from .memory import ( empty_cache, + get_per_process_memory_fraction, max_memory_allocated, max_memory_reserved, mem_get_info, @@ -562,6 +563,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: "get_device_name", "get_device_properties", "get_gencode_flags", + "get_per_process_memory_fraction", "get_rng_state", "get_rng_state_all", "get_stream_from_external", diff --git a/torch/xpu/memory.py b/torch/xpu/memory.py index 9086b1258fc8c..069d93cefa9b6 100644 --- a/torch/xpu/memory.py +++ b/torch/xpu/memory.py @@ -194,6 +194,26 @@ def mem_get_info(device: _device_t = None) -> tuple[int, int]: return torch._C._xpu_getMemoryInfo(device) +def get_per_process_memory_fraction(device: _device_t = None) -> float: + r""" + Retrieve the memory fraction currently set for a process on a given XPU device. + This fraction represents the portion of the total device memory that + the caching allocator is allowed to use. The allowed memory is calculated as: + + .. math:: \text{allowed\_memory} = \text{total\_memory} \times \text{fraction} + + Args: + device (torch.device or int or str, optional): selected device. It uses the current device, + given by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` (default). + + Returns: + float: The memory fraction in the range 0.0 to 1.0. + """ + _lazy_init() + device = _get_device_index(device, optional=True) + return torch._C._xpu_getMemoryFraction(device) + + def set_per_process_memory_fraction(fraction: float, device: _device_t = None) -> None: r""" Set the memory fraction for a single process on XPU device. @@ -216,11 +236,13 @@ def set_per_process_memory_fraction(fraction: float, device: _device_t = None) - device = _get_device_index(device, optional=True) if not isinstance(fraction, float): raise TypeError("Invalid type for fraction argument, must be `float`") + # pyrefly: ignore [missing-attribute] torch._C._xpu_setMemoryFraction(fraction, device) __all__ = [ "empty_cache", + "get_per_process_memory_fraction", "max_memory_allocated", "max_memory_reserved", "mem_get_info", diff --git a/torchgen/_autoheuristic/README.md b/torchgen/_autoheuristic/README.md index 2241785c2983b..091011d3f47a1 100644 --- a/torchgen/_autoheuristic/README.md +++ b/torchgen/_autoheuristic/README.md @@ -3,7 +3,7 @@ AutoHeuristic is a framework that allows one to use results from autotuning to l ## How to use AutoHeuristic In general, the following steps have to performed: -- The AutoHeursitic constructor has to be called. +- The AutoHeuristic constructor has to be called. - A script that runs benchmarks in order to collect training data has to be implemented. - The train_decision.py (if you want to learn a decision tree) or train_regression.py (if you want to learn a regression tree) script has to be run in order to learn the heuristic and generate it to code. diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 65161200256e5..ead2a2a1cf4cc 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -678,7 +678,7 @@ def gen_aoti_c_shim_files( # Use "aten" as the device name when dispatch_key is Generic device_name = "aten" if dispatch_key is None else dispatch_key.lower() - # header files were checked in for ABI-compatiblilty checking + # header files were checked in for ABI-compatibility checking header_file_name = f"c_shim_{device_name}.h" new_header = gen_aoti_c_shim( fallback_native_functions, diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index c396941cf913d..1cb681ba19d34 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1024,8 +1024,22 @@ def gen_functionalization_registration( ) -> list[str]: @with_native_function def emit_registration_helper(f: NativeFunction) -> str: - assert not f.has_composite_implicit_autograd_kernel - registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})" + if f.has_composite_implicit_autograd_kernel: + metadata = composite_implicit_autograd_index.get_kernel(f) + assert metadata is not None + native_api_name = metadata.kernel + sig = NativeSignature(f.func, symint=metadata.supports_symint()) + # Note [Composite view ops in the functionalization pass] + # We don't need to worry about implemententing functionalization kernels for views with + # CompositeImplicitAutograd kernels, because we can just decompose them into their base operators. + # We can't just opt the entire Functionalization dispatch key into the composite keyset though, + # because we don't want to decompose non-view ops that are composite, like `at::ones`. + registration_str = ( + f"static_cast<{sig.ptr_type()}>(at::native::{native_api_name})" + ) + else: + # non-composite view ops (and inplace ops) get a normal registration. + registration_str = f"TORCH_FN(functionalization::{wrapper_name(f.func)})" return f'm.impl("{f.func.name}", {registration_str});' # Don't generate kernels in mobile build @@ -1038,12 +1052,8 @@ def emit_registration_helper(f: NativeFunction) -> str: if str(g.view.func.name) == "lift_fresh": return [] view_str = [] - if not g.view.has_composite_implicit_autograd_kernel: - view_str.append(emit_registration_helper(g.view)) - if ( - g.view_inplace is not None - and not g.view_inplace.has_composite_implicit_autograd_kernel - ): + view_str.append(emit_registration_helper(g.view)) + if g.view_inplace is not None: assert g.view_inplace.is_view_op view_str.append(emit_registration_helper(g.view_inplace)) return view_str